[pypy-commit] pypy numpy-multidim-shards: add shape_agreement broadcast test, fix for it

mattip noreply at buildbot.pypy.org
Sun Nov 20 06:07:24 CET 2011


Author: mattip
Branch: numpy-multidim-shards
Changeset: r49562:cf0782e42b72
Date: 2011-11-20 07:04 +0200
http://bitbucket.org/pypy/pypy/changeset/cf0782e42b72/

Log:	add shape_agreement broadcast test, fix for it

diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -203,8 +203,16 @@
     def __init__(self, arr, res_shape):
         self.indices = [0] * len(res_shape)
         self.offset  = arr.start
-        self.shards  = [s for s in arr.shards]  # Is there a better way to make a copy in rpython?
-        self.backshards = [s for s in arr.backshards]  # Is there a better way to make a copy in rpython?
+        #shards are 0 where original shape==1
+        self.shards = []
+        self.backshards = []
+        for i in range(len(arr.shape)):
+            if arr.shape[i]==1:
+                self.shards.append(0)
+                self.backshards.append(0)
+            else:
+                self.shards.append(arr.shards[i])
+                self.backshards.append(arr.backshards[i])
         self.shape_len = len(res_shape)
         self.res_shape = res_shape
         for i in range(self.shape_len - len(arr.shape)):
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -859,15 +859,25 @@
         assert c.all()
 
     def test_broadcast_setslice(self):
-        from numpypy import zeros, ones, array
+        from numpypy import zeros, ones
         a = zeros((100, 100))
         b = ones(100)
         a[:, :] = b
         assert a[13, 15] == 1
+
+    def test_broadcast_shape_agreement(self):
+        from numpypy import zeros, array
         a = zeros((3, 1, 3))
         b = array(((10, 11, 12), (20, 21, 22), (30, 31, 32)))
         c = ((a + b) == [b, b, b])
         assert c.all()
+        a = array((((10, 11, 12), ), ((20, 21, 22), ), ((30, 31, 32), )))
+        assert a.shape == (3, 1, 3)
+        d = zeros((3, 3))
+        assert (a + d).shape == (3, 3, 3)  # length-1 axis broadcasts to 3
+        c = ((a + d) == array([[[10., 11., 12.]] * 3, [[20., 21., 22.]] * 3, [[30., 31., 32.]] * 3]))
+        assert c.all()
+
 
 class AppTestSupport(object):
     def setup_class(cls):


More information about the pypy-commit mailing list