[Python-checkins] python/dist/src/Objects longobject.c,1.161,1.162

tim_one at users.sourceforge.net tim_one at users.sourceforge.net
Mon Aug 30 00:16:53 CEST 2004


Update of /cvsroot/python/python/dist/src/Objects
In directory sc8-pr-cvs1.sourceforge.net:/tmp/cvs-serv18253/Objects

Modified Files:
	longobject.c 
Log Message:
SF patch 936813: fast modular exponentiation

This checkin is adapted from part 1 (of 3) of Trevor Perrin's patch set.

x_mul()
  - sped up a little by optimizing the C code
  - sped up a lot (~2X) when it's squaring (a == b); note that long_pow()
    squares often
k_mul()
  - more cache-friendly now if it's doing a square
KARATSUBA_CUTOFF
  - boosted from 35 to 70; gradeschool mult is quicker now, and the old
    value may have been too low for many platforms anyway
KARATSUBA_SQUARE_CUTOFF
  - new; defined as 2 * KARATSUBA_CUTOFF
  - since x_mul is a lot faster at squaring now, the point at which
    Karatsuba pays off for squaring is much higher than for general mult


Index: longobject.c
===================================================================
RCS file: /cvsroot/python/python/dist/src/Objects/longobject.c,v
retrieving revision 1.161
retrieving revision 1.162
diff -u -d -r1.161 -r1.162
--- longobject.c	28 Jun 2003 20:04:25 -0000	1.161
+++ longobject.c	29 Aug 2004 22:16:50 -0000	1.162
@@ -12,7 +12,8 @@
  * both operands contain more than KARATSUBA_CUTOFF digits (this
  * being an internal Python long digit, in base BASE).
  */
-#define KARATSUBA_CUTOFF 35
+#define KARATSUBA_CUTOFF 70
+#define KARATSUBA_SQUARE_CUTOFF (2 * KARATSUBA_CUTOFF)
 
 #define ABS(x) ((x) < 0 ? -(x) : (x))
 
@@ -1717,26 +1718,72 @@
 		return NULL;
 
 	memset(z->ob_digit, 0, z->ob_size * sizeof(digit));
-	for (i = 0; i < size_a; ++i) {
-		twodigits carry = 0;
-		twodigits f = a->ob_digit[i];
-		int j;
-		digit *pz = z->ob_digit + i;
+	if (a == b) {
+		/* Efficient squaring per HAC, Algorithm 14.16:
+		 * http://www.cacr.math.uwaterloo.ca/hac/about/chap14.pdf
+		 * Gives slightly less than a 2x speedup when a == b,
+		 * via exploiting that each entry in the multiplication
+		 * pyramid appears twice (except for the size_a squares).
+		 */
+		for (i = 0; i < size_a; ++i) {
+			twodigits carry;
+			twodigits f = a->ob_digit[i];
+			digit *pz = z->ob_digit + (i << 1);
+			digit *pa = a->ob_digit + i + 1;
+			digit *paend = a->ob_digit + size_a;
 
-		SIGCHECK({
-			Py_DECREF(z);
-			return NULL;
-		})
-		for (j = 0; j < size_b; ++j) {
-			carry += *pz + b->ob_digit[j] * f;
-			*pz++ = (digit) (carry & MASK);
+			SIGCHECK({
+				Py_DECREF(z);
+				return NULL;
+			})
+
+			carry = *pz + f * f;
+			*pz++ = (digit)(carry & MASK);
 			carry >>= SHIFT;
+			assert(carry <= MASK);
+
+			/* Now f is added in twice in each column of the
+			 * pyramid in which it appears.  Same as adding f<<1 once.
+			 */
+			f <<= 1;
+			while (pa < paend) {
+				carry += *pz + *pa++ * f;
+				*pz++ = (digit)(carry & MASK);
+				carry >>= SHIFT;
+				assert(carry <= (MASK << 1));
+			}
+			if (carry) {
+				carry += *pz;
+				*pz++ = (digit)(carry & MASK);
+				carry >>= SHIFT;
+			}
+			if (carry)
+				*pz += (digit)(carry & MASK);
+			assert((carry >> SHIFT) == 0);
 		}
-		for (; carry != 0; ++j) {
-			assert(i+j < z->ob_size);
-			carry += *pz;
-			*pz++ = (digit) (carry & MASK);
-			carry >>= SHIFT;
+	}
+	else {	/* a is not the same as b -- gradeschool long mult */
+		for (i = 0; i < size_a; ++i) {
+			twodigits carry = 0;
+			twodigits f = a->ob_digit[i];
+			digit *pz = z->ob_digit + i;
+			digit *pb = b->ob_digit;
+			digit *pbend = b->ob_digit + size_b;
+
+			SIGCHECK({
+				Py_DECREF(z);
+				return NULL;
+			})
+
+			while (pb < pbend) {
+				carry += *pz + *pb++ * f;
+				*pz++ = (digit)(carry & MASK);
+				carry >>= SHIFT;
+				assert(carry <= MASK);
+			}
+			if (carry)
+				*pz += (digit)(carry & MASK);
+			assert((carry >> SHIFT) == 0);
 		}
 	}
 	return long_normalize(z);
@@ -1816,7 +1863,8 @@
 	}
 
 	/* Use gradeschool math when either number is too small. */
-	if (asize <= KARATSUBA_CUTOFF) {
+	i = a == b ? KARATSUBA_SQUARE_CUTOFF : KARATSUBA_CUTOFF;
+	if (asize <= i) {
 		if (asize == 0)
 			return _PyLong_New(0);
 		else
@@ -1837,7 +1885,13 @@
 	if (kmul_split(a, shift, &ah, &al) < 0) goto fail;
 	assert(ah->ob_size > 0);	/* the split isn't degenerate */
 
-	if (kmul_split(b, shift, &bh, &bl) < 0) goto fail;
+	if (a == b) {
+		bh = ah;
+		bl = al;
+		Py_INCREF(bh);
+		Py_INCREF(bl);
+	}
+	else if (kmul_split(b, shift, &bh, &bl) < 0) goto fail;
 
 	/* The plan:
 	 * 1. Allocate result space (asize + bsize digits:  that's always
@@ -1906,7 +1960,11 @@
 	Py_DECREF(al);
 	ah = al = NULL;
 
-	if ((t2 = x_add(bh, bl)) == NULL) {
+	if (a == b) {
+		t2 = t1;
+		Py_INCREF(t2);
+	}
+	else if ((t2 = x_add(bh, bl)) == NULL) {
 		Py_DECREF(t1);
 		goto fail;
 	}



More information about the Python-checkins mailing list