[pypy-commit] pypy numpy-multidim: creation from sequences

fijal noreply at buildbot.pypy.org
Thu Oct 27 19:39:30 CEST 2011


Author: Maciej Fijalkowski <fijall at gmail.com>
Branch: numpy-multidim
Changeset: r48531:d73fb983a4bb
Date: 2011-10-27 19:38 +0200
http://bitbucket.org/pypy/pypy/changeset/d73fb983a4bb/

Log:	creation from sequences

diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -16,6 +16,42 @@
                                                        'dtype'])
 slice_driver = jit.JitDriver(greens=['signature'], reds=['i', 'self', 'source'])
 
+def _find_dtype(space, w_iterable):
+    # Depth-first walk of nested sequences, refining the guess per scalar leaf.
+    stack = [w_iterable]
+    w_dtype = None
+    while stack:
+        w_next = stack.pop()
+        if space.issequence_w(w_next):
+            stack.extend(space.listview(w_next))
+        else:
+            w_dtype = interp_ufuncs.find_dtype_for_scalar(space, w_next, w_dtype)
+            if w_dtype is space.fromcache(interp_dtype.W_Float64Dtype):
+                return w_dtype  # float64 ends the search early
+    if w_dtype is None:
+        return space.w_None  # no scalar leaves at all (e.g. an empty list)
+    return w_dtype
+
+def _find_shape_and_elems(space, w_iterable):
+    # Return (shape, flat scalars); a bare scalar yields shape [] (0-d).
+    shape, batch = [], [w_iterable]
+    while True:
+        if not batch or not space.issequence_w(batch[0]):
+            for w_elem in batch:
+                if space.issequence_w(w_elem):
+                    raise OperationError(space.w_ValueError, space.wrap(
+                        "setting an array element with a sequence"))
+            return shape, batch
+        size = space.len_w(batch[0])
+        new_batch = []
+        for w_elem in batch:
+            if not space.issequence_w(w_elem) or space.len_w(w_elem) != size:
+                raise OperationError(space.w_ValueError, space.wrap(
+                    "setting an array element with a sequence"))
+            new_batch += space.listview(w_elem)
+        shape.append(size)
+        batch = new_batch
+
 class BaseArray(Wrappable):
     _attrs_ = ["invalidates", "signature"]
 
@@ -36,24 +72,18 @@
         self.invalidates.append(other)
 
     def descr__new__(space, w_subtype, w_size_or_iterable, w_dtype=None):
-        l = space.listview(w_size_or_iterable)
+        # no dtype given: infer one from the (possibly nested) elements
         if space.is_w(w_dtype, space.w_None):
-            w_dtype = None
-            for w_item in l:
-                w_dtype = interp_ufuncs.find_dtype_for_scalar(space, w_item, w_dtype)
-                if w_dtype is space.fromcache(interp_dtype.W_Float64Dtype):
-                    break
-            if w_dtype is None:
-                w_dtype = space.w_None
-
+            w_dtype = _find_dtype(space, w_size_or_iterable)
         dtype = space.interp_w(interp_dtype.W_Dtype,
             space.call_function(space.gettypefor(interp_dtype.W_Dtype), w_dtype)
         )
-        arr = NDimArray(len(l), [len(l)], dtype=dtype)
+        shape, elems_w = _find_shape_and_elems(space, w_size_or_iterable)
+        size = len(elems_w)  # number of flattened scalar elements
+        arr = NDimArray(size, shape, dtype=dtype)
         i = 0
-        for w_elem in l:
+        for i, w_elem in enumerate(elems_w):  # elems_w is flattened row-major
             dtype.setitem_w(space, arr.storage, i, w_elem)
-            i += 1
         return arr
 
     def _unaryop_impl(ufunc_name):
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -49,6 +49,9 @@
         from numpy import array
         a = array(range(5))
         assert a[3] == 3
+        a = array(1)
+        assert a[0] == 1
+        assert a.shape == ()
 
     def test_repr(self):
         from numpy import array, zeros
@@ -660,6 +663,17 @@
         assert a[0][1][1] == 13
         assert a[1][2][1] == 15
 
+    def test_init_2(self):
+        import numpy
+        raises(ValueError, numpy.array, [[1], 2])  # list next to a scalar
+        raises(ValueError, numpy.array, [[1, 2], [3]])  # unequal row lengths
+        raises(ValueError, numpy.array, [[[1, 2], [3, 4], 5]])  # scalar row
+        raises(ValueError, numpy.array, [[[1, 2], [3, 4], [5]]])  # short row
+        a = numpy.array([[1, 2], [4, 5]])
+        assert a[0, 1] == a[0][1] == 2
+        a = numpy.array([[[1, 2], [3, 4], [5, 6]]])
+        assert a[0, 1][0] == 3 and a[0, 1][1] == 4
+
     def test_setitem_slice(self):
         import numpy
         a = numpy.zeros((3, 4))


More information about the pypy-commit mailing list