[Python-checkins] bpo-45876: Improve accuracy for stdev() and pstdev() in statistics (GH-29736)

rhettinger webhook-mailer at python.org
Sat Nov 27 00:55:17 EST 2021


https://github.com/python/cpython/commit/af9ee57b96cb872df6574e36027cc753417605f9
commit: af9ee57b96cb872df6574e36027cc753417605f9
branch: main
author: Raymond Hettinger <rhettinger at users.noreply.github.com>
committer: rhettinger <rhettinger at users.noreply.github.com>
date: 2021-11-26T22:54:50-07:00
summary:

bpo-45876: Improve accuracy for stdev() and pstdev() in statistics (GH-29736)

* Inlined code from variance functions

* Added helper functions for the float square root of a fraction

* Call helper functions

* Add blurb

* Fix over-specified test

* Add a test for the _sqrt_frac() helper function

* Increase the tested range

* Add type hints to the internal function.

* Fix test for correct rounding

* Simplify ⌊√(n/m)⌋ calculation

Co-authored-by: Mark Dickinson <dickinsm at gmail.com>

* Add comment and beef-up tests

* Test for zero denominator

* Add algorithmic references

* Add test for the _isqrt_frac_rto() helper function.

* Compute the 109 instead of hard-wiring it

* Stronger test for _isqrt_frac_rto()

* Bigger range

* Bigger range

* Replace float() call with int/int division to be parallel with the other code path.

* Factor out division. Update proof link. Remove internal type declaration

Co-authored-by: Mark Dickinson <dickinsm at gmail.com>

files:
A Misc/NEWS.d/next/Library/2021-11-23-15-36-56.bpo-45876.NO8Yaj.rst
M Lib/statistics.py
M Lib/test/test_statistics.py

diff --git a/Lib/statistics.py b/Lib/statistics.py
index 5c3f77df1549d..cf8eaa0a61e62 100644
--- a/Lib/statistics.py
+++ b/Lib/statistics.py
@@ -130,6 +130,7 @@
 import math
 import numbers
 import random
+import sys
 
 from fractions import Fraction
 from decimal import Decimal
@@ -304,6 +305,27 @@ def _fail_neg(values, errmsg='negative value'):
             raise StatisticsError(errmsg)
         yield x
 
