[pypy-commit] pypy numpy-refactor: some fixes for __getitem__
fijal
noreply at buildbot.pypy.org
Thu Sep 6 19:37:11 CEST 2012
Author: Maciej Fijalkowski <fijall at gmail.com>
Branch: numpy-refactor
Changeset: r57185:7ce8d4dbc076
Date: 2012-09-06 19:36 +0200
http://bitbucket.org/pypy/pypy/changeset/7ce8d4dbc076/
Log: some fixes for __getitem__
diff --git a/pypy/module/micronumpy/arrayimpl/concrete.py b/pypy/module/micronumpy/arrayimpl/concrete.py
--- a/pypy/module/micronumpy/arrayimpl/concrete.py
+++ b/pypy/module/micronumpy/arrayimpl/concrete.py
@@ -31,7 +31,6 @@
def next_skip_x(self, x):
self.offset += self.skip * x
- self.index += x
def done(self):
return self.offset >= self.size
diff --git a/pypy/module/micronumpy/interp_flatiter.py b/pypy/module/micronumpy/interp_flatiter.py
--- a/pypy/module/micronumpy/interp_flatiter.py
+++ b/pypy/module/micronumpy/interp_flatiter.py
@@ -1,15 +1,23 @@
+from pypy.module.micronumpy.base import W_NDimArray
+from pypy.module.micronumpy import loop
+from pypy.module.micronumpy.strides import to_coords
from pypy.interpreter.baseobjspace import Wrappable
from pypy.interpreter.error import OperationError
-from pypy.interpreter.typedef import TypeDef, interp2app
-from pypy.rlib import jit
+from pypy.interpreter.typedef import TypeDef, interp2app, GetSetProperty
class W_FlatIterator(Wrappable):
def __init__(self, arr):
self.base = arr
- self.iter = arr.create_iter()
+ self.reset()
+
+ def reset(self):
+ self.iter = self.base.create_iter()
self.index = 0
+ def descr_len(self, space):
+ return space.wrap(self.base.get_size())
+
def descr_next(self, space):
if self.iter.done():
raise OperationError(space.w_StopIteration, space.w_None)
@@ -18,44 +26,47 @@
self.index += 1
return w_res
- @jit.unroll_safe
+ def descr_index(self, space):
+ return space.wrap(self.index)
+
+ def descr_coords(self, space):
+ coords, step, lngth = to_coords(space, self.base.get_shape(),
+ self.base.get_size(), self.base.get_order(),
+ space.wrap(self.index))
+ return space.newtuple([space.wrap(c) for c in coords])
+
def descr_getitem(self, space, w_idx):
if not (space.isinstance_w(w_idx, space.w_int) or
space.isinstance_w(w_idx, space.w_slice)):
raise OperationError(space.w_IndexError,
space.wrap('unsupported iterator index'))
+ self.reset()
base = self.base
start, stop, step, length = space.decode_index4(w_idx, base.get_size())
# setslice would have been better, but flat[u:v] for arbitrary
# shapes of array a cannot be represented as a[x1:x2, y1:y2]
base_iter = base.create_iter()
- xxx
- return base.getitem(basei.offset)
- base_iter = ViewIterator(base.start, base.strides,
- base.backstrides, base.shape)
- shapelen = len(base.shape)
- basei = basei.next_skip_x(shapelen, start)
- res = W_NDimArray([lngth], base.dtype, base.order)
- ri = res.create_iter()
- while not ri.done():
- flat_get_driver.jit_merge_point(shapelen=shapelen,
- base=base,
- basei=basei,
- step=step,
- res=res,
- ri=ri)
- w_val = base.getitem(basei.offset)
- res.setitem(ri.offset, w_val)
- basei = basei.next_skip_x(shapelen, step)
- ri = ri.next(shapelen)
- return res
+ base_iter.next_skip_x(start)
+ if length == 1:
+ return base_iter.getitem()
+ res = W_NDimArray.from_shape([length], base.get_dtype(),
+ base.get_order())
+ return loop.flatiter_getitem(res, base_iter, step)
def descr_iter(self):
return self
+ def descr_base(self, space):
+ return space.wrap(self.base)
+
W_FlatIterator.typedef = TypeDef(
'flatiter',
- __iter__ = interp2app(W_FlatIterator.descr_iter),
+ __iter__ = interp2app(W_FlatIterator.descr_iter),
+ __getitem__ = interp2app(W_FlatIterator.descr_getitem),
+ __len__ = interp2app(W_FlatIterator.descr_len),
next = interp2app(W_FlatIterator.descr_next),
+ base = GetSetProperty(W_FlatIterator.descr_base),
+ index = GetSetProperty(W_FlatIterator.descr_index),
+ coords = GetSetProperty(W_FlatIterator.descr_coords),
)
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -40,6 +40,9 @@
def get_dtype(self):
return self.implementation.dtype
+ def get_order(self):
+ return self.implementation.order
+
def descr_get_dtype(self, space):
return self.implementation.dtype
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -186,3 +186,11 @@
arr_iter.next()
index_iter.next()
value_iter.next()
+
+def flatiter_getitem(res, base_iter, step):
+ ri = res.create_iter()
+ while not ri.done():
+ ri.setitem(base_iter.getitem())
+ base_iter.next_skip_x(step)
+ ri.next()
+ return res
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -1814,6 +1814,8 @@
b.next()
b.next()
b.next()
+ assert b.index == 3
+ assert b.coords == (0, 3)
assert b[3] == 3
assert (b[::3] == [0, 3, 6, 9]).all()
assert (b[2::5] == [2, 7]).all()
@@ -1822,7 +1824,7 @@
raises(IndexError, "b[-11]")
raises(IndexError, 'b[0, 1]')
assert b.index == 0
- assert b.coords == (0,0)
+ assert b.coords == (0, 0)
def test_flatiter_setitem(self):
from _numpypy import arange, array
More information about the pypy-commit mailing list