[pypy-commit] pypy numpy-refactor: some fixes for __getitem__

fijal noreply at buildbot.pypy.org
Thu Sep 6 19:37:11 CEST 2012


Author: Maciej Fijalkowski <fijall at gmail.com>
Branch: numpy-refactor
Changeset: r57185:7ce8d4dbc076
Date: 2012-09-06 19:36 +0200
http://bitbucket.org/pypy/pypy/changeset/7ce8d4dbc076/

Log:	some fixes for __getitem__

diff --git a/pypy/module/micronumpy/arrayimpl/concrete.py b/pypy/module/micronumpy/arrayimpl/concrete.py
--- a/pypy/module/micronumpy/arrayimpl/concrete.py
+++ b/pypy/module/micronumpy/arrayimpl/concrete.py
@@ -31,7 +31,6 @@
 
     def next_skip_x(self, x):
         self.offset += self.skip * x
-        self.index += x
 
     def done(self):
         return self.offset >= self.size
diff --git a/pypy/module/micronumpy/interp_flatiter.py b/pypy/module/micronumpy/interp_flatiter.py
--- a/pypy/module/micronumpy/interp_flatiter.py
+++ b/pypy/module/micronumpy/interp_flatiter.py
@@ -1,15 +1,25 @@
 
+from pypy.module.micronumpy.base import W_NDimArray
+from pypy.module.micronumpy import loop
+from pypy.module.micronumpy.strides import to_coords
 from pypy.interpreter.baseobjspace import Wrappable
 from pypy.interpreter.error import OperationError
-from pypy.interpreter.typedef import TypeDef, interp2app
-from pypy.rlib import jit
+from pypy.interpreter.typedef import TypeDef, interp2app, GetSetProperty
 
 class W_FlatIterator(Wrappable):
     def __init__(self, arr):
         self.base = arr
-        self.iter = arr.create_iter()
+        self.reset()
+
+    def reset(self):
+        """Rewind to the first element of the base array (index 0)."""
+        self.iter = self.base.create_iter()
         self.index = 0
 
+    def descr_len(self, space):
+        """len(flat): the total number of elements in the base array."""
+        return space.wrap(self.base.get_size())
+
     def descr_next(self, space):
         if self.iter.done():
             raise OperationError(space.w_StopIteration, space.w_None)
@@ -18,44 +28,53 @@
         self.index += 1
         return w_res
 
-    @jit.unroll_safe
+    def descr_index(self, space):
+        """Number of elements consumed by next() since the last reset."""
+        return space.wrap(self.index)
+
+    def descr_coords(self, space):
+        """Coordinates in the base array of flat position self.index."""
+        coords, step, lngth = to_coords(space, self.base.get_shape(),
+                            self.base.get_size(), self.base.get_order(),
+                            space.wrap(self.index))
+        return space.newtuple([space.wrap(c) for c in coords])
+
     def descr_getitem(self, space, w_idx):
+        """Implement flat[i] (a scalar) and flat[i:j:k] (a 1-d array)."""
         if not (space.isinstance_w(w_idx, space.w_int) or
             space.isinstance_w(w_idx, space.w_slice)):
             raise OperationError(space.w_IndexError,
                                  space.wrap('unsupported iterator index'))
+        # indexing rewinds the iterator, so .index reads 0 afterwards
+        self.reset()
         base = self.base
         start, stop, step, length = space.decode_index4(w_idx, base.get_size())
         # setslice would have been better, but flat[u:v] for arbitrary
         # shapes of array a cannot be represented as a[x1:x2, y1:y2]
         base_iter = base.create_iter()
-        xxx
-        return base.getitem(basei.offset)
-        base_iter = ViewIterator(base.start, base.strides,
-                             base.backstrides, base.shape)
-        shapelen = len(base.shape)
-        basei = basei.next_skip_x(shapelen, start)
-        res = W_NDimArray([lngth], base.dtype, base.order)
-        ri = res.create_iter()
-        while not ri.done():
-            flat_get_driver.jit_merge_point(shapelen=shapelen,
-                                             base=base,
-                                             basei=basei,
-                                             step=step,
-                                             res=res,
-                                             ri=ri)
-            w_val = base.getitem(basei.offset)
-            res.setitem(ri.offset, w_val)
-            basei = basei.next_skip_x(shapelen, step)
-            ri = ri.next(shapelen)
-        return res
+        base_iter.next_skip_x(start)
+        if space.isinstance_w(w_idx, space.w_int):
+            # test the index kind, not length: a.flat[2:3] must stay an array
+            return base_iter.getitem()
+        res = W_NDimArray.from_shape([length], base.get_dtype(),
+                                     base.get_order())
+        return loop.flatiter_getitem(res, base_iter, step)
 
     def descr_iter(self):
         return self
 
+    def descr_base(self, space):
+        """The array this iterator was created over."""
+        return space.wrap(self.base)
+
 W_FlatIterator.typedef = TypeDef(
     'flatiter',
-    __iter__ = interp2app(W_FlatIterator.descr_iter),    
+    __iter__ = interp2app(W_FlatIterator.descr_iter),
+    __getitem__ = interp2app(W_FlatIterator.descr_getitem),
+    __len__ = interp2app(W_FlatIterator.descr_len),
 
     next = interp2app(W_FlatIterator.descr_next),
+    base = GetSetProperty(W_FlatIterator.descr_base),
+    index = GetSetProperty(W_FlatIterator.descr_index),
+    coords = GetSetProperty(W_FlatIterator.descr_coords),
 )
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -40,6 +40,10 @@
     def get_dtype(self):
         return self.implementation.dtype
 
+    def get_order(self):
+        """Return the order recorded on the backing implementation."""
+        return self.implementation.order
+
     def descr_get_dtype(self, space):
         return self.implementation.dtype
 
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -186,3 +186,13 @@
         arr_iter.next()
         index_iter.next()
         value_iter.next()
+
+def flatiter_getitem(res, base_iter, step):
+    """Fill each element of res from base_iter (already positioned at the
+    first source element), advancing it by `step` per copy; returns res."""
+    ri = res.create_iter()
+    while not ri.done():
+        ri.setitem(base_iter.getitem())
+        base_iter.next_skip_x(step)
+        ri.next()
+    return res
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -1814,6 +1814,8 @@
         b.next()
         b.next()
         b.next()
+        assert b.index == 3
+        assert b.coords == (0, 3)
         assert b[3] == 3
         assert (b[::3] == [0, 3, 6, 9]).all()
         assert (b[2::5] == [2, 7]).all()
@@ -1822,7 +1824,7 @@
         raises(IndexError, "b[-11]")
         raises(IndexError, 'b[0, 1]')
         assert b.index == 0
-        assert b.coords == (0,0)
+        assert b.coords == (0, 0)
 
     def test_flatiter_setitem(self):
         from _numpypy import arange, array


More information about the pypy-commit mailing list