[Numpy-svn] r6194 - in trunk/numpy/ma: . tests
numpy-svn at scipy.org
numpy-svn at scipy.org
Tue Dec 23 18:43:48 EST 2008
Author: pierregm
Date: 2008-12-23 17:43:43 -0600 (Tue, 23 Dec 2008)
New Revision: 6194
Modified:
trunk/numpy/ma/core.py
trunk/numpy/ma/tests/test_core.py
trunk/numpy/ma/tests/test_mrecords.py
trunk/numpy/ma/testutils.py
Log:
testutils:
* assert_equal: use assert_array_equal on records
* assert_array_compare: prevent the common mask from being back-propagated to the initial input arrays.
* assert_array_equal: use operator.__eq__ instead of ma.equal
* assert_array_less: use operator.__lt__ instead of ma.less
core:
* Fixed _check_fill_value for nested flexible types
* Add a newtype option to the new _recursive_make_descr helper (factored out of make_mask_descr)
* Fixed mask_or for nested flexible types
* Fixed the printing of masked arrays with nested flexible types.
Modified: trunk/numpy/ma/core.py
===================================================================
--- trunk/numpy/ma/core.py 2008-12-23 09:02:15 UTC (rev 6193)
+++ trunk/numpy/ma/core.py 2008-12-23 23:43:43 UTC (rev 6194)
@@ -217,6 +217,28 @@
raise TypeError(errmsg)
+def _recursive_set_default_fill_value(dtypedescr):
+ deflist = []
+ for currentdescr in dtypedescr:
+ currenttype = currentdescr[1]
+ if isinstance(currenttype, list):
+ deflist.append(tuple(_recursive_set_default_fill_value(currenttype)))
+ else:
+ deflist.append(default_fill_value(np.dtype(currenttype)))
+ return tuple(deflist)
+
+def _recursive_set_fill_value(fillvalue, dtypedescr):
+ fillvalue = np.resize(fillvalue, len(dtypedescr))
+ output_value = []
+ for (fval, descr) in zip(fillvalue, dtypedescr):
+ cdtype = descr[1]
+ if isinstance(cdtype, list):
+ output_value.append(tuple(_recursive_set_fill_value(fval, cdtype)))
+ else:
+ output_value.append(np.array(fval, dtype=cdtype).item())
+ return tuple(output_value)
+
+
def _check_fill_value(fill_value, ndtype):
"""
Private function validating the given `fill_value` for the given dtype.
@@ -233,10 +255,9 @@
fields = ndtype.fields
if fill_value is None:
if fields:
- fdtype = [(_[0], _[1]) for _ in ndtype.descr]
- fill_value = np.array(tuple([default_fill_value(fields[n][0])
- for n in ndtype.names]),
- dtype=fdtype)
+ descr = ndtype.descr
+ fill_value = np.array(_recursive_set_default_fill_value(descr),
+ dtype=ndtype)
else:
fill_value = default_fill_value(ndtype)
elif fields:
@@ -248,10 +269,9 @@
err_msg = "Unable to transform %s to dtype %s"
raise ValueError(err_msg % (fill_value, fdtype))
else:
- fval = np.resize(fill_value, len(ndtype.descr))
- fill_value = [np.asarray(f).astype(desc[1]).item()
- for (f, desc) in zip(fval, ndtype.descr)]
- fill_value = np.array(tuple(fill_value), copy=False, dtype=fdtype)
+ descr = ndtype.descr
+ fill_value = np.array(_recursive_set_fill_value(fill_value, descr),
+ dtype=ndtype)
else:
if isinstance(fill_value, basestring) and (ndtype.char not in 'SV'):
fill_value = default_fill_value(ndtype)
@@ -831,35 +851,35 @@
#####--------------------------------------------------------------------------
#---- --- Mask creation functions ---
#####--------------------------------------------------------------------------
+def _recursive_make_descr(datatype, newtype=bool_):
+ "Private function allowing recursion in make_descr."
+ # Do we have some name fields ?
+ if datatype.names:
+ descr = []
+ for name in datatype.names:
+ field = datatype.fields[name]
+ if len(field) == 3:
+ # Prepend the title to the name
+ name = (field[-1], name)
+ descr.append((name, _recursive_make_descr(field[0], newtype)))
+ return descr
+ # Is this some kind of composite a la (np.float,2)
+ elif datatype.subdtype:
+ mdescr = list(datatype.subdtype)
+ mdescr[0] = newtype
+ return tuple(mdescr)
+ else:
+ return newtype
def make_mask_descr(ndtype):
"""Constructs a dtype description list from a given dtype.
Each field is set to a bool.
"""
- def _make_descr(datatype):
- "Private function allowing recursion."
- # Do we have some name fields ?
- if datatype.names:
- descr = []
- for name in datatype.names:
- field = datatype.fields[name]
- if len(field) == 3:
- # Prepend the title to the name
- name = (field[-1], name)
- descr.append((name, _make_descr(field[0])))
- return descr
- # Is this some kind of composite a la (np.float,2)
- elif datatype.subdtype:
- mdescr = list(datatype.subdtype)
- mdescr[0] = np.dtype(bool)
- return tuple(mdescr)
- else:
- return np.bool
# Make sure we do have a dtype
if not isinstance(ndtype, np.dtype):
ndtype = np.dtype(ndtype)
- return np.dtype(_make_descr(ndtype))
+ return np.dtype(_recursive_make_descr(ndtype, np.bool))
def get_mask(a):
"""Return the mask of a, if any, or nomask.
@@ -988,7 +1008,17 @@
ValueError
If m1 and m2 have different flexible dtypes.
- """
+ """
+ def _recursive_mask_or(m1, m2, newmask):
+ names = m1.dtype.names
+ for name in names:
+ current1 = m1[name]
+ if current1.dtype.names:
+ _recursive_mask_or(current1, m2[name], newmask[name])
+ else:
+ umath.logical_or(current1, m2[name], newmask[name])
+ return
+ #
if (m1 is nomask) or (m1 is False):
dtype = getattr(m2, 'dtype', MaskType)
return make_mask(m2, copy=copy, shrink=shrink, dtype=dtype)
@@ -1002,8 +1032,7 @@
raise ValueError("Incompatible dtypes '%s'<>'%s'" % (dtype1, dtype2))
if dtype1.names:
newmask = np.empty_like(m1)
- for n in dtype1.names:
- newmask[n] = umath.logical_or(m1[n], m2[n])
+ _recursive_mask_or(m1, m2, newmask)
return newmask
return make_mask(umath.logical_or(m1, m2), copy=copy, shrink=shrink)
@@ -1291,6 +1320,22 @@
#if you single index into a masked location you get this object.
masked_print_option = _MaskedPrintOption('--')
+
+def _recursive_printoption(result, mask, printopt):
+ """
+ Puts printoptions in result where mask is True.
+ Private function allowing for recursion
+ """
+ names = result.dtype.names
+ for name in names:
+ (curdata, curmask) = (result[name], mask[name])
+ if curdata.dtype.names:
+ _recursive_printoption(curdata, curmask, printopt)
+ else:
+ np.putmask(curdata, curmask, printopt)
+ return
+
+
#####--------------------------------------------------------------------------
#---- --- MaskedArray class ---
#####--------------------------------------------------------------------------
@@ -2184,13 +2229,9 @@
res = self._data.astype("|O8")
res[m] = f
else:
- rdtype = [list(_) for _ in self.dtype.descr]
- for r in rdtype:
- r[1] = '|O8'
- rdtype = [tuple(_) for _ in rdtype]
+ rdtype = _recursive_make_descr(self.dtype, "|O8")
res = self._data.astype(rdtype)
- for field in names:
- np.putmask(res[field], m[field], f)
+ _recursive_printoption(res, m, f)
else:
res = self.filled(self.fill_value)
return str(res)
Modified: trunk/numpy/ma/tests/test_core.py
===================================================================
--- trunk/numpy/ma/tests/test_core.py 2008-12-23 09:02:15 UTC (rev 6193)
+++ trunk/numpy/ma/tests/test_core.py 2008-12-23 23:43:43 UTC (rev 6194)
@@ -483,6 +483,16 @@
y._optinfo['info'] = '!!!'
assert_equal(x._optinfo['info'], '???')
+
+ def test_fancy_printoptions(self):
+ "Test printing a masked array w/ fancy dtype."
+ fancydtype = np.dtype([('x', int), ('y', [('t', int), ('s', float)])])
+ test = array([(1, (2, 3.0)), (4, (5, 6.0))],
+ mask=[(1, (0, 1)), (0, (1, 0))],
+ dtype=fancydtype)
+ control = "[(--, (2, --)) (4, (--, 6.0))]"
+ assert_equal(str(test), control)
+
#------------------------------------------------------------------------------
class TestMaskedArrayArithmetic(TestCase):
@@ -1049,19 +1059,19 @@
# The shape shouldn't matter
ndtype = [('f0', float, (2, 2))]
control = np.array((default_fill_value(0.),),
- dtype=[('f0',float)])
+ dtype=[('f0',float)]).astype(ndtype)
assert_equal(_check_fill_value(None, ndtype), control)
- control = np.array((0,), dtype=[('f0',float)])
+ control = np.array((0,), dtype=[('f0',float)]).astype(ndtype)
assert_equal(_check_fill_value(0, ndtype), control)
#
ndtype = np.dtype("int, (2,3)float, float")
control = np.array((default_fill_value(0),
default_fill_value(0.),
default_fill_value(0.),),
- dtype="int, float, float")
+ dtype="int, float, float").astype(ndtype)
test = _check_fill_value(None, ndtype)
assert_equal(test, control)
- control = np.array((0,0,0), dtype="int, float, float")
+ control = np.array((0,0,0), dtype="int, float, float").astype(ndtype)
assert_equal(_check_fill_value(0, ndtype), control)
#------------------------------------------------------------------------------
@@ -1912,8 +1922,8 @@
dtype=ndtype)
data[[0,1,2,-1]] = masked
record = data.torecords()
- assert_equal(record['_data'], data._data)
- assert_equal(record['_mask'], data._mask)
+ assert_equal_records(record['_data'], data._data)
+ assert_equal_records(record['_mask'], data._mask)
#------------------------------------------------------------------------------
@@ -2531,6 +2541,12 @@
test = mask_or(mask, other)
except ValueError:
pass
+ # Using nested arrays
+ dtype = [('a', np.bool), ('b', [('ba', np.bool), ('bb', np.bool)])]
+ amask = np.array([(0, (1, 0)), (0, (1, 0))], dtype=dtype)
+ bmask = np.array([(1, (0, 1)), (0, (0, 0))], dtype=dtype)
+ cntrl = np.array([(1, (1, 1)), (0, (1, 0))], dtype=dtype)
+ assert_equal(mask_or(amask, bmask), cntrl)
def test_flatten_mask(self):
Modified: trunk/numpy/ma/tests/test_mrecords.py
===================================================================
--- trunk/numpy/ma/tests/test_mrecords.py 2008-12-23 09:02:15 UTC (rev 6193)
+++ trunk/numpy/ma/tests/test_mrecords.py 2008-12-23 23:43:43 UTC (rev 6194)
@@ -334,8 +334,8 @@
mult[0] = masked
mult[1] = (1, 1, 1)
mult.filled(0)
- assert_equal(mult.filled(0),
- np.array([(0,0,0),(1,1,1)], dtype=mult.dtype))
+ assert_equal_records(mult.filled(0),
+ np.array([(0,0,0),(1,1,1)], dtype=mult.dtype))
class TestView(TestCase):
Modified: trunk/numpy/ma/testutils.py
===================================================================
--- trunk/numpy/ma/testutils.py 2008-12-23 09:02:15 UTC (rev 6193)
+++ trunk/numpy/ma/testutils.py 2008-12-23 23:43:43 UTC (rev 6194)
@@ -110,14 +110,14 @@
return _assert_equal_on_sequences(actual.tolist(),
desired.tolist(),
err_msg='')
- elif actual_dtype.char in "OV" and desired_dtype.char in "OV":
- if (actual_dtype != desired_dtype) and actual_dtype:
- msg = build_err_msg([actual_dtype, desired_dtype],
- err_msg, header='', names=('actual', 'desired'))
- raise ValueError(msg)
- return _assert_equal_on_sequences(actual.tolist(),
- desired.tolist(),
- err_msg='')
+# elif actual_dtype.char in "OV" and desired_dtype.char in "OV":
+# if (actual_dtype != desired_dtype) and actual_dtype:
+# msg = build_err_msg([actual_dtype, desired_dtype],
+# err_msg, header='', names=('actual', 'desired'))
+# raise ValueError(msg)
+# return _assert_equal_on_sequences(actual.tolist(),
+# desired.tolist(),
+# err_msg='')
return assert_array_equal(actual, desired, err_msg)
@@ -171,15 +171,14 @@
# yf = filled(y)
# Allocate a common mask and refill
m = mask_or(getmask(x), getmask(y))
- x = masked_array(x, copy=False, mask=m, subok=False)
- y = masked_array(y, copy=False, mask=m, subok=False)
+ x = masked_array(x, copy=False, mask=m, keep_mask=False, subok=False)
+ y = masked_array(y, copy=False, mask=m, keep_mask=False, subok=False)
if ((x is masked) and not (y is masked)) or \
((y is masked) and not (x is masked)):
msg = build_err_msg([x, y], err_msg=err_msg, verbose=verbose,
header=header, names=('x', 'y'))
raise ValueError(msg)
# OK, now run the basic tests on filled versions
- comparison = getattr(np, comparison.__name__, lambda x,y: True)
return utils.assert_array_compare(comparison,
x.filled(fill_value),
y.filled(fill_value),
@@ -189,7 +188,8 @@
def assert_array_equal(x, y, err_msg='', verbose=True):
"""Checks the elementwise equality of two masked arrays."""
- assert_array_compare(equal, x, y, err_msg=err_msg, verbose=verbose,
+ assert_array_compare(operator.__eq__, x, y,
+ err_msg=err_msg, verbose=verbose,
header='Arrays are not equal')
@@ -223,7 +223,8 @@
def assert_array_less(x, y, err_msg='', verbose=True):
"Checks that x is smaller than y elementwise."
- assert_array_compare(less, x, y, err_msg=err_msg, verbose=verbose,
+ assert_array_compare(operator.__lt__, x, y,
+ err_msg=err_msg, verbose=verbose,
header='Arrays are not less-ordered')
More information about the Numpy-svn
mailing list