[pypy-commit] pypy numpy-back-to-applevel: clean up scalar reshape and ravel

fijal noreply at buildbot.pypy.org
Fri Jan 27 21:08:20 CET 2012


Author: Maciej Fijalkowski <fijall at gmail.com>
Branch: numpy-back-to-applevel
Changeset: r51887:891a2ea64919
Date: 2012-01-27 22:07 +0200
http://bitbucket.org/pypy/pypy/changeset/891a2ea64919/

Log:	clean up scalar reshape and ravel

diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -274,7 +274,8 @@
 
     def descr_flatten(self, space, w_order=None):
         if isinstance(self, Scalar):
-            return self.copy(space)
+            # scalars have no storage
+            return self.descr_reshape(space, [space.wrap([1])])
         concr = self.get_concrete()
         w_res = concr.descr_ravel(space, w_order)
         if w_res.storage == concr.storage:
@@ -479,8 +480,11 @@
             w_shape = args_w[0]
         else:
             w_shape = space.newtuple(args_w)
+        new_shape = get_shape_from_iterable(space, self.size, w_shape)
+        return self.reshape(space, new_shape)
+        
+    def reshape(self, space, new_shape):
         concrete = self.get_concrete()
-        new_shape = get_shape_from_iterable(space, concrete.size, w_shape)
         # Since we got to here, prod(new_shape) == self.size
         new_strides = calc_new_strides(new_shape, concrete.shape,
                                      concrete.strides, concrete.order)
@@ -693,6 +697,11 @@
     def get_concrete_or_scalar(self):
         return self
 
+    def reshape(self, space, new_shape):  # scalar -> new array of new_shape
+        size = support.product(new_shape)  # via descr_reshape: == self.size
+        res = W_NDimArray(size, new_shape, self.dtype, 'C')  # no storage to view
+        res.setitem(0, self.value)  # copy the scalar's value into element 0
+        return res
 
 class VirtualArray(BaseArray):
     """
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -469,6 +469,13 @@
         y = z.reshape(4, 3, 8)
         assert y.shape == (4, 3, 8)
 
+    def test_scalar_reshape(self):  # reshape of a 0-d array(3)
+        from _numpypy import array  # same module as the other tests here
+        a = array(3)
+        assert a.reshape([1, 1]).shape == (1, 1)
+        assert a.reshape([1]).shape == (1,)
+        raises(ValueError, "a.reshape(3)")  # size mismatch must be rejected
+
     def test_add(self):
         from _numpypy import array
         a = array(range(5))
@@ -1104,6 +1111,7 @@
     def test_flatten(self):
         from _numpypy import array
 
+        assert array(3).flatten().shape == (1,)
         a = array([[1, 2], [3, 4]])
         b = a.flatten()
         c = a.ravel()


More information about the pypy-commit mailing list