[pypy-commit] pypy numpy-indexing-by-arrays: Initial (unoptimized) implementation of indexing by boolean vectors.

snus_mumrik noreply at buildbot.pypy.org
Sat Sep 10 18:27:23 CEST 2011


Author: Ilya Osadchiy <osadchiy.ilya at gmail.com>
Branch: numpy-indexing-by-arrays
Changeset: r47197:fc54fc827233
Date: 2011-09-10 19:25 +0300
http://bitbucket.org/pypy/pypy/changeset/fc54fc827233/

Log:	Initial (unoptimized) implementation of indexing by boolean vectors.

diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -230,8 +230,12 @@
             bool_dtype = space.fromcache(interp_dtype.W_BoolDtype)
             int_dtype = space.fromcache(interp_dtype.W_Int64Dtype)
             if w_idx.find_dtype() is bool_dtype:
-                # TODO: indexing by bool array
-                raise NotImplementedError("sorry, not yet implemented")
+                # Indexing by boolean array
+                new_sig = signature.Signature.find_sig([
+                    IndexedByBoolArray.signature, self.signature
+                ])                
+                res = IndexedByBoolArray(new_sig, bool_dtype, self, w_idx)
+                return space.wrap(res)
             else:
                 # Indexing by array
 
@@ -470,6 +474,54 @@
         val = self.source.eval(idx).convert_to(self.res_dtype)
         return val
 
+class IndexedByBoolArray(VirtualArray):
+    """
+    Lazy result of indexing an array by a boolean mask array:
+    element i is source[j], where j is the i-th True position in index.
+    """
+    # TODO: override "compute" to optimize (?)
+    signature = signature.BaseSignature()
+    def __init__(self, signature, bool_dtype, source, index):
+        VirtualArray.__init__(self, signature, source.find_dtype())
+        self.source = source
+        self.index = index
+        self.bool_dtype = bool_dtype
+        self.size = -1
+        # Positions of the True entries of index, filled lazily.
+        self.true_positions = None
+
+    def _del_sources(self):
+        self.source = None
+        self.index = None
+
+    def _get_true_positions(self):
+        # Scanning the mask is O(len(index)), so do it once and cache it.
+        # Caching also lets _eval handle any i directly, rather than
+        # assuming elements are requested strictly in order 0, 1, 2, ...
+        if self.true_positions is None:
+            # Mask values are read via the concrete form of index.
+            idxs = self.index.get_concrete()
+            positions = []
+            n = self.index.find_size()
+            for j in range(n):
+                idx_val = self.bool_dtype.unbox(
+                    idxs.eval(j).convert_to(self.bool_dtype))
+                assert isinstance(idx_val, bool)
+                if idx_val:
+                    positions.append(j)
+            self.true_positions = positions
+        return self.true_positions
+
+    def _find_size(self):
+        if self.size == -1:
+            self.size = len(self._get_true_positions())
+        return self.size
+
+    def _eval(self, i):
+        # NOTE(review): a mask whose length differs from source is not checked.
+        j = self._get_true_positions()[i]
+        return self.source.eval(j).convert_to(self.res_dtype)
+
 class ViewArray(BaseArray):
     """
     Class for representing views of arrays, they will reflect changes of parent
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -133,6 +133,22 @@
         for i in xrange(6):
             assert a_by_list[i] == range(5)[idx_list[i]]
 
+    def test_index_by_bool_array(self):
+        from numpy import array, dtype
+        mask = array([False, True, False, True, False])
+        assert mask.dtype is dtype(bool)
+        src = array(range(5))
+        # Length first: nothing has been evaluated yet when len() is called.
+        early = src[mask]
+        assert len(early) == 2
+        for k, want in enumerate([1, 3]):
+            assert early[k] == want
+        # Length last: every element has already been evaluated.
+        late = src[mask]
+        for k, want in enumerate([1, 3]):
+            assert late[k] == want
+        assert len(late) == 2
+
     def test_setitem(self):
         from numpy import array
         a = array(range(5))


More information about the pypy-commit mailing list