[pypy-commit] pypy numpy-refactor: fix setitem_filter with scalar

bdkearns noreply at buildbot.pypy.org
Thu Feb 27 01:59:51 CET 2014


Author: Brian Kearns <bdkearns at gmail.com>
Branch: numpy-refactor
Changeset: r69477:afe7292cc2d3
Date: 2014-02-26 16:54 -0500
http://bitbucket.org/pypy/pypy/changeset/afe7292cc2d3/

Log:	fix setitem_filter with scalar

diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -144,7 +144,7 @@
                 "cannot assign %d input values to "
                 "the %d output values where the mask is true" %
                 (val.get_size(), size)))
-        loop.setitem_filter(space, self, idx, val, size)
+        loop.setitem_filter(space, self, idx, val)
 
     def _prepare_array_index(self, space, w_index):
         if isinstance(w_index, W_NDimArray):
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -380,14 +380,17 @@
                                                 'index_dtype'],
                                       reds = 'auto')
 
-def setitem_filter(space, arr, index, value, size):
+def setitem_filter(space, arr, index, value):
     arr_iter = arr.create_iter()
     shapelen = len(arr.get_shape())
     if shapelen > 1 and len(index.get_shape()) < 2:
         index_iter = index.create_iter(arr.get_shape(), backward_broadcast=True)
     else:
         index_iter = index.create_iter()
-    value_iter = value.create_iter([size])
+    if value.get_size() == 1:
+        value_iter = value.create_iter(arr.get_shape())
+    else:
+        value_iter = value.create_iter()
     index_dtype = index.get_dtype()
     arr_dtype = arr.get_dtype()
     while not index_iter.done():
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -2281,6 +2281,12 @@
         assert (a[b] == a).all()
         a[b] = 1.
         assert (a == [[1., 1., 1.]]).all()
+        a[b] = np.array(2.)
+        assert (a == [[2., 2., 2.]]).all()
+        a[b] = np.array([3.])
+        assert (a == [[3., 3., 3.]]).all()
+        a[b] = np.array([[4.]])
+        assert (a == [[4., 4., 4.]]).all()
 
     def test_ellipsis_indexing(self):
         import numpy as np


More information about the pypy-commit mailing list