[Python-checkins] r56656 - python/branches/decimal-branch/Lib/decimal.py

facundo.batista python-checkins at python.org
Thu Aug 2 04:17:55 CEST 2007


Author: facundo.batista
Date: Thu Aug  2 04:17:54 2007
New Revision: 56656

Modified:
   python/branches/decimal-branch/Lib/decimal.py
Log:

The sqrt() function was rewritten from scratch; it is now around 10x
faster and no longer has some subtle errors. Also, a few small
details in _fixexponents() were fixed.
Thanks, Mark Dickinson.

Modified: python/branches/decimal-branch/Lib/decimal.py
==============================================================================
--- python/branches/decimal-branch/Lib/decimal.py	(original)
+++ python/branches/decimal-branch/Lib/decimal.py	Thu Aug  2 04:17:54 2007
@@ -1559,46 +1559,58 @@
         """Fix the exponents and return a copy with the exponent in bounds.
         Only call if known to not be a special value.
         """
-        folddown = context._clamp
-        Emin = context.Emin
+
         ans = self
-        ans_adjusted = ans.adjusted()
-        if ans_adjusted < Emin:
+        # deal with zeros first
+        if not ans:
             Etiny = context.Etiny()
             if ans._exp < Etiny:
-                if not ans:
+                ans = Decimal(self)
+                ans._exp = Etiny
+                context._raise_error(Clamped)
+            else:
+                if context._clamp:
+                    exp_max = context.Etop()
+                else:
+                    exp_max = context.Emax
+                if ans._exp > exp_max:
                     ans = Decimal(self)
-                    ans._exp = Etiny
-                    context._raise_error(Clamped)
-                    return ans
-                ans = ans._rescale(Etiny, context=context)
-                # It isn't zero, and exp < Emin => subnormal
-                context._raise_error(Subnormal)
-                if not ans:
+                    ans._exp = exp_max
                     context._raise_error(Clamped)
-                if context.flags[Inexact]:
-                    context._raise_error(Underflow)
-            else:
-                if ans:
-                    # Only raise subnormal if non-zero.
-                    context._raise_error(Subnormal)
-        else:
+            return ans
+
+        # self is nonzero; if adjusted exponent is > Emax, overflow
+        ans_adjusted = ans.adjusted()
+        if ans_adjusted > context.Emax:
+            context._raise_error(Inexact)
+            context._raise_error(Rounded)
+            c = context._raise_error(Overflow, 'above Emax', ans._sign)
+            return c
+
+        # Now check for subnormal results, and for the need to fold
+        # down.  These two conditions are *not* mutually
+        # exclusive---when the precision is large and Emin and Emax
+        # are small it's quite possible to have Emin > Etop.
+        # (Actually, the specification requires Emin and Emax to be at
+        # least 5*precision, so this shouldn't happen, but it never
+        # hurts to be careful.)
+        if context._clamp:
             Etop = context.Etop()
-            if folddown and ans._exp > Etop:
+            if ans._exp > Etop:
                 context._raise_error(Clamped)
                 ans = ans._rescale(Etop, context=context)
-            else:
-                Emax = context.Emax
-                if ans_adjusted > Emax:
+
+        if ans_adjusted < context.Emin:
+            context._raise_error(Subnormal)
+            Etiny = context.Etiny()
+            if ans._exp < Etiny:
+                ans_before_rescale = ans
+                ans = ans._rescale(Etiny, context=context)
+                if ans != ans_before_rescale:
+                    context._raise_error(Underflow)
                     if not ans:
-                        ans = Decimal(self)
-                        ans._exp = Emax
                         context._raise_error(Clamped)
-                        return ans
-                    context._raise_error(Inexact)
-                    context._raise_error(Rounded)
-                    c = context._raise_error(Overflow, 'above Emax', ans._sign)
-                    return c
+
         return ans
 
     def _round(self, prec=None, rounding=None, context=None, forceExp=None, fromQuantize=False):
@@ -2068,11 +2080,7 @@
     to_integral = to_integral_value
 
     def sqrt(self, context=None):
-        """Return the square root of self.
-
-        Uses a converging algorithm (Xn+1 = 0.5*(Xn + self / Xn))
-        Should quadratically approach the right answer.
-        """
+        """Return the square root of self."""
         if self._is_special:
             ans = self._check_nans(context=context)
             if ans:
@@ -2082,16 +2090,9 @@
                 return Decimal(self)
 
         if not self:
