[Numpy-discussion] Distance Matrix speed

Alan G Isaac aisaac at american.edu
Mon Jun 19 00:30:12 EDT 2006


On Sun, 18 Jun 2006, Tim Hochberg apparently wrote: 

> Alan G Isaac wrote: 

>> On Sun, 18 Jun 2006, Sebastian Beca apparently wrote: 

>>> def dist():
>>>     d = zeros([N, C], dtype=float)
>>>     if N < C:
>>>         for i in range(N):
>>>             xy = A[i] - B
>>>             d[i,:] = sqrt(sum(xy**2, axis=1))
>>>         return d
>>>     else:
>>>         for j in range(C):
>>>             xy = A - B[j]
>>>             d[:,j] = sqrt(sum(xy**2, axis=1))
>>>     return d

>> But that is 50% slower than Johannes's version: 

>> def dist_loehner1():
>>     d = A[:, newaxis, :] - B[newaxis, :, :]
>>     d = sqrt((d**2).sum(axis=2))
>>     return d

> Are you sure about that? I just ran it through timeit, using Sebastian's
> array sizes, and I get Sebastian's version being 150% faster. This
> could well be cache-size dependent, so it may vary from box to box, but I'd
> expect Sebastian's current version to scale better in general.

No, I'm not sure.
The script is attached at the bottom.
The most recent output follows; for reasons I have not
determined, it doesn't match my previous runs ...
Alan

>>> execfile(r'c:\temp\temp.py')
dist_beca :       3.042277
dist_loehner1:    3.170026
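
Run-to-run variation of this kind is common when a single timeit
measurement is taken on a machine that is doing other work. A minimal
sketch of one way to reduce it, assuming a current NumPy and Python 3:
repeat each measurement several times and keep the minimum. The helper
name best_time, the fixed seed, and the repeat count of 5 are assumptions
for illustration and are not part of the script below; the array sizes are
copied from it.

```python
import timeit
import numpy as np

# Sizes copied from the attached script; the seed is an assumption
# added only so that repeated runs use the same data.
K, C, N = 10, 2500, 3
rng = np.random.default_rng(0)
A = rng.random((N, K))
B = rng.random((C, K))

def dist_beca():
    # Loop over the shorter axis, filling one row (or column) at a time.
    d = np.zeros((N, C), dtype=float)
    if N < C:
        for i in range(N):
            d[i, :] = np.sqrt(((A[i] - B) ** 2).sum(axis=1))
    else:
        for j in range(C):
            d[:, j] = np.sqrt(((A - B[j]) ** 2).sum(axis=1))
    return d

def dist_loehner1():
    # Single broadcast expression; builds an (N, C, K) intermediate.
    d = A[:, np.newaxis, :] - B[np.newaxis, :, :]
    return np.sqrt((d ** 2).sum(axis=2))

def best_time(func, number=100, repeat=5):
    # Hypothetical helper: the minimum over several repeats is less
    # sensitive to background load than a single measurement.
    return min(timeit.repeat(func, number=number, repeat=repeat))

timings = {f.__name__: best_time(f) for f in (dist_beca, dist_loehner1)}
for name, seconds in timings.items():
    print("%-14s: %10.6f" % (name, seconds))
```

The absolute numbers will still differ from box to box; only the
comparison between the two functions on the same machine is meaningful.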


#################################
#THE SCRIPT
import sys
sys.path.append("c:\\temp")
import numpy
from numpy import *
import timeit


K = 10
C = 2500
N = 3 # One could switch around C and N now.
A = numpy.random.random( [N, K] )
B = numpy.random.random( [C, K] )

# beca
def dist_beca():
    d = zeros([N, C], dtype=float)
    if N < C:
        for i in range(N):
            xy = A[i] - B
            d[i,:] = sqrt(sum(xy**2, axis=1))
        return d
    else:
        for j in range(C):
            xy = A - B[j]
            d[:,j] = sqrt(sum(xy**2, axis=1))
    return d

#loehnert
def dist_loehner1():
    # drawback: memory usage temporarily doubled
    # solution see below
    d = A[:, newaxis, :] - B[newaxis, :, :]
    # written as 3 expressions for more clarity
    d = sqrt((d**2).sum(axis=2))
    return d


if __name__ == "__main__":
    # The setup strings assume this file is saved as temp.py on sys.path.
    t1 = timeit.Timer('dist_beca()', 'from temp import dist_beca').timeit(100)
    t8 = timeit.Timer('dist_loehner1()', 'from temp import dist_loehner1').timeit(100)
    fmt = "%-10s:\t" + "%10.6f"
    print(fmt % ('dist_beca', t1))
    print(fmt % ('dist_loehner1', t8))
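
Before comparing timings it is worth confirming that the two approaches
return the same matrix, and seeing how large the broadcast intermediate
is. The sketch below is an illustration under assumptions, not part of
the script above: the function names dist_loop and dist_broadcast are
hypothetical, the arrays are passed as arguments instead of being read
from globals, and a current NumPy on Python 3 is assumed.

```python
import numpy as np

def dist_loop(A, B):
    # Row-by-row (or column-by-column) version, looping over the shorter axis.
    n, c = A.shape[0], B.shape[0]
    d = np.zeros((n, c), dtype=float)
    if n < c:
        for i in range(n):
            d[i, :] = np.sqrt(((A[i] - B) ** 2).sum(axis=1))
    else:
        for j in range(c):
            d[:, j] = np.sqrt(((A - B[j]) ** 2).sum(axis=1))
    return d

def dist_broadcast(A, B):
    # Fully broadcast version; the difference array has shape (n, c, k).
    diff = A[:, np.newaxis, :] - B[np.newaxis, :, :]
    return np.sqrt((diff ** 2).sum(axis=2))

rng = np.random.default_rng(0)   # seed is an assumption, for reproducibility
A = rng.random((3, 10))          # N = 3, K = 10
B = rng.random((2500, 10))       # C = 2500, K = 10

d_loop = dist_loop(A, B)
d_bcast = dist_broadcast(A, B)
same = np.allclose(d_loop, d_bcast)

# Size of the (N, C, K) intermediate versus the (N, C) result, in bytes.
intermediate_bytes = A.shape[0] * B.shape[0] * A.shape[1] * A.itemsize
result_bytes = d_bcast.nbytes
print(same, intermediate_bytes, result_bytes)
```

With these sizes the intermediate holds K = 10 times as many elements as
the result, which is one plausible reason the broadcast version could
become more sensitive to cache size as the arrays grow.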





