[Numpy-discussion] Efficient square distance computation
matthew.brett at gmail.com
Tue Oct 8 15:44:11 EDT 2013
On Tue, Oct 8, 2013 at 4:38 AM, Ke Sun <sunk.cs at gmail.com> wrote:
> On Tue, Oct 08, 2013 at 01:49:14AM -0700, Matthew Brett wrote:
>> On Tue, Oct 8, 2013 at 1:06 AM, Ke Sun <sunk.cs at gmail.com> wrote:
>> > Dear all,
>> > I have written the following function to compute the squared distances of a large
>> > matrix (each sample a row). It computes row by row and prints the overall progress.
>> > The progress output is important, so I didn't use matrix multiplication.
>> > The input is a 70,000x800 matrix, so the output should be a 70,000x70,000
>> > matrix. The program runs really slowly (16 hours to get a third of the way through), and it uses
>> > 36 GB of memory (fortunately I have enough).
>> That is very slow.
>> As a matter of interest, why didn't you use matrix multiplication?
> Because it will take hours, and I want to see the progress and
> know how far it has got. Another concern is saving memory by
> computing one sample at a time.
>> On a machine I had access to, it took about 20 minutes.
> How? I am using matrix multiplication (the same code as
> http://stackoverflow.com/a/4856692) and it runs for around 18 hours.
I wonder if you are running into disk swap; the code there does
involve a large temporary array.
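To put numbers on the swap idea: for n = 70,000 samples the float64 result alone is 70,000^2 x 8 bytes, about 39.2 GB (roughly 36.5 GiB), and every extra n x n temporary costs the same again. Below is a minimal sketch contrasting a typical one-line formulation (an assumption for illustration, not necessarily the exact Stack Overflow code) with an in-place update; the names sqdist_temporaries, sqdist_inplace and nbytes_float64 are hypothetical.

```python
import numpy as np


def sqdist_temporaries(A):
    """Squared distances between rows of A via one expression."""
    sq = np.sum(A * A, axis=1)  # squared row norms, length n
    # This expression creates several n x n intermediates (the outer sum,
    # the product, 2 * product, and the result); NumPy may reuse some.
    return sq[:, None] + sq[None, :] - 2 * np.dot(A, A.T)


def sqdist_inplace(A):
    """Squared distances between rows of A with a single n x n array."""
    B = np.dot(A, A.T)  # the only n x n allocation
    sq = np.sum(A * A, axis=1)
    B *= -2
    B += sq[:, None]  # add |a_i|^2 to row i
    B += sq[None, :]  # add |a_j|^2 to column j
    return B


def nbytes_float64(n):
    """Bytes needed for one n x n float64 array."""
    return n * n * np.dtype(np.float64).itemsize


print(nbytes_float64(70000) / 2**30)  # about 36.5 GiB per n x n array

rng = np.random.default_rng(0)
A = rng.normal(size=(200, 50))
print(np.allclose(sqdist_temporaries(A), sqdist_inplace(A)))  # True
```

Both give the same values; the difference is only in peak memory, which matters when a single n x n array already fills most of RAM.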
I believe the appended version of the code is correct, and I think it
is also memory efficient.
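For reference, the appended code rests on expanding the squared Euclidean distance between two columns x_i and x_j of X in terms of the Gram matrix B:

```latex
\|x_i - x_j\|^2 = \|x_i\|^2 + \|x_j\|^2 - 2\,x_i^\top x_j
               = B_{ii} + B_{jj} - 2\,B_{ij},
\qquad B = X^\top X
```

With q_i = B_ii, the updates B *= -2, B += q and B += q.T compute exactly this. Because q is an n x 1 column, broadcasting adds the norm terms without allocating a second n x n array.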
On a fast machine with lots of memory, it ran in about 5 minutes.
It's using EPD, which might be using multiple cores for the matrix multiplication.
Does the code also work for you in reasonable time?
Another suggestion I saw, which calculates only the unique values (say,
the lower triangle), is scipy.spatial.distance.
On a first pass, that seems to be slower than the matrix multiply.
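For comparison, a minimal sketch of the scipy.spatial.distance route, assuming SciPy is installed: pdist with metric="sqeuclidean" returns only the n(n-1)/2 unique pairwise values as a flat "condensed" vector (roughly half the memory of the full matrix), and squareform expands it to the full n x n matrix when needed. Note that pdist treats rows as observations, unlike the pdista function in the appended code, which works on columns. The small random input is illustrative only.

```python
import numpy as np
from scipy.spatial.distance import pdist, squareform

rng = np.random.default_rng(0)
A = rng.normal(size=(300, 40))  # rows are samples, as in the original problem
n = A.shape[0]

# Condensed form: one entry per unordered pair (i, j) with i < j
condensed = pdist(A, metric="sqeuclidean")
print(condensed.shape)  # (44850,), i.e. n * (n - 1) // 2

# Expand to the full symmetric n x n matrix, with zeros on the diagonal
D = squareform(condensed)
print(D.shape)  # (300, 300)
```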
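import datetime
import numpy as np

def pdista(X):
    """Squared pairwise distances between all columns of X."""
    # Gram matrix: B[i, j] is the dot product of columns i and j
    B = np.dot(X.T, X)
    # Squared column norms as an (n, 1) column. The copy is necessary:
    # np.diag returns a view into B, and B is modified in place below.
    q = np.diag(B)[:, None].copy()
    # In place: B[i, j] = q[i] + q[j] - 2 * B[i, j]
    B *= -2
    B += q
    B += q.T
    return B

M = 70000
N = 800
A = np.random.normal(size=(M, N))
start = datetime.datetime.now()
dists = pdista(A.T)  # rows of A are samples, so pass A.T
elapsed = datetime.datetime.now() - start
print(elapsed)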
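Since the original goal was to see progress and keep memory bounded, here is a minimal sketch that applies the same expansion one block of rows at a time: it fills a preallocated output slab by slab and prints the percentage done after each block. The function name sqdist_blocked, its parameters, and the default block of 1000 rows are assumptions for illustration. For n = 70,000 the out argument could be an on-disk array such as np.memmap("dists.dat", dtype=np.float64, mode="w+", shape=(n, n)) (file name hypothetical), since out only needs to support slice assignment.

```python
import sys

import numpy as np


def sqdist_blocked(A, out=None, block=1000, report=True):
    """Squared distances between rows of A, computed block by block.

    Hypothetical helper: apart from the n x n output, only a
    (block x n) slab of intermediates is alive at any time.
    """
    n = A.shape[0]
    if out is None:
        out = np.empty((n, n), dtype=np.float64)
    sq = np.einsum("ij,ij->i", A, A)  # squared row norms, length n
    for start in range(0, n, block):
        stop = min(start + block, n)
        # Cross terms for this slab of rows: shape (stop - start, n)
        slab = np.dot(A[start:stop], A.T)
        slab *= -2
        slab += sq[start:stop, None]  # |a_i|^2 for the rows in this slab
        slab += sq[None, :]  # |a_j|^2 for every column
        out[start:stop] = slab
        if report:
            sys.stdout.write("\r%.1f%% done" % (100.0 * stop / n))
            sys.stdout.flush()
    if report:
        sys.stdout.write("\n")
    return out


rng = np.random.default_rng(0)
A = rng.normal(size=(250, 30))
D = sqdist_blocked(A, block=64)
```

Smaller blocks give finer-grained progress at the cost of more, smaller matrix multiplies; the total arithmetic is the same as the single np.dot.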