[pypy-commit] pypy numpy-refactor: fix setitem_filter with scalar
bdkearns
noreply at buildbot.pypy.org
Thu Feb 27 01:59:51 CET 2014
Author: Brian Kearns <bdkearns at gmail.com>
Branch: numpy-refactor
Changeset: r69477:afe7292cc2d3
Date: 2014-02-26 16:54 -0500
http://bitbucket.org/pypy/pypy/changeset/afe7292cc2d3/
Log: fix setitem_filter with scalar
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -144,7 +144,7 @@
"cannot assign %d input values to "
"the %d output values where the mask is true" %
(val.get_size(), size)))
- loop.setitem_filter(space, self, idx, val, size)
+ loop.setitem_filter(space, self, idx, val)
def _prepare_array_index(self, space, w_index):
if isinstance(w_index, W_NDimArray):
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -380,14 +380,17 @@
'index_dtype'],
reds = 'auto')
-def setitem_filter(space, arr, index, value, size):
+def setitem_filter(space, arr, index, value):
arr_iter = arr.create_iter()
shapelen = len(arr.get_shape())
if shapelen > 1 and len(index.get_shape()) < 2:
index_iter = index.create_iter(arr.get_shape(), backward_broadcast=True)
else:
index_iter = index.create_iter()
- value_iter = value.create_iter([size])
+ if value.get_size() == 1:
+ value_iter = value.create_iter(arr.get_shape())
+ else:
+ value_iter = value.create_iter()
index_dtype = index.get_dtype()
arr_dtype = arr.get_dtype()
while not index_iter.done():
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -2281,6 +2281,12 @@
assert (a[b] == a).all()
a[b] = 1.
assert (a == [[1., 1., 1.]]).all()
+ a[b] = np.array(2.)
+ assert (a == [[2., 2., 2.]]).all()
+ a[b] = np.array([3.])
+ assert (a == [[3., 3., 3.]]).all()
+ a[b] = np.array([[4.]])
+ assert (a == [[4., 4., 4.]]).all()
def test_ellipsis_indexing(self):
import numpy as np
More information about the pypy-commit mailing list