[pypy-svn] r14624 - in pypy/dist/pypy/objspace/std: . test

tismer at codespeak.net tismer at codespeak.net
Wed Jul 13 17:30:40 CEST 2005


Author: tismer
Date: Wed Jul 13 17:30:39 2005
New Revision: 14624

Modified:
   pypy/dist/pypy/objspace/std/longobject.py
   pypy/dist/pypy/objspace/std/test/test_longobject.py
Log:
a few enhancements which I allowed myself to make on the ferry and
in Kiel at my parent's (blush)

- removed all explicit typing, because this is the beginning of all evil!
- upgraded all implementations to the latest C source in CVS.
- implemented Karatsuba multiplication
- added optimizations for a*a
- added optimized pow()
- added more comments
- a lot of new tests

This appears to be almost state of the art for Python longs.

XXX still to do at some time: I think we should
re-adopt CPython's use of a signed length to store
the sign and size of the digit array. It is more efficient
to carry an explicit length than to actually resize
a list. Also I think the list will get optimized
away at some point, and we want an explicit length anyway.
But this isn't urgent.


Modified: pypy/dist/pypy/objspace/std/longobject.py
==============================================================================
--- pypy/dist/pypy/objspace/std/longobject.py	(original)
+++ pypy/dist/pypy/objspace/std/longobject.py	Wed Jul 13 17:30:39 2005
@@ -3,60 +3,44 @@
 from pypy.objspace.std.intobject import W_IntObject
 from pypy.objspace.std.floatobject import W_FloatObject
 from pypy.objspace.std.noneobject import W_NoneObject
-from pypy.rpython.rarithmetic import intmask, r_uint, r_ushort, r_ulong
-from pypy.rpython.rarithmetic import LONG_BIT
+from pypy.rpython.rarithmetic import LONG_BIT, LONG_MASK, intmask, r_uint
 
 import math
 
-# for now, we use r_uint as a digit.
-# we may later switch to r_ushort, when it is supported by rtyper etc.
-
-Digit = r_uint # already works: r_ushort
-Twodigits = r_uint
+# after many days of debugging and testing,
+# I (chris) finally found out about overflows
+# and how to assign the correct types.
+# But in the end, this really makes no sense at all.
+# Finally I think that we should avoid using anything
+# but general integers. r_uint and friends should go away!
+# Unsignedness can be completely deduced by back-propagation
+# of masking. I will change the annotator to do this.
+# remember: not typing at all is much stronger!
+#
+# my conclusion:
+# having no special types at all, but describing everything
+# in terms of operations and masks is the stronger way.
 
