[pypy-commit] pypy improve-rbigint: Fixes, and reintroduce the JIT stuff

stian noreply at buildbot.pypy.org
Sat Jul 21 18:41:57 CEST 2012


Author: stian
Branch: improve-rbigint
Changeset: r56365:3ac65815b0d4
Date: 2012-07-13 23:34 +0200
http://bitbucket.org/pypy/pypy/changeset/3ac65815b0d4/

Log:	Fixes, and reintroduce the JIT stuff (re-enable the @jit.elidable decorators).

diff --git a/pypy/rlib/rbigint.py b/pypy/rlib/rbigint.py
--- a/pypy/rlib/rbigint.py
+++ b/pypy/rlib/rbigint.py
@@ -141,11 +141,11 @@
 class rbigint(object):
     """This is a reimplementation of longs using a list of digits."""
 
-    def __init__(self, digits=[NULLDIGIT], sign=0):
+    def __init__(self, digits=[NULLDIGIT], sign=0, size=0):
         _check_digits(digits)
         make_sure_not_resized(digits)
         self._digits = digits
-        self.size = len(digits)
+        self.size = size or len(digits)
         self.sign = sign
 
     def digit(self, x):
@@ -165,7 +165,7 @@
     udigit._always_inline_ = True
     udigit._annonforceargs_ = [None, r_uint]
     def setdigit(self, x, val):
-        val = _mask_digit(val)
+        val = val & MASK
         assert val >= 0
         self._digits[x] = _store_digit(val)
     setdigit._annspecialcase_ = 'specialize:argtype(2)'
@@ -199,10 +199,10 @@
         if SHIFT >= 63:
             carry = ival >> SHIFT
             if carry:
-                return rbigint([_store_digit(_mask_digit(ival)),
-                    _store_digit(_mask_digit(carry))], sign)
+                return rbigint([_store_digit(ival & MASK),
+                    _store_digit(carry & MASK)], sign)
             else:
-                return rbigint([_store_digit(_mask_digit(ival))], sign)
+                return rbigint([_store_digit(ival & MASK)], sign)
             
         t = ival
         ndigits = 0
@@ -220,7 +220,6 @@
         return v
 
     @staticmethod
-    #@jit.elidable
     def frombool(b):
         # This function is marked as pure, so you must not call it and
         # then modify the result.
@@ -296,6 +295,7 @@
             raise OverflowError
         return intmask(intmask(x) * sign)
 
+    @jit.elidable
     def tolonglong(self):
         return _AsLongLong(self)
 
@@ -307,6 +307,7 @@
             raise ValueError("cannot convert negative integer to unsigned int")
         return self._touint_helper()
 
+    @jit.elidable
     def _touint_helper(self):
         x = r_uint(0)
         i = self.numdigits() - 1
@@ -319,14 +320,17 @@
             i -= 1
         return x
 
+    @jit.elidable
     def toulonglong(self):
         if self.sign == -1:
             raise ValueError("cannot convert negative integer to unsigned int")
         return _AsULonglong_ignore_sign(self)
 
+    @jit.elidable
     def uintmask(self):
         return _AsUInt_mask(self)
 
+    @jit.elidable
     def ulonglongmask(self):
         """Return r_ulonglong(self), truncating."""
         return _AsULonglong_mask(self)
@@ -335,21 +339,21 @@
     def tofloat(self):
         return _AsDouble(self)
 
-    #@jit.elidable
+    @jit.elidable
     def format(self, digits, prefix='', suffix=''):
         # 'digits' is a string whose length is the base to use,
         # and where each character is the corresponding digit.
         return _format(self, digits, prefix, suffix)
 
-    #@jit.elidable
+    @jit.elidable
     def repr(self):
         return _format(self, BASE10, '', 'L')
 
-    #@jit.elidable
+    @jit.elidable
     def str(self):
         return _format(self, BASE10)
 
-    #@jit.elidable
+    @jit.elidable
     def eq(self, other):
         if (self.sign != other.sign or
             self.numdigits() != other.numdigits()):
