[pypy-svn] r14624 - in pypy/dist/pypy/objspace/std: . test
tismer at codespeak.net
tismer at codespeak.net
Wed Jul 13 17:30:40 CEST 2005
Author: tismer
Date: Wed Jul 13 17:30:39 2005
New Revision: 14624
Modified:
pypy/dist/pypy/objspace/std/longobject.py
pypy/dist/pypy/objspace/std/test/test_longobject.py
Log:
a few enhancements which I allowed myself to do on the ferry and
in Kiel at my parents' (blush)
- removed all explicit typing, because this is the beginning of all evil!
- upgraded all implementations to the latest C source in CVS.
- implemented Karatsuba multiplication
- added optimizations for a*a
- added optimized pow()
- added more comments
- a lot of new tests
this appears to be almost State of Python Art.
XXX still to do at some time: I think we should
re-adopt CPython's use of a signed length to store
sign and size of the digit array. It is more efficient
to carry an explicit length than to really have to adjust
a list. Also I think the list will get optimized
away at some point, and we want an explicit length, anyway.
But this isn't urgent.
Modified: pypy/dist/pypy/objspace/std/longobject.py
==============================================================================
--- pypy/dist/pypy/objspace/std/longobject.py (original)
+++ pypy/dist/pypy/objspace/std/longobject.py Wed Jul 13 17:30:39 2005
@@ -3,60 +3,44 @@
from pypy.objspace.std.intobject import W_IntObject
from pypy.objspace.std.floatobject import W_FloatObject
from pypy.objspace.std.noneobject import W_NoneObject
-from pypy.rpython.rarithmetic import intmask, r_uint, r_ushort, r_ulong
-from pypy.rpython.rarithmetic import LONG_BIT
+from pypy.rpython.rarithmetic import LONG_BIT, LONG_MASK, intmask, r_uint
import math
-# for now, we use r_uint as a digit.
-# we may later switch to r_ushort, when it is supported by rtyper etc.
-
-Digit = r_uint # already works: r_ushort
-Twodigits = r_uint
+# after many days of debugging and testing,
+# I (chris) finally found out about overflows
+# and how to assign the correct types.
+# But in the end, this really makes no sense at all.
+# Finally I think that we should avoid using anything
+# but general integers. r_uint and friends should go away!
+# Unsignedness can be completely deduced by back-propagation
+# of masking. I will change the annotator to do this.
+# remember: not typing at all is much stronger!
+#
+# my conclusion:
+# having no special types at all, but describing everything
+# in terms of operations and masks is the stronger way.
-# the following describe a plain digit
-# XXX at the moment we can't save this bit,
-# or we would need a large enough type to hold
-# the carry bits in _x_divrem
-SHIFT = (Twodigits.BITS // 2) - 1
+SHIFT = (LONG_BIT // 2) - 1
MASK = int((1 << SHIFT) - 1)
-# find the correct type for carry/borrow
-if Digit.BITS - SHIFT >= 1:
- # we have one more bit in Digit
- Carryadd = Digit
- Stwodigits = int
-else:
- # we need another Digit
- Carryadd = Twodigits
- raise ValueError, "need a large enough type for Stwodigits"
-Carrymul = Twodigits
# Debugging digit array access.
#
-# 0 == no check at all
-# 1 == check correct type
-# 2 == check for extra (ab)used bits
-CHECK_DIGITS = 2
+# False == no checking at all
+# True == check 0 <= value <= MASK
+
+CHECK_DIGITS = True
if CHECK_DIGITS:
class DigitArray(list):
- if CHECK_DIGITS == 1:
- def __setitem__(self, idx, value):
- assert type(value) is Digit
- list.__setitem__(self, idx, value)
- elif CHECK_DIGITS == 2:
- def __setitem__(self, idx, value):
- assert type(value) is Digit
- assert value <= MASK
- list.__setitem__(self, idx, value)
- else:
- raise Exception, 'CHECK_DIGITS == %d not supported' % CHECK_DIGITS
+ def __setitem__(self, idx, value):
+ assert value >=0
+ assert value <= MASK
+ list.__setitem__(self, idx, value)
else:
DigitArray = list
-# XXX some operations below return one of their input arguments
-# without checking that it's really of type long (and not a subclass).
class W_LongObject(W_Object):
"""This is a reimplementation of longs using a list of digits."""
@@ -87,7 +71,7 @@
def _normalize(self):
if len(self.digits) == 0:
self.sign = 0
- self.digits = [Digit(0)]
+ self.digits = [0]
return
i = len(self.digits) - 1
while i != 0 and self.digits[i] == 0:
@@ -99,9 +83,25 @@
registerimplementation(W_LongObject)
+USE_KARATSUBA = True # set to False for comparison
+
+# For long multiplication, use the O(N**2) school algorithm unless
+# both operands contain more than KARATSUBA_CUTOFF digits (this
+# being an internal Python long digit, in base BASE).
+
+KARATSUBA_CUTOFF = 70
+KARATSUBA_SQUARE_CUTOFF = 2 * KARATSUBA_CUTOFF
+
+# For exponentiation, use the binary left-to-right algorithm
+# unless the exponent contains more than FIVEARY_CUTOFF digits.
+# In that case, do 5 bits at a time. The potential drawback is that
+# a table of 2**5 intermediate results is computed.
+
+FIVEARY_CUTOFF = 8
+
# bool-to-long
def delegate_Bool2Long(w_bool):
- return W_LongObject(w_bool.space, [Digit(w_bool.boolval)],
+ return W_LongObject(w_bool.space, [w_bool.boolval & MASK],
int(w_bool.boolval))
# int-to-long delegation
@@ -135,21 +135,21 @@
sign = 1
ival = w_intobj.intval
else:
- return W_LongObject(space, [Digit(0)], 0)
+ return W_LongObject(space, [0], 0)
# Count the number of Python digits.
# We used to pick 5 ("big enough for anything"), but that's a
# waste of time and space given that 5*15 = 75 bits are rarely
# needed.
- t = r_uint(ival)
+ t = ival
ndigits = 0
while t:
ndigits += 1
t >>= SHIFT
- v = W_LongObject(space, [Digit(0)] * ndigits, sign)
- t = r_uint(ival)
+ v = W_LongObject(space, [0] * ndigits, sign)
+ t = ival
p = 0
while t:
- v.digits[p] = Digit(t & MASK)
+ v.digits[p] = t & MASK
t >>= SHIFT
p += 1
return v
@@ -184,16 +184,16 @@
if w_value.sign == -1:
raise OperationError(space.w_ValueError, space.wrap(
"cannot convert negative integer to unsigned int"))
- x = r_uint(0)
+ x = 0
i = len(w_value.digits) - 1
while i >= 0:
prev = x
- x = (x << SHIFT) + w_value.digits[i]
+ x = ((x << SHIFT) + w_value.digits[i]) & LONG_MASK
if (x >> SHIFT) != prev:
raise OperationError(space.w_OverflowError, space.wrap(
"long int too large to convert to unsigned int"))
i -= 1
- return x
+ return r_uint(x) # XXX r_uint should go away
def repr__Long(space, w_long):
return space.wrap(_format(w_long, 10, True))
@@ -287,7 +287,10 @@
return result
def mul__Long_Long(space, w_long1, w_long2):
- result = _x_mul(w_long1, w_long2)
+ if USE_KARATSUBA:
+ result = _k_mul(w_long1, w_long2)
+ else:
+ result = _x_mul(w_long1, w_long2)
result.sign = w_long1.sign * w_long2.sign
return result
@@ -322,9 +325,9 @@
space.w_None)
if lz is not None:
if lz.sign == 0:
- raise OperationError(space.w_ValueError,
- space.wrap("pow() 3rd argument cannot be 0"))
- result = W_LongObject(space, [Digit(1)], 1)
+ raise OperationError(space.w_ValueError, space.wrap(
+ "pow() 3rd argument cannot be 0"))
+ result = W_LongObject(space, [1], 1)
if lw.sign == 0:
if lz is not None:
result = mod__Long_Long(space, result, lz)
@@ -337,7 +340,7 @@
# Treat the most significant digit specially to reduce multiplications
while i < len(lw.digits) - 1:
j = 0
- m = Digit(1)
+ m = 1
di = lw.digits[i]
while j < SHIFT:
if di & m:
@@ -349,7 +352,7 @@
m = m << 1
j += 1
i += 1
- m = Digit(1) << (SHIFT - 1)
+ m = 1 << (SHIFT - 1)
highest_set_bit = SHIFT
j = SHIFT - 1
di = lw.digits[i]
@@ -361,7 +364,7 @@
j -= 1
assert highest_set_bit != SHIFT, "long not normalized"
j = 0
- m = Digit(1)
+ m = 1
while j <= highest_set_bit:
if di & m:
result = mul__Long_Long(space, result, temp)
@@ -375,6 +378,106 @@
result = mod__Long_Long(space, result, lz)
return result
+def _impl_long_long_pow(space, a, b, c=None):
+ """ pow(a, b, c) """
+
+ negativeOutput = False # if x<0 return negative output
+
+ # 5-ary values. If the exponent is large enough, table is
+ # precomputed so that table[i] == a**i % c for i in range(32).
+ # python translation: the table is computed when needed.
+
+ if b.sign < 0: # if exponent is negative
+ if c is not None:
+ raise OperationError(space.w_TypeError, space.wrap(
+ "pow() 2nd argument "
+ "cannot be negative when 3rd argument specified"))
+ return space.pow(space.newfloat(_AsDouble(a)),
+ space.newfloat(_AsDouble(b)),
+ space.w_None)
+
+ if c is not None:
+ # if modulus == 0:
+ # raise ValueError()
+ if c.sign == 0:
+ raise OperationError(space.w_ValueError, space.wrap(
+ "pow() 3rd argument cannot be 0"))
+
+ # if modulus < 0:
+ # negativeOutput = True
+ # modulus = -modulus
+ if c.sign < 0:
+ negativeOutput = True
+ c = W_LongObject(space, c.digits, -c.sign)
+
+ # if modulus == 1:
+ # return 0
+ if len(c.digits) == 1 and c.digits[0] == 1:
+ return W_LongObject(space, [0], 0)
+
+ # if base < 0:
+ # base = base % modulus
+ # Having the base positive just makes things easier.
+ if a.sign < 0:
+ a, temp = _l_divmod(a, c)
+ a = temp
+
+ # At this point a, b, and c are guaranteed non-negative UNLESS
+ # c is NULL, in which case a may be negative. */
+
+ z = W_LongObject(space, [1], 1)
+
+ # python adaptation: moved macros REDUCE(X) and MULT(X, Y, result)
+ # into helper function result = _help_mult(x, y, c)
+ if len(b.digits) <= FIVEARY_CUTOFF:
+ # Left-to-right binary exponentiation (HAC Algorithm 14.79)
+ # http://www.cacr.math.uwaterloo.ca/hac/about/chap14.pdf
+ i = len(b.digits) - 1
+ while i >= 0:
+ bi = b.digits[i]
+ j = 1 << (SHIFT-1)
+ while j != 0:
+ z = _help_mult(z, z, c)
+ if bi & j:
+ z = _help_mult(z, a, c)
+ j >>= 1
+ i -= 1
+ else:
+ # Left-to-right 5-ary exponentiation (HAC Algorithm 14.82)
+ # z still holds 1L
+ table = [z] * 32
+ table[0] = z;
+ for i in range(1, 32):
+ table[i] = _help_mult(table[i-1], a, c)
+ i = len(b.digits) - 1
+ while i >= 0:
+ bi = b.digits[i]
+ j = j = SHIFT - 5
+ while j >= 0:
+ index = (bi >> j) & 0x1f
+ for k in range(5):
+ z = _help_mult(z, z, c)
+ if index:
+ z = _help_mult(z, table[index], c)
+ j -= 5
+ i -= 1
+
+ if negativeOutput and z.sign != 0:
+ z = sub__Long_Long(z.space, z, c)
+ return z
+
+def _help_mult(x, y, c):
+ """
+ Multiply two values, then reduce the result:
+ result = X*Y % c. If c is NULL, skip the mod.
+ """
+ res = mul__Long_Long(x.space, x, y)
+ # Perform a modular reduction, X = X % c, but leave X alone if c
+ # is NULL.
+ if c is not None:
+ res, temp = _l_divmod(res, c)
+ res = temp
+ return res
def pow__Long_Long_Long(space, w_long1, w_long2, w_long3):
return _impl_long_long_pow(space, w_long1, w_long2, w_long3)
@@ -382,6 +485,7 @@
def pow__Long_Long_None(space, w_long1, w_long2, w_long3):
return _impl_long_long_pow(space, w_long1, w_long2, None)
+
def neg__Long(space, w_long1):
return W_LongObject(space, w_long1.digits[:], -w_long1.sign)
@@ -395,7 +499,7 @@
return space.newbool(w_long.sign != 0)
def invert__Long(space, w_long): #Implement ~x as -(x + 1)
- w_lpp = add__Long_Long(space, w_long, W_LongObject(space, [Digit(1)], 1))
+ w_lpp = add__Long_Long(space, w_long, W_LongObject(space, [1], 1))
return neg__Long(space, w_lpp)
def lshift__Long_Long(space, w_long1, w_long2):
@@ -419,21 +523,21 @@
newsize = oldsize + wordshift
if remshift:
newsize += 1
- z = W_LongObject(space, [Digit(0)] * newsize, a.sign)
+ z = W_LongObject(space, [0] * newsize, a.sign)
# not sure if we will initialize things in the future?
for i in range(wordshift):
- z.digits[i] = Digit(0)
- accum = Twodigits(0)
+ z.digits[i] = 0
+ accum = 0
i = wordshift
j = 0
while j < oldsize:
- accum |= Twodigits(a.digits[j]) << remshift
- z.digits[i] = Digit(accum & MASK)
+ accum |= a.digits[j] << remshift
+ z.digits[i] = accum & MASK
accum >>= SHIFT
i += 1
j += 1
if remshift:
- z.digits[newsize-1] = Digit(accum)
+ z.digits[newsize-1] = accum
else:
assert not accum
z._normalize()
@@ -459,13 +563,13 @@
wordshift = shiftby // SHIFT
newsize = len(a.digits) - wordshift
if newsize <= 0:
- return W_LongObject(space, [Digit(0)], 0)
+ return W_LongObject(space, [0], 0)
loshift = shiftby % SHIFT
hishift = SHIFT - loshift
- lomask = (Digit(1) << hishift) - 1
+ lomask = (1 << hishift) - 1
himask = MASK ^ lomask
- z = W_LongObject(space, [Digit(0)] * newsize, a.sign)
+ z = W_LongObject(space, [0] * newsize, a.sign)
i = 0
j = wordshift
while i < newsize:
@@ -548,10 +652,10 @@
digits = []
i = 0
while l:
- digits.append(Digit(l & MASK))
+ digits.append(l & MASK)
l = l >> SHIFT
if sign == 0:
- digits = [Digit(0)]
+ digits = [0]
return digits, sign
@@ -564,20 +668,20 @@
if size_a < size_b:
a, b = b, a
size_a, size_b = size_b, size_a
- z = W_LongObject(a.space, [Digit(0)] * (len(a.digits) + 1), 1)
+ z = W_LongObject(a.space, [0] * (len(a.digits) + 1), 1)
i = 0
- carry = Carryadd(0)
+ carry = 0
while i < size_b:
- carry += Carryadd(a.digits[i]) + b.digits[i]
- z.digits[i] = Digit(carry & MASK)
+ carry += a.digits[i] + b.digits[i]
+ z.digits[i] = carry & MASK
carry >>= SHIFT
i += 1
while i < size_a:
carry += a.digits[i]
- z.digits[i] = Digit(carry & MASK)
+ z.digits[i] = carry & MASK
carry >>= SHIFT
i += 1
- z.digits[i] = Digit(carry)
+ z.digits[i] = carry
z._normalize()
return z
@@ -586,7 +690,7 @@
size_a = len(a.digits)
size_b = len(b.digits)
sign = 1
- borrow = Carryadd(0)
+ borrow = 0
# Ensure a is the larger of the two:
if size_a < size_b:
@@ -599,24 +703,24 @@
while i >= 0 and a.digits[i] == b.digits[i]:
i -= 1
if i < 0:
- return W_LongObject(a.space, [Digit(0)], 0)
+ return W_LongObject(a.space, [0], 0)
if a.digits[i] < b.digits[i]:
sign = -1
a, b = b, a
size_a = size_b = i+1
- z = W_LongObject(a.space, [Digit(0)] * size_a, 1)
+ z = W_LongObject(a.space, [0] * size_a, 1)
i = 0
while i < size_b:
# The following assumes unsigned arithmetic
# works modulo 2**N for some N>SHIFT.
- borrow = Carryadd(a.digits[i]) - b.digits[i] - borrow
- z.digits[i] = Digit(borrow & MASK)
+ borrow = a.digits[i] - b.digits[i] - borrow
+ z.digits[i] = borrow & MASK
borrow >>= SHIFT
borrow &= 1 # Keep only one sign bit
i += 1
while i < size_a:
borrow = a.digits[i] - borrow
- z.digits[i] = Digit(borrow & MASK)
+ z.digits[i] = borrow & MASK
borrow >>= SHIFT
borrow &= 1 # Keep only one sign bit
i += 1
@@ -627,37 +731,317 @@
return z
-#Multiply the absolute values of two longs
def _x_mul(a, b):
+ """
+ Grade school multiplication, ignoring the signs.
+ Returns the absolute value of the product, or NULL if error.
+ """
+
size_a = len(a.digits)
size_b = len(b.digits)
- z = W_LongObject(a.space, [Digit(0)] * (size_a + size_b), 1)
- i = 0
- while i < size_a:
- carry = Carrymul(0)
- f = Twodigits(a.digits[i])
- j = 0
- while j < size_b:
- carry += z.digits[i + j] + b.digits[j] * f
- z.digits[i + j] = Digit(carry & MASK)
- carry >>= SHIFT
- j += 1
- while carry != 0:
- assert i + j < size_a + size_b
- carry += z.digits[i + j]
- z.digits[i + j] = Digit(carry & MASK)
+ z = W_LongObject(a.space, [0] * (size_a + size_b), 1)
+ if a == b:
+ # Efficient squaring per HAC, Algorithm 14.16:
+ # http://www.cacr.math.uwaterloo.ca/hac/about/chap14.pdf
+ # Gives slightly less than a 2x speedup when a == b,
+ # via exploiting that each entry in the multiplication
+ # pyramid appears twice (except for the size_a squares).
+ i = 0
+ while i < size_a:
+ f = a.digits[i]
+ pz = i << 1
+ pa = i + 1
+ paend = size_a
+
+ carry = z.digits[pz] + f * f
+ z.digits[pz] = carry & MASK
+ pz += 1
carry >>= SHIFT
- j += 1
- i += 1
+ assert carry <= MASK
+
+ # Now f is added in twice in each column of the
+ # pyramid it appears. Same as adding f<<1 once.
+ f <<= 1
+ while pa < paend:
+ carry += z.digits[pz] + a.digits[pa] * f
+ pa += 1
+ z.digits[pz] = carry & MASK
+ pz += 1
+ carry >>= SHIFT
+ assert carry <= (MASK << 1)
+ if carry:
+ carry += z.digits[pz]
+ z.digits[pz] = carry & MASK
+ pz += 1
+ carry >>= SHIFT
+ if carry:
+ z.digits[pz] += carry & MASK
+ assert (carry >> SHIFT) == 0
+ i += 1
+ else:
+ # a is not the same as b -- gradeschool long mult
+ i = 0
+ while i < size_a:
+ carry = 0
+ f = a.digits[i]
+ pz = i
+ pb = 0
+ pbend = size_b
+ while pb < pbend:
+ carry += z.digits[pz] + b.digits[pb] * f
+ pb += 1
+ z.digits[pz] = carry & MASK
+ pz += 1
+ carry >>= SHIFT
+ assert carry <= MASK
+ if carry:
+ z.digits[pz] += carry & MASK
+ assert (carry >> SHIFT) == 0
+ i += 1
z._normalize()
return z
+
+def _kmul_split(n, size):
+ """
+ A helper for Karatsuba multiplication (k_mul).
+ Takes a long "n" and an integer "size" representing the place to
+ split, and sets low and high such that abs(n) == (high << size) + low,
+ viewing the shift as being by digits. The sign bit is ignored, and
+ the return values are >= 0.
+ """
+ size_n = len(n.digits)
+ size_lo = min(size_n, size)
+
+ lo = W_LongObject(n.space, n.digits[:size_lo], 1)
+ hi = W_LongObject(n.space, n.digits[size_lo:], 1)
+ lo._normalize()
+ hi._normalize()
+ return hi, lo
+
+def _k_mul(a, b):
+ """
+ Karatsuba multiplication. Ignores the input signs, and returns the
+ absolute value of the product (or raises if error).
+ See Knuth Vol. 2 Chapter 4.3.3 (Pp. 294-295).
+ """
+ asize = len(a.digits)
+ bsize = len(b.digits)
+ # (ah*X+al)(bh*X+bl) = ah*bh*X*X + (ah*bl + al*bh)*X + al*bl
+ # Let k = (ah+al)*(bh+bl) = ah*bl + al*bh + ah*bh + al*bl
+ # Then the original product is
+ # ah*bh*X*X + (k - ah*bh - al*bl)*X + al*bl
+ # By picking X to be a power of 2, "*X" is just shifting, and it's
+ # been reduced to 3 multiplies on numbers half the size.
+
+ # We want to split based on the larger number; fiddle so that b
+ # is largest.
+ if asize > bsize:
+ a, b, asize, bsize = b, a, bsize, asize
+
+ # Use gradeschool math when either number is too small.
+ if a == b:
+ i = KARATSUBA_SQUARE_CUTOFF
+ else:
+ i = KARATSUBA_CUTOFF
+ if asize <= i:
+ if a.sign == 0:
+ return W_LongObject(a.space, [0], 0)
+ else:
+ return _x_mul(a, b)
+
+ # If a is small compared to b, splitting on b gives a degenerate
+ # case with ah==0, and Karatsuba may be (even much) less efficient
+ # than "grade school" then. However, we can still win, by viewing
+ # b as a string of "big digits", each of width a->ob_size. That
+ # leads to a sequence of balanced calls to k_mul.
+ if 2 * asize <= bsize:
+ return _k_lopsided_mul(a, b)
+
+ # Split a & b into hi & lo pieces.
+ shift = bsize >> 1
+ ah, al = _kmul_split(a, shift)
+ assert ah.sign == 1 # the split isn't degenerate
+
+ if a == b:
+ bh = ah
+ bl = al
+ else:
+ bh, bl = _kmul_split(b, shift)
+
+ # The plan:
+ # 1. Allocate result space (asize + bsize digits: that's always
+ # enough).
+ # 2. Compute ah*bh, and copy into result at 2*shift.
+ # 3. Compute al*bl, and copy into result at 0. Note that this
+ # can't overlap with #2.
+ # 4. Subtract al*bl from the result, starting at shift. This may
+ # underflow (borrow out of the high digit), but we don't care:
+ # we're effectively doing unsigned arithmetic mod
+ # BASE**(sizea + sizeb), and so long as the *final* result fits,
+ # borrows and carries out of the high digit can be ignored.
+ # 5. Subtract ah*bh from the result, starting at shift.
+ # 6. Compute (ah+al)*(bh+bl), and add it into the result starting
+ # at shift.
+
+ # 1. Allocate result space.
+ ret = W_LongObject(a.space, [0] * (asize + bsize), 1)
+
+ # 2. t1 <- ah*bh, and copy into high digits of result.
+ t1 = _k_mul(ah, bh)
+ assert t1.sign >= 0
+ assert 2*shift + len(t1.digits) <= len(ret.digits)
+ ret.digits[2*shift : 2*shift + len(t1.digits)] = t1.digits
+
+ # Zero-out the digits higher than the ah*bh copy. */
+ ## ignored, assuming that we initialize to zero
+ ##i = ret->ob_size - 2*shift - t1->ob_size;
+ ##if (i)
+ ## memset(ret->ob_digit + 2*shift + t1->ob_size, 0,
+ ## i * sizeof(digit));
+
+ # 3. t2 <- al*bl, and copy into the low digits.
+ t2 = _k_mul(al, bl)
+ assert t2.sign >= 0
+ assert len(t2.digits) <= 2*shift # no overlap with high digits
+ ret.digits[:len(t2.digits)] = t2.digits
+
+ # Zero out remaining digits.
+ ## ignored, assuming that we initialize to zero
+ ##i = 2*shift - t2->ob_size; /* number of uninitialized digits */
+ ##if (i)
+ ## memset(ret->ob_digit + t2->ob_size, 0, i * sizeof(digit));
+
+ # 4 & 5. Subtract ah*bh (t1) and al*bl (t2). We do al*bl first
+ # because it's fresher in cache.
+ i = len(ret.digits) - shift # # digits after shift
+ _v_isub(ret.digits, shift, i, t2.digits, len(t2.digits))
+ _v_isub(ret.digits, shift, i, t1.digits, len(t1.digits))
+ del t1, t2
+
+ # 6. t3 <- (ah+al)(bh+bl), and add into result.
+ t1 = _x_add(ah, al)
+ del ah, al
+
+ if a == b:
+ t2 = t1
+ else:
+ t2 = _x_add(bh, bl)
+ del bh, bl
+
+ t3 = _k_mul(t1, t2)
+ del t1, t2
+ assert t3.sign ==1
+
+ # Add t3. It's not obvious why we can't run out of room here.
+ # See the (*) comment after this function.
+ _v_iadd(ret.digits, shift, i, t3.digits, len(t3.digits))
+ del t3
+
+ ret._normalize()
+ return ret
+
+""" (*) Why adding t3 can't "run out of room" above.
+
+Let f(x) mean the floor of x and c(x) mean the ceiling of x. Some facts
+to start with:
+
+1. For any integer i, i = c(i/2) + f(i/2). In particular,
+ bsize = c(bsize/2) + f(bsize/2).
+2. shift = f(bsize/2)
+3. asize <= bsize
+4. Since we call k_lopsided_mul if asize*2 <= bsize, asize*2 > bsize in this
+ routine, so asize > bsize/2 >= f(bsize/2) in this routine.
+
+We allocated asize + bsize result digits, and add t3 into them at an offset
+of shift. This leaves asize+bsize-shift allocated digit positions for t3
+to fit into, = (by #1 and #2) asize + f(bsize/2) + c(bsize/2) - f(bsize/2) =
+asize + c(bsize/2) available digit positions.
+
+bh has c(bsize/2) digits, and bl at most f(bsize/2) digits. So bh+bl has
+at most c(bsize/2) digits + 1 bit.
+
+If asize == bsize, ah has c(bsize/2) digits, else ah has at most f(bsize/2)
+digits, and al has at most f(bsize/2) digits in any case. So ah+al has at
+most (asize == bsize ? c(bsize/2) : f(bsize/2)) digits + 1 bit.
+
+The product (ah+al)*(bh+bl) therefore has at most
+
+ c(bsize/2) + (asize == bsize ? c(bsize/2) : f(bsize/2)) digits + 2 bits
+
+and we have asize + c(bsize/2) available digit positions. We need to show
+this is always enough. An instance of c(bsize/2) cancels out in both, so
+the question reduces to whether asize digits is enough to hold
+(asize == bsize ? c(bsize/2) : f(bsize/2)) digits + 2 bits. If asize < bsize,
+then we're asking whether asize digits >= f(bsize/2) digits + 2 bits. By #4,
+asize is at least f(bsize/2)+1 digits, so this in turn reduces to whether 1
+digit is enough to hold 2 bits. This is so since SHIFT = LONG_BIT//2 - 1 >= 2. If
+asize == bsize, then we're asking whether bsize digits is enough to hold
+c(bsize/2) digits + 2 bits, or equivalently (by #1) whether f(bsize/2) digits
+is enough to hold 2 bits. This is so if bsize >= 2, which holds because
+bsize >= KARATSUBA_CUTOFF >= 2.
+
+Note that since there's always enough room for (ah+al)*(bh+bl), and that's
+clearly >= each of ah*bh and al*bl, there's always enough room to subtract
+ah*bh and al*bl too.
+"""
+
+def _k_lopsided_mul(a, b):
+ """
+ b has at least twice the digits of a, and a is big enough that Karatsuba
+ would pay off *if* the inputs had balanced sizes. View b as a sequence
+ of slices, each with a->ob_size digits, and multiply the slices by a,
+ one at a time. This gives k_mul balanced inputs to work with, and is
+ also cache-friendly (we compute one double-width slice of the result
+ at a time, then move on, never backtracking except for the helpful
+ single-width slice overlap between successive partial sums).
+ """
+ asize = len(a.digits)
+ bsize = len(b.digits)
+ # nbdone is # of b digits already multiplied
+
+ assert asize > KARATSUBA_CUTOFF
+ assert 2 * asize <= bsize
+
+ # Allocate result space, and zero it out.
+ ret = W_LongObject(a.space, [0] * (asize + bsize), 1)
+
+ # Successive slices of b are copied into bslice.
+ #bslice = W_LongObject(a.space, [0] * asize, 1)
+ # XXX we cannot pre-allocate, see comments below!
+ bslice = W_LongObject(a.space, [0], 1)
+
+ nbdone = 0;
+ while bsize > 0:
+ nbtouse = min(bsize, asize)
+
+ # Multiply the next slice of b by a.
+
+ #bslice.digits[:nbtouse] = b.digits[nbdone : nbdone + nbtouse]
+ # XXX: this would be more efficient if we adopted CPython's
+ # way to store the size, instead of resizing the list!
+ # XXX change the implementation, encoding length via the sign.
+ bslice.digits = b.digits[nbdone : nbdone + nbtouse]
+ product = _k_mul(a, bslice)
+
+ # Add into result.
+ _v_iadd(ret.digits, nbdone, len(ret.digits) - nbdone,
+ product.digits, len(product.digits))
+ del product
+
+ bsize -= nbtouse
+ nbdone += nbtouse
+
+ ret._normalize()
+ return ret
+
+
def _inplace_divrem1(pout, pin, n, size=0):
"""
Divide long pin by non-zero digit n, storing quotient
in pout, and returning the remainder. It's OK for pin == pout on entry.
"""
- rem = Twodigits(0)
+ rem = 0
assert n > 0 and n <= MASK
if not size:
size = len(pin.digits)
@@ -665,7 +1049,7 @@
while size >= 0:
rem = (rem << SHIFT) + pin.digits[size]
hi = rem // n
- pout.digits[size] = Digit(hi)
+ pout.digits[size] = hi
rem -= hi * n
size -= 1
return rem
@@ -678,24 +1062,80 @@
"""
assert n > 0 and n <= MASK
size = len(a.digits)
- z = W_LongObject(a.space, [Digit(0)] * size, 1)
+ z = W_LongObject(a.space, [0] * size, 1)
rem = _inplace_divrem1(z, a, n)
z._normalize()
return z, rem
+def _v_iadd(x, xofs, m, y, n):
+ """
+ x[0:m] and y[0:n] are digit vectors, LSD first, m >= n required. x[0:n]
+ is modified in place, by adding y to it. Carries are propagated as far as
+ x[m-1], and the remaining carry (0 or 1) is returned.
+ Python adaptation: x is addressed relative to xofs!
+ """
+ carry = 0;
+
+ assert m >= n
+ i = xofs
+ iend = xofs + n
+ while i < iend:
+ carry += x[i] + y[i-xofs]
+ x[i] = carry & MASK
+ carry >>= SHIFT
+ assert (carry & 1) == carry
+ i += 1
+ iend = xofs + m
+ while carry and i < iend:
+ carry += x[i]
+ x[i] = carry & MASK
+ carry >>= SHIFT
+ assert (carry & 1) == carry
+ i += 1
+ return carry
+
+def _v_isub(x, xofs, m, y, n):
+ """
+ x[0:m] and y[0:n] are digit vectors, LSD first, m >= n required. x[0:n]
+ is modified in place, by subtracting y from it. Borrows are propagated as
+ far as x[m-1], and the remaining borrow (0 or 1) is returned.
+ Python adaptation: x is addressed relative to xofs!
+ """
+ borrow = 0
+
+ assert m >= n
+ i = xofs
+ iend = xofs + n
+ while i < iend:
+ borrow = x[i] - y[i-xofs] - borrow
+ x[i] = borrow & MASK
+ borrow >>= SHIFT
+ borrow &= 1 # keep only 1 sign bit
+ i += 1
+ iend = xofs + m
+ while borrow and i < iend:
+ borrow = x[i] - borrow
+ x[i] = borrow & MASK
+ borrow >>= SHIFT
+ borrow &= 1
+ i += 1
+ return borrow
+
+
def _muladd1(a, n, extra):
"""Multiply by a single digit and add a single digit, ignoring the sign.
"""
size_a = len(a.digits)
- z = W_LongObject(a.space, [Digit(0)] * (size_a+1), 1)
- carry = Carrymul(extra)
+ z = W_LongObject(a.space, [0] * (size_a+1), 1)
+ carry = extra
+ assert carry & MASK == carry
i = 0
while i < size_a:
- carry += Twodigits(a.digits[i]) * n
- z.digits[i] = Digit(carry & MASK)
+ carry += a.digits[i] * n
+ z.digits[i] = carry & MASK
carry >>= SHIFT
i += 1
- z.digits[i] = Digit(carry)
+ z.digits[i] = carry
z._normalize()
return z
@@ -703,69 +1143,66 @@
def _x_divrem(v1, w1):
""" Unsigned long division with remainder -- the algorithm """
size_w = len(w1.digits)
- d = Digit(Twodigits(MASK+1) // (w1.digits[size_w-1] + 1))
- v = _muladd1(v1, d, Digit(0))
- w = _muladd1(w1, d, Digit(0))
+ d = (MASK+1) // (w1.digits[size_w-1] + 1)
+ v = _muladd1(v1, d, 0)
+ w = _muladd1(w1, d, 0)
size_v = len(v.digits)
size_w = len(w.digits)
assert size_v >= size_w and size_w > 1 # Assert checks by div()
size_a = size_v - size_w + 1
- a = W_LongObject(v.space, [Digit(0)] * size_a, 1)
+ a = W_LongObject(v.space, [0] * size_a, 1)
j = size_v
k = size_a - 1
while k >= 0:
if j >= size_v:
- vj = Digit(0)
+ vj = 0
else:
vj = v.digits[j]
- carry = Stwodigits(0) # note: this must hold two digits and a sign!
+ carry = 0
if vj == w.digits[size_w-1]:
- q = Twodigits(MASK)
+ q = MASK
else:
- q = ((Twodigits(vj) << SHIFT) + v.digits[j-1]) // w.digits[size_w-1]
+ q = ((vj << SHIFT) + v.digits[j-1]) // w.digits[size_w-1]
- # notabene!
- # this check needs a signed two digits result
- # or we get an overflow.
while (w.digits[size_w-2] * q >
((
- (Stwodigits(vj) << SHIFT) # this one dominates
- + Stwodigits(v.digits[j-1])
- - Stwodigits(q) * Stwodigits(w.digits[size_w-1])
+ (vj << SHIFT)
+ + v.digits[j-1]
+ - q * w.digits[size_w-1]
) << SHIFT)
- + Stwodigits(v.digits[j-2])):
+ + v.digits[j-2]):
q -= 1
i = 0
while i < size_w and i+k < size_v:
- z = Stwodigits(w.digits[i] * q)
+ z = w.digits[i] * q
zz = z >> SHIFT
- carry += Stwodigits(v.digits[i+k]) - z + (zz << SHIFT)
- v.digits[i+k] = Digit(carry & MASK)
+ carry += v.digits[i+k] - z + (zz << SHIFT)
+ v.digits[i+k] = carry & MASK
carry >>= SHIFT
carry -= zz
i += 1
if i+k < size_v:
- carry += Stwodigits(v.digits[i+k])
- v.digits[i+k] = Digit(0)
+ carry += v.digits[i+k]
+ v.digits[i+k] = 0
if carry == 0:
- a.digits[k] = Digit(q & MASK)
+ a.digits[k] = q & MASK
assert not q >> SHIFT
else:
assert carry == -1
q -= 1
- a.digits[k] = Digit(q & MASK)
+ a.digits[k] = q & MASK
assert not q >> SHIFT
- carry = Stwodigits(0)
+ carry = 0
i = 0
while i < size_w and i+k < size_v:
- carry += Stwodigits(v.digits[i+k]) + Stwodigits(w.digits[i])
- v.digits[i+k] = Digit(carry & MASK)
+ carry += v.digits[i+k] + w.digits[i]
+ v.digits[i+k] = carry & MASK
carry >>= SHIFT
i += 1
j -= 1
@@ -789,7 +1226,7 @@
(size_a == size_b and
a.digits[size_a-1] < b.digits[size_b-1])):
# |a| < |b|
- z = W_LongObject(a.space, [Digit(0)], 0)
+ z = W_LongObject(a.space, [0], 0)
rem = a
return z, rem
if size_b == 1:
@@ -846,12 +1283,12 @@
##def ldexp(x, exp):
## assert type(x) is float
## lb1 = LONG_BIT - 1
-## multiplier = float(Digit(1) << lb1)
+## multiplier = float(1 << lb1)
## while exp >= lb1:
## x *= multiplier
## exp -= lb1
## if exp:
-## x *= float(Digit(1) << exp)
+## x *= float(1 << exp)
## return x
# note that math.ldexp checks for overflows,
@@ -904,13 +1341,13 @@
dval = -dval
frac, expo = math.frexp(dval) # dval = frac*2**expo; 0.0 <= frac < 1.0
if expo <= 0:
- return W_LongObject(space, [Digit(0)], 0)
+ return W_LongObject(space, [0], 0)
ndig = (expo-1) // SHIFT + 1 # Number of 'digits' in result
- v = W_LongObject(space, [Digit(0)] * ndig, 1)
+ v = W_LongObject(space, [0] * ndig, 1)
frac = math.ldexp(frac, (expo-1) % SHIFT + 1)
for i in range(ndig-1, -1, -1):
- bits = int(frac)
- v.digits[i] = Digit(bits)
+ bits = int(frac) & MASK # help the future annotator?
+ v.digits[i] = bits
frac -= float(bits)
frac = math.ldexp(frac, SHIFT)
if neg:
@@ -937,7 +1374,7 @@
div, mod = _divrem(v, w)
if mod.sign * w.sign == -1:
mod = add__Long_Long(v.space, mod, w)
- one = W_LongObject(v.space, [Digit(1)], 1)
+ one = W_LongObject(v.space, [1], 1)
div = sub__Long_Long(v.space, div, one)
return div, mod
@@ -974,7 +1411,7 @@
s[p] = '0'
elif (base & (base - 1)) == 0:
# JRH: special case for power-of-2 bases
- accum = Twodigits(0)
+ accum = 0
accumbits = 0 # # of bits in accum
basebits = 1 # # of bits in base-1
i = base
@@ -985,7 +1422,7 @@
basebits += 1
for i in range(size_a):
- accum |= Twodigits(a.digits[i]) << accumbits
+ accum |= a.digits[i] << accumbits
accumbits += SHIFT
assert accumbits >= basebits
while 1:
@@ -1012,17 +1449,17 @@
size = size_a
pin = a # just for similarity to C source which uses the array
# powbase <- largest power of base that fits in a digit.
- powbase = Digit(base) # powbase == base ** power
+ powbase = base # powbase == base ** power
power = 1
while 1:
- newpow = Twodigits(powbase) * Digit(base)
+ newpow = powbase * base
if newpow >> SHIFT: # doesn't fit in a digit
break
- powbase = Digit(newpow)
+ powbase = newpow
power += 1
# Get a scratch area for repeated division.
- scratch = W_LongObject(a.space, [Digit(0)] * size, 1)
+ scratch = W_LongObject(a.space, [0] * size, 1)
# Repeatedly divide by powbase.
while 1:
@@ -1086,14 +1523,14 @@
if a.sign < 0:
a = invert__Long(a.space, a)
- maska = Digit(MASK)
+ maska = MASK
else:
- maska = Digit(0)
+ maska = 0
if b.sign < 0:
b = invert__Long(b.space, b)
- maskb = Digit(MASK)
+ maskb = MASK
else:
- maskb = Digit(0)
+ maskb = 0
negz = 0
if op == '^':
@@ -1135,7 +1572,7 @@
else:
size_z = max(size_a, size_b)
- z = W_LongObject(a.space, [Digit(0)] * size_z, 1)
+ z = W_LongObject(a.space, [0] * size_z, 1)
for i in range(size_z):
if i < size_a:
@@ -1161,17 +1598,17 @@
def _AsLong(v):
"""
Get an integer from a long int object.
- Returns -1 and sets an error condition if overflow occurs.
+ Raises OverflowError if overflow occurs.
"""
# This version by Tim Peters
i = len(v.digits) - 1
sign = v.sign
if not sign:
return 0
- x = r_uint(0)
+ x = 0
while i >= 0:
prev = x
- x = (x << SHIFT) + v.digits[i]
+ x = ((x << SHIFT) + v.digits[i]) & LONG_MASK
if (x >> SHIFT) != prev:
raise OverflowError
i -= 1
@@ -1180,7 +1617,7 @@
# trouble *unless* this is the min negative number. So,
# trouble iff sign bit set && (positive || some bit set other
# than the sign bit).
- if int(x) < 0 and (sign > 0 or (x << 1) != 0):
+ if intmask(x) < 0 and (sign > 0 or (x << 1) & LONG_MASK != 0):
raise OverflowError
return intmask(int(x) * sign)
Modified: pypy/dist/pypy/objspace/std/test/test_longobject.py
==============================================================================
--- pypy/dist/pypy/objspace/std/test/test_longobject.py (original)
+++ pypy/dist/pypy/objspace/std/test/test_longobject.py Wed Jul 13 17:30:39 2005
@@ -3,8 +3,8 @@
from random import random, randint
from pypy.objspace.std import longobject as lobj
from pypy.objspace.std.objspace import FailedToImplement
-from pypy.rpython.rarithmetic import r_uint
from pypy.interpreter.error import OperationError
+from pypy.rpython.rarithmetic import r_uint # will go away
objspacename = 'std'
@@ -40,7 +40,7 @@
assert result.longval() == x * i - y * j
def test_subzz(self):
- w_l0 = lobj.W_LongObject(self.space, [r_uint(0)])
+ w_l0 = lobj.W_LongObject(self.space, [0])
assert self.space.sub(w_l0, w_l0).longval() == 0
def test_mul(self):
@@ -50,13 +50,16 @@
f2 = lobj.W_LongObject(self.space, *lobj.args_from_long(y))
result = lobj.mul__Long_Long(self.space, f1, f2)
assert result.longval() == x * y
+ # also test a * a, it has special code
+ result = lobj.mul__Long_Long(self.space, f1, f1)
+ assert result.longval() == x * x
def test__inplace_divrem1(self):
# signs are not handled in the helpers!
x = 1238585838347L
y = 3
f1 = lobj.W_LongObject(self.space, *lobj.args_from_long(x))
- f2 = r_uint(y)
+ f2 = y
remainder = lobj._inplace_divrem1(f1, f1, f2)
assert (f1.longval(), remainder) == divmod(x, y)
@@ -65,7 +68,7 @@
x = 1238585838347L
y = 3
f1 = lobj.W_LongObject(self.space, *lobj.args_from_long(x))
- f2 = r_uint(y)
+ f2 = y
div, rem = lobj._divrem1(f1, f2)
assert (div.longval(), rem) == divmod(x, y)
@@ -74,8 +77,8 @@
y = 3
z = 42
f1 = lobj.W_LongObject(self.space, *lobj.args_from_long(x))
- f2 = r_uint(y)
- f3 = r_uint(z)
+ f2 = y
+ f3 = z
prod = lobj._muladd1(f1, f2, f3)
assert prod.longval() == x * y + z
@@ -125,6 +128,45 @@
except OperationError, e:
assert e.w_type is self.space.w_OverflowError
+ # testing Karatsuba stuff
+ def test__v_iadd(self):
+ f1 = lobj.W_LongObject(self.space, [lobj.MASK] * 10, 1)
+ f2 = lobj.W_LongObject(self.space, [1], 1)
+ carry = lobj._v_iadd(f1.digits, 1, len(f1.digits)-1, f2.digits, 1)
+ assert carry == 1
+ assert f1.longval() == lobj.MASK
+
+ def test__v_isub(self):
+ f1 = lobj.W_LongObject(self.space, [lobj.MASK] + [0] * 9 + [1], 1)
+ f2 = lobj.W_LongObject(self.space, [1], 1)
+ borrow = lobj._v_isub(f1.digits, 1, len(f1.digits)-1, f2.digits, 1)
+ assert borrow == 0
+ assert f1.longval() == (1 << lobj.SHIFT) ** 10 - 1
+
+ def test__kmul_split(self):
+ split = 5
+ diglo = [0] * split
+ dighi = [lobj.MASK] * split
+ f1 = lobj.W_LongObject(self.space, diglo + dighi, 1)
+ hi, lo = lobj._kmul_split(f1, split)
+ assert lo.digits == [0]
+ assert hi.digits == dighi
+
+ def test__k_mul(self):
+ digs= lobj.KARATSUBA_CUTOFF * 5
+ f1 = lobj.W_LongObject(self.space, [lobj.MASK] * digs, 1)
+ f2 = lobj._x_add(f1,lobj.W_LongObject(self.space, [1], 1))
+ ret = lobj._k_mul(f1, f2)
+ assert ret.longval() == f1.longval() * f2.longval()
+
+ def test__k_lopsided_mul(self):
+ digs_a = lobj.KARATSUBA_CUTOFF + 3
+ digs_b = 3 * digs_a
+ f1 = lobj.W_LongObject(self.space, [lobj.MASK] * digs_a, 1)
+ f2 = lobj.W_LongObject(self.space, [lobj.MASK] * digs_b, 1)
+ ret = lobj._k_lopsided_mul(f1, f2)
+ assert ret.longval() == f1.longval() * f2.longval()
+
def test_eq(self):
x = 5858393919192332223L
y = 585839391919233111223311112332L
@@ -157,7 +199,7 @@
u = lobj.uint_w__Long(self.space, f2)
assert u == 12332
- assert isinstance(u, r_uint)
+ assert type(u) is r_uint
def test_conversions(self):
space = self.space
@@ -170,10 +212,10 @@
assert space.is_true(space.isinstance(lobj.int__Long(space, w_lv), space.w_int))
assert space.eq_w(lobj.int__Long(space, w_lv), w_v)
- if v>=0:
+ if v >= 0:
u = lobj.uint_w__Long(space, w_lv)
assert u == v
- assert isinstance(u, r_uint)
+ assert type(u) is r_uint
else:
space.raises_w(space.w_ValueError, lobj.uint_w__Long, space, w_lv)
@@ -193,8 +235,7 @@
u = lobj.uint_w__Long(space, w_lmaxuint)
assert u == 2*sys.maxint+1
- assert isinstance(u, r_uint)
-
+
space.raises_w(space.w_ValueError, lobj.uint_w__Long, space, w_toobig_lv3)
space.raises_w(space.w_OverflowError, lobj.uint_w__Long, space, w_toobig_lv4)
@@ -229,10 +270,10 @@
assert v.longval() == x ** y
def test_normalize(self):
- f1 = lobj.W_LongObject(self.space, [lobj.r_uint(1), lobj.r_uint(0)], 1)
+ f1 = lobj.W_LongObject(self.space, [1, 0], 1)
f1._normalize()
assert len(f1.digits) == 1
- f0 = lobj.W_LongObject(self.space, [lobj.r_uint(0)], 0)
+ f0 = lobj.W_LongObject(self.space, [0], 0)
assert self.space.is_true(
self.space.eq(lobj.sub__Long_Long(self.space, f1, f1), f0))
More information about the Pypy-commit mailing list