bpo-39576: Prevent memory error for overly optimistic precisions (GH-18581)
https://github.com/python/cpython/commit/90930e65455f60216f09d175586139242db... commit: 90930e65455f60216f09d175586139242dbba260 branch: master author: Stefan Krah <skrah@bytereef.org> committer: GitHub <noreply@github.com> date: 2020-02-21T01:52:47+01:00 summary: bpo-39576: Prevent memory error for overly optimistic precisions (GH-18581) files: M Lib/test/test_decimal.py M Modules/_decimal/libmpdec/mpdecimal.c M Modules/_decimal/tests/deccheck.py diff --git a/Lib/test/test_decimal.py b/Lib/test/test_decimal.py index fe0cfc7b66d7e..f1abd2aecb122 100644 --- a/Lib/test/test_decimal.py +++ b/Lib/test/test_decimal.py @@ -5476,6 +5476,41 @@ def __abs__(self): self.assertEqual(Decimal.from_float(cls(101.1)), Decimal.from_float(101.1)) + def test_maxcontext_exact_arith(self): + + # Make sure that exact operations do not raise MemoryError due + # to huge intermediate values when the context precision is very + # large. + + # The following functions fill the available precision and are + # therefore not suitable for large precisions (by design of the + # specification). + MaxContextSkip = ['logical_invert', 'next_minus', 'next_plus', + 'logical_and', 'logical_or', 'logical_xor', + 'next_toward', 'rotate', 'shift'] + + Decimal = C.Decimal + Context = C.Context + localcontext = C.localcontext + + # Here only some functions that are likely candidates for triggering a + # MemoryError are tested. deccheck.py has an exhaustive test. + maxcontext = Context(prec=C.MAX_PREC, Emin=C.MIN_EMIN, Emax=C.MAX_EMAX) + with localcontext(maxcontext): + self.assertEqual(Decimal(0).exp(), 1) + self.assertEqual(Decimal(1).ln(), 0) + self.assertEqual(Decimal(1).log10(), 0) + self.assertEqual(Decimal(10**2).log10(), 2) + self.assertEqual(Decimal(10**223).log10(), 223) + self.assertEqual(Decimal(10**19).logb(), 19) + self.assertEqual(Decimal(4).sqrt(), 2) + self.assertEqual(Decimal("40E9").sqrt(), Decimal('2.0E+5')) + self.assertEqual(divmod(Decimal(10), 3), (3, 1)) + self.assertEqual(Decimal(10) // 3, 3) + self.assertEqual(Decimal(4) / 2, 2) + self.assertEqual(Decimal(400) ** -1, Decimal('0.0025')) + + @requires_docstrings @unittest.skipUnless(C, "test requires C version") class SignatureTest(unittest.TestCase): diff --git a/Modules/_decimal/libmpdec/mpdecimal.c b/Modules/_decimal/libmpdec/mpdecimal.c index bfa8bb343e60c..0986edb576a10 100644 --- a/Modules/_decimal/libmpdec/mpdecimal.c +++ b/Modules/_decimal/libmpdec/mpdecimal.c @@ -3781,6 +3781,43 @@ mpd_qdiv(mpd_t *q, const mpd_t *a, const mpd_t *b, const mpd_context_t *ctx, uint32_t *status) { _mpd_qdiv(SET_IDEAL_EXP, q, a, b, ctx, status); + + if (*status & MPD_Malloc_error) { + /* Inexact quotients (the usual case) fill the entire context precision, + * which can lead to malloc() failures for very high precisions. Retry + * the operation with a lower precision in case the result is exact. + * + * We need an upper bound for the number of digits of a_coeff / b_coeff + * when the result is exact. If a_coeff' * 1 / b_coeff' is in lowest + * terms, then maxdigits(a_coeff') + maxdigits(1 / b_coeff') is a suitable + * bound. + * + * 1 / b_coeff' is exact iff b_coeff' exclusively has prime factors 2 or 5. + * The largest amount of digits is generated if b_coeff' is a power of 2 or + * a power of 5 and is less than or equal to log5(b_coeff') <= log2(b_coeff'). + * + * We arrive at a total upper bound: + * + * maxdigits(a_coeff') + maxdigits(1 / b_coeff') <= + * a->digits + log2(b_coeff) = + * a->digits + log10(b_coeff) / log10(2) <= + * a->digits + b->digits * 4; + */ + uint32_t workstatus = 0; + mpd_context_t workctx = *ctx; + workctx.prec = a->digits + b->digits * 4; + if (workctx.prec >= ctx->prec) { + return; /* No point in retrying, keep the original error. */ + } + + _mpd_qdiv(SET_IDEAL_EXP, q, a, b, &workctx, &workstatus); + if (workstatus == 0) { /* The result is exact, unrounded, normal etc. */ + *status = 0; + return; + } + + mpd_seterror(q, *status, status); + } } /* Internal function. */ @@ -7702,9 +7739,9 @@ mpd_qinvroot(mpd_t *result, const mpd_t *a, const mpd_context_t *ctx, /* END LIBMPDEC_ONLY */ /* Algorithm from decimal.py */ -void -mpd_qsqrt(mpd_t *result, const mpd_t *a, const mpd_context_t *ctx, - uint32_t *status) +static void +_mpd_qsqrt(mpd_t *result, const mpd_t *a, const mpd_context_t *ctx, + uint32_t *status) { mpd_context_t maxcontext; MPD_NEW_STATIC(c,0,0,0,0); @@ -7836,6 +7873,40 @@ mpd_qsqrt(mpd_t *result, const mpd_t *a, const mpd_context_t *ctx, goto out; } +void +mpd_qsqrt(mpd_t *result, const mpd_t *a, const mpd_context_t *ctx, + uint32_t *status) +{ + _mpd_qsqrt(result, a, ctx, status); + + if (*status & (MPD_Malloc_error|MPD_Division_impossible)) { + /* The above conditions can occur at very high context precisions + * if intermediate values get too large. Retry the operation with + * a lower context precision in case the result is exact. + * + * If the result is exact, an upper bound for the number of digits + * is the number of digits in the input. + * + * NOTE: sqrt(40e9) = 2.0e+5 /\ digits(40e9) = digits(2.0e+5) = 2 + */ + uint32_t workstatus = 0; + mpd_context_t workctx = *ctx; + workctx.prec = a->digits; + + if (workctx.prec >= ctx->prec) { + return; /* No point in repeating this, keep the original error. */ + } + + _mpd_qsqrt(result, a, &workctx, &workstatus); + if (workstatus == 0) { + *status = 0; + return; + } + + mpd_seterror(result, *status, status); + } +} + /******************************************************************************/ /* Base conversions */ diff --git a/Modules/_decimal/tests/deccheck.py b/Modules/_decimal/tests/deccheck.py index f907531e1ffa5..5cd5db5711426 100644 --- a/Modules/_decimal/tests/deccheck.py +++ b/Modules/_decimal/tests/deccheck.py @@ -125,6 +125,12 @@ 'special': ('context.__reduce_ex__', 'context.create_decimal_from_float') } +# Functions that set no context flags but whose result can differ depending +# on prec, Emin and Emax. +MaxContextSkip = ['is_normal', 'is_subnormal', 'logical_invert', 'next_minus', + 'next_plus', 'number_class', 'logical_and', 'logical_or', + 'logical_xor', 'next_toward', 'rotate', 'shift'] + # Functions that require a restricted exponent range for reasonable runtimes. UnaryRestricted = [ '__ceil__', '__floor__', '__int__', '__trunc__', @@ -344,6 +350,20 @@ def __init__(self, funcname, operands): self.pex = RestrictedList() # Python exceptions for P.Decimal self.presults = RestrictedList() # P.Decimal results + # If the above results are exact, unrounded and not clamped, repeat + # the operation with a maxcontext to ensure that huge intermediate + # values do not cause a MemoryError. + self.with_maxcontext = False + self.maxcontext = context.c.copy() + self.maxcontext.prec = C.MAX_PREC + self.maxcontext.Emax = C.MAX_EMAX + self.maxcontext.Emin = C.MIN_EMIN + self.maxcontext.clear_flags() + + self.maxop = RestrictedList() # converted C.Decimal operands + self.maxex = RestrictedList() # Python exceptions for C.Decimal + self.maxresults = RestrictedList() # C.Decimal results + # ====================================================================== # SkipHandler: skip known discrepancies @@ -545,13 +565,17 @@ def function_as_string(t): if t.contextfunc: cargs = t.cop pargs = t.pop + maxargs = t.maxop cfunc = "c_func: %s(" % t.funcname pfunc = "p_func: %s(" % t.funcname + maxfunc = "max_func: %s(" % t.funcname else: cself, cargs = t.cop[0], t.cop[1:] pself, pargs = t.pop[0], t.pop[1:] + maxself, maxargs = t.maxop[0], t.maxop[1:] cfunc = "c_func: %s.%s(" % (repr(cself), t.funcname) pfunc = "p_func: %s.%s(" % (repr(pself), t.funcname) + maxfunc = "max_func: %s.%s(" % (repr(maxself), t.funcname) err = cfunc for arg in cargs: @@ -565,6 +589,14 @@ def function_as_string(t): err = err.rstrip(", ") err += ")" + if t.with_maxcontext: + err += "\n" + err += maxfunc + for arg in maxargs: + err += "%s, " % repr(arg) + err = err.rstrip(", ") + err += ")" + return err def raise_error(t): @@ -577,9 +609,24 @@ def raise_error(t): err = "Error in %s:\n\n" % t.funcname err += "input operands: %s\n\n" % (t.op,) err += function_as_string(t) - err += "\n\nc_result: %s\np_result: %s\n\n" % (t.cresults, t.presults) - err += "c_exceptions: %s\np_exceptions: %s\n\n" % (t.cex, t.pex) - err += "%s\n\n" % str(t.context) + + err += "\n\nc_result: %s\np_result: %s\n" % (t.cresults, t.presults) + if t.with_maxcontext: + err += "max_result: %s\n\n" % (t.maxresults) + else: + err += "\n" + + err += "c_exceptions: %s\np_exceptions: %s\n" % (t.cex, t.pex) + if t.with_maxcontext: + err += "max_exceptions: %s\n\n" % t.maxex + else: + err += "\n" + + err += "%s\n" % str(t.context) + if t.with_maxcontext: + err += "%s\n" % str(t.maxcontext) + else: + err += "\n" raise VerifyError(err) @@ -603,6 +650,13 @@ def raise_error(t): # are printed to stdout. # ====================================================================== +def all_nan(a): + if isinstance(a, C.Decimal): + return a.is_nan() + elif isinstance(a, tuple): + return all(all_nan(v) for v in a) + return False + def convert(t, convstr=True): """ t is the testset. At this stage the testset contains a tuple of operands t.op of various types. For decimal methods the first @@ -617,10 +671,12 @@ def convert(t, convstr=True): for i, op in enumerate(t.op): context.clear_status() + t.maxcontext.clear_flags() if op in RoundModes: t.cop.append(op) t.pop.append(op) + t.maxop.append(op) elif not t.contextfunc and i == 0 or \ convstr and isinstance(op, str): @@ -638,11 +694,25 @@ def convert(t, convstr=True): p = None pex = e.__class__ + try: + C.setcontext(t.maxcontext) + maxop = C.Decimal(op) + maxex = None + except (TypeError, ValueError, OverflowError) as e: + maxop = None + maxex = e.__class__ + finally: + C.setcontext(context.c) + t.cop.append(c) t.cex.append(cex) + t.pop.append(p) t.pex.append(pex) + t.maxop.append(maxop) + t.maxex.append(maxex) + if cex is pex: if str(c) != str(p) or not context.assert_eq_status(): raise_error(t) @@ -652,14 +722,21 @@ def convert(t, convstr=True): else: raise_error(t) + # The exceptions in the maxcontext operation can legitimately + # differ, only test that maxex implies cex: + if maxex is not None and cex is not maxex: + raise_error(t) + elif isinstance(op, Context): t.context = op t.cop.append(op.c) t.pop.append(op.p) + t.maxop.append(t.maxcontext) else: t.cop.append(op) t.pop.append(op) + t.maxop.append(op) return 1 @@ -673,6 +750,7 @@ def callfuncs(t): t.rc and t.rp are the results of the operation. """ context.clear_status() + t.maxcontext.clear_flags() try: if t.contextfunc: @@ -700,6 +778,35 @@ def callfuncs(t): t.rp = None t.pex.append(e.__class__) + # If the above results are exact, unrounded, normal etc., repeat the + # operation with a maxcontext to ensure that huge intermediate values + # do not cause a MemoryError. + if (t.funcname not in MaxContextSkip and + not context.c.flags[C.InvalidOperation] and + not context.c.flags[C.Inexact] and + not context.c.flags[C.Rounded] and + not context.c.flags[C.Subnormal] and + not context.c.flags[C.Clamped] and + not context.clamp and # results are padded to context.prec if context.clamp==1. + not any(isinstance(v, C.Context) for v in t.cop)): # another context is used. + t.with_maxcontext = True + try: + if t.contextfunc: + maxargs = t.maxop + t.rmax = getattr(t.maxcontext, t.funcname)(*maxargs) + else: + maxself = t.maxop[0] + maxargs = t.maxop[1:] + try: + C.setcontext(t.maxcontext) + t.rmax = getattr(maxself, t.funcname)(*maxargs) + finally: + C.setcontext(context.c) + t.maxex.append(None) + except (TypeError, ValueError, OverflowError, MemoryError) as e: + t.rmax = None + t.maxex.append(e.__class__) + def verify(t, stat): """ t is the testset. At this stage the testset contains the following tuples: @@ -714,6 +821,9 @@ def verify(t, stat): """ t.cresults.append(str(t.rc)) t.presults.append(str(t.rp)) + if t.with_maxcontext: + t.maxresults.append(str(t.rmax)) + if isinstance(t.rc, C.Decimal) and isinstance(t.rp, P.Decimal): # General case: both results are Decimals. t.cresults.append(t.rc.to_eng_string()) @@ -725,6 +835,12 @@ def verify(t, stat): t.presults.append(str(t.rp.imag)) t.presults.append(str(t.rp.real)) + if t.with_maxcontext and isinstance(t.rmax, C.Decimal): + t.maxresults.append(t.rmax.to_eng_string()) + t.maxresults.append(t.rmax.as_tuple()) + t.maxresults.append(str(t.rmax.imag)) + t.maxresults.append(str(t.rmax.real)) + nc = t.rc.number_class().lstrip('+-s') stat[nc] += 1 else: @@ -732,6 +848,9 @@ def verify(t, stat): if not isinstance(t.rc, tuple) and not isinstance(t.rp, tuple): if t.rc != t.rp: raise_error(t) + if t.with_maxcontext and not isinstance(t.rmax, tuple): + if t.rmax != t.rc: + raise_error(t) stat[type(t.rc).__name__] += 1 # The return value lists must be equal. @@ -744,6 +863,20 @@ def verify(t, stat): if not t.context.assert_eq_status(): raise_error(t) + if t.with_maxcontext: + # NaN payloads etc. depend on precision and clamp. + if all_nan(t.rc) and all_nan(t.rmax): + return + # The return value lists must be equal. + if t.maxresults != t.cresults: + raise_error(t) + # The Python exception lists (TypeError, etc.) must be equal. + if t.maxex != t.cex: + raise_error(t) + # The context flags must be equal. + if t.maxcontext.flags != t.context.c.flags: + raise_error(t) + # ====================================================================== # Main test loops
participants (1)
-
Stefan Krah