[Scipy-svn] r2761 - trunk/Lib/sandbox/maskedarray

scipy-svn at scipy.org scipy-svn at scipy.org
Sun Feb 25 22:21:42 EST 2007


Author: pierregm
Date: 2007-02-25 21:21:38 -0600 (Sun, 25 Feb 2007)
New Revision: 2761

Modified:
   trunk/Lib/sandbox/maskedarray/extras.py
   trunk/Lib/sandbox/maskedarray/testutils.py
Log:
testutils: prevent chararrays from being converted to float, and force elementwise comparison for object/string ndarrays.

Modified: trunk/Lib/sandbox/maskedarray/extras.py
===================================================================
--- trunk/Lib/sandbox/maskedarray/extras.py	2007-02-26 02:54:59 UTC (rev 2760)
+++ trunk/Lib/sandbox/maskedarray/extras.py	2007-02-26 03:21:38 UTC (rev 2761)
@@ -41,7 +41,7 @@
 
 #...............................................................................
 def issequence(seq):
-    """Returns True if the argumnet is a sequence (ndarray, list or tuple)."""
+    """Returns True if the argument is a sequence (ndarray, list or tuple)."""
     if isinstance(seq, ndarray):
         return True
     elif isinstance(seq, tuple):

Modified: trunk/Lib/sandbox/maskedarray/testutils.py
===================================================================
--- trunk/Lib/sandbox/maskedarray/testutils.py	2007-02-26 02:54:59 UTC (rev 2760)
+++ trunk/Lib/sandbox/maskedarray/testutils.py	2007-02-26 03:21:38 UTC (rev 2761)
@@ -39,10 +39,19 @@
     y = filled(masked_array(d2, copy=False, mask=m), 1).astype(float_)
     d = N.less_equal(umath.absolute(x-y), atol + rtol * umath.absolute(y))
     return d.ravel()
-#............................
+#................................................
+def _assert_equal_on_sequences(actual, desired, err_msg=''):
+    "Asserts the equality of two non-array sequences."
+    assert_equal(len(actual),len(desired),err_msg)
+    for k in range(len(desired)):
+        assert_equal(actual[k], desired[k], 'item=%r\n%s' % (k,err_msg))
+    return
+    
+
 def assert_equal(actual,desired,err_msg=''):
     """Asserts that two items are equal.
     """
+    # Case #1: dictionary .....
     if isinstance(desired, dict):
         assert isinstance(actual, dict), repr(type(actual))
         assert_equal(len(actual),len(desired),err_msg)
@@ -50,15 +59,21 @@
             assert actual.has_key(k), repr(k)
             assert_equal(actual[k], desired[k], 'key=%r\n%s' % (k,err_msg))
         return
+    # Case #2: lists .....
     if isinstance(desired, (list,tuple)) and isinstance(actual, (list,tuple)):
-        assert_equal(len(actual),len(desired),err_msg)
-        for k in range(len(desired)):
-            assert_equal(actual[k], desired[k], 'item=%r\n%s' % (k,err_msg))
+        return _assert_equal_on_sequences(actual, desired, err_msg='')
+    if not (isinstance(actual, ndarray) or isinstance(desired, ndarray)):
+        msg = build_err_msg([actual, desired], err_msg,)
+        assert desired == actual, msg
         return
-    if isinstance(actual, ndarray) or isinstance(desired, ndarray):
+    # Case #4. arrays or equivalent
+    actual = N.array(actual, copy=False, subok=True)
+    desired = N.array(desired, copy=False, subok=True)
+    if actual.dtype.char in "OS" and desired.dtype.char in "OS":
+        return _assert_equal_on_sequences(actual.tolist(), 
+                                          desired.tolist(), 
+                                          err_msg='')
         return assert_array_equal(actual, desired, err_msg)
-    msg = build_err_msg([actual, desired], err_msg,)
-    assert desired == actual, msg
 #.............................
 def fail_if_equal(actual,desired,err_msg='',):
     """Raises an assertion error if two items are equal.
@@ -100,13 +115,13 @@
     
     x = masked_array(xf, copy=False, mask=m).filled(fill_value)
     y = masked_array(yf, copy=False, mask=m).filled(fill_value)
-    if (x.dtype.char != "O"):
+    if (x.dtype.char != "O") and (x.dtype.char != "S"):
         x = x.astype(float_)
         if isinstance(x, N.ndarray) and x.size > 1:
             x[N.isnan(x)] = 0
         elif N.isnan(x):
             x = 0
-    if (y.dtype.char != "O"):
+    if (y.dtype.char != "O") and (y.dtype.char != "S"):
         y = y.astype(float_)
         if isinstance(y, N.ndarray) and y.size > 1:
             y[N.isnan(y)] = 0
@@ -162,11 +177,13 @@
     """Checks the elementwise equality of two masked arrays, up to a given 
     number of decimals."""
     def compare(x, y):
+        "Returns the result of the loose comparison between x and y)."
         return approx(x,y)
     assert_array_compare(compare, x, y, err_msg=err_msg, 
                          header='Arrays are not almost equal')
 #............................
 def assert_array_less(x, y, err_msg=''):
+    "Checks that x is smaller than y elementwise."
     assert_array_compare(less, x, y, err_msg=err_msg,
                          header='Arrays are not less-ordered')
 #............................




More information about the Scipy-svn mailing list