[pypy-commit] pypy newindex: hack in support for [None] to create new index

MichaelBlume noreply at buildbot.pypy.org
Sat Mar 17 09:38:29 CET 2012


Author: Mike Blume <mike at loggly.com>
Branch: newindex
Changeset: r53752:4805ca9b8d4c
Date: 2012-03-14 15:04 -0700
http://bitbucket.org/pypy/pypy/changeset/4805ca9b8d4c/

Log:	hack in support for [None] to create new index

diff --git a/pypy/module/micronumpy/interp_iter.py b/pypy/module/micronumpy/interp_iter.py
--- a/pypy/module/micronumpy/interp_iter.py
+++ b/pypy/module/micronumpy/interp_iter.py
@@ -50,6 +50,7 @@
 # structures to describe slicing
 
 class Chunk(object):
+    ind_step = 1
     def __init__(self, start, stop, step, lgt):
         self.start = start
         self.stop = stop
@@ -64,6 +65,16 @@
         return 'Chunk(%d, %d, %d, %d)' % (self.start, self.stop, self.step,
                                           self.lgt)
 
+class NewIndexChunk(Chunk):  # chunk built for a None index item (see _prepare_slice_args)
+    start = 0
+    stop = 1
+    step = 1  # nonzero, so calculate_slice_strides emits a stride for this chunk
+    lgt = 1  # new axis of length 1 (assumes extend_shape reads lgt -- TODO confirm)
+    ind_step = 0  # consumes no source-array axis (a plain Chunk has ind_step = 1)
+
+    def __init__(self):
+        pass  # fields are the class attributes above; Chunk.__init__ is deliberately skipped
+
 class BaseTransform(object):
     pass
 
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -7,10 +7,10 @@
 from pypy.module.micronumpy.appbridge import get_appbridge_cache
 from pypy.module.micronumpy.dot import multidim_dot, match_dot_shapes
 from pypy.module.micronumpy.interp_iter import (ArrayIterator,
-    SkipLastAxisIterator, Chunk, ViewIterator)
+    SkipLastAxisIterator, Chunk, NewIndexChunk, ViewIterator)
 from pypy.module.micronumpy.strides import (calculate_slice_strides,
     shape_agreement, find_shape_and_elems, get_shape_from_iterable,
-    calc_new_strides, to_coords)
+    calc_new_strides, to_coords, enumerate_chunks)
 from pypy.rlib import jit
 from pypy.rlib.rstring import StringBuilder
 from pypy.rpython.lltypesystem import lltype, rffi
@@ -321,6 +321,13 @@
         is a list of scalars that match the size of shape
         """
         shape_len = len(self.shape)
+        if space.isinstance_w(w_idx, space.w_tuple):
+            for w_item in space.fixedview(w_idx):
+                if (space.isinstance_w(w_item, space.w_slice) or
+                    space.isinstance_w(w_item, space.w_NoneType)):
+                    return False
+        elif space.isinstance_w(w_idx, space.w_NoneType):
+            return False
         if shape_len == 0:
             raise OperationError(space.w_IndexError, space.wrap(
                 "0-d arrays can't be indexed"))
@@ -336,20 +343,25 @@
         if lgt > shape_len:
             raise OperationError(space.w_IndexError,
                                  space.wrap("invalid index"))
-        if lgt < shape_len:
-            return False
-        for w_item in space.fixedview(w_idx):
-            if space.isinstance_w(w_item, space.w_slice):
-                return False
-        return True
+        return lgt == shape_len
 
     @jit.unroll_safe
     def _prepare_slice_args(self, space, w_idx):
         if (space.isinstance_w(w_idx, space.w_int) or
             space.isinstance_w(w_idx, space.w_slice)):
             return [Chunk(*space.decode_index4(w_idx, self.shape[0]))]
-        return [Chunk(*space.decode_index4(w_item, self.shape[i])) for i, w_item in
-                enumerate(space.fixedview(w_idx))]
+        elif space.isinstance_w(w_idx, space.w_NoneType):  # bare a[None]
+            return [NewIndexChunk()]  # single new-axis chunk, consumes no source axis
+        result = []  # tuple index: one chunk per item, in order
+        i = 0  # next source axis to consume; None items do not advance it
+        for w_item in space.fixedview(w_idx):
+            if space.isinstance_w(w_item, space.w_NoneType):
+                result.append(NewIndexChunk())  # None inserts a new axis
+            else:
+                result.append(Chunk(*space.decode_index4(w_item,
+                                                         self.shape[i])))  # NOTE(review): i unchecked; with a None present, _single_item_result skips its lgt > shape_len check
+                i += 1  # only non-None items consume a source axis
+        return result
 
     def count_all_true(self, arr):
         sig = arr.find_sig()
@@ -443,7 +455,7 @@
     def create_slice(self, chunks):
         shape = []
         i = -1
-        for i, chunk in enumerate(chunks):
+        for i, chunk in enumerate_chunks(chunks):
             chunk.extend_shape(shape)
         s = i + 1
         assert s >= 0
diff --git a/pypy/module/micronumpy/strides.py b/pypy/module/micronumpy/strides.py
--- a/pypy/module/micronumpy/strides.py
+++ b/pypy/module/micronumpy/strides.py
@@ -1,6 +1,14 @@
 from pypy.rlib import jit
 from pypy.interpreter.error import OperationError
 
+def enumerate_chunks(chunks):  # like enumerate(), but the index advances by chunk.ind_step
+    result = []
+    i = -1  # no source axis consumed yet
+    for chunk in chunks:
+        i += chunk.ind_step  # Chunk adds 1; NewIndexChunk adds 0 and repeats the previous index (or -1)
+        result.append((i, chunk))  # NOTE(review): a leading NewIndexChunk yields i == -1, so strides[i] reads strides[-1]
+    return result  # list of (source_axis_index, chunk) pairs
+
 @jit.look_inside_iff(lambda shape, start, strides, backstrides, chunks:
     jit.isconstant(len(chunks))
 )
@@ -10,7 +18,7 @@
     rstart = start
     rshape = []
     i = -1
-    for i, chunk in enumerate(chunks):
+    for i, chunk in enumerate_chunks(chunks):
         if chunk.step != 0:
             rstrides.append(strides[i] * chunk.step)
             rbackstrides.append(strides[i] * (chunk.lgt - 1) * chunk.step)


More information about the pypy-commit mailing list