[Numpy-discussion] Efficient square distance computation

Jaime Fernández del Río jaime.frio at gmail.com
Tue Oct 8 09:42:04 EDT 2013


On Tue, Oct 8, 2013 at 4:38 AM, Ke Sun <sunk.cs at gmail.com> wrote:

> On Tue, Oct 08, 2013 at 01:49:14AM -0700, Matthew Brett wrote:
> > Hi,
> >
> > On Tue, Oct 8, 2013 at 1:06 AM, Ke Sun <sunk.cs at gmail.com> wrote:
> > > Dear all,
> > >
> > > I have written the following function to compute the square distances
> > > of a large matrix (each sample a row). It computes row by row and
> > > prints the overall progress. The progress output is important and I
> > > didn't use matrix multiplication.
> > >
> > > I give as input a 70,000x800 matrix. The output should be a
> > > 70,000x70,000 matrix. The program runs really slow (16 hours for 1/3
> > > progress). And it eats 36G memory (fortunately I have enough).
> >
> > That is very slow.
> >
> > As a matter of interest - why didn't you use matrix multiplication?
> Because it will cost hours and I want to see the progress and
> know how far it goes. Another concern is to save memory and
> compute one sample at a time.
>

Another option you may want to consider is to do your calculation in
chunks, not one item at a time, e.g.:

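import numpy as np

rows, cols = 70000, 800
data = np.random.rand(rows, cols)
chunks = 100
chunk_len = rows // chunks

out = np.empty((rows, rows))

# Fill the matrix of dot products between all pairs of rows, block by
# block. Only the upper-triangular blocks are computed; the rest are
# mirrored from them.
for j in range(0, rows, chunk_len):
    chunk_j = data[j: j + chunk_len]
    for k in range(j, rows, chunk_len):
        chunk_k = data[k: k + chunk_len]
        out[j: j + chunk_len,
            k: k + chunk_len] = np.dot(chunk_j, chunk_k.T)
        if j != k:
            out[k: k + chunk_len,
                j: j + chunk_len] = out[j: j + chunk_len,
                                        k: k + chunk_len].T

# |a - b|^2 = |a|^2 + |b|^2 - 2 a.b, and the diagonal holds the squared
# norms. Copy it so q is not a view of out (np.diag returns a view in
# newer NumPy, which the in-place updates below would corrupt).
q = np.diag(out).copy()
out *= -2
out += q
out += q[:, np.newaxis]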
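
The last three in-place steps rely on the identity
|x_i - x_j|^2 = |x_i|^2 + |x_j|^2 - 2 x_i.x_j, with the squared norms read
off the diagonal of the dot-product matrix. Here is a minimal check of that
identity against a brute-force computation; the array size and the seed are
arbitrary choices for illustration, not anything from the original problem:

```python
import numpy as np

rng = np.random.RandomState(0)      # fixed seed, arbitrary
x = rng.rand(6, 3)                  # 6 samples, 3 features (illustrative sizes)

gram = np.dot(x, x.T)               # gram[i, j] = x_i . x_j
q = np.diag(gram).copy()            # squared norms |x_i|^2

# The same three steps as above, written as a single expression.
fast = -2 * gram + q + q[:, np.newaxis]

# Brute-force reference: explicit differences between every pair of rows.
diff = x[:, np.newaxis, :] - x[np.newaxis, :, :]
slow = (diff ** 2).sum(axis=-1)

print(np.allclose(fast, slow))
```

The brute-force version needs a rows x rows x cols temporary, so it is only
usable as a check on small inputs.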

This way you can still gauge progress, do most of the work with fast,
vectorized operations, and probably offset the (relatively small) cost of
the Python looping by not calculating most of the symmetrical blocks.
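
The loop above does not print anything itself, so here is a rough sketch of
the same chunked approach wrapped in a function that reports progress once
per block of rows. The function name chunked_sq_dists, the verbose flag and
the final clamp of tiny negative values (rounding error) are my own
additions for illustration, and the demo sizes are arbitrary:

```python
import numpy as np


def chunked_sq_dists(data, chunk_len, verbose=True):
    """Squared Euclidean distances between all rows of data, block by block.

    Hypothetical helper for illustration; not part of NumPy.
    """
    rows = data.shape[0]
    out = np.empty((rows, rows))
    starts = range(0, rows, chunk_len)
    for n, j in enumerate(starts, 1):
        chunk_j = data[j:j + chunk_len]
        # Only blocks on or above the diagonal are computed.
        for k in range(j, rows, chunk_len):
            block = np.dot(chunk_j, data[k:k + chunk_len].T)
            out[j:j + chunk_len, k:k + chunk_len] = block
            if j != k:
                # Mirror the block into the lower triangle.
                out[k:k + chunk_len, j:j + chunk_len] = block.T
        if verbose:
            print("row block %d of %d done" % (n, len(starts)))
    q = np.diag(out).copy()         # squared norms; copy, not a view
    out *= -2
    out += q
    out += q[:, np.newaxis]
    np.maximum(out, 0, out=out)     # assumption: clip rounding noise below 0
    return out


demo = np.random.RandomState(0).rand(10, 4)   # small, arbitrary test input
d = chunked_sq_dists(demo, chunk_len=3)       # 3 does not divide 10 evenly
```

Slicing past the end of an array just returns a shorter chunk, so chunk_len
does not have to divide the number of rows. Later row blocks have fewer
column blocks left to compute, so they finish faster than the early ones,
which is worth keeping in mind when reading the progress output.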

Jaime