[Numpy-svn] r5506 - in trunk/numpy/ma: . tests

numpy-svn at scipy.org numpy-svn at scipy.org
Tue Jul 22 16:52:51 EDT 2008


Author: pierregm
Date: 2008-07-22 15:52:48 -0500 (Tue, 22 Jul 2008)
New Revision: 5506

Modified:
   trunk/numpy/ma/core.py
   trunk/numpy/ma/tests/test_core.py
Log:
* Resize the domain to the shape of the other element in DomainedBinaryOperation when their sizes differ (bugfix 857)

Modified: trunk/numpy/ma/core.py
===================================================================
--- trunk/numpy/ma/core.py	2008-07-22 20:37:12 UTC (rev 5505)
+++ trunk/numpy/ma/core.py	2008-07-22 20:52:48 UTC (rev 5506)
@@ -639,7 +639,10 @@
         if t.any(None):
             mb = mask_or(mb, t)
             # The following line controls the domain filling
-            d2 = np.where(t,self.filly,d2)
+            if t.size == d2.size:
+                d2 = np.where(t,self.filly,d2)
+            else:
+                d2 = np.where(np.resize(t, d2.shape),self.filly, d2)
         m = mask_or(ma, mb)
         if (not m.ndim) and m:
             return masked

Modified: trunk/numpy/ma/tests/test_core.py
===================================================================
--- trunk/numpy/ma/tests/test_core.py	2008-07-22 20:37:12 UTC (rev 5505)
+++ trunk/numpy/ma/tests/test_core.py	2008-07-22 20:52:48 UTC (rev 5506)
@@ -492,6 +492,25 @@
             assert_equal(np.multiply(x,y), multiply(xm, ym))
             assert_equal(np.divide(x,y), divide(xm, ym))
 
+    def test_divide_on_different_shapes(self):
+        x = arange(6, dtype=float)
+        x.shape = (2,3)
+        y = arange(3, dtype=float)
+        #
+        z = x/y
+        assert_equal(z, [[-1.,1.,1.], [-1.,4.,2.5]])
+        assert_equal(z.mask, [[1,0,0],[1,0,0]])
+        #
+        z = x/y[None,:]
+        assert_equal(z, [[-1.,1.,1.], [-1.,4.,2.5]])
+        assert_equal(z.mask, [[1,0,0],[1,0,0]])
+        #
+        y = arange(2, dtype=float)
+        z = x/y[:,None]
+        assert_equal(z, [[-1.,-1.,-1.], [3.,4.,5.]])
+        assert_equal(z.mask, [[1,1,1],[0,0,0]])
+        
+
     def test_mixed_arithmetic(self):
         "Tests mixed arithmetics."
         na = np.array([1])




More information about the Numpy-svn mailing list