[Numpy-discussion] nditer gurus: is there a more efficient way to do this?
Dave Hirschfeld
dave.hirschfeld at gmail.com
Sun Feb 10 09:25:00 EST 2013
I have two NxMx3 arrays. I want to reduce the first array over its last
dimension to give an NxM result: for each (i, j), pick the element of the first
array at the index where the corresponding 3-vector of the second array has its
maximum. Hopefully that makes sense; if not, hopefully the example below will
shed some light.
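To pin down the intended semantics, here is a minimal sketch on a hypothetical 1x2x3 case (the data is made up for illustration), written as an explicit loop: result[i, j] = value_array[i, j, index_array[i, j].argmax()].

```python
import numpy as np

# Hypothetical tiny inputs: N=1, M=2, last dimension of length 3.
value_array = np.array([[[10, 20, 30],
                         [40, 50, 60]]])
index_array = np.array([[[0.1, 0.9, 0.5],    # argmax is 1, so pick 20
                         [0.7, 0.2, 0.3]]])  # argmax is 0, so pick 40

# Reference definition: one argmax per 3-vector of index_array selects
# the element of the matching 3-vector of value_array.
result = np.empty(value_array.shape[:-1], dtype=value_array.dtype)
for i in range(value_array.shape[0]):
    for j in range(value_array.shape[1]):
        k = index_array[i, j].argmax()
        result[i, j] = value_array[i, j, k]

print(result)  # [[20 40]]
```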
Can anyone think of a more efficient way to do this than my first attempt? I
thought that nditer might be the solution; I got it to work (mostly by trial &
error) but found that it's ~50x slower for this task! Is this not a good use case
for nditer, or am I doing something wrong?
In [42]: value_array = np.outer(np.ones(1000), np.arange(3)).reshape(20,50,3)
...: index_array = np.random.randn(*value_array.shape)
In [43]: indices = index_array.reshape(-1,3).argmax(axis=1)
...: result = value_array.reshape(-1,3)[np.arange(indices.size), indices]
...: result = result.reshape(value_array.shape[0:-1])
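The reshape-and-fancy-index step in In [43] can also be written without flattening, using np.take_along_axis. This is a sketch, not something from the original thread; it assumes NumPy >= 1.15 (for take_along_axis) and >= 1.17 (for default_rng, used only to make the data reproducible).

```python
import numpy as np

rng = np.random.default_rng(0)
value_array = np.outer(np.ones(1000), np.arange(3)).reshape(20, 50, 3)
index_array = rng.standard_normal(value_array.shape)

# Approach from In [43]: flatten to (N*M, 3) and fancy-index row by row.
indices = index_array.reshape(-1, 3).argmax(axis=1)
result = value_array.reshape(-1, 3)[np.arange(indices.size), indices]
result = result.reshape(value_array.shape[:-1])

# Alternative: keep the 3-D shape; take_along_axis pairs each argmax
# index with its own 3-vector. The trailing length-1 axis is dropped at the end.
idx = index_array.argmax(axis=-1)[..., np.newaxis]               # shape (20, 50, 1)
result2 = np.take_along_axis(value_array, idx, axis=-1)[..., 0]  # shape (20, 50)
```

Both versions make a single vectorized argmax call over all 3-vectors; the second just avoids the manual reshape bookkeeping.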
In [44]: it = np.nditer([value_array, index_array, None],
...: flags=['reduce_ok', 'external_loop','buffered', 'delay_bufalloc'],
...: op_flags=[['readonly'],['readonly'],['readwrite', 'allocate']],
...: op_axes=[[0,1,2], [0,1,2], [0,1,-1]])
...: it.reset()
...: for values, index_values, out in it:
...: out[...] = values[index_values.argmax()]
...: #
In [45]: np.allclose(result, it.operands[2])
Out[45]: True
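For context on the nditer call in In [44]: op_axes=[0,1,-1] maps the output's two axes onto the first two iteration axes, and -1 marks the third iteration axis as one the output lacks, so the output is revisited (stride 0) along it; 'reduce_ok' permits that. The sketch below is adapted from the reduction example in the NumPy "Iterating Over Arrays" guide (it assumes NumPy >= 1.15 for the with-statement). It sums over the last axis with the same op_axes pattern and counts how often the Python loop body runs. Without 'external_loop' the body runs once per element. Since the loop in In [44] calls argmax on each inner chunk and still gives the right answer, each chunk there is presumably a single 3-vector, so the body would run N*M = 1000 times per call; that per-iteration Python overhead is one plausible reason for the timing gap.

```python
import numpy as np

a = np.arange(24).reshape(2, 3, 4)

# Output has only axes 0 and 1; -1 makes it a reduction over axis 2.
it = np.nditer([a, None], flags=['reduce_ok'],
               op_flags=[['readonly'], ['readwrite', 'allocate']],
               op_axes=[None, [0, 1, -1]])

n_iterations = 0
with it:
    it.operands[1][...] = 0        # the allocated output starts uninitialized
    for x, y in it:
        y[...] += x                # accumulate into the stride-0 output slot
        n_iterations += 1          # one Python-level iteration per element
    result = it.operands[1]

print(result)        # [[ 6 22 38]
                     #  [54 70 86]]
print(n_iterations)  # 24
```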
In [46]: %%timeit
...: indices = index_array.reshape(-1,3).argmax(axis=1)
...: result = value_array.reshape(-1,3)[np.arange(indices.size), indices]
...: result = result.reshape(value_array.shape[0:-1])
...:
10000 loops, best of 3: 113 µs per loop
In [47]: %%timeit
...: it = np.nditer([value_array, index_array, None],
...: flags=['reduce_ok', 'external_loop','buffered', 'delay_bufalloc'],
...: op_flags=[['readonly'],['readonly'],['readwrite', 'allocate']],
...: op_axes=[[0,1,2], [0,1,2], [0,1,-1]])
...: it.reset()
...: for values, index_values, out in it:
...: out[...] = values[index_values.argmax()]
...: #
...:
100 loops, best of 3: 5.26 ms per loop
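To reproduce a comparison like this outside IPython, here is a sketch using the standard-library timeit module. It does not time the nditer code itself; the hypothetical python_loop function is only a rough stand-in for a loop body that runs once per 3-vector. Absolute numbers will depend on the machine and NumPy version, so no expected timings are given.

```python
import timeit

import numpy as np

value_array = np.outer(np.ones(1000), np.arange(3)).reshape(20, 50, 3)
index_array = np.random.default_rng(1).standard_normal(value_array.shape)

def vectorized(values, index):
    # Same reshape + fancy-indexing idea as In [43]: one argmax for all 3-vectors.
    idx = index.reshape(-1, 3).argmax(axis=1)
    out = values.reshape(-1, 3)[np.arange(idx.size), idx]
    return out.reshape(values.shape[:-1])

def python_loop(values, index):
    # Hypothetical stand-in: one Python-level iteration and one small
    # argmax call per 3-vector (N*M iterations in total).
    out = np.empty(values.shape[:-1], dtype=values.dtype)
    for ij in np.ndindex(*values.shape[:-1]):
        out[ij] = values[ij][index[ij].argmax()]
    return out

if __name__ == "__main__":
    n = 100
    for fn in (vectorized, python_loop):
        t = timeit.timeit(lambda: fn(value_array, index_array), number=n)
        print(f"{fn.__name__}: {t / n * 1e6:.1f} µs per call")
```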
Thanks,
Dave