[Numpy-svn] r8519 - in branches/1.5.x/numpy/core: src/multiarray tests
numpy-svn at scipy.org
numpy-svn at scipy.org
Sat Jul 24 06:42:20 EDT 2010
Author: ptvirtan
Date: 2010-07-24 05:42:19 -0500 (Sat, 24 Jul 2010)
New Revision: 8519
Modified:
branches/1.5.x/numpy/core/src/multiarray/arraytypes.c.src
branches/1.5.x/numpy/core/tests/test_multiarray.py
Log:
BUG: core: fix argmax and argmin NaN handling to conform with max/min (#1429)
This makes `argmax` and `argmin` treat NaN as an extremal element (maximal
for `argmax`, minimal for `argmin`), so the position of the first NaN is
returned. Effectively, this causes NaNs to propagate, which is consistent
with the current behavior of `amax` and `amin`.
(cherry picked from commit r8509)
Modified: branches/1.5.x/numpy/core/src/multiarray/arraytypes.c.src
===================================================================
--- branches/1.5.x/numpy/core/src/multiarray/arraytypes.c.src 2010-07-24 10:41:57 UTC (rev 8518)
+++ branches/1.5.x/numpy/core/src/multiarray/arraytypes.c.src 2010-07-24 10:42:19 UTC (rev 8519)
@@ -2260,6 +2260,8 @@
* #type = Bool, byte, ubyte, short, ushort, int, uint, long, ulong,
* longlong, ulonglong, float, double, longdouble,
* float, double, longdouble#
+ * #isfloat = 0*11, 1*6#
+ * #iscomplex = 0*14, 1*3#
* #incr= ip++*14, ip+=2*3#
*/
static int
@@ -2267,14 +2269,54 @@
{
intp i;
@type@ mp = *ip;
+#if @iscomplex@
+ @type@ mp_im = ip[1];
+#endif
*max_ind = 0;
+
+#if @isfloat@
+ if (npy_isnan(mp)) {
+ /* nan encountered; it's maximal */
+ return 0;
+ }
+#endif
+#if @iscomplex@
+ if (npy_isnan(mp_im)) {
+ /* nan encountered; it's maximal */
+ return 0;
+ }
+#endif
+
for (i = 1; i < n; i++) {
@incr@;
- if (*ip > mp) {
+ /*
+ * Propagate nans, similarly as max() and min()
+ */
+#if @iscomplex@
+ /* Lexical order for complex numbers */
+ if ((ip[0] > mp) || ((ip[0] == mp) && (ip[1] > mp_im))
+ || npy_isnan(ip[0]) || npy_isnan(ip[1])) {
+ mp = ip[0];
+ mp_im = ip[1];
+ *max_ind = i;
+ if (npy_isnan(mp) || npy_isnan(mp_im)) {
+ /* nan encountered, it's maximal */
+ break;
+ }
+ }
+#else
+ if (!(*ip <= mp)) { /* negated, for correct nan handling */
mp = *ip;
*max_ind = i;
+#if @isfloat@
+ if (npy_isnan(mp)) {
+ /* nan encountered, it's maximal */
+ break;
+ }
+#endif
}
+#endif
}
return 0;
}
Modified: branches/1.5.x/numpy/core/tests/test_multiarray.py
===================================================================
--- branches/1.5.x/numpy/core/tests/test_multiarray.py 2010-07-24 10:41:57 UTC (rev 8518)
+++ branches/1.5.x/numpy/core/tests/test_multiarray.py 2010-07-24 10:42:19 UTC (rev 8519)
@@ -671,6 +671,27 @@
class TestArgmax(TestCase):
+
+ nan_arr = [
+ ([0, 1, 2, 3, np.nan], 4),
+ ([0, 1, 2, np.nan, 3], 3),
+ ([np.nan, 0, 1, 2, 3], 0),
+ ([np.nan, 0, np.nan, 2, 3], 0),
+ ([0, 1, 2, 3, complex(0,np.nan)], 4),
+ ([0, 1, 2, 3, complex(np.nan,0)], 4),
+ ([0, 1, 2, complex(np.nan,0), 3], 3),
+ ([0, 1, 2, complex(0,np.nan), 3], 3),
+ ([complex(0,np.nan), 0, 1, 2, 3], 0),
+ ([complex(np.nan, np.nan), 0, 1, 2, 3], 0),
+ ([complex(np.nan, 0), complex(np.nan, 2), complex(np.nan, 1)], 0),
+ ([complex(np.nan, np.nan), complex(np.nan, 2), complex(np.nan, 1)], 0),
+ ([complex(np.nan, 0), complex(np.nan, 2), complex(np.nan, np.nan)], 0),
+
+ ([complex(0, 0), complex(0, 2), complex(0, 1)], 1),
+ ([complex(1, 0), complex(0, 2), complex(0, 1)], 0),
+ ([complex(1, 0), complex(0, 2), complex(1, 1)], 2),
+ ]
+
def test_all(self):
a = np.random.normal(0,1,(4,5,6,7,8))
for i in xrange(a.ndim):
@@ -680,6 +701,12 @@
axes.remove(i)
assert all(amax == aargmax.choose(*a.transpose(i,*axes)))
+ def test_combinations(self):
+ for arr, pos in self.nan_arr:
+ assert_equal(np.argmax(arr), pos, err_msg="%r"%arr)
+ assert_equal(arr[np.argmax(arr)], np.max(arr), err_msg="%r"%arr)
+
+
class TestMinMax(TestCase):
def test_scalar(self):
assert_raises(ValueError, np.amax, 1, 1)
More information about the Numpy-svn
mailing list