[Numpy-discussion] Extracting values from one array corresponding to argmax elements in another array
Friedrich Romstedt
friedrichromstedt at gmail.com
Fri Apr 9 15:19:13 EDT 2010
2010/4/5 Ken Basye <kbasye1 at jhu.edu>:
> I have two arrays, A and B, with the same shape. I want to find the
> highest values in A along some axis, then extract the corresponding
> values from B.
Maybe:
import numpy

def select(A, B, axis):
    # Normalise a negative axis so that list.insert() below uses the right slot.
    if axis < 0:
        axis += A.ndim
    # Position of the maximum along `axis`; its shape is A.shape without `axis`.
    argmax = A.argmax(axis=axis)
    # Create the selection list to be handed over to B.__getitem__():
    # one open-mesh index array per remaining dimension, shaped so that
    # it broadcasts against argmax.
    advanced_index = list(numpy.ix_(*[numpy.arange(n) for n in argmax.shape]))
    # Insert the argmax positions for the reduced dimension.
    advanced_index.insert(axis, argmax)
    # Perform advanced (integer) selection; the index must be a tuple.
    return B[tuple(advanced_index)]

>>> a
array([[[0, 1],
        [2, 3]],

       [[4, 5],
        [6, 7]]])
>>> select(a, a, 0)
array([[4, 5],
       [6, 7]])
>>> select(a, a, 1)
array([[2, 3],
       [6, 7]])
>>> select(a, a, 2)
array([[1, 3],
       [5, 7]])
It seems to work.
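With NumPy 1.15 or later, numpy.take_along_axis performs the same selection without building the index tuple by hand. A minimal sketch, assuming A and B have the same shape; the function name select_along and the small arrays below are hypothetical, chosen only for illustration:

```python
import numpy as np


def select_along(A, B, axis):
    """Values of B where A is maximal along `axis` (hypothetical helper).

    Assumes A and B have the same shape; needs NumPy >= 1.15.
    """
    A = np.asarray(A)
    B = np.asarray(B)
    # argmax drops `axis`; expand_dims restores it with length 1 so the
    # index array has as many dimensions as B.
    idx = np.expand_dims(A.argmax(axis=axis), axis)
    # Pick B's entries at those positions, then drop the length-1 axis.
    return np.take_along_axis(B, idx, axis=axis).squeeze(axis=axis)


A = np.array([[1, 9, 3],
              [7, 2, 8]])
B = np.array([[10, 20, 30],
              [40, 50, 60]])
print(select_along(A, B, 0))  # [40 20 60]
print(select_along(A, B, 1))  # [20 60]
```

Using different arrays for A and B makes it visible that the positions come from A while the values come from B.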
Friedrich