[Numpy-discussion] Vectorizing code, for loops, and all that
Tim Hochberg
tim.hochberg at ieee.org
Mon Oct 2 20:17:46 EDT 2006
Tim Hochberg wrote:
> Travis Oliphant wrote:
>
>> Albert Strasheim wrote:
>>
>>
>>
>>> In [571]: x1 = N.random.randn(2000,39)
>>>
>>> In [572]: y1 = N.random.randn(64,39)
>>>
>>> In [574]: %timeit z1=x1[...,N.newaxis,...]-y1
>>> 10 loops, best of 3: 703 ms per loop
>>>
>>> In [575]: z1.shape
>>> Out[575]: (2000, 64, 39)
>>>
>>> As far as I can figure, this operation is doing 2000*64*39 subtractions.
>>> Doing this straight up yields the following:
>>>
>>> In [576]: x2 = N.random.randn(2000,64,39)
>>>
>>> In [577]: y2 = N.random.randn(2000,64,39)
>>>
>>> In [578]: %timeit z2 = x2-y2
>>> 10 loops, best of 3: 108 ms per loop
>>>
>>> Does anybody have any ideas on why this is so much faster? Hopefully I
>>> didn't mess up somewhere...
>>>
>>>
>>>
[SNIP]
>
>
> I just spent a while playing with this, and assuming I've correctly
> translated your original intent I've come up with two alternative,
> looping versions that run, respectively 2 and 3 times faster. I've a
> feeling that kmean3, the fastest one, still has a little more room to be
> sped up, but I'm out of time now. Code is below
>
> -tim
>
>
>
One more iteration, this time with a small algorithmic improvement, and
we're up to 4x as fast as the original code. Expanding the squared
distance (di-ci)**2 gives di**2 - 2*di*ci + ci**2. The di**2 term is
constant across the axis we are minimizing on, so it can be dropped; the
computation of ci**2 can be hoisted out of the loop; and the -2*di*ci
term can be rejiggered to di*ci by rescaling everything by -0.5, which
turns the argmin into an argmax. This reduces the computation in the
inner loop from a subtract and a multiply to just a multiply.
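
import numpy as N

def kmean4(data):
    nclusters = 64
    naxes = data.shape[-1]
    code = data[:nclusters]
    # Transpose so each row holds one axis across all data points.
    transdata = data.transpose().copy()
    totals = N.empty([nclusters, len(data)], float)
    # Hoisted ci**2 term, rescaled by -0.5.
    code2 = (code**2).sum(-1)
    code2 *= -0.5
    totals[:] = code2[:, N.newaxis]
    for cluster, tot in zip(code, totals):
        for di, ci in zip(transdata, cluster):
            tot += di*ci
    # The negative rescaling turns the argmin into an argmax.
    return totals.argmax(axis=0)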
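
As a sanity check on the rescaling trick, here is a minimal sketch that compares a brute-force argmin over squared distances with the argmax of di*ci - 0.5*ci**2 summed over the axes. The function names and the test data are made up for illustration, and the second function uses N.dot in place of the explicit loop above, so it checks the algebra rather than the timing.

```python
import numpy as N

def nearest_bruteforce(data, code):
    # Full squared distances via broadcasting: shape (ndata, nclusters).
    diff = data[:, N.newaxis, :] - code[N.newaxis, :, :]
    return (diff ** 2).sum(-1).argmin(axis=1)

def nearest_rescaled(data, code):
    # scores[j, i] = d_i . c_j - 0.5 * |c_j|**2; the |d_i|**2 term is dropped
    # because it is constant across clusters for a given data point.
    scores = N.dot(code, data.T) - 0.5 * (code ** 2).sum(-1)[:, N.newaxis]
    return scores.argmax(axis=0)

# Hypothetical test data: same number of axes as in the thread, fewer points.
rng = N.random.RandomState(0)
data = rng.randn(200, 39)
code = data[:64]

a = nearest_bruteforce(data, code)
b = nearest_rescaled(data, code)
print((a == b).all())
```

The two results could in principle differ on exact ties or near-ties because the floating-point rounding differs between the two formulations, but that should be rare with random data like this.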
[SNIP CODE]