[pypy-commit] pypy numpy-multidim-shards: make ones support multidim arrays. Also write a passing test

fijal noreply at buildbot.pypy.org
Sun Nov 20 11:13:11 CET 2011


Author: Maciej Fijalkowski <fijall at gmail.com>
Branch: numpy-multidim-shards
Changeset: r49570:29fe0349fa99
Date: 2011-11-20 12:12 +0200
http://bitbucket.org/pypy/pypy/changeset/29fe0349fa99/

Log:	make ones support multidim arrays. Also write a passing test

diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -1049,13 +1049,21 @@
             shape.append(item)
     return space.wrap(NDimArray(size, shape[:], dtype=dtype))
 
-@unwrap_spec(size=int)
-def ones(space, size, w_dtype=None):
+def ones(space, w_size, w_dtype=None):
     dtype = space.interp_w(interp_dtype.W_Dtype,
         space.call_function(space.gettypefor(interp_dtype.W_Dtype), w_dtype)
     )
-
-    arr = NDimArray(size, [size], dtype=dtype)
+    if space.isinstance_w(w_size, space.w_int):
+        size = space.int_w(w_size)
+        shape = [size]
+    else:
+        size = 1
+        shape = []
+        for w_item in space.fixedview(w_size):
+            item = space.int_w(w_item)
+            size *= item
+            shape.append(item)
+    arr = NDimArray(size, shape[:], dtype=dtype)
     one = dtype.adapt_val(1)
     arr.dtype.fill(arr.storage, one, 0, size)
     return space.wrap(arr)
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -851,6 +851,11 @@
         c = b + b
         assert c[1][1] == 12
 
+    def test_multidim_ones(self):
+        from numpypy import ones
+        a = ones((1, 2, 3))
+        assert a[0, 1, 2] == 1.0
+
     def test_broadcast_ufunc(self):
         from numpypy import array
         a = array([[1, 2], [3, 4], [5, 6]])
@@ -877,7 +882,13 @@
         c = ((a + d) == [b, b, b])
         c = ((a + d) == array([[[10., 11., 12.]]*3, [[20.,21.,22.]]*3, [[30.,31.,32.]]*3]))
         assert c.all()
-        
+
+    def test_broadcast_call2(self):
+        from numpypy import zeros, ones
+        a = zeros((4, 1, 5))
+        b = ones((4, 3, 5))
+        b[:] = (a + a)
+        assert (b == zeros((4, 3, 5))).all()
 
 class AppTestSupport(object):
     def setup_class(cls):


More information about the pypy-commit mailing list