[Numpy-discussion] Extracting values from one array corresponding to argmax elements in another array

Friedrich Romstedt friedrichromstedt at gmail.com
Fri Apr 9 15:19:13 EDT 2010


2010/4/5 Ken Basye <kbasye1 at jhu.edu>:
>  I have two arrays, A and B, with the same shape.  I want to find the
> highest values in A along some axis, then extract the corresponding
> values from B.

Maybe:

import numpy

def select(A, B, axis):
    # Positions of the maxima of A along `axis`; the result has A's shape
    # with `axis` removed.
    argmax = A.argmax(axis=axis)

    # Create the selection tuple to be handed over to B.__getitem__():
    # one integer index array per dimension of B.  numpy.indices() gives
    # the coordinate grids for the dimensions other than `axis` ...
    advanced_index = list(numpy.indices(argmax.shape))

    # ... and the argmax positions fill in the missing dimension.
    advanced_index.insert(axis, argmax)

    # Perform advanced (integer) selection.  The index must be a tuple;
    # recent NumPy versions no longer accept a list here.
    return B[tuple(advanced_index)]

>>> a = numpy.arange(8).reshape(2, 2, 2)
>>> a
array([[[0, 1],
        [2, 3]],

       [[4, 5],
        [6, 7]]])
>>> select(a, a, 0)
array([[4, 5],
       [6, 7]])
>>> select(a, a, 1)
array([[2, 3],
       [6, 7]])
>>> select(a, a, 2)
array([[1, 3],
       [5, 7]])

It seems to work.
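For comparison, NumPy 1.15 and later (which postdate this message) provide numpy.take_along_axis, which does the same gather without building the index tuple by hand. A minimal sketch, assuming a NumPy version that has it; select_take is a hypothetical helper name, and the small A and B arrays are made up purely for illustration:

```python
import numpy as np

def select_take(A, B, axis):
    """Values of B at the positions where A is maximal along `axis`.

    Hypothetical helper; relies on np.take_along_axis (NumPy >= 1.15).
    """
    # argmax drops `axis`; re-insert it with length 1 so the index array
    # has the same number of dimensions as B, as take_along_axis requires.
    idx = np.expand_dims(A.argmax(axis=axis), axis)
    # Gather from B along `axis`, then remove the length-1 axis again.
    return np.take_along_axis(B, idx, axis=axis).squeeze(axis)

# Illustrative data: B is different from A, so the result shows which
# entries of B sit under the maxima of A.
A = np.array([[3, 1, 2],
              [0, 5, 4]])
B = np.array([[10, 20, 30],
              [40, 50, 60]])
print(select_take(A, B, axis=1))  # [10 50]
print(select_take(A, B, axis=0))  # [10 50 60]
```

As with argmax itself, ties in A resolve to the first occurrence along the axis.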

Friedrich