-# the following describe a plain digit
-# XXX at the moment we can't save this bit,
-# or we would need a large enough type to hold
-# the carry bits in _x_divrem
-SHIFT = (Twodigits.BITS // 2) - 1
+SHIFT = (LONG_BIT // 2) - 1
 MASK = int((1 << SHIFT) - 1)
 
-# find the correct type for carry/borrow
-if Digit.BITS - SHIFT >= 1:
-    # we have one more bit in Digit
-    Carryadd = Digit
-    Stwodigits = int
-else:
-    # we need another Digit
-    Carryadd = Twodigits
-    raise ValueError, "need a large enough type for Stwodigits"
-Carrymul = Twodigits
 
 # Debugging digit array access.
 #
-# 0 == no check at all
-# 1 == check correct type
-# 2 == check for extra (ab)used bits
-CHECK_DIGITS = 2
+# False == no checking at all
+# True == check 0 <= value <= MASK
+
+CHECK_DIGITS = True
 
 if CHECK_DIGITS:
     class DigitArray(list):
-        if CHECK_DIGITS == 1:
-            def __setitem__(self, idx, value):
-                assert type(value) is Digit
-                list.__setitem__(self, idx, value)
-        elif CHECK_DIGITS == 2:
-            def __setitem__(self, idx, value):
-                assert type(value) is Digit
-                assert value <= MASK
-                list.__setitem__(self, idx, value)
-        else:
-            raise Exception, 'CHECK_DIGITS == %d not supported' % CHECK_DIGITS
+        def __setitem__(self, idx, value):
+            assert value >=0
+            assert value <= MASK
+            list.__setitem__(self, idx, value)
 else:
     DigitArray = list
 
-# XXX some operations below return one of their input arguments
-#     without checking that it's really of type long (and not a subclass).
 
 class W_LongObject(W_Object):
     """This is a reimplementation of longs using a list of digits."""
@@ -87,7 +71,7 @@
     def _normalize(self):
         if len(self.digits) == 0:
             self.sign = 0
-            self.digits = [Digit(0)]
+            self.digits = [0]
             return
         i = len(self.digits) - 1
         while i != 0 and self.digits[i] == 0:
@@ -99,9 +83,25 @@
 
 registerimplementation(W_LongObject)
 
+USE_KARATSUBA = True # set to False for comparison
+
+# For long multiplication, use the O(N**2) school algorithm unless
+# both operands contain more than KARATSUBA_CUTOFF digits (each
+# digit holding SHIFT bits); squaring uses KARATSUBA_SQUARE_CUTOFF.
+
+KARATSUBA_CUTOFF = 70
+KARATSUBA_SQUARE_CUTOFF = 2 * KARATSUBA_CUTOFF
+
+# For exponentiation, use the binary left-to-right algorithm
+# unless the exponent contains more than FIVEARY_CUTOFF digits.
+# In that case, do 5 bits at a time.  The potential drawback is that
+# a table of 2**5 intermediate results is computed.
+
+FIVEARY_CUTOFF = 8
+
 # bool-to-long
 def delegate_Bool2Long(w_bool):
-    return W_LongObject(w_bool.space, [Digit(w_bool.boolval)],
+    return W_LongObject(w_bool.space, [w_bool.boolval & MASK],
                         int(w_bool.boolval))
 
 # int-to-long delegation
@@ -135,21 +135,21 @@
         sign = 1
         ival = w_intobj.intval
     else:
-        return W_LongObject(space, [Digit(0)], 0)
+        return W_LongObject(space, [0], 0)
     # Count the number of Python digits.
     # We used to pick 5 ("big enough for anything"), but that's a
     # waste of time and space given that 5*15 = 75 bits are rarely
     # needed.
-    t = r_uint(ival)
+    t = ival
     ndigits = 0
     while t:
         ndigits += 1
         t >>= SHIFT
-    v = W_LongObject(space, [Digit(0)] * ndigits, sign)
-    t = r_uint(ival)
+    v = W_LongObject(space, [0] * ndigits, sign)
+    t = ival
     p = 0
     while t:
-        v.digits[p] = Digit(t & MASK)
+        v.digits[p] = t & MASK
         t >>= SHIFT
         p += 1
     return v
@@ -184,16 +184,16 @@
     if w_value.sign == -1:
         raise OperationError(space.w_ValueError, space.wrap(
             "cannot convert negative integer to unsigned int"))
-    x = r_uint(0)
+    x = 0
     i = len(w_value.digits) - 1
     while i >= 0:
         prev = x
-        x = (x << SHIFT) + w_value.digits[i]
+        x = ((x << SHIFT) + w_value.digits[i]) & LONG_MASK
         if (x >> SHIFT) != prev:
             raise OperationError(space.w_OverflowError, space.wrap(
                 "long int too large to convert to unsigned int"))
         i -= 1
-    return x
+    return r_uint(x) # XXX r_uint should go away
 
 def repr__Long(space, w_long):
     return space.wrap(_format(w_long, 10, True))
@@ -287,7 +287,10 @@
     return result
 
 def mul__Long_Long(space, w_long1, w_long2):
-    result = _x_mul(w_long1, w_long2)
+    if USE_KARATSUBA:
+        result = _k_mul(w_long1, w_long2)
+    else:
+        result = _x_mul(w_long1, w_long2)
     result.sign = w_long1.sign * w_long2.sign
     return result
 
@@ -322,9 +325,9 @@
                          space.w_None)
     if lz is not None:
         if lz.sign == 0:
-            raise OperationError(space.w_ValueError,
-                                    space.wrap("pow() 3rd argument cannot be 0"))
-    result = W_LongObject(space, [Digit(1)], 1)
+            raise OperationError(space.w_ValueError, space.wrap(
+                "pow() 3rd argument cannot be 0"))
+    result = W_LongObject(space, [1], 1)
     if lw.sign == 0:
         if lz is not None:
             result = mod__Long_Long(space, result, lz)
@@ -337,7 +340,7 @@
     # Treat the most significant digit specially to reduce multiplications
     while i < len(lw.digits) - 1:
         j = 0
-        m = Digit(1)
+        m = 1
         di = lw.digits[i]
         while j < SHIFT:
             if di & m:
@@ -349,7 +352,7 @@
             m = m << 1
             j += 1
         i += 1
-    m = Digit(1) << (SHIFT - 1)
+    m = 1 << (SHIFT - 1)
     highest_set_bit = SHIFT
     j = SHIFT - 1
     di = lw.digits[i]
@@ -361,7 +364,7 @@
         j -= 1
     assert highest_set_bit != SHIFT, "long not normalized"
     j = 0
-    m = Digit(1)
+    m = 1
     while j <= highest_set_bit:
         if di & m:
             result = mul__Long_Long(space, result, temp)
@@ -375,6 +378,106 @@
         result = mod__Long_Long(space, result, lz)
     return result
 
+def _impl_long_long_pow(space, a, b, c=None):
+    """ pow(a, b, c): a ** b, reduced modulo c unless c is None """
+
+    negativeOutput = False  # set if the modulus c is negative
+
+    # 5-ary values.  If the exponent is large enough, table is
+    # precomputed so that table[i] == a**i % c for i in range(32).
+    # python translation: the table is computed when needed.
+
+    if b.sign < 0:  # if exponent is negative
+        if c is not None:
+            raise OperationError(space.w_TypeError, space.wrap(
+                "pow() 2nd argument "
+                "cannot be negative when 3rd argument specified"))
+        return space.pow(space.newfloat(_AsDouble(a)),
+                         space.newfloat(_AsDouble(b)),
+                         space.w_None)
+
+    if c is not None:
+        # if modulus == 0:
+        #     raise ValueError()
+        if c.sign == 0:
+            raise OperationError(space.w_ValueError, space.wrap(
+                "pow() 3rd argument cannot be 0"))
+
+        # if modulus < 0:
+        #     negativeOutput = True
+        #     modulus = -modulus
+        if c.sign < 0:
+            negativeOutput = True
+            c = W_LongObject(space, c.digits, -c.sign)
+
+        # if modulus == 1:
+        #     return 0
+        if len(c.digits) == 1 and c.digits[0] == 1:
+            return W_LongObject(space, [0], 0)
+
+        # if base < 0:
+        #     base = base % modulus
+        # Having the base positive just makes things easier.
+        if a.sign < 0:
+            a, temp = _l_divmod(a, c)
+            a = temp
+
+    # At this point a, b, and c are guaranteed non-negative UNLESS
+    # c is None, in which case a may be negative.
+
+    z = W_LongObject(space, [1], 1)
+
+    # python adaptation: moved macros REDUCE(X) and MULT(X, Y, result)
+    # into helper function result = _help_mult(x, y, c)
+    if len(b.digits) <= FIVEARY_CUTOFF:
+        # Left-to-right binary exponentiation (HAC Algorithm 14.79)
+        # http://www.cacr.math.uwaterloo.ca/hac/about/chap14.pdf
+        i = len(b.digits) - 1
+        while i >= 0:
+            bi = b.digits[i]
+            j = 1 << (SHIFT-1)
+            while j != 0:
+                z = _help_mult(z, z, c)
+                if bi & j:
+                    z = _help_mult(z, a, c)
+                j >>= 1
+            i -= 1
+    else:
+        # Left-to-right 5-ary exponentiation (HAC Algorithm 14.82)
+        # z still holds 1L
+        table = [z] * 32
+        # table[i] == a**i (mod c); table[0] is z, i.e. 1
+        for i in range(1, 32):
+            table[i] = _help_mult(table[i-1], a, c)
+        i = len(b.digits) - 1
+        while i >= 0:
+            bi = b.digits[i]
+            j = SHIFT - 5
+            while j > -5:  # lowest window is narrower if SHIFT % 5 != 0
+                index = (bi >> max(j, 0)) & ((1 << (5 + min(j, 0))) - 1)
+                for k in range(5 + min(j, 0)):
+                    z = _help_mult(z, z, c)
+                if index:
+                    z = _help_mult(z, table[index], c)
+                j -= 5
+            i -= 1
+
+    if negativeOutput and z.sign != 0:
+        z = sub__Long_Long(z.space, z, c)
+    return z
+
+def _help_mult(x, y, c):
+    """
+    Multiply two longs, then reduce the result:
+    result = x*y % c.  If c is None, skip the reduction.
+    """
+    res = mul__Long_Long(x.space, x, y)
+    # Reduce modulo c (keeping the second result of _l_divmod, the
+    # remainder), but leave res alone if c is None.
+    if c is not None:
+        res, temp = _l_divmod(res, c)
+        res = temp
+    return res
 
 def pow__Long_Long_Long(space, w_long1, w_long2, w_long3):
     return _impl_long_long_pow(space, w_long1, w_long2, w_long3)
@@ -382,6 +485,7 @@
 def pow__Long_Long_None(space, w_long1, w_long2, w_long3):
     return _impl_long_long_pow(space, w_long1, w_long2, None)
 
+
 def neg__Long(space, w_long1):
     return W_LongObject(space, w_long1.digits[:], -w_long1.sign)
 
@@ -395,7 +499,7 @@
     return space.newbool(w_long.sign != 0)
 
 def invert__Long(space, w_long): #Implement ~x as -(x + 1)
-    w_lpp = add__Long_Long(space, w_long, W_LongObject(space, [Digit(1)], 1))
+    w_lpp = add__Long_Long(space, w_long, W_LongObject(space, [1], 1))
     return neg__Long(space, w_lpp)
 
 def lshift__Long_Long(space, w_long1, w_long2):
@@ -419,21 +523,21 @@
     newsize = oldsize + wordshift
     if remshift:
         newsize += 1
-    z = W_LongObject(space, [Digit(0)] * newsize, a.sign)
+    z = W_LongObject(space, [0] * newsize, a.sign)
     # not sure if we will initialize things in the future?
     for i in range(wordshift):
-        z.digits[i] = Digit(0)
-    accum = Twodigits(0)
+        z.digits[i] = 0
+    accum = 0
     i = wordshift
     j = 0
     while j < oldsize:
-        accum |= Twodigits(a.digits[j]) << remshift
-        z.digits[i] = Digit(accum & MASK)
+        accum |= a.digits[j] << remshift
+        z.digits[i] = accum & MASK
         accum >>= SHIFT
         i += 1
         j += 1
     if remshift:
-        z.digits[newsize-1] = Digit(accum)
+        z.digits[newsize-1] = accum
     else:
         assert not accum
     z._normalize()
@@ -459,13 +563,13 @@
     wordshift = shiftby // SHIFT
     newsize = len(a.digits) - wordshift
     if newsize <= 0:
-        return W_LongObject(space, [Digit(0)], 0)
+        return W_LongObject(space, [0], 0)
 
     loshift = shiftby % SHIFT
     hishift = SHIFT - loshift
-    lomask = (Digit(1) << hishift) - 1
+    lomask = (1 << hishift) - 1
     himask = MASK ^ lomask
-    z = W_LongObject(space, [Digit(0)] * newsize, a.sign)
+    z = W_LongObject(space, [0] * newsize, a.sign)
     i = 0
     j = wordshift
     while i < newsize:
@@ -548,10 +652,10 @@
     digits = []
     i = 0
     while l:
-        digits.append(Digit(l & MASK))
+        digits.append(l & MASK)
         l = l >> SHIFT
     if sign == 0:
-        digits = [Digit(0)]
+        digits = [0]
     return digits, sign
 
 
@@ -564,20 +668,20 @@
     if size_a < size_b:
         a, b = b, a
         size_a, size_b = size_b, size_a
-    z = W_LongObject(a.space, [Digit(0)] * (len(a.digits) + 1), 1)
+    z = W_LongObject(a.space, [0] * (len(a.digits) + 1), 1)
     i = 0
-    carry = Carryadd(0)
+    carry = 0
     while i < size_b:
-        carry += Carryadd(a.digits[i]) + b.digits[i]
-        z.digits[i] = Digit(carry & MASK)
+        carry += a.digits[i] + b.digits[i]
+        z.digits[i] = carry & MASK
         carry >>= SHIFT
         i += 1
     while i < size_a:
         carry += a.digits[i]
-        z.digits[i] = Digit(carry & MASK)
+        z.digits[i] = carry & MASK
         carry >>= SHIFT
         i += 1
-    z.digits[i] = Digit(carry)
+    z.digits[i] = carry
     z._normalize()
     return z
 
@@ -586,7 +690,7 @@
     size_a = len(a.digits)
     size_b = len(b.digits)
     sign = 1
-    borrow = Carryadd(0)
+    borrow = 0
 
     # Ensure a is the larger of the two:
     if size_a < size_b:
@@ -599,24 +703,24 @@
         while i >= 0 and a.digits[i] == b.digits[i]:
             i -= 1
         if i < 0:
-            return W_LongObject(a.space, [Digit(0)], 0)
+            return W_LongObject(a.space, [0], 0)
         if a.digits[i] < b.digits[i]:
             sign = -1
             a, b = b, a
         size_a = size_b = i+1
-    z = W_LongObject(a.space, [Digit(0)] * size_a, 1)
+    z = W_LongObject(a.space, [0] * size_a, 1)
     i = 0
     while i < size_b:
         # The following assumes unsigned arithmetic
         # works modulo 2**N for some N>SHIFT.
-        borrow = Carryadd(a.digits[i]) - b.digits[i] - borrow
-        z.digits[i] = Digit(borrow & MASK)
+        borrow = a.digits[i] - b.digits[i] - borrow
+        z.digits[i] = borrow & MASK
         borrow >>= SHIFT
         borrow &= 1 # Keep only one sign bit
         i += 1
     while i < size_a:
         borrow = a.digits[i] - borrow
-        z.digits[i] = Digit(borrow & MASK)
+        z.digits[i] = borrow & MASK
         borrow >>= SHIFT
         borrow &= 1 # Keep only one sign bit
         i += 1
@@ -627,37 +731,317 @@
     return z
 
 
-#Multiply the absolute values of two longs
 def _x_mul(a, b):
+    """
+    Grade school multiplication, ignoring the signs (fast path if a == b).
+    Returns the absolute value of the product as a new, normalized long.
+    """
+
     size_a = len(a.digits)
     size_b = len(b.digits)
-    z = W_LongObject(a.space, [Digit(0)] * (size_a + size_b), 1)
-    i = 0
-    while i < size_a:
-        carry = Carrymul(0)
-        f = Twodigits(a.digits[i])
-        j = 0
-        while j < size_b:
-            carry += z.digits[i + j] + b.digits[j] * f
-            z.digits[i + j] = Digit(carry & MASK)
-            carry >>= SHIFT
-            j += 1
-        while carry != 0:
-            assert i + j < size_a + size_b
-            carry += z.digits[i + j]
-            z.digits[i + j] = Digit(carry & MASK)
+    z = W_LongObject(a.space, [0] * (size_a + size_b), 1)
+    if a == b:
+        # Efficient squaring per HAC, Algorithm 14.16:
+        # http://www.cacr.math.uwaterloo.ca/hac/about/chap14.pdf
+        # Gives slightly less than a 2x speedup when a == b,
+        # via exploiting that each entry in the multiplication
+        # pyramid appears twice (except for the size_a squares).
+        i = 0
+        while i < size_a:
+            f = a.digits[i]
+            pz = i << 1
+            pa = i + 1
+            paend = size_a
+
+            carry = z.digits[pz] + f * f
+            z.digits[pz] = carry & MASK
+            pz += 1
             carry >>= SHIFT
-            j += 1
-        i += 1
+            assert carry <= MASK
+
+            # Now f is added in twice in each column of the
+            # pyramid it appears.  Same as adding f<<1 once.
+            f <<= 1
+            while pa < paend:
+                carry += z.digits[pz] + a.digits[pa] * f
+                pa += 1
+                z.digits[pz] = carry & MASK
+                pz += 1
+                carry >>= SHIFT
+                assert carry <= (MASK << 1)
+            if carry:
+                carry += z.digits[pz]
+                z.digits[pz] = carry & MASK
+                pz += 1
+                carry >>= SHIFT
+            if carry:
+                z.digits[pz] += carry & MASK
+            assert (carry >> SHIFT) == 0
+            i += 1
+    else:
+        # a is not the same as b -- gradeschool long mult
+        i = 0
+        while i < size_a:
+            carry = 0
+            f = a.digits[i]
+            pz = i
+            pb = 0
+            pbend = size_b
+            while pb < pbend:
+                carry += z.digits[pz] + b.digits[pb] * f
+                pb += 1
+                z.digits[pz] = carry & MASK
+                pz += 1
+                carry >>= SHIFT
+                assert carry <= MASK
+            if carry:
+                z.digits[pz] += carry & MASK
+            assert (carry >> SHIFT) == 0
+            i += 1
     z._normalize()
     return z
 
+
+def _kmul_split(n, size):
+    """
+    A helper for Karatsuba multiplication (_k_mul).
+    Takes a long "n" and an integer "size" giving the digit position at
+    which to split, and returns (high, low) with abs(n) == (high << size) + low,
+    viewing the shift as being by digits.  The sign of n is ignored; both
+    results are normalized and >= 0.
+    """
+    size_n = len(n.digits)
+    size_lo = min(size_n, size)
+
+    lo = W_LongObject(n.space, n.digits[:size_lo], 1)
+    hi = W_LongObject(n.space, n.digits[size_lo:], 1)
+    lo._normalize()
+    hi._normalize()
+    return hi, lo
+
+def _k_mul(a, b):
+    """
+    Karatsuba multiplication.  Ignores the input signs, and returns the
+    absolute value of the product (or raises if error).
+    See Knuth Vol. 2 Chapter 4.3.3 (Pp. 294-295).
+    """
+    asize = len(a.digits)
+    bsize = len(b.digits)
+    # (ah*X+al)(bh*X+bl) = ah*bh*X*X + (ah*bl + al*bh)*X + al*bl
+    # Let k = (ah+al)*(bh+bl) = ah*bl + al*bh  + ah*bh + al*bl
+    # Then the original product is
+    #     ah*bh*X*X + (k - ah*bh - al*bl)*X + al*bl
+    # By picking X to be a power of 2, "*X" is just shifting, and it's
+    # been reduced to 3 multiplies on numbers half the size.
+
+    # We want to split based on the larger number; fiddle so that b
+    # is largest.
+    if asize > bsize:
+        a, b, asize, bsize = b, a, bsize, asize
+
+    # Use gradeschool math when either number is too small.
+    if a == b:
+        i = KARATSUBA_SQUARE_CUTOFF
+    else:
+        i = KARATSUBA_CUTOFF
+    if asize <= i:
+        if a.sign == 0:
+            return W_LongObject(a.space, [0], 0)
+        else:
+            return _x_mul(a, b)
+
+    # If a is small compared to b, splitting on b gives a degenerate
+    # case with ah==0, and Karatsuba may be (even much) less efficient
+    # than "grade school" then.  However, we can still win, by viewing
+    # b as a string of "big digits", each of width len(a.digits).  That
+    # leads to a sequence of balanced calls to _k_mul.
+    if 2 * asize <= bsize:
+        return _k_lopsided_mul(a, b)
+
+    # Split a & b into hi & lo pieces.
+    shift = bsize >> 1
+    ah, al = _kmul_split(a, shift)
+    assert ah.sign == 1    # the split isn't degenerate
+
+    if a == b:
+        bh = ah
+        bl = al
+    else:
+        bh, bl = _kmul_split(b, shift)
+
+    # The plan:
+    # 1. Allocate result space (asize + bsize digits:  that's always
+    #    enough).
+    # 2. Compute ah*bh, and copy into result at 2*shift.
+    # 3. Compute al*bl, and copy into result at 0.  Note that this
+    #    can't overlap with #2.
+    # 4. Subtract al*bl from the result, starting at shift.  This may
+    #    underflow (borrow out of the high digit), but we don't care:
+    #    we're effectively doing unsigned arithmetic mod
+    #    (MASK+1)**(asize + bsize), and so long as the *final* result fits,
+    #    borrows and carries out of the high digit can be ignored.
+    # 5. Subtract ah*bh from the result, starting at shift.
+    # 6. Compute (ah+al)*(bh+bl), and add it into the result starting
+    #    at shift.
+
+    # 1. Allocate result space.
+    ret = W_LongObject(a.space, [0] * (asize + bsize), 1)
+
+    # 2. t1 <- ah*bh, and copy into high digits of result.
+    t1 = _k_mul(ah, bh)
+    assert t1.sign >= 0
+    assert 2*shift + len(t1.digits) <= len(ret.digits)
+    ret.digits[2*shift : 2*shift + len(t1.digits)] = t1.digits
+
+    # Zero-out the digits higher than the ah*bh copy.
+    ## ignored, assuming that we initialize to zero
+    ##i = ret->ob_size - 2*shift - t1->ob_size;
+    ##if (i)
+    ##    memset(ret->ob_digit + 2*shift + t1->ob_size, 0,
+    ##           i * sizeof(digit));
+
+    # 3. t2 <- al*bl, and copy into the low digits.
+    t2 = _k_mul(al, bl)
+    assert t2.sign >= 0
+    assert len(t2.digits) <= 2*shift # no overlap with high digits
+    ret.digits[:len(t2.digits)] = t2.digits
+
+    # Zero out remaining digits.
+    ## ignored, assuming that we initialize to zero
+    ##i = 2*shift - t2->ob_size;  /* number of uninitialized digits */
+    ##if (i)
+    ##    memset(ret->ob_digit + t2->ob_size, 0, i * sizeof(digit));
+
+    # 4 & 5. Subtract ah*bh (t1) and al*bl (t2).  We do al*bl first
+    # because it's fresher in cache.
+    i = len(ret.digits) - shift  # # digits after shift
+    _v_isub(ret.digits, shift, i, t2.digits, len(t2.digits))
+    _v_isub(ret.digits, shift, i, t1.digits, len(t1.digits))
+    del t1, t2
+
+    # 6. t3 <- (ah+al)(bh+bl), and add into result.
+    t1 = _x_add(ah, al)
+    del ah, al
+
+    if a == b:
+        t2 = t1
+    else:
+        t2 = _x_add(bh, bl)
+    del bh, bl
+
+    t3 = _k_mul(t1, t2)
+    del t1, t2
+    assert t3.sign ==1
+
+    # Add t3.  It's not obvious why we can't run out of room here.
+    # See the (*) comment after this function.
+    _v_iadd(ret.digits, shift, i, t3.digits, len(t3.digits))
+    del t3
+
+    ret._normalize()
+    return ret
+
+""" (*) Why adding t3 can't "run out of room" above.
+
+Let f(x) mean the floor of x and c(x) mean the ceiling of x.  Some facts
+to start with:
+
+1. For any integer i, i = c(i/2) + f(i/2).  In particular,
+   bsize = c(bsize/2) + f(bsize/2).
+2. shift = f(bsize/2)
+3. asize <= bsize
+4. Since we call k_lopsided_mul if asize*2 <= bsize, asize*2 > bsize in this
+   routine, so asize > bsize/2 >= f(bsize/2) in this routine.
+
+We allocated asize + bsize result digits, and add t3 into them at an offset
+of shift.  This leaves asize+bsize-shift allocated digit positions for t3
+to fit into, = (by #1 and #2) asize + f(bsize/2) + c(bsize/2) - f(bsize/2) =
+asize + c(bsize/2) available digit positions.
+
+bh has c(bsize/2) digits, and bl at most f(bsize/2) digits.  So bh+bl has
+at most c(bsize/2) digits + 1 bit.
+
+If asize == bsize, ah has c(bsize/2) digits, else ah has at most f(bsize/2)
+digits, and al has at most f(bsize/2) digits in any case.  So ah+al has at
+most (asize == bsize ? c(bsize/2) : f(bsize/2)) digits + 1 bit.
+
+The product (ah+al)*(bh+bl) therefore has at most
+
+    c(bsize/2) + (asize == bsize ? c(bsize/2) : f(bsize/2)) digits + 2 bits
+
+and we have asize + c(bsize/2) available digit positions.  We need to show
+this is always enough.  An instance of c(bsize/2) cancels out in both, so
+the question reduces to whether asize digits is enough to hold
+(asize == bsize ? c(bsize/2) : f(bsize/2)) digits + 2 bits.  If asize < bsize,
+then we're asking whether asize digits >= f(bsize/2) digits + 2 bits.  By #4,
+asize is at least f(bsize/2)+1 digits, so this in turn reduces to whether 1
+digit is enough to hold 2 bits.  This is so since SHIFT (15 or 31) >= 2.  If
+asize == bsize, then we're asking whether bsize digits is enough to hold
+c(bsize/2) digits + 2 bits, or equivalently (by #1) whether f(bsize/2) digits
+is enough to hold 2 bits.  This is so if bsize >= 2, which holds because
+bsize >= KARATSUBA_CUTOFF >= 2.
+
+Note that since there's always enough room for (ah+al)*(bh+bl), and that's
+clearly >= each of ah*bh and al*bl, there's always enough room to subtract
+ah*bh and al*bl too.
+"""
+
+def _k_lopsided_mul(a, b):
+    """
+    b has at least twice the digits of a, and a is big enough that Karatsuba
+    would pay off *if* the inputs had balanced sizes.  View b as a sequence
+    of slices, each with len(a.digits) digits, and multiply the slices by a,
+    one at a time.  This gives _k_mul balanced inputs to work with, and is
+    also cache-friendly (we compute one double-width slice of the result
+    at a time, then move on, never backtracking except for the helpful
+    single-width slice overlap between successive partial sums).
+    """
+    asize = len(a.digits)
+    bsize = len(b.digits)
+    # nbdone is # of b digits already multiplied
+
+    assert asize > KARATSUBA_CUTOFF
+    assert 2 * asize <= bsize
+
+    # Allocate result space, and zero it out.
+    ret = W_LongObject(a.space, [0] * (asize + bsize), 1)
+
+    # Successive slices of b are copied into bslice.
+    #bslice = W_LongObject(a.space, [0] * asize, 1)
+    # XXX we cannot pre-allocate, see comments below!
+    bslice = W_LongObject(a.space, [0], 1)
+
+    nbdone = 0;
+    while bsize > 0:
+        nbtouse = min(bsize, asize)
+
+        # Multiply the next slice of b by a.
+
+        #bslice.digits[:nbtouse] = b.digits[nbdone : nbdone + nbtouse]
+        # XXX: this would be more efficient if we adopted CPython's
+        # way to store the size, instead of resizing the list!
+        # XXX change the implementation, encoding length via the sign.
+        bslice.digits = b.digits[nbdone : nbdone + nbtouse]
+        product = _k_mul(a, bslice)
+
+        # Add into result.
+        _v_iadd(ret.digits, nbdone, len(ret.digits) - nbdone,
+                 product.digits, len(product.digits))
+        del product
+
+        bsize -= nbtouse
+        nbdone += nbtouse
+
+    ret._normalize()
+    return ret
+
+
 def _inplace_divrem1(pout, pin, n, size=0):
     """
     Divide long pin by non-zero digit n, storing quotient
     in pout, and returning the remainder. It's OK for pin == pout on entry.
     """
-    rem = Twodigits(0)
+    rem = 0
     assert n > 0 and n <= MASK
     if not size:
         size = len(pin.digits)
@@ -665,7 +1049,7 @@
     while size >= 0:
         rem = (rem << SHIFT) + pin.digits[size]
         hi = rem // n
-        pout.digits[size] = Digit(hi)
+        pout.digits[size] = hi
         rem -= hi * n
         size -= 1
     return rem
@@ -678,24 +1062,80 @@
     """
     assert n > 0 and n <= MASK
     size = len(a.digits)
-    z = W_LongObject(a.space, [Digit(0)] * size, 1)
+    z = W_LongObject(a.space, [0] * size, 1)
     rem = _inplace_divrem1(z, a, n)
     z._normalize()
     return z, rem
 
+def _v_iadd(x, xofs, m, y, n):
+    """
+    x[xofs:xofs+m] and y[0:n] are digit lists, LSD first, m >= n required.
+    x[xofs:xofs+n] is modified in place by adding y to it.  Carries are
+    propagated as far as x[xofs+m-1]; the remaining carry (0 or 1) is
+    returned.  (Python adaptation: x is addressed relative to xofs.)
+    """
+    carry = 0;
+
+    assert m >= n
+    i = xofs
+    iend = xofs + n
+    while i < iend:
+        carry += x[i] + y[i-xofs]
+        x[i] = carry & MASK
+        carry >>= SHIFT
+        assert (carry & 1) == carry
+        i += 1
+    iend = xofs + m
+    while carry and i < iend:
+        carry += x[i]
+        x[i] = carry & MASK
+        carry >>= SHIFT
+        assert (carry & 1) == carry
+        i += 1
+    return carry
+
+def _v_isub(x, xofs, m, y, n):
+    """
+    x[xofs:xofs+m] and y[0:n] are digit lists, LSD first, m >= n required.
+    x[xofs:xofs+n] is modified in place by subtracting y from it.  Borrows
+    are propagated as far as x[xofs+m-1]; the remaining borrow (0 or 1) is
+    returned.  (Python adaptation: x is addressed relative to xofs.)
+    """
+    borrow = 0
+
+    assert m >= n
+    i = xofs
+    iend = xofs + n
+    while i < iend:
+        borrow = x[i] - y[i-xofs] - borrow
+        x[i] = borrow & MASK
+        borrow >>= SHIFT
+        borrow &= 1    # keep only 1 sign bit
+        i += 1
+    iend = xofs + m
+    while borrow and i < iend:
+        borrow = x[i] - borrow
+        x[i] = borrow & MASK
+        borrow >>= SHIFT
+        borrow &= 1
+        i += 1
+    return borrow
+
+
 def _muladd1(a, n, extra):
     """Multiply by a single digit and add a single digit, ignoring the sign.
     """
     size_a = len(a.digits)
-    z = W_LongObject(a.space, [Digit(0)] * (size_a+1), 1)
-    carry = Carrymul(extra)
+    z = W_LongObject(a.space, [0] * (size_a+1), 1)
+    carry = extra
+    assert carry & MASK == carry
     i = 0
     while i < size_a:
-        carry += Twodigits(a.digits[i]) * n
-        z.digits[i] = Digit(carry & MASK)
+        carry += a.digits[i] * n
+        z.digits[i] = carry & MASK
         carry >>= SHIFT
         i += 1
-    z.digits[i] = Digit(carry)
+    z.digits[i] = carry
     z._normalize()
     return z
 
@@ -703,69 +1143,66 @@
 def _x_divrem(v1, w1):
     """ Unsigned long division with remainder -- the algorithm """
     size_w = len(w1.digits)
-    d = Digit(Twodigits(MASK+1) // (w1.digits[size_w-1] + 1))
-    v = _muladd1(v1, d, Digit(0))
-    w = _muladd1(w1, d, Digit(0))
+    d = (MASK+1) // (w1.digits[size_w-1] + 1)
+    v = _muladd1(v1, d, 0)
+    w = _muladd1(w1, d, 0)
     size_v = len(v.digits)
     size_w = len(w.digits)
     assert size_v >= size_w and size_w > 1 # Assert checks by div()
 
     size_a = size_v - size_w + 1
-    a = W_LongObject(v.space, [Digit(0)] * size_a, 1)
+    a = W_LongObject(v.space, [0] * size_a, 1)
 
     j = size_v
     k = size_a - 1
     while k >= 0:
         if j >= size_v:
-            vj = Digit(0)
+            vj = 0
         else:
             vj = v.digits[j]
-        carry = Stwodigits(0) # note: this must hold two digits and a sign!
+        carry = 0
 
         if vj == w.digits[size_w-1]:
-            q = Twodigits(MASK)
+            q = MASK
         else:
-            q = ((Twodigits(vj) << SHIFT) + v.digits[j-1]) // w.digits[size_w-1]
+            q = ((vj << SHIFT) + v.digits[j-1]) // w.digits[size_w-1]
 
-        # notabene!
-        # this check needs a signed two digits result
-        # or we get an overflow.
         while (w.digits[size_w-2] * q >
                 ((
-                    (Stwodigits(vj) << SHIFT) # this one dominates
-                    + Stwodigits(v.digits[j-1])
-                    - Stwodigits(q) * Stwodigits(w.digits[size_w-1])
+                    (vj << SHIFT)
+                    + v.digits[j-1]
+                    - q * w.digits[size_w-1]
                                 ) << SHIFT)
-                + Stwodigits(v.digits[j-2])):
+                + v.digits[j-2]):
             q -= 1
         i = 0
         while i < size_w and i+k < size_v:
-            z = Stwodigits(w.digits[i] * q)
+            z = w.digits[i] * q
             zz = z >> SHIFT
-            carry += Stwodigits(v.digits[i+k]) - z + (zz << SHIFT)
-            v.digits[i+k] = Digit(carry & MASK)
+            carry += v.digits[i+k] - z + (zz << SHIFT)
+            v.digits[i+k] = carry & MASK
             carry >>= SHIFT
             carry -= zz
             i += 1
 
         if i+k < size_v:
-            carry += Stwodigits(v.digits[i+k])
-            v.digits[i+k] = Digit(0)
+            carry += v.digits[i+k]
+            v.digits[i+k] = 0
 
         if carry == 0:
-            a.digits[k] = Digit(q & MASK)
+            a.digits[k] = q & MASK
             assert not q >> SHIFT
         else:
             assert carry == -1
             q -= 1
-            a.digits[k] = Digit(q & MASK)
+            a.digits[k] = q & MASK
             assert not q >> SHIFT
 
-            carry = Stwodigits(0)
+            carry = 0
             i = 0
             while i < size_w and i+k < size_v:
-                carry += Stwodigits(v.digits[i+k]) + Stwodigits(w.digits[i])
-                v.digits[i+k] = Digit(carry & MASK)
+                carry += v.digits[i+k] + w.digits[i]
+                v.digits[i+k] = carry & MASK
                 carry >>= SHIFT
                 i += 1
         j -= 1
@@ -789,7 +1226,7 @@
         (size_a == size_b and
          a.digits[size_a-1] < b.digits[size_b-1])):
         # |a| < |b|
-        z = W_LongObject(a.space, [Digit(0)], 0)
+        z = W_LongObject(a.space, [0], 0)
         rem = a
         return z, rem
     if size_b == 1:
@@ -846,12 +1283,12 @@
 ##def ldexp(x, exp):
 ##    assert type(x) is float
 ##    lb1 = LONG_BIT - 1
-##    multiplier = float(Digit(1) << lb1)
+##    multiplier = float(1 << lb1)
 ##    while exp >= lb1:
 ##        x *= multiplier
 ##        exp -= lb1
 ##    if exp:
-##        x *= float(Digit(1) << exp)
+##        x *= float(1 << exp)
 ##    return x
 
 # note that math.ldexp checks for overflows,
@@ -904,13 +1341,13 @@
         dval = -dval
     frac, expo = math.frexp(dval) # dval = frac*2**expo; 0.0 <= frac < 1.0
     if expo <= 0:
-        return W_LongObject(space, [Digit(0)], 0)
+        return W_LongObject(space, [0], 0)
     ndig = (expo-1) // SHIFT + 1 # Number of 'digits' in result
-    v = W_LongObject(space, [Digit(0)] * ndig, 1)
+    v = W_LongObject(space, [0] * ndig, 1)
     frac = math.ldexp(frac, (expo-1) % SHIFT + 1)
     for i in range(ndig-1, -1, -1):
-        bits = int(frac)
-        v.digits[i] = Digit(bits)
+        bits = int(frac) & MASK # help the future annotator?
+        v.digits[i] = bits
         frac -= float(bits)
         frac = math.ldexp(frac, SHIFT)
     if neg:
@@ -937,7 +1374,7 @@
     div, mod = _divrem(v, w)
     if mod.sign * w.sign == -1:
         mod = add__Long_Long(v.space, mod, w)
-        one = W_LongObject(v.space, [Digit(1)], 1)
+        one = W_LongObject(v.space, [1], 1)
         div = sub__Long_Long(v.space, div, one)
     return div, mod
 
@@ -974,7 +1411,7 @@
         s[p] = '0'
     elif (base & (base - 1)) == 0:
         # JRH: special case for power-of-2 bases
-        accum = Twodigits(0)
+        accum = 0
         accumbits = 0  # # of bits in accum 
         basebits = 1   # # of bits in base-1
         i = base
@@ -985,7 +1422,7 @@
             basebits += 1
 
         for i in range(size_a):
-            accum |= Twodigits(a.digits[i]) << accumbits
+            accum |= a.digits[i] << accumbits
             accumbits += SHIFT
             assert accumbits >= basebits
             while 1:
@@ -1012,17 +1449,17 @@
         size = size_a
         pin = a # just for similarity to C source which uses the array
         # powbase <- largest power of base that fits in a digit.
-        powbase = Digit(base)  # powbase == base ** power
+        powbase = base  # powbase == base ** power
         power = 1
         while 1:
-            newpow = Twodigits(powbase) * Digit(base)
+            newpow = powbase * base
             if newpow >> SHIFT:  # doesn't fit in a digit
                 break
-            powbase = Digit(newpow)
+            powbase = newpow
             power += 1
 
         # Get a scratch area for repeated division.
-        scratch = W_LongObject(a.space, [Digit(0)] * size, 1)
+        scratch = W_LongObject(a.space, [0] * size, 1)
 
         # Repeatedly divide by powbase.
         while 1:
@@ -1086,14 +1523,14 @@
 
     if a.sign < 0:
         a = invert__Long(a.space, a)
-        maska = Digit(MASK)
+        maska = MASK
     else:
-        maska = Digit(0)
+        maska = 0
     if b.sign < 0:
         b = invert__Long(b.space, b)
-        maskb = Digit(MASK)
+        maskb = MASK
     else:
-        maskb = Digit(0)
+        maskb = 0
 
     negz = 0
     if op == '^':
@@ -1135,7 +1572,7 @@
     else:
         size_z = max(size_a, size_b)
 
-    z = W_LongObject(a.space, [Digit(0)] * size_z, 1)
+    z = W_LongObject(a.space, [0] * size_z, 1)
 
     for i in range(size_z):
         if i < size_a:
@@ -1161,17 +1598,17 @@
 def _AsLong(v):
     """
     Get an integer from a long int object.
-    Returns -1 and sets an error condition if overflow occurs.
+    Raises OverflowError if overflow occurs.
     """
     # This version by Tim Peters
     i = len(v.digits) - 1
     sign = v.sign
     if not sign:
         return 0
-    x = r_uint(0)
+    x = 0
     while i >= 0:
         prev = x
-        x = (x << SHIFT) + v.digits[i]
+        x = ((x << SHIFT) + v.digits[i]) & LONG_MASK
         if (x >> SHIFT) != prev:
             raise OverflowError
         i -= 1
@@ -1180,7 +1617,7 @@
     # trouble *unless* this is the min negative number.  So,
     # trouble iff sign bit set && (positive || some bit set other
     # than the sign bit).
-    if int(x) < 0 and (sign > 0 or (x << 1) != 0):
+    if intmask(x) < 0 and (sign > 0 or (x << 1) & LONG_MASK != 0):
             raise OverflowError
     return intmask(int(x) * sign)
 

Modified: pypy/dist/pypy/objspace/std/test/test_longobject.py
==============================================================================
--- pypy/dist/pypy/objspace/std/test/test_longobject.py	(original)
+++ pypy/dist/pypy/objspace/std/test/test_longobject.py	Wed Jul 13 17:30:39 2005
@@ -3,8 +3,8 @@
 from random import random, randint
 from pypy.objspace.std import longobject as lobj
 from pypy.objspace.std.objspace import FailedToImplement
-from pypy.rpython.rarithmetic import r_uint
 from pypy.interpreter.error import OperationError
+from pypy.rpython.rarithmetic import r_uint # will go away
 
 objspacename = 'std'
 
@@ -40,7 +40,7 @@
                 assert result.longval() == x * i - y * j
 
     def test_subzz(self):
-        w_l0 = lobj.W_LongObject(self.space, [r_uint(0)])
+        w_l0 = lobj.W_LongObject(self.space, [0])
         assert self.space.sub(w_l0, w_l0).longval() == 0
 
     def test_mul(self):
@@ -50,13 +50,16 @@
         f2 = lobj.W_LongObject(self.space, *lobj.args_from_long(y))
         result = lobj.mul__Long_Long(self.space, f1, f2)
         assert result.longval() == x * y
+        # also test a * a, it has special code
+        result = lobj.mul__Long_Long(self.space, f1, f1)
+        assert result.longval() == x * x
 
     def test__inplace_divrem1(self):
         # signs are not handled in the helpers!
         x = 1238585838347L
         y = 3
         f1 = lobj.W_LongObject(self.space, *lobj.args_from_long(x))
-        f2 = r_uint(y)
+        f2 = y
         remainder = lobj._inplace_divrem1(f1, f1, f2)
         assert (f1.longval(), remainder) == divmod(x, y)
 
@@ -65,7 +68,7 @@
         x = 1238585838347L
         y = 3
         f1 = lobj.W_LongObject(self.space, *lobj.args_from_long(x))
-        f2 = r_uint(y)
+        f2 = y
         div, rem = lobj._divrem1(f1, f2)
         assert (div.longval(), rem) == divmod(x, y)
 
@@ -74,8 +77,8 @@
         y = 3
         z = 42
         f1 = lobj.W_LongObject(self.space, *lobj.args_from_long(x))
-        f2 = r_uint(y)
-        f3 = r_uint(z)
+        f2 = y
+        f3 = z
         prod = lobj._muladd1(f1, f2, f3)
         assert prod.longval() == x * y + z
 
@@ -125,6 +128,45 @@
         except OperationError, e:
             assert e.w_type is self.space.w_OverflowError
 
+    # testing Karatsuba stuff
+    def test__v_iadd(self):
+        f1 = lobj.W_LongObject(self.space, [lobj.MASK] * 10, 1)
+        f2 = lobj.W_LongObject(self.space, [1], 1)
+        carry = lobj._v_iadd(f1.digits, 1, len(f1.digits)-1, f2.digits, 1)
+        assert carry == 1
+        assert f1.longval() == lobj.MASK
+
+    def test__v_isub(self):
+        f1 = lobj.W_LongObject(self.space, [lobj.MASK] + [0] * 9 + [1], 1)
+        f2 = lobj.W_LongObject(self.space, [1], 1)
+        borrow = lobj._v_isub(f1.digits, 1, len(f1.digits)-1, f2.digits, 1)
+        assert borrow == 0
+        assert f1.longval() == (1 << lobj.SHIFT) ** 10 - 1
+
+    def test__kmul_split(self):
+        split = 5
+        diglo = [0] * split
+        dighi = [lobj.MASK] * split
+        f1 = lobj.W_LongObject(self.space, diglo + dighi, 1)
+        hi, lo = lobj._kmul_split(f1, split)
+        assert lo.digits == [0]
+        assert hi.digits == dighi
+
+    def test__k_mul(self):
+        digs= lobj.KARATSUBA_CUTOFF * 5
+        f1 = lobj.W_LongObject(self.space, [lobj.MASK] * digs, 1)
+        f2 = lobj._x_add(f1,lobj.W_LongObject(self.space, [1], 1))
+        ret = lobj._k_mul(f1, f2)
+        assert ret.longval() == f1.longval() * f2.longval()
+
+    def test__k_lopsided_mul(self):
+        digs_a = lobj.KARATSUBA_CUTOFF + 3
+        digs_b = 3 * digs_a
+        f1 = lobj.W_LongObject(self.space, [lobj.MASK] * digs_a, 1)
+        f2 = lobj.W_LongObject(self.space, [lobj.MASK] * digs_b, 1)
+        ret = lobj._k_lopsided_mul(f1, f2)
+        assert ret.longval() == f1.longval() * f2.longval()
+
     def test_eq(self):
         x = 5858393919192332223L
         y = 585839391919233111223311112332L
@@ -157,7 +199,7 @@
 
         u = lobj.uint_w__Long(self.space, f2)
         assert u == 12332
-        assert isinstance(u, r_uint)
+        assert type(u) is r_uint
 
     def test_conversions(self):
         space = self.space
@@ -170,10 +212,10 @@
                 assert space.is_true(space.isinstance(lobj.int__Long(space, w_lv), space.w_int))            
                 assert space.eq_w(lobj.int__Long(space, w_lv), w_v)
 
-                if v>=0:
+                if v >= 0:
                     u = lobj.uint_w__Long(space, w_lv)
                     assert u == v
-                    assert isinstance(u, r_uint)
+                    assert type(u) is r_uint
                 else:
                     space.raises_w(space.w_ValueError, lobj.uint_w__Long, space, w_lv)
 
@@ -193,8 +235,7 @@
 
         u = lobj.uint_w__Long(space, w_lmaxuint)
         assert u == 2*sys.maxint+1
-        assert isinstance(u, r_uint)
-        
+
         space.raises_w(space.w_ValueError, lobj.uint_w__Long, space, w_toobig_lv3)       
         space.raises_w(space.w_OverflowError, lobj.uint_w__Long, space, w_toobig_lv4)
 
@@ -229,10 +270,10 @@
         assert v.longval() == x ** y
 
     def test_normalize(self):
-        f1 = lobj.W_LongObject(self.space, [lobj.r_uint(1), lobj.r_uint(0)], 1)
+        f1 = lobj.W_LongObject(self.space, [1, 0], 1)
         f1._normalize()
         assert len(f1.digits) == 1
-        f0 = lobj.W_LongObject(self.space, [lobj.r_uint(0)], 0)
+        f0 = lobj.W_LongObject(self.space, [0], 0)
         assert self.space.is_true(
             self.space.eq(lobj.sub__Long_Long(self.space, f1, f1), f0))
 



More information about the Pypy-commit mailing list