@@ -365,7 +369,7 @@
     def ne(self, other):
         return not self.eq(other)
 
-    #@jit.elidable
+    @jit.elidable
     def lt(self, other):
         if self.sign > other.sign:
             return False
@@ -413,7 +417,7 @@
     def hash(self):
         return _hash(self)
 
-    #@jit.elidable
+    @jit.elidable
     def add(self, other):
         if self.sign == 0:
             return other
@@ -426,12 +430,12 @@
         result.sign *= other.sign
         return result
 
-    #@jit.elidable
+    @jit.elidable
     def sub(self, other):
         if other.sign == 0:
             return self
         if self.sign == 0:
-            return rbigint(other._digits, -other.sign)
+            return rbigint(other._digits[:], -other.sign, other.size)
         if self.sign == other.sign:
             result = _x_sub(self, other)
         else:
@@ -439,7 +443,7 @@
         result.sign *= self.sign
         return result
 
-    #@jit.elidable
+    @jit.elidable
     def mul(self, b):
         asize = self.numdigits()
         bsize = b.numdigits()
@@ -456,7 +460,7 @@
             if a._digits[0] == NULLDIGIT:
                 return rbigint()
             elif a._digits[0] == ONEDIGIT:
-                return rbigint(b._digits, a.sign * b.sign)
+                return rbigint(b._digits[:], a.sign * b.sign, b.size)
             elif bsize == 1:
                 result = rbigint([NULLDIGIT] * 2, a.sign * b.sign)
                 carry = b.widedigit(0) * a.widedigit(0)
@@ -464,6 +468,7 @@
                 carry >>= SHIFT
                 if carry:
                     result.setdigit(1, carry)
+                result._normalize()
                 return result
                 
             result =  _x_mul(a, b, a.digit(0))
@@ -487,12 +492,12 @@
         result.sign = a.sign * b.sign
         return result
 
-    #@jit.elidable
+    @jit.elidable
     def truediv(self, other):
         div = _bigint_true_divide(self, other)
         return div
 
-    #@jit.elidable
+    @jit.elidable
     def floordiv(self, other):
         if other.numdigits() == 1 and other.sign == 1:
             digit = other.digit(0)
@@ -506,11 +511,10 @@
             div = div.sub(ONERBIGINT)
         return div
 
-    #@jit.elidable
     def div(self, other):
         return self.floordiv(other)
 
-    #@jit.elidable
+    @jit.elidable
     def mod(self, other):
         if self.sign == 0:
             return NULLRBIGINT
@@ -549,7 +553,7 @@
             mod = mod.add(other)
         return mod
 
