[Numpy-svn] r8714 - in trunk/numpy/ma: . tests

numpy-svn at scipy.org numpy-svn at scipy.org
Mon Sep 13 11:43:27 EDT 2010


Author: pierregm
Date: 2010-09-13 10:43:27 -0500 (Mon, 13 Sep 2010)
New Revision: 8714

Modified:
   trunk/numpy/ma/core.py
   trunk/numpy/ma/tests/test_core.py
Log:
* ma.core._print_templates: renamed the keys 'short' and 'long' to 'short_std' and 'long_std', respectively (bug #1586)
* Fixed incorrect mask broadcasting in ma.power when the base and exponent have different shapes (bug #1606)


Modified: trunk/numpy/ma/core.py
===================================================================
--- trunk/numpy/ma/core.py	2010-09-13 12:34:37 UTC (rev 8713)
+++ trunk/numpy/ma/core.py	2010-09-13 15:43:27 UTC (rev 8714)
@@ -2296,14 +2296,14 @@
             np.putmask(curdata, curmask, printopt)
     return
 
-_print_templates = dict(long="""\
+_print_templates = dict(long_std="""\
 masked_%(name)s(data =
  %(data)s,
        %(nlen)s mask =
  %(mask)s,
  %(nlen)s fill_value = %(fill)s)
 """,
-                        short="""\
+                        short_std="""\
 masked_%(name)s(data = %(data)s,
        %(nlen)s mask = %(mask)s,
 %(nlen)s  fill_value = %(fill)s)
@@ -3574,8 +3574,8 @@
                 return _print_templates['short_flx'] % parameters
             return  _print_templates['long_flx'] % parameters
         elif n <= 1:
-            return _print_templates['short'] % parameters
-        return _print_templates['long'] % parameters
+            return _print_templates['short_std'] % parameters
+        return _print_templates['long_std'] % parameters
 
 
     def __eq__(self, other):
@@ -5972,7 +5972,7 @@
 ids = _frommethod('ids')
 maximum = _maximum_operation()
 mean = _frommethod('mean')
-minimum = _minimum_operation ()
+minimum = _minimum_operation()
 nonzero = _frommethod('nonzero')
 prod = _frommethod('prod')
 product = _frommethod('prod')
@@ -6040,8 +6040,7 @@
     if m is not nomask:
         if not (result.ndim):
             return masked
-        m |= invalid
-        result._mask = m
+        result._mask = np.logical_or(m, invalid)
     # Fix the invalid parts
     if invalid.any():
         if not result.ndim:

Modified: trunk/numpy/ma/tests/test_core.py
===================================================================
--- trunk/numpy/ma/tests/test_core.py	2010-09-13 12:34:37 UTC (rev 8713)
+++ trunk/numpy/ma/tests/test_core.py	2010-09-13 15:43:27 UTC (rev 8714)
@@ -2975,7 +2975,40 @@
         assert_almost_equal(x, y)
         assert_almost_equal(x._data, y._data)
 
+    def test_power_w_broadcasting(self):
+        "Test power w/ broadcasting"
+        a2 = np.array([[1., 2., 3.], [4., 5., 6.]])
+        a2m = array(a2, mask=[[1, 0, 0], [0, 0, 1]])
+        b1 = np.array([2, 4, 3])
+        b1m = array(b1, mask=[0, 1, 0])  # NOTE(review): unused below
+        b2 = np.array([b1, b1])
+        b2m = array(b2, mask=[[0, 1, 0], [0, 1, 0]])
+        # Same-shape (2x3) operands: expected mask = a2m.mask | b2m.mask
+        ctrl = array([[1 ** 2, 2 ** 4, 3 ** 3], [4 ** 2, 5 ** 4, 6 ** 3]],
+                mask=[[1, 1, 0], [0, 1, 1]])
+        # No broadcasting, base & exp w/ mask
+        test = a2m ** b2m
+        assert_equal(test, ctrl)
+        assert_equal(test.mask, ctrl.mask)
+        # No broadcasting, base w/ mask, exp w/o mask: mask == base's mask
+        test = a2m ** b2
+        assert_equal(test, ctrl)
+        assert_equal(test.mask, a2m.mask)
+        # No broadcasting, base w/o mask, exp w/ mask: mask == exp's mask
+        test = a2 ** b2m
+        assert_equal(test, ctrl)
+        assert_equal(test.mask, b2m.mask)
+        # Broadcasting: 1-D b1 against 2-D masked b2m, on either side of **
+        ctrl = array([[2 ** 2, 4 ** 4, 3 ** 3], [2 ** 2, 4 ** 4, 3 ** 3]],
+                mask=[[0, 1, 0], [0, 1, 0]])
+        test = b1 ** b2m
+        assert_equal(test, ctrl)
+        assert_equal(test.mask, ctrl.mask)
+        test = b2m ** b1
+        assert_equal(test, ctrl)
+        assert_equal(test.mask, ctrl.mask)
 
+
     def test_where(self):
         "Test the where function"
         x = np.array([1., 1., 1., -2., pi / 2.0, 4., 5., -10., 10., 1., 2., 3.])




More information about the Numpy-svn mailing list