Hello,
I noticed a bug in ndarray.argmax which prevents computing the argmax
along any axis but the last one.
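To show where it fails, here is a rough NumPy emulation of what the C code does, based on the macro the patch touches. It is a sketch, not the real implementation: it swaps `axis` with the last axis, reduces along the last axis, then tries to swap back. The function name `emulate_old_argmax` is made up for this example and only uses public NumPy calls. After the reduction the array has one dimension fewer, so the final swap refers to an axis that no longer exists.

```python
import numpy as np

def emulate_old_argmax(a, axis):
    """Hypothetical emulation of the pre-patch swap/reduce/swap-back strategy."""
    orign = a.ndim - 1
    # Step 1: bring `axis` to the last position, like SwapAxes(ap, axis, orign)
    swapped = a.swapaxes(axis, orign)
    # Step 2: reduce along the last axis; the result has ndim - 1 dimensions
    reduced = swapped.argmax(-1)
    if axis == orign:
        return reduced
    # Step 3: try to swap back; `orign` is now out of range for `reduced`
    return reduced.swapaxes(axis, orign)

a = np.zeros((4, 5, 6))
print(emulate_old_argmax(a, 2).shape)  # last axis: works
try:
    emulate_old_argmax(a, 0)
except (ValueError, IndexError) as exc:  # AxisError subclasses both
    print("swap back failed:", exc)
```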
I attach a patch to correct this.
Also, here is a small Python script to test the behaviour of the argmax I
implemented:
==8<====8<====8<====8<====8<====8<====8<====8<===8<===
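import numpy as np

# Random 5-d array, so that argmax is checked along every axis
a = np.random.normal(0, 1, (4, 5, 6, 7, 8))
for i in range(a.ndim):
    amax = a.max(i)
    aargmax = a.argmax(i)
    # Put axis i first; choose() then picks, at each position,
    # the slice selected by the argmax index
    axes = list(range(a.ndim))
    axes.remove(i)
    assert np.all(amax == aargmax.choose(a.transpose(i, *axes)))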
==8<====8<====8<====8<====8<====8<====8<====8<===8<===
Pierre
diff numpy-0.9.6/numpy/core/src/multiarraymodule.c numpy-0.9.6.mod/numpy/core/src/multiarraymodule.c
1952a1953,1955
> If orign >= ap->nd, then we cannot "swap it back",
> as that dimension no longer exists after the reduction.
> Instead, the axis must be moved back to the end of the array.
1956c1959,1979
< (op) = (PyAO *)PyArray_SwapAxes((ap), axis, orign); \
---
> int nb_dims = (ap)->nd; \
> if (orign > nb_dims-1 ) { \
> PyArray_Dims dims; \
> int i; \
> dims.ptr = ( intp* )malloc( sizeof( intp )*nb_dims );\
> dims.len = nb_dims; \
> for(i = 0 ; i < axis ; ++i) \
> { \
> dims.ptr[i] = i; \
> } \
> for(i = axis ; i < nb_dims-1 ; ++i) \
> { \
> dims.ptr[i] = i+1; \
> } \
> dims.ptr[nb_dims-1] = axis; \
> (op) = (PyAO *)PyArray_Transpose((ap), &dims); free(dims.ptr); \
> } \
> else \
> { \
> (op) = (PyAO *)PyArray_SwapAxes((ap), axis, orign); \
> } \
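The permutation built in the new branch can be cross-checked from Python. In this sketch, `fixed_permutation` and `emulate_new_argmax` are hypothetical helper names. The first builds the same index array as the three loops that fill `dims.ptr`, i.e. `[0, ..., axis-1, axis+1, ..., nb_dims-1, axis]`. The second applies it with `transpose` after the swap-and-reduce step. The comparison with `a.argmax(axis)` assumes a NumPy version in which argmax along any axis is already correct.

```python
import numpy as np

def fixed_permutation(nb_dims, axis):
    # Mirrors the patch: identity below `axis`, shift by one from `axis` on,
    # and the dimension sitting at position `axis` moved to the end.
    perm = list(range(axis))
    perm += [i + 1 for i in range(axis, nb_dims - 1)]
    perm.append(axis)
    return perm

def emulate_new_argmax(a, axis):
    """Hypothetical emulation of swap, reduce, then transpose back."""
    orign = a.ndim - 1
    reduced = a.swapaxes(axis, orign).argmax(-1)
    if axis == orign:
        return reduced
    # `orign` no longer exists in `reduced`: move position `axis` to the end
    return reduced.transpose(fixed_permutation(reduced.ndim, axis))

a = np.random.normal(0, 1, (4, 5, 6, 7, 8))
for axis in range(a.ndim):
    assert (emulate_new_argmax(a, axis) == a.argmax(axis)).all()
```

For example, with a 5-d input and `axis = 1`, the reduced array has 4 dimensions and the permutation is `[0, 2, 3, 1]`.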