[pypy-commit] pypy numpy-indexing-by-arrays-2: progress on arr[arr_of_bools] = arr

fijal noreply at buildbot.pypy.org
Tue Jan 17 12:16:18 CET 2012


Author: Maciej Fijalkowski <fijall at gmail.com>
Branch: numpy-indexing-by-arrays-2
Changeset: r51391:fe9bca1da2b6
Date: 2012-01-17 13:15 +0200
http://bitbucket.org/pypy/pypy/changeset/fe9bca1da2b6/

Log:	progress on arr[arr_of_bools] = arr

diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -10,7 +10,7 @@
 from pypy.tool.sourcetools import func_with_new_name
 from pypy.rlib.rstring import StringBuilder
 from pypy.module.micronumpy.interp_iter import ArrayIterator, OneDimIterator,\
-     SkipLastAxisIterator, Chunk
+     SkipLastAxisIterator, Chunk, ViewIterator
 
 numpy_driver = jit.JitDriver(
     greens=['shapelen', 'sig'],
@@ -518,7 +518,7 @@
         res = W_NDimArray(size, [size], self.find_dtype())
         ri = ArrayIterator(size)
         shapelen = len(self.shape)
-        argi = ArrayIterator(concr.size)
+        argi = concr.create_iter()
         sig = self.find_sig()
         frame = sig.create_frame(self)
         v = None
@@ -536,6 +536,18 @@
             frame.next(shapelen)
         return res
 
+    def setitem_filter(self, space, idx, val):  # backs arr[bool_mask] = val: evaluates the assignment only where idx is true
+        arr = SliceArray(self.shape, self.dtype, self, val)  # NOTE(review): presumably a lazy "self <- val" node; confirm against SliceArray
+        shapelen = len(arr.shape)
+        sig = arr.find_sig()
+        frame = sig.create_frame(arr)
+        idxi = idx.create_iter()  # separate iterator over the mask, stepped once per frame step below
+        while not frame.done():
+            if idx.dtype.getitem_bool(idx.storage, idxi.offset):  # read the mask element at the current offset
+                sig.eval(frame, arr)
+            idxi = idxi.next(shapelen)
+            frame.next(shapelen)  # NOTE(review): frame advances on every element, so val looks position-aligned rather than consumed only on true entries; confirm
+
     def descr_getitem(self, space, w_idx):
         if (isinstance(w_idx, BaseArray) and w_idx.shape == self.shape and
             w_idx.find_dtype().is_bool_type()):
@@ -549,6 +561,11 @@
 
     def descr_setitem(self, space, w_idx, w_value):
         self.invalidated()
+        if (isinstance(w_idx, BaseArray) and w_idx.shape == self.shape and
+            w_idx.find_dtype().is_bool_type()):
+            return self.get_concrete().setitem_filter(space,
+                                                      w_idx.get_concrete(),
+                                             convert_to_array(space, w_value))
         if self._single_item_result(space, w_idx):
             concrete = self.get_concrete()
             item = concrete._index_of_single_item(space, w_idx)
@@ -1135,6 +1152,10 @@
                                parent)
         self.start = start
 
+    def create_iter(self):  # iterator built from this object's start, strides, backstrides and shape
+        return ViewIterator(self.start, self.strides, self.backstrides,
+                            self.shape)
+
     def setshape(self, space, new_shape):
         if len(self.shape) < 1:
             return
@@ -1181,6 +1202,9 @@
         self.shape = new_shape
         self.calc_strides(new_shape)
 
+    def create_iter(self):  # iterator constructed from self.size alone
+        return ArrayIterator(self.size)
+
     def create_sig(self):
         return signature.ArraySignature(self.dtype)
 
@@ -1235,6 +1259,7 @@
     arr = W_NDimArray(size, shape[:], dtype=dtype, order=order)
     shapelen = len(shape)
     arr_iter = ArrayIterator(arr.size)
+    # XXX we might want to have a jitdriver here
     for i in range(len(elems_w)):
         w_elem = elems_w[i]
         dtype.setitem(arr.storage, arr_iter.offset,
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -1320,6 +1320,15 @@
         assert (a[a > 3] == [4, 5, 6, 7, 8, 9]).all()
         assert (a[a & 1 == 1] == [1, 3, 5, 7, 9]).all()
 
+    def test_array_indexing_bool_setitem(self):  # arr[bool_mask] = value, with scalar and array right-hand sides
+        from _numpypy import arange, array
+        a = arange(6)
+        a[a > 3] = 15
+        assert (a == [0, 1, 2, 3, 15, 15]).all()
+        a = arange(6).reshape(3, 2)
+        a[a & 1 == 1] = array([8, 9, 10])  # & binds tighter than ==, so the mask is (a & 1) == 1: the odd entries 1, 3, 5
+        assert (a == [[0, 8], [2, 9], [4, 10]]).all()  # even entries 0, 2, 4 must be left untouched
+
 class AppTestSupport(BaseNumpyAppTest):
     def setup_class(cls):
         import struct


More information about the pypy-commit mailing list