[Numpy-discussion] nditer gurus: is there a more efficient way to do this?

Dave Hirschfeld dave.hirschfeld at gmail.com
Sun Feb 10 09:25:00 EST 2013


I have two NxMx3 arrays and I want to reduce over the last dimension of the 
first array by selecting those elements corresponding to the index of the 
maximum value of each 3-vector of the second array to give an NxM result.
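To pin down the semantics, here is a minimal plain-Python reference implementation. It is a sketch only: the function name `select_by_argmax_loop` and the tiny 1x2x3 arrays are hypothetical and do not come from the original post. For each (i, j) it finds the position of the largest entry in the second array's 3-vector, then picks the element at that position from the first array:

```python
import numpy as np

def select_by_argmax_loop(values, keys):
    """Slow but explicit reference for the reduction described above.

    Hypothetical helper: for every (i, j), return values[i, j, k], where
    k is the index of the largest entry in keys[i, j, :].
    """
    n, m, _ = values.shape
    out = np.empty((n, m), dtype=values.dtype)
    for i in range(n):
        for j in range(m):
            k = np.argmax(keys[i, j])   # index of the max in this 3-vector
            out[i, j] = values[i, j, k]
    return out

# Tiny 1x2x3 example with made-up data
values = np.array([[[10, 20, 30], [40, 50, 60]]])
keys = np.array([[[0.1, 0.9, 0.5], [0.7, 0.2, 0.3]]])
print(select_by_argmax_loop(values, keys))  # [[20 40]]
```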

Hopefully that makes sense? If not, hopefully the example below will shed some
light.

Can anyone think of a more efficient way to do this than my first attempt? I
thought that nditer might be the solution; I got it to work (mostly by trial and
error) but found that it's ~50x slower for this task! Is this not a good use case
for nditer, or am I doing something wrong?


In [42]: value_array = np.outer(np.ones(1000), np.arange(3)).reshape(20,50,3)
    ...: index_array = np.random.randn(*value_array.shape)

In [43]: indices = index_array.reshape(-1,3).argmax(axis=1)
    ...: result = value_array.reshape(-1,3)[np.arange(indices.size), indices]
    ...: result = result.reshape(value_array.shape[0:-1])

In [44]: it = np.nditer([value_array, index_array, None],
    ...:     flags=['reduce_ok', 'external_loop','buffered', 'delay_bufalloc'],
    ...:     op_flags=[['readonly'],['readonly'],['readwrite', 'allocate']],
    ...:     op_axes=[[0,1,2], [0,1,2], [0,1,-1]])
    ...: it.reset()
    ...: for values, index_values, out in it:
    ...:     out[...] = values[index_values.argmax()]
    ...: #

In [45]: np.allclose(result, it.operands[2])
Out[45]: True

In [46]: %%timeit
    ...: indices = index_array.reshape(-1,3).argmax(axis=1)
    ...: result = value_array.reshape(-1,3)[np.arange(indices.size), indices]
    ...: result = result.reshape(value_array.shape[0:-1])
    ...: 
10000 loops, best of 3: 113 µs per loop

In [47]: %%timeit
    ...: it = np.nditer([value_array, index_array, None],
    ...:     flags=['reduce_ok', 'external_loop','buffered', 'delay_bufalloc'],
    ...:     op_flags=[['readonly'],['readonly'],['readwrite', 'allocate']],
    ...:     op_axes=[[0,1,2], [0,1,2], [0,1,-1]])
    ...: it.reset()
    ...: for values, index_values, out in it:
    ...:     out[...] = values[index_values.argmax()]
    ...: #
    ...: 
100 loops, best of 3: 5.26 ms per loop

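The check in In [45] passes, so each inner chunk the iterator hands back must be a single 3-vector. That means the Python loop body runs N*M = 1000 times, and this per-iteration interpreter overhead is the likely cause of the slowdown rather than nditer itself. Two fully vectorized alternatives that avoid the reshape are sketched below. They are not part of the original post, the function names are hypothetical, and `np.take_along_axis` needs NumPy 1.15 or later (`np.random.default_rng` needs 1.17). The final asserts check both against the reshape-based result from In [43]:

```python
import numpy as np

def select_take_along_axis(values, keys):
    # argmax over the last axis; add a trailing axis so idx stays 3-D
    idx = np.argmax(keys, axis=-1)[..., np.newaxis]           # shape (N, M, 1)
    return np.take_along_axis(values, idx, axis=-1)[..., 0]   # shape (N, M)

def select_ogrid(values, keys):
    # Same selection via broadcast fancy indexing, no reshape needed
    idx = np.argmax(keys, axis=-1)   # shape (N, M)
    n, m = idx.shape
    i, j = np.ogrid[:n, :m]          # shapes (N, 1) and (1, M), broadcast with idx
    return values[i, j, idx]

rng = np.random.default_rng(0)
value_array = np.outer(np.ones(1000), np.arange(3)).reshape(20, 50, 3)
index_array = rng.standard_normal(value_array.shape)

# Reference: the reshape-based approach from In [43]
indices = index_array.reshape(-1, 3).argmax(axis=1)
result = value_array.reshape(-1, 3)[np.arange(indices.size), indices]
result = result.reshape(value_array.shape[:-1])

assert np.array_equal(select_take_along_axis(value_array, index_array), result)
assert np.array_equal(select_ogrid(value_array, index_array), result)
```

Neither variant has been timed here. Both keep the loop in compiled code like the In [43] version, so a %timeit run next to In [46] is the fair comparison.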

Thanks,
Dave



