[pypy-commit] pypy improve-rbigint: More fixes to Toom-Cook

stian noreply at buildbot.pypy.org
Sat Jul 21 18:41:07 CEST 2012


Author: stian
Branch: improve-rbigint
Changeset: r56321:423421c5c9ba
Date: 2012-06-23 06:41 +0200
http://bitbucket.org/pypy/pypy/changeset/423421c5c9ba/

Log:	More fixes to Toom-Cook

diff --git a/pypy/rlib/rbigint.py b/pypy/rlib/rbigint.py
--- a/pypy/rlib/rbigint.py
+++ b/pypy/rlib/rbigint.py
@@ -36,8 +36,8 @@
 KARATSUBA_CUTOFF = 38
 KARATSUBA_SQUARE_CUTOFF = 2 * KARATSUBA_CUTOFF
 
-USE_TOOMCOCK = False
-TOOMCOOK_CUTOFF = 102
+USE_TOOMCOCK = False # WIP
+TOOMCOOK_CUTOFF = 3 # Smallest possible cutoff is 3. Ideal is probably around 150+
 
 # For exponentiation, use the binary left-to-right algorithm
 # unless the exponent contains more than FIVEARY_CUTOFF digits.
@@ -956,30 +956,18 @@
     """
     A helper for Karatsuba multiplication (k_mul).
     Takes a bigint "n" and an integer "size" representing the place to
-    split, and sets low and high such that abs(n) == (high << size) + low,
+    split, and sets low and high such that abs(n) == (high << (size * 2) + (mid << size) + low,
     viewing the shift as being by digits.  The sign bit is ignored, and
     the return values are >= 0.
     """
-    
-    assert size > 0
-    
-    size_n = n.numdigits()
-    shift = min(size_n, size)
-    
-    lo = rbigint(n._digits[:shift], 1)
-    if size_n >= (shift * 2):
-        mid = rbigint(n._digits[shift:shift >> 1], 1)
-        hi = rbigint(n._digits[shift >> 1:], 1)
-    else:
-        mid = rbigint(n._digits[shift:], 1)
-        hi = rbigint([NULLDIGIT] * ((shift * 3) - size_n), 1)
+    lo = rbigint(n._digits[:size], 1)
+    mid = rbigint(n._digits[size:size * 2], 1)
+    hi = rbigint(n._digits[size *2:], 1)
     lo._normalize()
     mid._normalize()
     hi._normalize()
     return hi, mid, lo
 
-# Declear a simple 2 as constants for our toom cook
-POINT2 = rbigint.fromint(2)
 def _tc_mul(a, b):
     """
     Toom Cook
@@ -988,7 +976,7 @@
     bsize = b.numdigits()
 
     # Split a & b into hi, mid and lo pieces.
-    shift = bsize >> 1
+    shift = asize // 3
     ah, am, al = _tcmul_split(a, shift)
     assert ah.sign == 1    # the split isn't degenerate
 
@@ -1006,15 +994,16 @@
     pO = al.add(ah)
     p1 = pO.add(am)
     pn1 = pO.sub(am)
-    pn2 = pn1.add(ah).mul(POINT2).sub(al)
+    pn2 = pn1.add(ah).lshift(1).sub(al)
     
     qO = bl.add(bh)
     q1 = qO.add(bm)
     qn1 = qO.sub(bm)
-    qn2 = qn1.add(bh).mul(POINT2).sub(bl)
+    qn2 = qn1.add(bh).lshift(1).sub(bl)
     
     w0 = al.mul(bl)
     winf = ah.mul(bh)
+
     w1 = p1.mul(q1)
     wn1 = pn1.mul(qn1)
     wn2 = pn2.mul(qn2)
@@ -1024,26 +1013,29 @@
     r0 = w0
     r4 = winf
     r3 = _divrem1(wn2.sub(wn1), 3)[0]
-    r1 = _divrem1(w1.sub(wn1), 2)[0]
+    r1 = w1.sub(wn1).rshift(1)
     r2 = wn1.sub(w0)
-    r3 = _divrem1(r2.sub(r3), 2)[0].add(r4.mul(POINT2))
+    r3 = _divrem1(r2.sub(r3), 2)[0].add(r4.lshift(1))
     r2 = r2.add(r1).sub(r4)
     r1 = r1.sub(r3)
     
     # Now we fit r+ r2 + r4 into the new string.
     # Now we got to add the r1 and r3 in the mid shift. This is TODO (aga, not fixed yet)
-    pointer = r0.numdigits()
-    ret._digits[:pointer] = r0._digits
+    ret._digits[:shift] = r0._digits
     
-    pointer2 = pointer + r2.numdigits()
-    ret._digits[pointer:pointer2] = r2._digits
+    ret._digits[shift:shift*2] = r2._digits
     
-    pointer3 = pointer2 + r4.numdigits()
-    ret._digits[pointer2:pointer3] = r4._digits
+    ret._digits[shift*2:(shift*2)+r4.numdigits()] = r4._digits
     
     # TODO!!!!
-    #_v_iadd(ret, shift, i, r1, r1.numdigits())
-    #_v_iadd(ret, shift >> 1, i, r3, r3.numdigits())
+    """
+    x and y are rbigints, m >= n required.  x.digits[0:n] is modified in place,
+    by adding y.digits[0:m] to it.  Carries are propagated as far as
+    x[m-1], and the remaining carry (0 or 1) is returned.
+    Python adaptation: x is addressed relative to xofs!
+    """
+    _v_iadd(ret, shift, shift + r1.numdigits(), r1, r1.numdigits())
+    _v_iadd(ret, shift * 2, shift + r3.numdigits(), r3, r3.numdigits())
 
     ret._normalize()
     return ret


More information about the pypy-commit mailing list