[Numpy-svn] r8713 - in trunk/numpy/lib: . tests

numpy-svn at scipy.org numpy-svn at scipy.org
Mon Sep 13 08:34:38 EDT 2010


Author: pierregm
Date: 2010-09-13 07:34:37 -0500 (Mon, 13 Sep 2010)
New Revision: 8713

Modified:
   trunk/numpy/lib/_iotools.py
   trunk/numpy/lib/tests/test__iotools.py
Log:
* Fixed 'flatten_dtype' to support fields with titles (bug #1591). Thanks to Stefan vdW for the fix.
* Added a unit test for flatten_dtype.

Modified: trunk/numpy/lib/_iotools.py
===================================================================
--- trunk/numpy/lib/_iotools.py	2010-09-11 21:36:23 UTC (rev 8712)
+++ trunk/numpy/lib/_iotools.py	2010-09-13 12:34:37 UTC (rev 8713)
@@ -132,8 +132,8 @@
     else:
         types = []
         for field in names:
-            (typ, _) = ndtype.fields[field]
-            flat_dt = flatten_dtype(typ, flatten_base)
+            info = ndtype.fields[field]
+            flat_dt = flatten_dtype(info[0], flatten_base)
             types.extend(flat_dt)
         return types
 

Modified: trunk/numpy/lib/tests/test__iotools.py
===================================================================
--- trunk/numpy/lib/tests/test__iotools.py	2010-09-11 21:36:23 UTC (rev 8712)
+++ trunk/numpy/lib/tests/test__iotools.py	2010-09-13 12:34:37 UTC (rev 8713)
@@ -10,8 +10,8 @@
 import time
 
 import numpy as np
-from numpy.lib._iotools import LineSplitter, NameValidator, StringConverter,\
-                               has_nested_fields, easy_dtype
+from numpy.lib._iotools import LineSplitter, NameValidator, StringConverter, \
+                               has_nested_fields, easy_dtype, flatten_dtype
 from numpy.testing import *
 
 from numpy.compat import asbytes, asbytes_nested
@@ -37,10 +37,10 @@
 
     def test_tab_delimiter(self):
         "Test tab delimiter"
-        strg= asbytes(" 1\t 2\t 3\t 4\t 5  6")
+        strg = asbytes(" 1\t 2\t 3\t 4\t 5  6")
         test = LineSplitter(asbytes('\t'))(strg)
         assert_equal(test, asbytes_nested(['1', '2', '3', '4', '5  6']))
-        strg= asbytes(" 1  2\t 3  4\t 5  6")
+        strg = asbytes(" 1  2\t 3  4\t 5  6")
         test = LineSplitter(asbytes('\t'))(strg)
         assert_equal(test, asbytes_nested(['1  2', '3  4', '5  6']))
 
@@ -70,11 +70,11 @@
 
     def test_variable_fixed_width(self):
         strg = asbytes("  1     3  4  5  6# test")
-        test = LineSplitter((3,6,6,3))(strg)
+        test = LineSplitter((3, 6, 6, 3))(strg)
         assert_equal(test, asbytes_nested(['1', '3', '4  5', '6']))
         #
         strg = asbytes("  1     3  4  5  6# test")
-        test = LineSplitter((6,6,9))(strg)
+        test = LineSplitter((6, 6, 9))(strg)
         assert_equal(test, asbytes_nested(['1', '3  4', '5  6']))
 
 
@@ -97,7 +97,7 @@
     def test_excludelist(self):
         "Test excludelist"
         names = ['dates', 'data', 'Other Data', 'mask']
-        validator = NameValidator(excludelist = ['dates', 'data', 'mask'])
+        validator = NameValidator(excludelist=['dates', 'data', 'mask'])
         test = validator.validate(names)
         assert_equal(test, ['dates_', 'data_', 'Other_Data', 'mask_'])
     #
@@ -117,7 +117,7 @@
         "Test validate nb names"
         namelist = ('a', 'b', 'c')
         validator = NameValidator()
-        assert_equal(validator(namelist, nbfields=1), ('a', ))
+        assert_equal(validator(namelist, nbfields=1), ('a',))
         assert_equal(validator(namelist, nbfields=5, defaultfmt="g%i"),
                      ['a', 'b', 'c', 'g0', 'g1'])
     #
@@ -159,7 +159,7 @@
         converter.upgrade(asbytes('0j'))
         assert_equal(converter._status, 3)
         converter.upgrade(asbytes('a'))
-        assert_equal(converter._status, len(converter._mapper)-1)
+        assert_equal(converter._status, len(converter._mapper) - 1)
     #
     def test_missing(self):
         "Tests the use of missing values."
@@ -178,7 +178,7 @@
     def test_upgrademapper(self):
         "Tests updatemapper"
         dateparser = _bytes_to_date
-        StringConverter.upgrade_mapper(dateparser, date(2000,1,1))
+        StringConverter.upgrade_mapper(dateparser, date(2000, 1, 1))
         convert = StringConverter(dateparser, date(2000, 1, 1))
         test = convert(asbytes('2001-01-01'))
         assert_equal(test, date(2001, 01, 01))
@@ -196,7 +196,7 @@
     def test_keep_default(self):
         "Make sure we don't lose an explicit default"
         converter = StringConverter(None, missing_values=asbytes(''),
-                                    default=-999)
+                                    default= -999)
         converter.upgrade(asbytes('3.14159265'))
         assert_equal(converter.default, -999)
         assert_equal(converter.type, np.dtype(float))
@@ -287,3 +287,25 @@
         assert_equal(easy_dtype(ndtype, names=['', '', ''], defaultfmt="f%02i"),
                      np.dtype([(_, float) for _ in ('f00', 'f01', 'f02')]))
 
+
+    def test_flatten_dtype(self):
+        "Testing flatten_dtype"
+        # Standard dtype
+        dt = np.dtype([("a", "f8"), ("b", "f8")])
+        dt_flat = flatten_dtype(dt)
+        assert_equal(dt_flat, [float, float])
+        # Recursive dtype
+        dt = np.dtype([("a", [("aa", '|S1'), ("ab", '|S2')]), ("b", int)])
+        dt_flat = flatten_dtype(dt)
+        assert_equal(dt_flat, [np.dtype('|S1'), np.dtype('|S2'), int])
+        # dtype with shaped fields
+        dt = np.dtype([("a", (float, 2)), ("b", (int, 3))])
+        dt_flat = flatten_dtype(dt)
+        assert_equal(dt_flat, [float, int])
+        dt_flat = flatten_dtype(dt, True)
+        assert_equal(dt_flat, [float] * 2 + [int] * 3)
+        # dtype w/ titles
+        dt = np.dtype([(("a", "A"), "f8"), (("b", "B"), "f8")])
+        dt_flat = flatten_dtype(dt)
+        assert_equal(dt_flat, [float, float])
+




More information about the Numpy-svn mailing list