[Numpy-svn] r5156 - in trunk/numpy/ma: . tests

numpy-svn at scipy.org numpy-svn at scipy.org
Mon May 12 11:48:30 EDT 2008


Author: pierregm
Date: 2008-05-12 10:48:27 -0500 (Mon, 12 May 2008)
New Revision: 5156

Modified:
   trunk/numpy/ma/core.py
   trunk/numpy/ma/tests/test_core.py
Log:
core : power : use the quick-and-dirty approach: compute the whole result first, then mask the invalid (NaN/Inf) entries afterwards
	 : MaskedArray._update_from(obj) : ensure that _baseclass is an ndarray (or subclass); fall back to ndarray if obj is None or not an ndarray
	 : added clip to the namespace (bound to numpy.clip), just for convenience

Modified: trunk/numpy/ma/core.py
===================================================================
--- trunk/numpy/ma/core.py	2008-05-11 03:08:06 UTC (rev 5155)
+++ trunk/numpy/ma/core.py	2008-05-12 15:48:27 UTC (rev 5156)
@@ -26,8 +26,8 @@
            'arctanh', 'argmax', 'argmin', 'argsort', 'around',
            'array', 'asarray','asanyarray',
            'bitwise_and', 'bitwise_or', 'bitwise_xor',
-           'ceil', 'choose', 'common_fill_value', 'compress', 'compressed',
-           'concatenate', 'conjugate', 'cos', 'cosh', 'count',
+           'ceil', 'choose', 'clip', 'common_fill_value', 'compress',
+           'compressed', 'concatenate', 'conjugate', 'cos', 'cosh', 'count',
            'default_fill_value', 'diagonal', 'divide', 'dump', 'dumps',
            'empty', 'empty_like', 'equal', 'exp',
            'fabs', 'fmod', 'filled', 'floor', 'floor_divide','fix_invalid',
@@ -1226,7 +1226,7 @@
     def _update_from(self, obj):
         """Copies some attributes of obj to self.
         """  
-        if obj is not None:
+        if obj is not None and isinstance(obj,ndarray):
             _baseclass = type(obj)
         else:
             _baseclass = ndarray
@@ -2845,23 +2845,45 @@
     """
     if third is not None:
         raise MAError, "3-argument power not supported."
+    # Get the masks
     ma = getmask(a)
     mb = getmask(b)
     m = mask_or(ma, mb)
+    # Get the rawdata
     fa = getdata(a)
     fb = getdata(b)
-    if fb.dtype.char in typecodes["Integer"]:
-        return masked_array(umath.power(fa, fb), m)
-    m = mask_or(m, (fa < 0) & (fb != fb.astype(int))) 
-    if m is nomask:
-        return masked_array(umath.power(fa, fb))
+    # Get the type of the result (so that we preserve subclasses)
+    if isinstance(a,MaskedArray):
+        basetype = type(a)
     else:
-        fa = fa.copy()
-        if m.all():
-            fa.flat = 1
-        else: 
-            numpy.putmask(fa,m,1)
-        return masked_array(umath.power(fa, fb), m)
+        basetype = MaskedArray
+    # Get the result and view it as a (subclass of) MaskedArray
+    result = umath.power(fa,fb).view(basetype)
+    # Retrieve some extra attributes if needed
+    result._update_from(a)
+    # Find where we're in trouble w/ NaNs and Infs
+    invalid = numpy.logical_not(numpy.isfinite(result.view(ndarray)))
+    # Add the initial mask
+    if m is not nomask:
+        result._mask = m
+    # Fix the invalid parts
+    if invalid.any():
+        result[invalid] = masked
+        result._data[invalid] = result.fill_value
+    return result
+    
+#    if fb.dtype.char in typecodes["Integer"]:
+#        return masked_array(umath.power(fa, fb), m)
+#    m = mask_or(m, (fa < 0) & (fb != fb.astype(int))) 
+#    if m is nomask:
+#        return masked_array(umath.power(fa, fb))
+#    else:
+#        fa = fa.copy()
+#        if m.all():
+#            fa.flat = 1
+#        else: 
+#            numpy.putmask(fa,m,1)
+#        return masked_array(umath.power(fa, fb), m)
 
 #..............................................................................
 def argsort(a, axis=None, kind='quicksort', order=None, fill_value=None):
@@ -3373,6 +3395,7 @@
 fromfunction = _convert2ma('fromfunction')
 identity = _convert2ma('identity')
 indices = numpy.indices
+clip = numpy.clip
 
 ###############################################################################
 

Modified: trunk/numpy/ma/tests/test_core.py
===================================================================
--- trunk/numpy/ma/tests/test_core.py	2008-05-11 03:08:06 UTC (rev 5155)
+++ trunk/numpy/ma/tests/test_core.py	2008-05-12 15:48:27 UTC (rev 5156)
@@ -1571,7 +1571,11 @@
         b = array([0.5,2.,0.5,2.,1.], mask=[0,0,0,0,1])
         y = power(x,b)
         assert_almost_equal(y, [0, 1.21, 1.04880884817, 1.21, 0.] )
-        assert_equal(y._mask, [1,0,0,0,1])        
+        assert_equal(y._mask, [1,0,0,0,1])
+        b.mask = nomask
+        y = power(x,b)
+        assert_equal(y._mask, [1,0,0,0,1])
+         
 
 
 ###############################################################################




More information about the Numpy-svn mailing list