Challenge: optimizing isqrt

Christian Gollwitzer auriocus at
Sat Nov 1 09:02:12 CET 2014

Hi Steven,

let me start by answering in reverse order:
 > Q3: What is the largest value of n beyond which you can never use the 
 > optimization?

A3: There is no such value, apart from the upper limit of floats (DBL_MAX ≈ 1.8e308).

P3: If you feed a perfect square into the floating-point square root, 
and the root's mantissa is shorter than the mantissa width of your float 
(53 bits for an IEEE 754 double), it will always come out exactly. That 
is, computing sqrt(25) in FP math is no different from computing 
sqrt(25*2**200):

 >>> import math
 >>> n = 25*2**200
 >>> x = int(math.sqrt(n))
 >>> x == 5*2**100
 True
 >>> x*x == n
 True
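To check the claim a bit more broadly, here is a small sketch (the helper name float_root_is_exact is hypothetical) that builds r = m * 2**k with a short mantissa m, squares it, and tests whether the float square root gives back r exactly. With m below 1000 the square needs only about 20 bits, so both the int-to-float conversion and the square root are exact; the last case, r = 2**53 + 1, has 54 significant bits and cannot even be represented as a double, so it has to fail.

```python
import math

def float_root_is_exact(m, k):
    """Return True if int(math.sqrt(r*r)) == r for r = m * 2**k.

    Hypothetical helper for illustration only.
    """
    r = m * 2**k
    n = r * r  # exact integer square
    # math.sqrt converts n to a float first, then takes the FP square root
    return int(math.sqrt(n)) == r

# short mantissas at very different scales all round-trip exactly
print(all(float_root_is_exact(m, k)
          for m in range(1, 1000)
          for k in (0, 50, 200, 400)))   # True

# 2**53 + 1 needs 54 significant bits, more than a double holds
print(float_root_is_exact(2**53 + 1, 0))  # False
```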

On 01.11.14 02:29, Steven D'Aprano wrote:
> There is an algorithm for calculating the integer square root of any
> positive integer using only integer operations:
> def isqrt(n):
>      if n < 0: raise ValueError
>      if n == 0:
>          return 0
>      bits = n.bit_length()
>      a, b = divmod(bits, 2)
>      x = 2**(a+b)
>      while True:
>          y = (x + n//x)//2
>          if y >= x:
>              return x
>          x = y

> Q2: For values above M, is there a way of identifying which values of n are
> okay to use the optimized version?

A2: Do it in a different way.

Your algorithm above is clearly doing Heron (Newton-Raphson) 
iterations, the same as a floating-point square root would. The line 
just before the while loop computes an initial approximation to sqrt(n). 
Instead of deriving it from the bit length, you could compute it with FP 
math and get much closer to the desired result, unless the integer is too 
large to be represented as a float. The terminating condition seems to 
rely on the initial estimate satisfying x >= sqrt(n), but I don't really 
understand it. My guess is that if you set x = int(sqrt(n)), do one 
iteration, swap x and y so that x is the larger one, and then enter the 
loop, you would simply start from a better estimate whenever the 
significant bits fit into a float.
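A quick empirical check of the two facts this relies on, assuming Python 3.8+ so that math.isqrt can serve as an exact reference: the bit-length start 2**(a+b) from the quoted code always exceeds sqrt(n), and one integer Newton step from any positive x lands at or above isqrt(n) (by AM-GM, (x + n/x)/2 >= sqrt(n), and the floor divisions cannot push the result below floor(sqrt(n))). The helper names are made up for illustration, and this is a sanity check on small n, not a proof.

```python
import math

def bit_length_estimate(n):
    """Starting value from the quoted isqrt: 2**ceil(bits/2), > sqrt(n) for n >= 1."""
    a, b = divmod(n.bit_length(), 2)
    return 2**(a + b)

def newton_step(n, x):
    """One integer Heron/Newton step; for any x >= 1 the result is >= isqrt(n)."""
    return (x + n // x) // 2

def check(limit=5000):
    """Exhaustively test both properties for 1 <= n < limit."""
    for n in range(1, limit):
        r = math.isqrt(n)  # exact reference (Python 3.8+)
        est = bit_length_estimate(n)
        assert est * est > n
        # starting points below, at, and above the true root
        for x in (1, max(1, r - 3), r, r + 1, est):
            assert newton_step(n, x) >= r
    return True

print(check())  # True
```

This is why taking the larger of x and y after one step from a float estimate is enough to start the loop from above.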

So here is my attempt, not thoroughly tested:

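import math

def isqrt(n):
    if n < 0:
        raise ValueError("isqrt() argument must be nonnegative")
    if n == 0:
        return 0
    bits = n.bit_length()
    # the highest exponent in 64-bit IEEE is 1023, so n must stay
    # well below 2**1024 to be converted to a float
    if n > 2**1022:
        # too large for a float: use the bit-length estimate,
        # which is always >= sqrt(n)
        a, b = divmod(bits, 2)
        x = 2**(a+b)
    else:
        # FP estimate: close to sqrt(n), but possibly slightly off
        x = int(math.sqrt(n))
        # one Newton step from any x > 0 gives y >= isqrt(n); keep the
        # larger value so that the loop starts from above
        y = (x + n//x)//2
        if x < y:
            x, y = y, x

    while True:
        y = (x + n//x)//2
        if y >= x:
            return x
        x = y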
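The 2**1022 cutoff above exists because converting an int to a float overflows past DBL_MAX. A hypothetical alternative, sketched here with a made-up helper name float_estimate, is to simply try the float path and fall back only when Python raises OverflowError, which uses the float estimate right up to the real limit.

```python
import math
import sys

def float_estimate(n):
    """Return int(math.sqrt(n)), or None when n is too large for a float.

    Hypothetical helper: instead of a fixed cutoff such as 2**1022, try the
    float conversion and fall back only if it overflows.
    """
    try:
        return int(math.sqrt(n))
    except OverflowError:
        # int -> float conversion fails beyond sys.float_info.max (~1.8e308)
        return None

print(float_estimate(2**1022) == 2**511)  # True: exact power of two
print(float_estimate(2**1024))            # None: does not fit in a double
print(sys.float_info.max < 2**1024)       # True
```

Either way the Newton loop still has to run afterwards, since the float estimate is only approximate once n has more significant bits than a double holds.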

