[pypy-commit] pypy improve-rbigint: Fix one test and make a few more tests pass (divrem still fails for a reason I don't understand yet). Optimize mod(), fix an issue with lshift, and fix translation (for some reason the previous commit failed to translate today, although it worked last night)

stian noreply at buildbot.pypy.org
Sat Jul 21 18:41:41 CEST 2012


Author: stian
Branch: improve-rbigint
Changeset: r56352:0abcf5b8aaba
Date: 2012-07-06 20:01 +0200
http://bitbucket.org/pypy/pypy/changeset/0abcf5b8aaba/

Log:	Fix one test and make a few more tests pass (divrem still fails
	for a reason I don't understand yet). Optimize mod(), fix an issue
	with lshift, and fix translation (for some reason the previous
	commit failed to translate today, although it worked last night)

diff --git a/pypy/rlib/rbigint.py b/pypy/rlib/rbigint.py
--- a/pypy/rlib/rbigint.py
+++ b/pypy/rlib/rbigint.py
@@ -151,23 +151,24 @@
         """Return the x'th digit, as an int."""
         return self._digits[x]
     digit._always_inline_ = True
-    
+    digit._annonforceargs_ = [None, r_uint] # These are necessary because x can't always be proven non negative, no matter how hard we try.
     def widedigit(self, x):
         """Return the x'th digit, as a long long int if needed
         to have enough room to contain two digits."""
         return _widen_digit(self._digits[x])
     widedigit._always_inline_ = True
-    
+    widedigit._annonforceargs_ = [None, r_uint]
     def udigit(self, x):
         """Return the x'th digit, as an unsigned int."""
         return _load_unsigned_digit(self._digits[x])
     udigit._always_inline_ = True
-    
+    udigit._annonforceargs_ = [None, r_uint]
     def setdigit(self, x, val):
         val = _mask_digit(val)
         assert val >= 0
         self._digits[x] = _store_digit(val)
     setdigit._annspecialcase_ = 'specialize:argtype(2)'
+    digit._annonforceargs_ = [None, r_uint, None]
     setdigit._always_inline_ = True
 
     def numdigits(self):
@@ -450,23 +451,21 @@
         if a.sign == 0 or b.sign == 0:
             return rbigint()
         
-        
         if asize == 1:
-            digit = a.widedigit(0)
-            if digit == 0:
+            if a._digits[0] == NULLDIGIT:
                 return rbigint()
-            elif digit == 1:
+            elif b._digits[0] == ONEDIGIT:
                 return rbigint(b._digits, a.sign * b.sign)
             elif bsize == 1:
                 result = rbigint([NULLDIGIT] * 2, a.sign * b.sign)
-                carry = b.widedigit(0) * digit
+                carry = b.widedigit(0) * a.widedigit(0)
                 result.setdigit(0, carry)
                 carry >>= SHIFT
                 if carry:
                     result.setdigit(1, carry)
                 return result
                 
-            result =  _x_mul(a, b, digit)
+            result =  _x_mul(a, b, a.digit(0))
         elif USE_TOOMCOCK and asize >= TOOMCOOK_CUTOFF:
             result = _tc_mul(a, b)
         elif USE_KARATSUBA:
@@ -512,7 +511,21 @@
 
     @jit.elidable
     def mod(self, other):
-        div, mod = _divrem(self, other)
+        if other.numdigits() == 1:
+            # Faster.
+            i = 0
+            mod = 0
+            b = other.digit(0) * other.sign
+            while i < self.numdigits():
+                digit = self.digit(i) * self.sign
+                if digit:
+                    mod <<= SHIFT
+                    mod = (mod + digit) % b
+                
+                i += 1
+            mod = rbigint.fromint(mod)
+        else:        
+            div, mod = _divrem(self, other)
         if mod.sign * other.sign == -1:
             mod = mod.add(other)
         return mod
@@ -577,7 +590,7 @@
 
             # if modulus == 1:
             #     return 0
-            if c.numdigits() == 1 and c.digit(0) == 1:
+            if c.numdigits() == 1 and c._digits[0] == ONEDIGIT:
                 return NULLRBIGINT
 
             # if base < 0:
@@ -588,13 +601,13 @@
                 
             
         elif size_b == 1:
-            digit = b.digit(0)
-            if digit == 0:
+            if b._digits[0] == NULLDIGIT:
                 return ONERBIGINT if a.sign == 1 else ONENEGATIVERBIGINT
-            elif digit == 1:
+            elif b._digits[0] == ONEDIGIT:
                 return a
             elif a.numdigits() == 1:
                 adigit = a.digit(0)
+                digit = b.digit(0)
                 if adigit == 1:
                     if a.sign == -1 and digit % 2:
                         return ONENEGATIVERBIGINT
@@ -612,7 +625,7 @@
         
         # python adaptation: moved macros REDUCE(X) and MULT(X, Y, result)
         # into helper function result = _help_mult(x, y, c)
-        if True: #not c or size_b <= FIVEARY_CUTOFF:
+        if not c or size_b <= FIVEARY_CUTOFF:
             # Left-to-right binary exponentiation (HAC Algorithm 14.79)
             # http://www.cacr.math.uwaterloo.ca/hac/about/chap14.pdf
             size_b -= 1
@@ -627,7 +640,6 @@
                 size_b -= 1
                 
         else:
-            # XXX: Not working with int128! Yet
             # Left-to-right 5-ary exponentiation (HAC Algorithm 14.82)
             # This is only useful in the case where c != None.
             # z still holds 1L
@@ -662,7 +674,7 @@
                         break # Done
                         
                     size_b -= 1
-
+                    assert size_b >= 0
                     bi = b.udigit(size_b)
                     index = ((accum << (-j)) | (bi >> (j+SHIFT))) & 0x1f
                     accum = bi
@@ -706,11 +718,12 @@
         wordshift = int_other // SHIFT
         remshift  = int_other - wordshift * SHIFT
 
-        oldsize = self.numdigits()
         if not remshift:
             return rbigint([NULLDIGIT] * wordshift + self._digits, self.sign)
-            
-        z = rbigint([NULLDIGIT] * (oldsize + wordshift + 1), self.sign)
+        
+        oldsize = self.numdigits()
+        newsize = oldsize + wordshift + 1
+        z = rbigint([NULLDIGIT] * newsize, self.sign)
         accum = _widen_digit(0)
         i = wordshift
         j = 0
@@ -720,8 +733,10 @@
             accum >>= SHIFT
             i += 1
             j += 1
-            
-        z.setdigit(oldsize, accum)
+        
+        newsize -= 1
+        assert newsize >= 0
+        z.setdigit(newsize, accum)
 
         z._positivenormalize()
         return z
@@ -830,31 +845,31 @@
             self._digits = [NULLDIGIT]
             return
         
-        while i > 1 and self.digit(i - 1) == 0:
+        while i > 1 and self._digits[i - 1] == NULLDIGIT:
             i -= 1
         assert i > 0
         if i != c:
             self._digits = self._digits[:i]
-        if self.numdigits() == 1 and self.digit(0) == 0:
+        if self.numdigits() == 1 and self._digits[0] == NULLDIGIT:
             self.sign = 0
             
-    _normalize._always_inline_ = True
+    #_normalize._always_inline_ = True
     
     def _positivenormalize(self):
         """ This function assumes numdigits > 0. Good for shifts and such """
         i = c = self.numdigits()
-        while i > 1 and self.digit(i - 1) == 0:
+        while i > 1 and self._digits[i - 1] == NULLDIGIT:
             i -= 1
         assert i > 0
         if i != c:
             self._digits = self._digits[:i]
-        if self.numdigits() == 1 and self.digit(0) == 0:
+        if self.numdigits() == 1 and self._digits[0] == NULLDIGIT:
             self.sign = 0
     _positivenormalize._always_inline_ = True
     
     def bit_length(self):
         i = self.numdigits()
-        if i == 1 and self.digit(0) == 0:
+        if i == 1 and self._digits[0] == NULLDIGIT:
             return 0
         msd = self.digit(i - 1)
         msd_bits = 0
@@ -1047,12 +1062,11 @@
         # via exploiting that each entry in the multiplication
         # pyramid appears twice (except for the size_a squares).
         z = rbigint([NULLDIGIT] * (size_a + size_b), 1)
-        i = _load_unsigned_digit(0)
+        i = 0
         while i < size_a:
             f = a.widedigit(i)
             pz = i << 1
             pa = i + 1
-            paend = size_a
 
             carry = z.widedigit(pz) + f * f
             z.setdigit(pz, carry)
@@ -1063,7 +1077,7 @@
             # Now f is added in twice in each column of the
             # pyramid it appears.  Same as adding f<<1 once.
             f <<= 1
-            while pa < paend:
+            while pa < size_a:
                 carry += z.widedigit(pz) + a.widedigit(pa) * f
                 pa += 1
                 z.setdigit(pz, carry)
@@ -1075,8 +1089,8 @@
                 z.setdigit(pz, carry)
                 pz += 1
                 carry >>= SHIFT
-            if carry:
-                z.setdigit(pz, z.widedigit(pz) + carry)
+                if carry:
+                    z.setdigit(pz, z.widedigit(pz) + carry)
             assert (carry >> SHIFT) == 0
             i += 1
         z._positivenormalize()
@@ -1087,7 +1101,7 @@
 
     z = rbigint([NULLDIGIT] * (size_a + size_b), 1)
     # gradeschool long mult
-    i = _load_unsigned_digit(0)
+    i = 0
     while i < size_a:
         carry = 0
         f = a.widedigit(i)
@@ -1101,6 +1115,7 @@
             carry >>= SHIFT
             assert carry <= MASK
         if carry:
+            assert pz >= 0
             z.setdigit(pz, z.widedigit(pz) + carry)
         assert (carry >> SHIFT) == 0
         i += 1
@@ -1550,7 +1565,7 @@
     w = _muladd1(w1, d)
     size_v = v1.numdigits()
     size_w = w1.numdigits()
-    assert size_v >= size_w and size_w >= 1 # (Assert checks by div()
+    assert size_v >= size_w and size_w > 1 # (Assert checks by div()
 
     """v = rbigint([NULLDIGIT] * (size_v + 1))
     w = rbigint([NULLDIGIT] * (size_w))
@@ -1565,12 +1580,13 @@
         
     size_a = size_v - size_w + 1
     a = rbigint([NULLDIGIT] * size_a, 1)
-
-    wm1 = w.widedigit(abs(size_w-1))
-    wm2 = w.widedigit(abs(size_w-2))
-    j = _load_unsigned_digit(size_v)
+    assert size_w >= 2
+    wm1 = w.widedigit(size_w-1)
+    wm2 = w.widedigit(size_w-2)
+    j = size_v
     k = size_a - 1
     while k >= 0:
+        assert j >= 2
         if j >= size_v:
             vj = 0
         else:
@@ -2099,7 +2115,7 @@
             ntostore = power
             rem = _inplace_divrem1(scratch, pin, powbase, size)
             pin = scratch  # no need to use a again
-            if pin.digit(size - 1) == 0:
+            if pin._digits[size - 1] == NULLDIGIT:
                 size -= 1
 
             # Break rem into digits.
diff --git a/pypy/rlib/test/test_rbigint.py b/pypy/rlib/test/test_rbigint.py
--- a/pypy/rlib/test/test_rbigint.py
+++ b/pypy/rlib/test/test_rbigint.py
@@ -360,7 +360,7 @@
                       for i in (10L, 5L, 0L)]
         py.test.raises(ValueError, f1.pow, f2, f3)
         #
-        MAX = 1E40
+        MAX = 1E20
         x = long(random() * MAX) + 1
         y = long(random() * MAX) + 1
         z = long(random() * MAX) + 1
@@ -521,9 +521,9 @@
     def test__x_divrem(self):
         x = 12345678901234567890L
         for i in range(100):
-            y = long(randint(0, 1 << 30))
-            y <<= 30
-            y += randint(0, 1 << 30)
+            y = long(randint(0, 1 << 60))
+            y <<= 60
+            y += randint(0, 1 << 60)
             f1 = rbigint.fromlong(x)
             f2 = rbigint.fromlong(y)
             div, rem = lobj._x_divrem(f1, f2)
@@ -532,9 +532,9 @@
     def test__divrem(self):
         x = 12345678901234567890L
         for i in range(100):
-            y = long(randint(0, 1 << 30))
-            y <<= 30
-            y += randint(0, 1 << 30)
+            y = long(randint(0, 1 << 60))
+            y <<= 60
+            y += randint(0, 1 << 60)
             for sx, sy in (1, 1), (1, -1), (-1, -1), (-1, 1):
                 sx *= x
                 sy *= y
diff --git a/pypy/rpython/lltypesystem/rlist.py b/pypy/rpython/lltypesystem/rlist.py
--- a/pypy/rpython/lltypesystem/rlist.py
+++ b/pypy/rpython/lltypesystem/rlist.py
@@ -303,12 +303,12 @@
     return l.items
 
 def ll_getitem_fast(l, index):
-    #ll_assert(index < l.length, "getitem out of bounds")
+    ll_assert(index < l.length, "getitem out of bounds")
     return l.ll_items()[index]
 ll_getitem_fast.oopspec = 'list.getitem(l, index)'
 
 def ll_setitem_fast(l, index, item):
-    #ll_assert(index < l.length, "setitem out of bounds")
+    ll_assert(index < l.length, "setitem out of bounds")
     l.ll_items()[index] = item
 ll_setitem_fast.oopspec = 'list.setitem(l, index, item)'
 
@@ -316,7 +316,7 @@
 
 @typeMethod
 def ll_fixed_newlist(LIST, length):
-    #ll_assert(length >= 0, "negative fixed list length")
+    ll_assert(length >= 0, "negative fixed list length")
     l = malloc(LIST, length)
     return l
 ll_fixed_newlist.oopspec = 'newlist(length)'
@@ -333,12 +333,12 @@
     return l
 
 def ll_fixed_getitem_fast(l, index):
-    #ll_assert(index < len(l), "fixed getitem out of bounds")
+    ll_assert(index < len(l), "fixed getitem out of bounds")
     return l[index]
 ll_fixed_getitem_fast.oopspec = 'list.getitem(l, index)'
 
 def ll_fixed_setitem_fast(l, index, item):
-    #ll_assert(index < len(l), "fixed setitem out of bounds")
+    ll_assert(index < len(l), "fixed setitem out of bounds")
     l[index] = item
 ll_fixed_setitem_fast.oopspec = 'list.setitem(l, index, item)'
 


More information about the pypy-commit mailing list