[Python-checkins] bpo-39576: Prevent memory error for overly optimistic precisions (GH-18581)

Stefan Krah webhook-mailer at python.org
Thu Feb 20 19:52:52 EST 2020


https://github.com/python/cpython/commit/90930e65455f60216f09d175586139242dbba260
commit: 90930e65455f60216f09d175586139242dbba260
branch: master
author: Stefan Krah <skrah at bytereef.org>
committer: GitHub <noreply at github.com>
date: 2020-02-21T01:52:47+01:00
summary:

bpo-39576: Prevent memory error for overly optimistic precisions (GH-18581)

files:
M Lib/test/test_decimal.py
M Modules/_decimal/libmpdec/mpdecimal.c
M Modules/_decimal/tests/deccheck.py

diff --git a/Lib/test/test_decimal.py b/Lib/test/test_decimal.py
index fe0cfc7b66d7e..f1abd2aecb122 100644
--- a/Lib/test/test_decimal.py
+++ b/Lib/test/test_decimal.py
@@ -5476,6 +5476,41 @@ def __abs__(self):
             self.assertEqual(Decimal.from_float(cls(101.1)),
                              Decimal.from_float(101.1))
 
+    def test_maxcontext_exact_arith(self):
+
+        # Make sure that exact operations do not raise MemoryError due
+        # to huge intermediate values when the context precision is very
+        # large.
+
+        # The following functions fill the available precision and are
+        # therefore not suitable for large precisions (by design of the
+        # specification).
+        MaxContextSkip = ['logical_invert', 'next_minus', 'next_plus',
+                          'logical_and', 'logical_or', 'logical_xor',
+                          'next_toward', 'rotate', 'shift']
+
+        Decimal = C.Decimal
+        Context = C.Context
+        localcontext = C.localcontext
+
+        # Here only some functions that are likely candidates for triggering a
+        # MemoryError are tested.  deccheck.py has an exhaustive test.
+        maxcontext = Context(prec=C.MAX_PREC, Emin=C.MIN_EMIN, Emax=C.MAX_EMAX)
+        with localcontext(maxcontext):
+            self.assertEqual(Decimal(0).exp(), 1)
+            self.assertEqual(Decimal(1).ln(), 0)
+            self.assertEqual(Decimal(1).log10(), 0)
+            self.assertEqual(Decimal(10**2).log10(), 2)
+            self.assertEqual(Decimal(10**223).log10(), 223)
+            self.assertEqual(Decimal(10**19).logb(), 19)
+            self.assertEqual(Decimal(4).sqrt(), 2)
+            self.assertEqual(Decimal("40E9").sqrt(), Decimal('2.0E+5'))
+            self.assertEqual(divmod(Decimal(10), 3), (3, 1))
+            self.assertEqual(Decimal(10) // 3, 3)
+            self.assertEqual(Decimal(4) / 2, 2)
+            self.assertEqual(Decimal(400) ** -1, Decimal('0.0025'))
+
+
 @requires_docstrings
 @unittest.skipUnless(C, "test requires C version")
 class SignatureTest(unittest.TestCase):
diff --git a/Modules/_decimal/libmpdec/mpdecimal.c b/Modules/_decimal/libmpdec/mpdecimal.c
index bfa8bb343e60c..0986edb576a10 100644
--- a/Modules/_decimal/libmpdec/mpdecimal.c
+++ b/Modules/_decimal/libmpdec/mpdecimal.c
@@ -3781,6 +3781,43 @@ mpd_qdiv(mpd_t *q, const mpd_t *a, const mpd_t *b,
          const mpd_context_t *ctx, uint32_t *status)
 {
     _mpd_qdiv(SET_IDEAL_EXP, q, a, b, ctx, status);
+
+    if (*status & MPD_Malloc_error) {
+        /* Inexact quotients (the usual case) fill the entire context precision,
+         * which can lead to malloc() failures for very high precisions. Retry
+         * the operation with a lower precision in case the result is exact.
+         *
+         * We need an upper bound for the number of digits of a_coeff / b_coeff
+         * when the result is exact.  If a_coeff' * 1 / b_coeff' is in lowest
+         * terms, then maxdigits(a_coeff') + maxdigits(1 / b_coeff') is a suitable
+         * bound.
+         *
+         * 1 / b_coeff' is exact iff b_coeff' has no prime factors other than 2 or 5.
+         * The largest number of digits is generated if b_coeff' is a power of 2 or
+         * a power of 5, and that number is bounded by log2(b_coeff') >= log5(b_coeff').
+         *
+         * We arrive at a total upper bound:
+         *
+         *   maxdigits(a_coeff') + maxdigits(1 / b_coeff') <=
+         *   a->digits + log2(b_coeff) =
+         *   a->digits + log10(b_coeff) / log10(2) <=
+         *   a->digits + b->digits * 4;
+         */
+        uint32_t workstatus = 0;
+        mpd_context_t workctx = *ctx;
+        workctx.prec = a->digits + b->digits * 4;
+        if (workctx.prec >= ctx->prec) {
+            return;  /* No point in retrying, keep the original error. */
+        }
+
+        _mpd_qdiv(SET_IDEAL_EXP, q, a, b, &workctx, &workstatus);
+        if (workstatus == 0) { /* The result is exact, unrounded, normal etc. */
+            *status = 0;
+            return;
+        }
+
+        mpd_seterror(q, *status, status);
+    }
 }
 
 /* Internal function. */
@@ -7702,9 +7739,9 @@ mpd_qinvroot(mpd_t *result, const mpd_t *a, const mpd_context_t *ctx,
 /* END LIBMPDEC_ONLY */
 
 /* Algorithm from decimal.py */
-void
-mpd_qsqrt(mpd_t *result, const mpd_t *a, const mpd_context_t *ctx,
-          uint32_t *status)
+static void
+_mpd_qsqrt(mpd_t *result, const mpd_t *a, const mpd_context_t *ctx,
+           uint32_t *status)
 {
     mpd_context_t maxcontext;
     MPD_NEW_STATIC(c,0,0,0,0);
@@ -7836,6 +7873,40 @@ mpd_qsqrt(mpd_t *result, const mpd_t *a, const mpd_context_t *ctx,
     goto out;
 }
 
+void
+mpd_qsqrt(mpd_t *result, const mpd_t *a, const mpd_context_t *ctx,
+          uint32_t *status)
+{
+    _mpd_qsqrt(result, a, ctx, status);
+
+    if (*status & (MPD_Malloc_error|MPD_Division_impossible)) {
+        /* The above conditions can occur at very high context precisions
+         * if intermediate values get too large. Retry the operation with
+         * a lower context precision in case the result is exact.
+         *
+         * If the result is exact, an upper bound for the number of digits
+         * is the number of digits in the input.
+         *
+         * NOTE: sqrt(40e9) = 2.0e+5 /\ digits(40e9) = digits(2.0e+5) = 2
+         */
+        uint32_t workstatus = 0;
+        mpd_context_t workctx = *ctx;
+        workctx.prec = a->digits;
+
+        if (workctx.prec >= ctx->prec) {
+            return; /* No point in repeating this, keep the original error. */
+        }
+
+        _mpd_qsqrt(result, a, &workctx, &workstatus);
+        if (workstatus == 0) {
+            *status = 0;
+            return;
+        }
+
+        mpd_seterror(result, *status, status);
+    }
+}
+
 
 /******************************************************************************/
 /*                              Base conversions                              */
diff --git a/Modules/_decimal/tests/deccheck.py b/Modules/_decimal/tests/deccheck.py
index f907531e1ffa5..5cd5db5711426 100644
--- a/Modules/_decimal/tests/deccheck.py
+++ b/Modules/_decimal/tests/deccheck.py
@@ -125,6 +125,12 @@
     'special': ('context.__reduce_ex__', 'context.create_decimal_from_float')
 }
 
+# Functions that set no context flags but whose result can differ depending
+# on prec, Emin and Emax.
+MaxContextSkip = ['is_normal', 'is_subnormal', 'logical_invert', 'next_minus',
+                  'next_plus', 'number_class', 'logical_and', 'logical_or',
+                  'logical_xor', 'next_toward', 'rotate', 'shift']
+
 # Functions that require a restricted exponent range for reasonable runtimes.
 UnaryRestricted = [
   '__ceil__', '__floor__', '__int__', '__trunc__',
@@ -344,6 +350,20 @@ def __init__(self, funcname, operands):
         self.pex = RestrictedList()      # Python exceptions for P.Decimal
         self.presults = RestrictedList() # P.Decimal results
 
+        # If the above results are exact, unrounded and not clamped, repeat
+        # the operation with a maxcontext to ensure that huge intermediate
+        # values do not cause a MemoryError.
+        self.with_maxcontext = False
+        self.maxcontext = context.c.copy()
+        self.maxcontext.prec = C.MAX_PREC
+        self.maxcontext.Emax = C.MAX_EMAX
+        self.maxcontext.Emin = C.MIN_EMIN
+        self.maxcontext.clear_flags()
+
+        self.maxop = RestrictedList()       # converted C.Decimal operands
+        self.maxex = RestrictedList()       # Python exceptions for C.Decimal
+        self.maxresults = RestrictedList()  # C.Decimal results
+
 
 # ======================================================================
 #                SkipHandler: skip known discrepancies
@@ -545,13 +565,17 @@ def function_as_string(t):
     if t.contextfunc:
         cargs = t.cop
         pargs = t.pop
+        maxargs = t.maxop
         cfunc = "c_func: %s(" % t.funcname
         pfunc = "p_func: %s(" % t.funcname
+        maxfunc = "max_func: %s(" % t.funcname
     else:
         cself, cargs = t.cop[0], t.cop[1:]
         pself, pargs = t.pop[0], t.pop[1:]
+        maxself, maxargs = t.maxop[0], t.maxop[1:]
         cfunc = "c_func: %s.%s(" % (repr(cself), t.funcname)
         pfunc = "p_func: %s.%s(" % (repr(pself), t.funcname)
+        maxfunc = "max_func: %s.%s(" % (repr(maxself), t.funcname)
 
     err = cfunc
     for arg in cargs:
@@ -565,6 +589,14 @@ def function_as_string(t):
     err = err.rstrip(", ")
     err += ")"
 
+    if t.with_maxcontext:
+        err += "\n"
+        err += maxfunc
+        for arg in maxargs:
+            err += "%s, " % repr(arg)
+        err = err.rstrip(", ")
+        err += ")"
+
     return err
 
 def raise_error(t):
@@ -577,9 +609,24 @@ def raise_error(t):
     err = "Error in %s:\n\n" % t.funcname
     err += "input operands: %s\n\n" % (t.op,)
     err += function_as_string(t)
-    err += "\n\nc_result: %s\np_result: %s\n\n" % (t.cresults, t.presults)
-    err += "c_exceptions: %s\np_exceptions: %s\n\n" % (t.cex, t.pex)
-    err += "%s\n\n" % str(t.context)
+
+    err += "\n\nc_result: %s\np_result: %s\n" % (t.cresults, t.presults)
+    if t.with_maxcontext:
+        err += "max_result: %s\n\n" % (t.maxresults)
+    else:
+        err += "\n"
+
+    err += "c_exceptions: %s\np_exceptions: %s\n" % (t.cex, t.pex)
+    if t.with_maxcontext:
+        err += "max_exceptions: %s\n\n" % t.maxex
+    else:
+        err += "\n"
+
+    err += "%s\n" % str(t.context)
+    if t.with_maxcontext:
+        err += "%s\n" % str(t.maxcontext)
+    else:
+        err += "\n"
 
     raise VerifyError(err)
 
@@ -603,6 +650,13 @@ def raise_error(t):
 #                are printed to stdout.
 # ======================================================================
 
+def all_nan(a):
+    if isinstance(a, C.Decimal):
+        return a.is_nan()
+    elif isinstance(a, tuple):
+        return all(all_nan(v) for v in a)
+    return False
+
 def convert(t, convstr=True):
     """ t is the testset. At this stage the testset contains a tuple of
         operands t.op of various types. For decimal methods the first
@@ -617,10 +671,12 @@ def convert(t, convstr=True):
     for i, op in enumerate(t.op):
 
         context.clear_status()
+        t.maxcontext.clear_flags()
 
         if op in RoundModes:
             t.cop.append(op)
             t.pop.append(op)
+            t.maxop.append(op)
 
         elif not t.contextfunc and i == 0 or \
              convstr and isinstance(op, str):
@@ -638,11 +694,25 @@ def convert(t, convstr=True):
                 p = None
                 pex = e.__class__
 
+            try:
+                C.setcontext(t.maxcontext)
+                maxop = C.Decimal(op)
+                maxex = None
+            except (TypeError, ValueError, OverflowError) as e:
+                maxop = None
+                maxex = e.__class__
+            finally:
+                C.setcontext(context.c)
+
             t.cop.append(c)
             t.cex.append(cex)
+
             t.pop.append(p)
             t.pex.append(pex)
 
+            t.maxop.append(maxop)
+            t.maxex.append(maxex)
+
             if cex is pex:
                 if str(c) != str(p) or not context.assert_eq_status():
                     raise_error(t)
@@ -652,14 +722,21 @@ def convert(t, convstr=True):
             else:
                 raise_error(t)
 
+            # The exceptions in the maxcontext operation can legitimately
+            # differ; only test that maxex implies cex:
+            if maxex is not None and cex is not maxex:
+                raise_error(t)
+
         elif isinstance(op, Context):
             t.context = op
             t.cop.append(op.c)
             t.pop.append(op.p)
+            t.maxop.append(t.maxcontext)
 
         else:
             t.cop.append(op)
             t.pop.append(op)
+            t.maxop.append(op)
 
     return 1
 
@@ -673,6 +750,7 @@ def callfuncs(t):
         t.rc and t.rp are the results of the operation.
     """
     context.clear_status()
+    t.maxcontext.clear_flags()
 
     try:
         if t.contextfunc:
@@ -700,6 +778,35 @@ def callfuncs(t):
         t.rp = None
         t.pex.append(e.__class__)
 
+    # If the above results are exact, unrounded, normal etc., repeat the
+    # operation with a maxcontext to ensure that huge intermediate values
+    # do not cause a MemoryError.
+    if (t.funcname not in MaxContextSkip and
+        not context.c.flags[C.InvalidOperation] and
+        not context.c.flags[C.Inexact] and
+        not context.c.flags[C.Rounded] and
+        not context.c.flags[C.Subnormal] and
+        not context.c.flags[C.Clamped] and
+        not context.clamp and # results are padded to context.prec if context.clamp==1.
+        not any(isinstance(v, C.Context) for v in t.cop)): # another context is used.
+        t.with_maxcontext = True
+        try:
+            if t.contextfunc:
+                maxargs = t.maxop
+                t.rmax = getattr(t.maxcontext, t.funcname)(*maxargs)
+            else:
+                maxself = t.maxop[0]
+                maxargs = t.maxop[1:]
+                try:
+                    C.setcontext(t.maxcontext)
+                    t.rmax = getattr(maxself, t.funcname)(*maxargs)
+                finally:
+                    C.setcontext(context.c)
+            t.maxex.append(None)
+        except (TypeError, ValueError, OverflowError, MemoryError) as e:
+            t.rmax = None
+            t.maxex.append(e.__class__)
+
 def verify(t, stat):
     """ t is the testset. At this stage the testset contains the following
         tuples:
@@ -714,6 +821,9 @@ def verify(t, stat):
     """
     t.cresults.append(str(t.rc))
     t.presults.append(str(t.rp))
+    if t.with_maxcontext:
+        t.maxresults.append(str(t.rmax))
+
     if isinstance(t.rc, C.Decimal) and isinstance(t.rp, P.Decimal):
         # General case: both results are Decimals.
         t.cresults.append(t.rc.to_eng_string())
@@ -725,6 +835,12 @@ def verify(t, stat):
         t.presults.append(str(t.rp.imag))
         t.presults.append(str(t.rp.real))
 
+        if t.with_maxcontext and isinstance(t.rmax, C.Decimal):
+            t.maxresults.append(t.rmax.to_eng_string())
+            t.maxresults.append(t.rmax.as_tuple())
+            t.maxresults.append(str(t.rmax.imag))
+            t.maxresults.append(str(t.rmax.real))
+
         nc = t.rc.number_class().lstrip('+-s')
         stat[nc] += 1
     else:
@@ -732,6 +848,9 @@ def verify(t, stat):
         if not isinstance(t.rc, tuple) and not isinstance(t.rp, tuple):
             if t.rc != t.rp:
                 raise_error(t)
+            if t.with_maxcontext and not isinstance(t.rmax, tuple):
+                if t.rmax != t.rc:
+                    raise_error(t)
         stat[type(t.rc).__name__] += 1
 
     # The return value lists must be equal.
@@ -744,6 +863,20 @@ def verify(t, stat):
     if not t.context.assert_eq_status():
         raise_error(t)
 
+    if t.with_maxcontext:
+        # NaN payloads etc. depend on precision and clamp.
+        if all_nan(t.rc) and all_nan(t.rmax):
+            return
+        # The return value lists must be equal.
+        if t.maxresults != t.cresults:
+            raise_error(t)
+        # The Python exception lists (TypeError, etc.) must be equal.
+        if t.maxex != t.cex:
+            raise_error(t)
+        # The context flags must be equal.
+        if t.maxcontext.flags != t.context.c.flags:
+            raise_error(t)
+
 
 # ======================================================================
 #                           Main test loops



More information about the Python-checkins mailing list