[pypy-commit] pypy newindex: hack in support for [None] to create new index
MichaelBlume
noreply at buildbot.pypy.org
Sat Mar 17 09:38:29 CET 2012
Author: Mike Blume <mike at loggly.com>
Branch: newindex
Changeset: r53752:4805ca9b8d4c
Date: 2012-03-14 15:04 -0700
http://bitbucket.org/pypy/pypy/changeset/4805ca9b8d4c/
Log: hack in support for [None] to create new index
diff --git a/pypy/module/micronumpy/interp_iter.py b/pypy/module/micronumpy/interp_iter.py
--- a/pypy/module/micronumpy/interp_iter.py
+++ b/pypy/module/micronumpy/interp_iter.py
@@ -50,6 +50,7 @@
# structures to describe slicing
class Chunk(object):
+ ind_step = 1
def __init__(self, start, stop, step, lgt):
self.start = start
self.stop = stop
@@ -64,6 +65,16 @@
return 'Chunk(%d, %d, %d, %d)' % (self.start, self.stop, self.step,
self.lgt)
+class NewIndexChunk(Chunk):
+    start = 0
+    stop = 1
+    step = 1
+    lgt = 1
+    ind_step = 0
+
+    def __init__(self):
+        pass
+
class BaseTransform(object):
pass
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -7,10 +7,10 @@
from pypy.module.micronumpy.appbridge import get_appbridge_cache
from pypy.module.micronumpy.dot import multidim_dot, match_dot_shapes
from pypy.module.micronumpy.interp_iter import (ArrayIterator,
- SkipLastAxisIterator, Chunk, ViewIterator)
+ SkipLastAxisIterator, Chunk, NewIndexChunk, ViewIterator)
from pypy.module.micronumpy.strides import (calculate_slice_strides,
shape_agreement, find_shape_and_elems, get_shape_from_iterable,
- calc_new_strides, to_coords)
+ calc_new_strides, to_coords, enumerate_chunks)
from pypy.rlib import jit
from pypy.rlib.rstring import StringBuilder
from pypy.rpython.lltypesystem import lltype, rffi
@@ -321,6 +321,13 @@
is a list of scalars that match the size of shape
"""
shape_len = len(self.shape)
+ if space.isinstance_w(w_idx, space.w_tuple):
+ for w_item in space.fixedview(w_idx):
+ if (space.isinstance_w(w_item, space.w_slice) or
+ space.isinstance_w(w_item, space.w_NoneType)):
+ return False
+ elif space.isinstance_w(w_idx, space.w_NoneType):
+ return False
if shape_len == 0:
raise OperationError(space.w_IndexError, space.wrap(
"0-d arrays can't be indexed"))
@@ -336,20 +343,25 @@
if lgt > shape_len:
raise OperationError(space.w_IndexError,
space.wrap("invalid index"))
- if lgt < shape_len:
- return False
- for w_item in space.fixedview(w_idx):
- if space.isinstance_w(w_item, space.w_slice):
- return False
- return True
+ return lgt == shape_len
@jit.unroll_safe
def _prepare_slice_args(self, space, w_idx):
if (space.isinstance_w(w_idx, space.w_int) or
space.isinstance_w(w_idx, space.w_slice)):
return [Chunk(*space.decode_index4(w_idx, self.shape[0]))]
- return [Chunk(*space.decode_index4(w_item, self.shape[i])) for i, w_item in
- enumerate(space.fixedview(w_idx))]
+ elif space.isinstance_w(w_idx, space.w_NoneType):
+ return [NewIndexChunk()]
+ result = []
+ i = 0
+ for w_item in space.fixedview(w_idx):
+ if space.isinstance_w(w_item, space.w_NoneType):
+ result.append(NewIndexChunk())
+ else:
+ result.append(Chunk(*space.decode_index4(w_item,
+ self.shape[i])))
+ i += 1
+ return result
def count_all_true(self, arr):
sig = arr.find_sig()
@@ -443,7 +455,7 @@
def create_slice(self, chunks):
shape = []
i = -1
- for i, chunk in enumerate(chunks):
+ for i, chunk in enumerate_chunks(chunks):
chunk.extend_shape(shape)
s = i + 1
assert s >= 0
diff --git a/pypy/module/micronumpy/strides.py b/pypy/module/micronumpy/strides.py
--- a/pypy/module/micronumpy/strides.py
+++ b/pypy/module/micronumpy/strides.py
@@ -1,6 +1,14 @@
from pypy.rlib import jit
from pypy.interpreter.error import OperationError
+def enumerate_chunks(chunks):
+    result = []
+    i = -1
+    for chunk in chunks:
+        i += chunk.ind_step
+        result.append((i, chunk))
+    return result
+
@jit.look_inside_iff(lambda shape, start, strides, backstrides, chunks:
jit.isconstant(len(chunks))
)
@@ -10,7 +18,7 @@
rstart = start
rshape = []
i = -1
- for i, chunk in enumerate(chunks):
+ for i, chunk in enumerate_chunks(chunks):
if chunk.step != 0:
rstrides.append(strides[i] * chunk.step)
rbackstrides.append(strides[i] * (chunk.lgt - 1) * chunk.step)
More information about the pypy-commit mailing list