[pypy-commit] pypy default: Merge improve-rbigint. This branch improves the performance of most long operations and uses 64-bit storage and __int128 for wide digits on systems where it is available.

Stian Andreassen noreply at buildbot.pypy.org
Wed Jul 25 02:42:50 CEST 2012


Author: Stian Andreassen
Branch: 
Changeset: r56444:8a78c6bf2abb
Date: 2012-07-25 02:37 +0200
http://bitbucket.org/pypy/pypy/changeset/8a78c6bf2abb/

Log:	Merge improve-rbigint. This branch improves the performance of most
	long operations and uses 64-bit storage and __int128 for wide digits
	on systems where it is available.

	It adds special cases for power-of-two mod, division and
	multiplication, improves pow (see
	pypy/translator/goal/targetbigintbenchmark.py for some runs on my
	system), marks operations as elidable and makes various other tweaks.

	Overall, it makes things run faster than CPython if the script
	doesn't heavily rely on division.

diff --git a/pypy/module/sys/system.py b/pypy/module/sys/system.py
--- a/pypy/module/sys/system.py
+++ b/pypy/module/sys/system.py
@@ -48,8 +48,8 @@
 
 def get_long_info(space):
     #assert rbigint.SHIFT == 31
-    bits_per_digit = 31 #rbigint.SHIFT
-    sizeof_digit = rffi.sizeof(rffi.ULONG)
+    bits_per_digit = rbigint.SHIFT
+    sizeof_digit = rffi.sizeof(rbigint.STORE_TYPE)
     info_w = [
         space.wrap(bits_per_digit),
         space.wrap(sizeof_digit),
diff --git a/pypy/rlib/rarithmetic.py b/pypy/rlib/rarithmetic.py
--- a/pypy/rlib/rarithmetic.py
+++ b/pypy/rlib/rarithmetic.py
@@ -115,21 +115,16 @@
         n -= 2*LONG_TEST
     return int(n)
 
-if LONG_BIT >= 64:
-    def longlongmask(n):
-        assert isinstance(n, (int, long))
-        return int(n)
-else:
-    def longlongmask(n):
-        """
-        NOT_RPYTHON
-        """
-        assert isinstance(n, (int, long))
-        n = long(n)
-        n &= LONGLONG_MASK
-        if n >= LONGLONG_TEST:
-            n -= 2*LONGLONG_TEST
-        return r_longlong(n)
+def longlongmask(n):
+    """
+    NOT_RPYTHON
+    """
+    assert isinstance(n, (int, long))
+    n = long(n)
+    n &= LONGLONG_MASK
+    if n >= LONGLONG_TEST:
+        n -= 2*LONGLONG_TEST
+    return r_longlong(n)
 
 def longlonglongmask(n):
     # Assume longlonglong doesn't overflow. This is perfectly fine for rbigint.
diff --git a/pypy/rlib/rbigint.py b/pypy/rlib/rbigint.py
--- a/pypy/rlib/rbigint.py
+++ b/pypy/rlib/rbigint.py
@@ -23,7 +23,10 @@
     SHIFT = 63
     BASE = long(1 << SHIFT)
     UDIGIT_TYPE = r_ulonglong
-    UDIGIT_MASK = longlongmask
+    if LONG_BIT >= 64:
+        UDIGIT_MASK = intmask
+    else:
+        UDIGIT_MASK = longlongmask
     LONG_TYPE = rffi.__INT128
     if LONG_BIT > SHIFT:
         STORE_TYPE = lltype.Signed
@@ -41,7 +44,7 @@
     LONG_TYPE = rffi.LONGLONG
 
 MASK = BASE - 1
-FLOAT_MULTIPLIER = float(1 << LONG_BIT) # Because it works.
+FLOAT_MULTIPLIER = float(1 << SHIFT)
 
 # Debugging digit array access.
 #
@@ -137,7 +140,7 @@
     udigit._always_inline_ = True
 
     def setdigit(self, x, val):
-        val = val & MASK
+        val = _mask_digit(val)
         assert val >= 0
         self._digits[x] = _store_digit(val)
     setdigit._annspecialcase_ = 'specialize:argtype(2)'
@@ -448,8 +451,8 @@
                 
             if asize <= i:
                 result = _x_mul(a, b)
-            elif 2 * asize <= bsize:
-                result = _k_lopsided_mul(a, b)
+                """elif 2 * asize <= bsize:
+                    result = _k_lopsided_mul(a, b)"""
             else:
                 result = _k_mul(a, b)
         else:
@@ -465,10 +468,10 @@
 
     @jit.elidable
     def floordiv(self, other):
-        if other.numdigits() == 1 and other.sign == 1:
+        if self.sign == 1 and other.numdigits() == 1 and other.sign == 1:
             digit = other.digit(0)
             if digit == 1:
-                return rbigint(self._digits[:], other.sign * self.sign, self.size)
+                return rbigint(self._digits[:], 1, self.size)
             elif digit and digit & (digit - 1) == 0:
                 return self.rshift(ptwotable[digit])
             
@@ -476,10 +479,8 @@
         if mod.sign * other.sign == -1:
             if div.sign == 0:
                 return ONENEGATIVERBIGINT
-            if div.sign == 1:
-                _v_isub(div, 0, div.numdigits(), ONERBIGINT, 1)
-            else:
-                _v_iadd(div, 0, div.numdigits(), ONERBIGINT, 1)
+            div = div.sub(ONERBIGINT)
+            
         return div
 
     def div(self, other):
@@ -545,10 +546,7 @@
             mod = mod.add(w)
             if div.sign == 0:
                 return ONENEGATIVERBIGINT, mod
-            if div.sign == 1:
-                _v_isub(div, 0, div.numdigits(), ONERBIGINT, 1)
-            else:
-                _v_iadd(div, 0, div.numdigits(), ONERBIGINT, 1)
+            div = div.sub(ONERBIGINT)
         return div, mod
 
     @jit.elidable
@@ -567,11 +565,6 @@
             # XXX failed to implement
             raise ValueError("bigint pow() too negative")
         
-        if b.sign == 0:
-            return ONERBIGINT
-        elif a.sign == 0:
-            return NULLRBIGINT
-        
         size_b = b.numdigits()
         
         if c is not None:
@@ -589,14 +582,17 @@
             #     return 0
             if c.numdigits() == 1 and c._digits[0] == ONEDIGIT:
                 return NULLRBIGINT
-
+   
             # if base < 0:
             #     base = base % modulus
             # Having the base positive just makes things easier.
             if a.sign < 0:
                 a = a.mod(c)
-                
             
+        elif b.sign == 0:
+            return ONERBIGINT
+        elif a.sign == 0:
+            return NULLRBIGINT
         elif size_b == 1:
             if b._digits[0] == NULLDIGIT:
                 return ONERBIGINT if a.sign == 1 else ONENEGATIVERBIGINT
@@ -689,12 +685,12 @@
         return z
 
     def neg(self):
-        return rbigint(self._digits[:], -self.sign, self.size)
+        return rbigint(self._digits, -self.sign, self.size)
 
     def abs(self):
         if self.sign != -1:
             return self
-        return rbigint(self._digits[:], abs(self.sign), self.size)
+        return rbigint(self._digits, 1, self.size)
 
     def invert(self): #Implement ~x as -(x + 1)
         if self.sign == 0:
@@ -703,16 +699,6 @@
         ret = self.add(ONERBIGINT)
         ret.sign = -ret.sign
         return ret
-
-    def inplace_invert(self): # Used by rshift and bitwise to prevent a double allocation.
-        if self.sign == 0:
-            return ONENEGATIVERBIGINT
-        if self.sign == 1:
-            _v_iadd(self, 0, self.numdigits(), ONERBIGINT, 1)
-        else:
-             _v_isub(self, 0, self.numdigits(), ONERBIGINT, 1)
-        self.sign = -self.sign
-        return self
         
     @jit.elidable    
     def lshift(self, int_other):
@@ -726,6 +712,9 @@
         remshift  = int_other - wordshift * SHIFT
 
         if not remshift:
+            # So we can avoid problems with eq, AND avoid the need for normalize.
+            if self.sign == 0:
+                return self
             return rbigint([NULLDIGIT] * wordshift + self._digits, self.sign, self.size + wordshift)
         
         oldsize = self.numdigits()
@@ -777,7 +766,7 @@
         if self.sign == -1 and not dont_invert:
             a1 = self.invert()
             a2 = a1.rshift(int_other)
-            return a2.inplace_invert()
+            return a2.invert()
 
         wordshift = int_other // SHIFT
         newsize = self.numdigits() - wordshift
@@ -846,7 +835,7 @@
 
     def _normalize(self):
         i = self.numdigits()
-        # i is always >= 1
+
         while i > 1 and self._digits[i - 1] == NULLDIGIT:
             i -= 1
         assert i > 0
@@ -855,7 +844,7 @@
         if self.numdigits() == 1 and self._digits[0] == NULLDIGIT:
             self.sign = 0
             self._digits = [NULLDIGIT]
-            
+
     _normalize._always_inline_ = True
     
     @jit.elidable
@@ -878,8 +867,9 @@
         return bits
 
     def __repr__(self):
-        return "<rbigint digits=%s, sign=%s, %s>" % (self._digits,
-                                                     self.sign, self.str())
+        return "<rbigint digits=%s, sign=%s, size=%d, len=%d, %s>" % (self._digits,
+                                            self.sign, self.size, len(self._digits),
+                                            self.str())
 
 ONERBIGINT = rbigint([ONEDIGIT], 1, 1)
 ONENEGATIVERBIGINT = rbigint([ONEDIGIT], -1, 1)
@@ -1218,8 +1208,9 @@
     size_n = n.numdigits()
     size_lo = min(size_n, size)
 
-    lo = rbigint(n._digits[:size_lo], 1)
-    hi = rbigint(n._digits[size_lo:], 1)
+    # We use "or" her to avoid having a check where list can be empty in _normalize.
+    lo = rbigint(n._digits[:size_lo] or [NULLDIGIT], 1)
+    hi = rbigint(n._digits[size_lo:n.size] or [NULLDIGIT], 1)
     lo._normalize()
     hi._normalize()
     return hi, lo
@@ -1243,7 +1234,10 @@
     # Split a & b into hi & lo pieces.
     shift = bsize >> 1
     ah, al = _kmul_split(a, shift)
-    assert ah.sign == 1    # the split isn't degenerate
+    if ah.sign == 0:
+        # This may happen now that _k_lopsided_mul ain't catching it.
+        return _x_mul(a, b)
+    #assert ah.sign == 1    # the split isn't degenerate
 
     if a is b:
         bh = ah
@@ -1271,6 +1265,7 @@
 
     # 2. t1 <- ah*bh, and copy into high digits of result.
     t1 = ah.mul(bh)
+
     assert t1.sign >= 0
     assert 2*shift + t1.numdigits() <= ret.numdigits()
     ret._digits[2*shift : 2*shift + t1.numdigits()] = t1._digits
@@ -1364,6 +1359,8 @@
 """
 
 def _k_lopsided_mul(a, b):
+    # Not in use anymore, only account for like 1% performance. Perhaps if we
+    # Got rid of the extra list allocation this would be more effective.
     """
     b has at least twice the digits of a, and a is big enough that Karatsuba
     would pay off *if* the inputs had balanced sizes.  View b as a sequence
@@ -1579,30 +1576,27 @@
     wm2 = w.widedigit(abs(size_w-2))
     j = size_v
     k = size_a - 1
+    carry = _widen_digit(0)
     while k >= 0:
-        assert j >= 2
+        assert j > 1
         if j >= size_v:
             vj = 0
         else:
             vj = v.widedigit(j)
-            
-        carry = 0
-        vj1 = v.widedigit(abs(j-1))
         
         if vj == wm1:
             q = MASK
-            r = 0
         else:
-            vv = ((vj << SHIFT) | vj1)
-            q = vv // wm1
-            r = _widen_digit(vv) - wm1 * q
-        
-        vj2 = v.widedigit(abs(j-2))
-        while wm2 * q > ((r << SHIFT) | vj2):
+            q = ((vj << SHIFT) + v.widedigit(abs(j-1))) // wm1
+
+        while (wm2 * q >
+                ((
+                    (vj << SHIFT)
+                    + v.widedigit(abs(j-1))
+                    - q * wm1
+                                ) << SHIFT)
+                + v.widedigit(abs(j-2))):
             q -= 1
-            r += wm1
-            if r > MASK:
-                break
         i = 0
         while i < size_w and i+k < size_v:
             z = w.widedigit(i) * q
@@ -1635,6 +1629,7 @@
                 i += 1
         j -= 1
         k -= 1
+        carry = 0
 
     a._normalize()
     _inplace_divrem1(v, v, d, size_v)
@@ -2223,7 +2218,7 @@
     if negz == 0:
         return z
     
-    return z.inplace_invert()
+    return z.invert()
 _bitwise._annspecialcase_ = "specialize:arg(1)"
 
 
diff --git a/pypy/rlib/test/test_rbigint.py b/pypy/rlib/test/test_rbigint.py
--- a/pypy/rlib/test/test_rbigint.py
+++ b/pypy/rlib/test/test_rbigint.py
@@ -442,6 +442,12 @@
                     res2 = getattr(operator, mod)(x, y)
                     assert res1 == res2
 
+    def test_mul_eq_shift(self):
+        p2 = rbigint.fromlong(1).lshift(63)
+        f1 = rbigint.fromlong(0).lshift(63)
+        f2 = rbigint.fromlong(0).mul(p2)
+        assert f1.eq(f2)
+            
     def test_tostring(self):
         z = rbigint.fromlong(0)
         assert z.str() == '0'
diff --git a/pypy/translator/goal/targetbigintbenchmark.py b/pypy/translator/goal/targetbigintbenchmark.py
--- a/pypy/translator/goal/targetbigintbenchmark.py
+++ b/pypy/translator/goal/targetbigintbenchmark.py
@@ -35,24 +35,24 @@
         Sum:  142.686547
         
         Pypy with improvements:
-        mod by 2:  0.003079
-        mod by 10000:  3.148599
-        mod by 1024 (power of two):  0.009572
-        Div huge number by 2**128: 2.202237
-        rshift: 2.240624
-        lshift: 1.405393
-        Floordiv by 2: 1.562338
-        Floordiv by 3 (not power of two): 4.197440
-        2**500000: 0.033737
-        (2**N)**5000000 (power of two): 0.046997
-        10000 ** BIGNUM % 100 1.321710
-        i = i * i: 3.929341
-        n**10000 (not power of two): 6.215907
-        Power of two ** power of two: 0.014209
-        v = v * power of two 3.506702
-        v = v * v 6.253210
-        v = v + v 2.772122
-        Sum:  38.863216
+        mod by 2:  0.006321
+        mod by 10000:  3.143117
+        mod by 1024 (power of two):  0.009611
+        Div huge number by 2**128: 2.138351
+        rshift: 2.247337
+        lshift: 1.334369
+        Floordiv by 2: 1.555604
+        Floordiv by 3 (not power of two): 4.275014
+        2**500000: 0.033836
+        (2**N)**5000000 (power of two): 0.049600
+        10000 ** BIGNUM % 100 1.326477
+        i = i * i: 3.924958
+        n**10000 (not power of two): 6.335759
+        Power of two ** power of two: 0.013380
+        v = v * power of two 3.497662
+        v = v * v 6.359251
+        v = v + v 2.785971
+        Sum:  39.036619
 
         With SUPPORT_INT128 set to False
         mod by 2:  0.004103


More information about the pypy-commit mailing list