[Numpy-svn] r8713 - in trunk/numpy/lib: . tests
numpy-svn at scipy.org
numpy-svn at scipy.org
Mon Sep 13 08:34:38 EDT 2010
Author: pierregm
Date: 2010-09-13 07:34:37 -0500 (Mon, 13 Sep 2010)
New Revision: 8713
Modified:
trunk/numpy/lib/_iotools.py
trunk/numpy/lib/tests/test__iotools.py
Log:
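* Fixed 'flatten_dtype' to support fields with titles (bug #1591): entries of `dtype.fields` for titled fields are (dtype, offset, title) 3-tuples, so the old two-name unpacking failed. Thanks to Stefan vdW for the fix.
* Added a unit test for flatten_dtype.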
Modified: trunk/numpy/lib/_iotools.py
===================================================================
--- trunk/numpy/lib/_iotools.py 2010-09-11 21:36:23 UTC (rev 8712)
+++ trunk/numpy/lib/_iotools.py 2010-09-13 12:34:37 UTC (rev 8713)
@@ -132,8 +132,8 @@
else:
types = []
for field in names:
- (typ, _) = ndtype.fields[field]
- flat_dt = flatten_dtype(typ, flatten_base)
+ info = ndtype.fields[field]
+ flat_dt = flatten_dtype(info[0], flatten_base)
types.extend(flat_dt)
return types
Modified: trunk/numpy/lib/tests/test__iotools.py
===================================================================
--- trunk/numpy/lib/tests/test__iotools.py 2010-09-11 21:36:23 UTC (rev 8712)
+++ trunk/numpy/lib/tests/test__iotools.py 2010-09-13 12:34:37 UTC (rev 8713)
@@ -10,8 +10,8 @@
import time
import numpy as np
-from numpy.lib._iotools import LineSplitter, NameValidator, StringConverter,\
- has_nested_fields, easy_dtype
+from numpy.lib._iotools import LineSplitter, NameValidator, StringConverter, \
+ has_nested_fields, easy_dtype, flatten_dtype
from numpy.testing import *
from numpy.compat import asbytes, asbytes_nested
@@ -37,10 +37,10 @@
def test_tab_delimiter(self):
"Test tab delimiter"
- strg= asbytes(" 1\t 2\t 3\t 4\t 5 6")
+ strg = asbytes(" 1\t 2\t 3\t 4\t 5 6")
test = LineSplitter(asbytes('\t'))(strg)
assert_equal(test, asbytes_nested(['1', '2', '3', '4', '5 6']))
- strg= asbytes(" 1 2\t 3 4\t 5 6")
+ strg = asbytes(" 1 2\t 3 4\t 5 6")
test = LineSplitter(asbytes('\t'))(strg)
assert_equal(test, asbytes_nested(['1 2', '3 4', '5 6']))
@@ -70,11 +70,11 @@
def test_variable_fixed_width(self):
strg = asbytes(" 1 3 4 5 6# test")
- test = LineSplitter((3,6,6,3))(strg)
+ test = LineSplitter((3, 6, 6, 3))(strg)
assert_equal(test, asbytes_nested(['1', '3', '4 5', '6']))
#
strg = asbytes(" 1 3 4 5 6# test")
- test = LineSplitter((6,6,9))(strg)
+ test = LineSplitter((6, 6, 9))(strg)
assert_equal(test, asbytes_nested(['1', '3 4', '5 6']))
@@ -97,7 +97,7 @@
def test_excludelist(self):
"Test excludelist"
names = ['dates', 'data', 'Other Data', 'mask']
- validator = NameValidator(excludelist = ['dates', 'data', 'mask'])
+ validator = NameValidator(excludelist=['dates', 'data', 'mask'])
test = validator.validate(names)
assert_equal(test, ['dates_', 'data_', 'Other_Data', 'mask_'])
#
@@ -117,7 +117,7 @@
"Test validate nb names"
namelist = ('a', 'b', 'c')
validator = NameValidator()
- assert_equal(validator(namelist, nbfields=1), ('a', ))
+ assert_equal(validator(namelist, nbfields=1), ('a',))
assert_equal(validator(namelist, nbfields=5, defaultfmt="g%i"),
['a', 'b', 'c', 'g0', 'g1'])
#
@@ -159,7 +159,7 @@
converter.upgrade(asbytes('0j'))
assert_equal(converter._status, 3)
converter.upgrade(asbytes('a'))
- assert_equal(converter._status, len(converter._mapper)-1)
+ assert_equal(converter._status, len(converter._mapper) - 1)
#
def test_missing(self):
"Tests the use of missing values."
@@ -178,7 +178,7 @@
def test_upgrademapper(self):
"Tests updatemapper"
dateparser = _bytes_to_date
- StringConverter.upgrade_mapper(dateparser, date(2000,1,1))
+ StringConverter.upgrade_mapper(dateparser, date(2000, 1, 1))
convert = StringConverter(dateparser, date(2000, 1, 1))
test = convert(asbytes('2001-01-01'))
assert_equal(test, date(2001, 01, 01))
@@ -196,7 +196,7 @@
def test_keep_default(self):
"Make sure we don't lose an explicit default"
converter = StringConverter(None, missing_values=asbytes(''),
- default=-999)
+ default= -999)
converter.upgrade(asbytes('3.14159265'))
assert_equal(converter.default, -999)
assert_equal(converter.type, np.dtype(float))
@@ -287,3 +287,25 @@
assert_equal(easy_dtype(ndtype, names=['', '', ''], defaultfmt="f%02i"),
np.dtype([(_, float) for _ in ('f00', 'f01', 'f02')]))
+
+ def test_flatten_dtype(self):
+ "Testing flatten_dtype"
+ # Standard dtype
+ dt = np.dtype([("a", "f8"), ("b", "f8")])
+ dt_flat = flatten_dtype(dt)
+ assert_equal(dt_flat, [float, float])
+ # Recursive dtype
+ dt = np.dtype([("a", [("aa", '|S1'), ("ab", '|S2')]), ("b", int)])
+ dt_flat = flatten_dtype(dt)
+ assert_equal(dt_flat, [np.dtype('|S1'), np.dtype('|S2'), int])
+ # dtype with shaped fields
+ dt = np.dtype([("a", (float, 2)), ("b", (int, 3))])
+ dt_flat = flatten_dtype(dt)
+ assert_equal(dt_flat, [float, int])
+ dt_flat = flatten_dtype(dt, True)
+ assert_equal(dt_flat, [float] * 2 + [int] * 3)
+ # dtype w/ titles
+ dt = np.dtype([(("a", "A"), "f8"), (("b", "B"), "f8")])
+ dt_flat = flatten_dtype(dt)
+ assert_equal(dt_flat, [float, float])
+
More information about the Numpy-svn mailing list