-            # exponent = self._exp / 2, using round_down.
-            # if self._exp < 0:
-            #    exp = (self._exp+1) // 2
-            # else:
-            exp = (self._exp) // 2
-            if self._sign == 1:
-                # sqrt(-0) = -0
-                return Decimal( (1, (0,), exp))
-            else:
-                return Decimal( (0, (0,), exp))
+            # exponent = self._exp // 2.  sqrt(-0) = -0
+            ans = Decimal((self._sign, (0,), self._exp // 2))
+            return ans._fix(context)
 
         if context is None:
             context = getcontext()
@@ -2099,94 +2100,83 @@
         if self._sign == 1:
             return context._raise_error(InvalidOperation, 'sqrt(-x), x > 0')
 
-        tmp = Decimal(self)
-
-        expadd = tmp._exp // 2
-        if tmp._exp & 1:
-            tmp._int += (0,)
-            tmp._exp = 0
-        else:
-            tmp._exp = 0
-
-        context = context._shallow_copy()
-        flags = context._ignore_all_flags()
-        firstprec = context.prec
-        context.prec = 3
-        if tmp.adjusted() & 1 == 0:
-            ans = Decimal( (0, (8,1,9), tmp.adjusted()  - 2) )
-            ans = ans.__add__(tmp.__mul__(Decimal((0, (2,5,9), -2)),
-                                          context=context), context=context)
-            ans._exp -= 1 + tmp.adjusted() // 2
-        else:
-            ans = Decimal( (0, (2,5,9), tmp._exp + len(tmp._int)- 3) )
-            ans = ans.__add__(tmp.__mul__(Decimal((0, (8,1,9), -3)),
-                                          context=context), context=context)
-            ans._exp -= 1 + tmp.adjusted()  // 2
-
-        # ans is now a linear approximation.
-        Emax, Emin = context.Emax, context.Emin
-        context.Emax, context.Emin = DefaultContext.Emax, DefaultContext.Emin
-
-        half = Decimal('0.5')
-
-        maxp = firstprec + 2
-        rounding = context._set_rounding(ROUND_HALF_EVEN)
-        while 1:
-            context.prec = min(2*context.prec - 2, maxp)
-            ans = half.__mul__(ans.__add__(tmp.__div__(ans, context=context),
-                                           context=context), context=context)
-            if context.prec == maxp:
+        # At this point self represents a positive number.  Let p be
+        # the desired precision and express self in the form c*100**e
+        # with c a positive real number and e an integer, c and e
+        # being chosen so that 100**(p-1) <= c < 100**p.  Then the
+        # (exact) square root of self is sqrt(c)*10**e, and 10**(p-1)
+        # <= sqrt(c) < 10**p, so the closest representable Decimal at
+        # precision p is n*10**e where n = round_half_even(sqrt(c)),
+        # the closest integer to sqrt(c) with the even integer chosen
+        # in the case of a tie.
+        #
+        # To ensure correct rounding in all cases, we use the
+        # following trick: we compute the square root to an extra
+        # place (precision p+1 instead of precision p), rounding down.
+        # Then, if the result is inexact and its last digit is 0 or 5,
+        # we increase the last digit to 1 or 6 respectively; if it's
+        # exact we leave the last digit alone.  Now the final round to
+        # p places (or fewer in the case of underflow) will round
+        # correctly and raise the appropriate flags.
+
+        # use an extra digit of precision
+        prec = context.prec+1
+
+        # write argument in the form c*100**e where e = self._exp//2
+        # is the 'ideal' exponent, to be used if the square root is
+        # exactly representable.  l is the number of 'digits' of c in
+        # base 100, so that 100**(l-1) <= c < 100**l.
+        op = _WorkRep(self)
+        e = op.exp >> 1
+        if op.exp & 1:
+            c = op.int * 10
+            l = (len(self._int) >> 1) + 1
+        else:
+            c = op.int
+            l = len(self._int)+1 >> 1
+
+        # rescale so that c has exactly prec base 100 'digits'
+        shift = prec-l
+        if shift >= 0:
+            c *= 100**shift
+            exact = True
+        else:
+            c, remainder = divmod(c, 100**-shift)
+            exact = not remainder
+        e -= shift
+
+        # find n = floor(sqrt(c)) using Newton's method
+        n = 10**prec
+        while True:
+            q = c//n
+            if n <= q:
                 break
+            else:
+                n = n + q >> 1
+        exact = exact and n*n == c
 
-        # Round to the answer's precision-- the only error can be 1 ulp.
-        context.prec = firstprec
-        prevexp = ans.adjusted()
-        ans = ans._round(context=context)
-
-        # Now, check if the other last digits are better.
-        context.prec = firstprec + 1
-        # In case we rounded up another digit and we should actually go lower.
-        if prevexp != ans.adjusted():
-            ans._int += (0,)
-            ans._exp -= 1
-
-
-        lower = ans.__sub__(Decimal((0, (5,), ans._exp-1)), context=context)
-        context._set_rounding(ROUND_UP)
-        if lower.__mul__(lower, context=context) > (tmp):
-            ans = ans.__sub__(Decimal((0, (1,), ans._exp)), context=context)
-
+        if exact:
+            # result is exact; rescale to use ideal exponent e
+            if shift >= 0:
+                # assert n % 10**shift == 0
+                n //= 10**shift
+            else:
+                n *= 10**-shift
+            e += shift
         else:
-            upper = ans.__add__(Decimal((0, (5,), ans._exp-1)),context=context)
-            context._set_rounding(ROUND_DOWN)
-            if upper.__mul__(upper, context=context) < tmp:
-                ans = ans.__add__(Decimal((0, (1,), ans._exp)),context=context)
+            # result is not exact; fix last digit as described above
+            if n % 5 == 0:
+                n += 1
 
-        ans._exp += expadd
+        ans = Decimal((0, map(int, str(n)), e))
 
-        context.prec = firstprec
-        context.rounding = rounding
+        # round, and fit to current context
+        context = context._shallow_copy()
+        rounding = context._set_rounding(ROUND_HALF_EVEN)
         ans = ans._fix(context)
+        context.rounding = rounding
 
-        rounding = context._set_rounding_decision(NEVER_ROUND)
-        if not ans.__mul__(ans, context=context) == self:
-            # Only rounded/inexact if here.
-            context._regard_flags(flags)
-            context._raise_error(Rounded)
-            context._raise_error(Inexact)
-        else:
-            # Exact answer, so let's set the exponent right.
-            # if self._exp < 0:
-            #    exp = (self._exp +1)// 2
-            # else:
-            exp = self._exp // 2
-            context.prec += ans._exp - exp
-            ans = ans._rescale(exp, context=context)
-            context.prec = firstprec
-            context._regard_flags(flags)
-        context.Emax, context.Emin = Emax, Emin
-
-        return ans._fix(context)
+        return ans
 
     def max(self, other, context=None):
         """Returns the larger value.


More information about the Python-checkins mailing list