-    #@jit.elidable
+    @jit.elidable
     def divmod(v, w):
         """
         The / and % operators are now defined in terms of divmod().
@@ -573,7 +577,7 @@
             div = div.sub(ONERBIGINT)
         return div, mod
 
-    #@jit.elidable
+    @jit.elidable
     def pow(a, b, c=None):
         negativeOutput = False  # if x<0 return negative output
 
@@ -711,12 +715,12 @@
         return z
 
     def neg(self):
-        return rbigint(self._digits, -self.sign)
+        return rbigint(self._digits[:], -self.sign, self.size)
 
     def abs(self):
         if self.sign != -1:
             return self
-        return rbigint(self._digits, abs(self.sign))
+        return rbigint(self._digits[:], abs(self.sign), self.size)
 
     def invert(self): #Implement ~x as -(x + 1)
         if self.sign == 0:
@@ -726,7 +730,17 @@
         ret.sign = -ret.sign
         return ret
 
-    #@jit.elidable
+    def inplace_invert(self): # Used by rshift and bitwise to prevent a double allocation.
+        if self.sign == 0:
+            return ONENEGATIVERBIGINT
+        if self.sign == 1:
+            _v_iadd(self, 0, self.numdigits(), ONERBIGINT, 1)
+        else:
+             _v_isub(self, 0, self.numdigits(), ONERBIGINT, 1)
+        self.sign = -self.sign
+        return self
+        
+    @jit.elidable    
     def lshift(self, int_other):
         if int_other < 0:
             raise ValueError("negative shift count")
@@ -738,7 +752,9 @@
         remshift  = int_other - wordshift * SHIFT
 
         if not remshift:
-            return rbigint([NULLDIGIT] * wordshift + self._digits, self.sign)
+            ret = rbigint([NULLDIGIT] * wordshift + self._digits, self.sign)
+            ret._normalize()
+            return ret
         
         oldsize = self.numdigits()
         newsize = oldsize + wordshift + 1
@@ -760,7 +776,7 @@
         return z
     lshift._always_inline_ = True # It's so fast that it's always benefitial.
     
-    #@jit.elidable
+    @jit.elidable
     def lqshift(self, int_other):
         " A quicker one with much less checks, int_other is valid and for the most part constant."
         assert int_other > 0
@@ -780,7 +796,7 @@
         return z
     lqshift._always_inline_ = True # It's so fast that it's always benefitial.
     
-    #@jit.elidable
+    @jit.elidable
     def rshift(self, int_other, dont_invert=False):
         if int_other < 0:
             raise ValueError("negative shift count")
@@ -789,7 +805,7 @@
         if self.sign == -1 and not dont_invert:
             a1 = self.invert()
             a2 = a1.rshift(int_other)
-            return a2.invert()
+            return a2.inplace_invert()
 
         wordshift = int_other // SHIFT
         newsize = self.numdigits() - wordshift
@@ -807,7 +823,7 @@
         while i < newsize:
             newdigit = (self.udigit(wordshift) >> loshift) #& lomask
             if i+1 < newsize:
-                newdigit |= (self.udigit(wordshift+1) << hishift) #& himask
+                newdigit += (self.udigit(wordshift+1) << hishift) #& himask
             z.setdigit(i, newdigit)
             i += 1
             wordshift += 1
@@ -815,15 +831,15 @@
         return z
     rshift._always_inline_ = True # It's so fast that it's always benefitial.
     
-    #@jit.elidable
+    @jit.elidable
     def and_(self, other):
         return _bitwise(self, '&', other)
 
-    #@jit.elidable
+    @jit.elidable
     def xor(self, other):
         return _bitwise(self, '^', other)
 
-    #@jit.elidable
+    @jit.elidable
     def or_(self, other):
         return _bitwise(self, '|', other)
 
@@ -836,7 +852,7 @@
     def hex(self):
         return _format(self, BASE16, '0x', 'L')
 
-    #@jit.elidable
+    @jit.elidable
     def log(self, base):
         # base is supposed to be positive or 0.0, which means we use e
         if base == 10.0:
@@ -875,6 +891,7 @@
             
     _normalize._always_inline_ = True
     
+    @jit.elidable
     def bit_length(self):
         i = self.numdigits()
         if i == 1 and self._digits[0] == NULLDIGIT:
@@ -1044,6 +1061,7 @@
         borrow >>= SHIFT
         borrow &= 1
         i += 1
+        
     assert borrow == 0
     z._normalize()
     return z
@@ -1091,7 +1109,11 @@
                 z.setdigit(pz, carry)
                 pz += 1
                 carry >>= SHIFT
-                #assert carry <= (_widen_digit(MASK) << 1)
+            if carry:
+                carry += z.widedigit(pz)
+                z.setdigit(pz, carry)
+                pz += 1
+                carry >>= SHIFT
             if carry:
                 z.setdigit(pz, z.widedigit(pz) + carry)
             assert (carry >> SHIFT) == 0
@@ -1442,7 +1464,7 @@
         pout.setdigit(size, hi)
         rem -= hi * n
         size -= 1
-    return _mask_digit(rem)
+    return rem & MASK
 
 def _divrem1(a, n):
     """
@@ -1649,6 +1671,7 @@
 
     a._normalize()
     _inplace_divrem1(v, v, d, size_v)