+def _isqrt_frac_rto(n: int, m: int) -> int:
+    """Square root of n/m, rounded to an integer using round-to-odd."""
+    # Reference: https://www.lri.fr/~melquion/doc/05-imacs17_1-expose.pdf
+    a = math.isqrt(n // m)   # floor(sqrt(n/m)) when m > 0
+    return a | (a*a*m != n)  # if the root is inexact, force the low bit on
+
+# Bit-length target for n/m in _sqrt_frac(); 109 for 53-bit precision floats.
+_sqrt_shift: int = 2 * sys.float_info.mant_dig + 3
+
+def _sqrt_frac(n: int, m: int) -> float:
+    """Square root of n/m as a float, correctly rounded."""
+    # See principle and proof sketch at: https://bugs.python.org/msg407078
+    q = (n.bit_length() - m.bit_length() - _sqrt_shift) // 2  # 4**q scale
+    if q >= 0:
+        numerator = _isqrt_frac_rto(n, m << 2 * q) << q  # undo via 2**q
+        denominator = 1
+    else:
+        numerator = _isqrt_frac_rto(n << -2 * q, m)  # n scaled by 4**-q
+        denominator = 1 << -q  # undo via 2**-q
+    return numerator / denominator   # Convert to float (single rounding)
+
 
 # === Measures of central tendency (averages) ===
 
@@ -837,14 +859,17 @@ def stdev(data, xbar=None):
     1.0810874155219827
 
     """
-    # Fixme: Despite the exact sum of squared deviations, some inaccuracy
-    # remain because there are two rounding steps.  The first occurs in
-    # the _convert() step for variance(), the second occurs in math.sqrt().
-    var = variance(data, xbar)
-    try:
+    if iter(data) is data:
+        data = list(data)
+    n = len(data)
+    if n < 2:
+        raise StatisticsError('stdev requires at least two data points')
+    T, ss = _ss(data, xbar)
+    mss = ss / (n - 1)
+    if hasattr(T, 'sqrt'):
+        var = _convert(mss, T)
         return var.sqrt()
-    except AttributeError:
-        return math.sqrt(var)
+    return _sqrt_frac(mss.numerator, mss.denominator)
 
 
 def pstdev(data, mu=None):
@@ -856,14 +881,17 @@ def pstdev(data, mu=None):
     0.986893273527251
 
     """
-    # Fixme: Despite the exact sum of squared deviations, some inaccuracy
-    # remain because there are two rounding steps.  The first occurs in
-    # the _convert() step for pvariance(), the second occurs in math.sqrt().
-    var = pvariance(data, mu)
-    try:
+    if iter(data) is data:
+        data = list(data)
+    n = len(data)
+    if n < 1:
+        raise StatisticsError('pstdev requires at least one data point')
+    T, ss = _ss(data, mu)
+    mss = ss / n
+    if hasattr(T, 'sqrt'):
+        var = _convert(mss, T)
         return var.sqrt()
-    except AttributeError:
-        return math.sqrt(var)
+    return _sqrt_frac(mss.numerator, mss.denominator)
 
 
 # === Statistics for relations between two inputs ===
diff --git a/Lib/test/test_statistics.py b/Lib/test/test_statistics.py
index c0e427d9355f2..771a03e707ee0 100644
--- a/Lib/test/test_statistics.py
+++ b/Lib/test/test_statistics.py
@@ -9,13 +9,14 @@
 import copy
 import decimal
 import doctest
+import itertools
 import math
 import pickle
 import random
 import sys
 import unittest
 from test import support
-from test.support import import_helper
+from test.support import import_helper, requires_IEEE_754
 
 from decimal import Decimal
 from fractions import Fraction
@@ -2161,6 +2162,66 @@ def test_center_not_at_mean(self):
         self.assertEqual(self.func(data), 2.5)
         self.assertEqual(self.func(data, mu=0.5), 6.5)
 
+class TestSqrtHelpers(unittest.TestCase):
+
+    def test_isqrt_frac_rto(self):
+        for n, m in itertools.product(range(100), range(1, 1000)):
+            r = statistics._isqrt_frac_rto(n, m)
+            self.assertIsInstance(r, int)
+            if r*r*m == n:
+                # Root is exact
+                continue
+            # Inexact, so the root should be odd
+            self.assertEqual(r&1, 1)
+            # Verify correct rounding
+            self.assertTrue(m * (r - 1)**2 < n < m * (r + 1)**2)
+
+    @requires_IEEE_754
+    def test_sqrt_frac(self):
+
+        def is_root_correctly_rounded(x: Fraction, root: float) -> bool:
+            if not x:
+                return root == 0.0
+
+            # Extract adjacent representable floats
+            r_up: float = math.nextafter(root, math.inf)
+            r_down: float = math.nextafter(root, -math.inf)
+            assert r_down < root < r_up
+
+            # Convert to fractions for exact arithmetic
+            frac_root: Fraction = Fraction(root)
+            half_way_up: Fraction = (frac_root + Fraction(r_up)) / 2
+            half_way_down: Fraction = (frac_root + Fraction(r_down)) / 2
+
+            # Check a closed interval.
+            # Does not test for a midpoint rounding rule.
+            return half_way_down ** 2 <= x <= half_way_up ** 2
+
+        randrange = random.randrange
+
+        for i in range(60_000):
+            numerator: int = randrange(10 ** randrange(50))
+            denonimator: int = randrange(10 ** randrange(50)) + 1
+            with self.subTest(numerator=numerator, denonimator=denonimator):
+                x: Fraction = Fraction(numerator, denonimator)
+                root: float = statistics._sqrt_frac(numerator, denonimator)
+                self.assertTrue(is_root_correctly_rounded(x, root))
+
+        # Verify that corner cases and error handling match math.sqrt()
+        self.assertEqual(statistics._sqrt_frac(0, 1), 0.0)
+        with self.assertRaises(ValueError):
+            statistics._sqrt_frac(-1, 1)
+        with self.assertRaises(ValueError):
+            statistics._sqrt_frac(1, -1)
+
+        # Error handling for zero denominator matches that for Fraction(1, 0)
+        with self.assertRaises(ZeroDivisionError):
+            statistics._sqrt_frac(1, 0)
+
+        # The result is well defined if both inputs are negative
+        self.assertAlmostEqual(statistics._sqrt_frac(-2, -1), math.sqrt(2.0))
+
+
 class TestStdev(VarianceStdevMixin, NumericTestCase):
     # Tests for sample standard deviation.
     def setUp(self):
@@ -2175,7 +2236,7 @@ def test_compare_to_variance(self):
         # Test that stdev is, in fact, the square root of variance.
         data = [random.uniform(-2, 9) for _ in range(1000)]
         expected = math.sqrt(statistics.variance(data))
-        self.assertEqual(self.func(data), expected)
+        self.assertAlmostEqual(self.func(data), expected)
 
     def test_center_not_at_mean(self):
         data = (1.0, 2.0)
diff --git a/Misc/NEWS.d/next/Library/2021-11-23-15-36-56.bpo-45876.NO8Yaj.rst b/Misc/NEWS.d/next/Library/2021-11-23-15-36-56.bpo-45876.NO8Yaj.rst
new file mode 100644
index 0000000000000..889ed6ce3ffb2
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2021-11-23-15-36-56.bpo-45876.NO8Yaj.rst
@@ -0,0 +1,2 @@
+Improve the accuracy of stdev() and pstdev() in the statistics module.  When
+the inputs are floats or fractions, the output is a correctly rounded float.



More information about the Python-checkins mailing list