[Scipy-svn] r2761 - trunk/Lib/sandbox/maskedarray
scipy-svn at scipy.org
scipy-svn at scipy.org
Sun Feb 25 22:21:42 EST 2007
Author: pierregm
Date: 2007-02-25 21:21:38 -0600 (Sun, 25 Feb 2007)
New Revision: 2761
Modified:
trunk/Lib/sandbox/maskedarray/extras.py
trunk/Lib/sandbox/maskedarray/testutils.py
Log:
testutils: prevent chararrays from being transformed to float, and force elementwise comparison for object/string ndarrays
Modified: trunk/Lib/sandbox/maskedarray/extras.py
===================================================================
--- trunk/Lib/sandbox/maskedarray/extras.py 2007-02-26 02:54:59 UTC (rev 2760)
+++ trunk/Lib/sandbox/maskedarray/extras.py 2007-02-26 03:21:38 UTC (rev 2761)
@@ -41,7 +41,7 @@
#...............................................................................
def issequence(seq):
- """Returns True if the argumnet is a sequence (ndarray, list or tuple)."""
+ """Returns True if the argument is a sequence (ndarray, list or tuple)."""
if isinstance(seq, ndarray):
return True
elif isinstance(seq, tuple):
Modified: trunk/Lib/sandbox/maskedarray/testutils.py
===================================================================
--- trunk/Lib/sandbox/maskedarray/testutils.py 2007-02-26 02:54:59 UTC (rev 2760)
+++ trunk/Lib/sandbox/maskedarray/testutils.py 2007-02-26 03:21:38 UTC (rev 2761)
@@ -39,10 +39,19 @@
y = filled(masked_array(d2, copy=False, mask=m), 1).astype(float_)
d = N.less_equal(umath.absolute(x-y), atol + rtol * umath.absolute(y))
return d.ravel()
-#............................
+#................................................
+def _assert_equal_on_sequences(actual, desired, err_msg=''):
+ "Asserts the equality of two non-array sequences."
+ assert_equal(len(actual),len(desired),err_msg)
+ for k in range(len(desired)):
+ assert_equal(actual[k], desired[k], 'item=%r\n%s' % (k,err_msg))
+ return
+
+
def assert_equal(actual,desired,err_msg=''):
"""Asserts that two items are equal.
"""
+ # Case #1: dictionary .....
if isinstance(desired, dict):
assert isinstance(actual, dict), repr(type(actual))
assert_equal(len(actual),len(desired),err_msg)
@@ -50,15 +59,21 @@
assert actual.has_key(k), repr(k)
assert_equal(actual[k], desired[k], 'key=%r\n%s' % (k,err_msg))
return
+ # Case #2: lists .....
if isinstance(desired, (list,tuple)) and isinstance(actual, (list,tuple)):
- assert_equal(len(actual),len(desired),err_msg)
- for k in range(len(desired)):
- assert_equal(actual[k], desired[k], 'item=%r\n%s' % (k,err_msg))
+ return _assert_equal_on_sequences(actual, desired, err_msg='')
+ if not (isinstance(actual, ndarray) or isinstance(desired, ndarray)):
+ msg = build_err_msg([actual, desired], err_msg,)
+ assert desired == actual, msg
return
- if isinstance(actual, ndarray) or isinstance(desired, ndarray):
+ # Case #4. arrays or equivalent
+ actual = N.array(actual, copy=False, subok=True)
+ desired = N.array(desired, copy=False, subok=True)
+ if actual.dtype.char in "OS" and desired.dtype.char in "OS":
+ return _assert_equal_on_sequences(actual.tolist(),
+ desired.tolist(),
+ err_msg='')
return assert_array_equal(actual, desired, err_msg)
- msg = build_err_msg([actual, desired], err_msg,)
- assert desired == actual, msg
#.............................
def fail_if_equal(actual,desired,err_msg='',):
"""Raises an assertion error if two items are equal.
@@ -100,13 +115,13 @@
x = masked_array(xf, copy=False, mask=m).filled(fill_value)
y = masked_array(yf, copy=False, mask=m).filled(fill_value)
- if (x.dtype.char != "O"):
+ if (x.dtype.char != "O") and (x.dtype.char != "S"):
x = x.astype(float_)
if isinstance(x, N.ndarray) and x.size > 1:
x[N.isnan(x)] = 0
elif N.isnan(x):
x = 0
- if (y.dtype.char != "O"):
+ if (y.dtype.char != "O") and (y.dtype.char != "S"):
y = y.astype(float_)
if isinstance(y, N.ndarray) and y.size > 1:
y[N.isnan(y)] = 0
@@ -162,11 +177,13 @@
"""Checks the elementwise equality of two masked arrays, up to a given
number of decimals."""
def compare(x, y):
+ "Returns the result of the loose comparison between x and y)."
return approx(x,y)
assert_array_compare(compare, x, y, err_msg=err_msg,
header='Arrays are not almost equal')
#............................
def assert_array_less(x, y, err_msg=''):
+ "Checks that x is smaller than y elementwise."
assert_array_compare(less, x, y, err_msg=err_msg,
header='Arrays are not less-ordered')
#............................
More information about the Scipy-svn
mailing list