+    v._normalize()
     return a, v
 
     """
@@ -2221,6 +2244,7 @@
             digb = b.digit(i) ^ maskb
         else:
             digb = maskb
+            
         if op == '&':
             z.setdigit(i, diga & digb)
         elif op == '|':
@@ -2231,7 +2255,8 @@
     z._normalize()
     if negz == 0:
         return z
-    return z.invert()
+    
+    return z.inplace_invert()
 _bitwise._annspecialcase_ = "specialize:arg(1)"
 
 
diff --git a/pypy/rlib/test/test_rbigint.py b/pypy/rlib/test/test_rbigint.py
--- a/pypy/rlib/test/test_rbigint.py
+++ b/pypy/rlib/test/test_rbigint.py
@@ -407,7 +407,7 @@
     def test_normalize(self):
         f1 = bigint([1, 0], 1)
         f1._normalize()
-        assert len(f1._digits) == 1
+        assert f1.size == 1
         f0 = bigint([0], 0)
         assert f1.sub(f1).eq(f0)
 
diff --git a/pypy/rpython/lltypesystem/ll2ctypes.py b/pypy/rpython/lltypesystem/ll2ctypes.py
--- a/pypy/rpython/lltypesystem/ll2ctypes.py
+++ b/pypy/rpython/lltypesystem/ll2ctypes.py
@@ -133,6 +133,7 @@
         rffi.LONGLONG:   ctypes.c_longlong,
         rffi.ULONGLONG:  ctypes.c_ulonglong,
         rffi.SIZE_T:     ctypes.c_size_t,
+        rffi.__INT128:   ctypes.c_longlong, # XXX: Not right at all. But for some reason, It started by while doing JIT compile after a merge with default. Can't extend ctypes, because thats a python standard, right?
         lltype.Bool:     getattr(ctypes, "c_bool", ctypes.c_byte),
         llmemory.Address:  ctypes.c_void_p,
         llmemory.GCREF:    ctypes.c_void_p,
diff --git a/pypy/translator/goal/targetbigintbenchmark.py b/pypy/translator/goal/targetbigintbenchmark.py
--- a/pypy/translator/goal/targetbigintbenchmark.py
+++ b/pypy/translator/goal/targetbigintbenchmark.py
@@ -35,24 +35,24 @@
         Sum:  142.686547
         
         Pypy with improvements:
-        mod by 2:  0.005516
-        mod by 10000:  3.650751
-        mod by 1024 (power of two):  0.011492
-        Div huge number by 2**128: 2.148300
-        rshift: 2.333236
-        lshift: 1.355453
-        Floordiv by 2: 1.604574
-        Floordiv by 3 (not power of two): 4.155219
-        2**500000: 0.033960
-        (2**N)**5000000 (power of two): 0.046241
-        10000 ** BIGNUM % 100 1.963261
-        i = i * i: 3.906100
-        n**10000 (not power of two): 5.994802
-        Power of two ** power of two: 0.013270
-        v = v * power of two 3.481778
-        v = v * v 6.348381
-        v = v + v 2.782792
-        Sum:  39.835126
+        mod by 2:  0.007256
+        mod by 10000:  3.175842
+        mod by 1024 (power of two):  0.011571
+        Div huge number by 2**128: 2.187273
+        rshift: 2.319537
+        lshift: 1.488359
+        Floordiv by 2: 1.513284
+        Floordiv by 3 (not power of two): 4.210322
+        2**500000: 0.033903
+        (2**N)**5000000 (power of two): 0.052366
+        10000 ** BIGNUM % 100 2.032749
+        i = i * i: 4.609749
+        n**10000 (not power of two): 6.266791
+        Power of two ** power of two: 0.013294
+        v = v * power of two 4.107085
+        v = v * v 6.384141
+        v = v + v 2.820538
+        Sum:  41.234060
 
         A pure python form of those tests where also run
         Improved pypy           | Pypy                  | CPython 2.7.3


More information about the pypy-commit mailing list