[pypy-commit] pypy numppy-flatitter: cleanup merge, add failing test

mattip noreply at buildbot.pypy.org
Thu Jan 26 22:47:15 CET 2012


Author: mattip
Branch: numppy-flatitter
Changeset: r51828:88475cf6b4ab
Date: 2012-01-19 09:51 +0200
http://bitbucket.org/pypy/pypy/changeset/88475cf6b4ab/

Log:	cleanup merge, add failing test

diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -4,7 +4,7 @@
 from pypy.interpreter.typedef import TypeDef, GetSetProperty
 from pypy.module.micronumpy import interp_ufuncs, interp_dtype, signature,\
      interp_boxes
-from pypy.module.micronumpy.strides import calculate_slice_strides
+from pypy.module.micronumpy.strides import calculate_slice_strides, to_coords
 from pypy.rlib import jit
 from pypy.rpython.lltypesystem import lltype, rffi
 from pypy.tool.sourcetools import func_with_new_name
@@ -1449,40 +1449,6 @@
     tolist = interp2app(BaseArray.descr_tolist),
 )
 
-#TODO:Move all this to another file after fijal finishes reorganization
-@jit.unroll_safe
-def _to_coords(space, arr, w_item_or_slice):
-    '''Returns a start coord, step, and length.
-    '''
-    start = lngth = step = 0
-    if space.isinstance_w(w_item_or_slice, space.w_int):
-        start = space.int_w(w_item_or_slice)
-        if start < 0:
-            start += arr.size
-        lngth = 1
-        step = 1
-    elif space.isinstance_w(w_item_or_slice, space.w_slice):
-        start, stop, step, lngth = space.decode_index4(w_item_or_slice,arr.size)
-    else:
-        raise OperationError(space.w_IndexError,
-                             space.wrap('unsupported iterator index'))
-    coords = []
-    i = start
-    if arr.order =='C':
-        for s in range(len(arr.shape) -1, -1, -1):
-            coords.insert(0, i % arr.shape[s])
-            i /= arr.shape[s]
-    else:
-        raise NotImplementedError
-        #untested code. Erase?
-        for s in range(len(arr.shape)):
-            coords.append(i % arr.shape[s])
-            i /= arr.shape[s]
-    if i != 0:
-        raise OperationError(space.w_IndexError, space.wrap("invalid index"))
-
-    return coords, start, step, lngth
-
 
 class W_FlatIterator(ViewArray):
 
@@ -1511,26 +1477,30 @@
         return space.wrap(self.index)
 
     def descr_coords(self, space):
-        coords, start, step, lngth = _to_coords(space, self.base, space.wrap(self.index))
+        coords, step, lngth = to_coords(space, self.base.shape, 
+                            self.base.size, self.base.order, 
+                            space.wrap(self.index))
         return space.newtuple([space.wrap(c) for c in coords])
 
     def descr_getitem(self, space, w_idx):
-        coords, start, step, lngth = _to_coords(space, self.base, w_idx)
-        if lngth <2:
-            w_coords = space.newtuple([space.wrap(c) for c in coords])
-            return self.base.descr_getitem(space, w_coords)
-
-        res = W_NDimArray(lngth, [lngth], self.base.dtype,
-                                    self.base.order)
+        if not (space.isinstance_w(w_idx, space.w_int) or
+            space.isinstance_w(w_idx, space.w_slice)):
+            raise OperationError(space.w_IndexError,
+                                 space.wrap('unsupported iterator index'))
+        start, stop, step, lngth = space.decode_index4(w_idx, self.base.size)
         # setslice would have been better, but flat[u:v] for arbitrary
         # shapes of array a cannot be represented as a[x1:x2, y1:y2]
         basei = ViewIterator(self.base.start, self.base.strides,
                                self.base.backstrides,self.base.shape)
         shapelen = len(self.base.shape)
         basei = basei.next_skip_x(shapelen, start)
+        if lngth <2:
+            return self.base.getitem(basei.offset)
         ri = ArrayIterator(lngth)
+        res = W_NDimArray(lngth, [lngth], self.base.dtype,
+                                    self.base.order)
         while not ri.done():
-            # TODO: add a jit_merge_point
+            # TODO: add a jit_merge_point?
             w_val = self.base.getitem(basei.offset)
             res.setitem(ri.offset,w_val)
             basei = basei.next_skip_x(shapelen, step)
@@ -1538,7 +1508,11 @@
         return res
 
     def descr_setitem(self, space, w_idx, w_value):
-        coords, start, step, lngth = _to_coords(space, self.base, w_idx)
+        if not (space.isinstance_w(w_idx, space.w_int) or
+            space.isinstance_w(w_idx, space.w_slice)):
+            raise OperationError(space.w_IndexError,
+                                 space.wrap('unsupported iterator index'))
+        start, stop, step, lngth = space.decode_index4(w_idx, self.base.size)
         arr = convert_to_array(space, w_value)
         ai = 0
         basei = ViewIterator(self.base.start, self.base.strides,
@@ -1546,7 +1520,7 @@
         shapelen = len(self.base.shape)
         basei = basei.next_skip_x(shapelen, start)
         for i in range(lngth):
-            # TODO: add jit_merge_point
+            # TODO: add jit_merge_point?
             v = arr.getitem(ai).convert_to(self.base.dtype)
             self.base.setitem(basei.offset, v)
             # need to repeat input values until all assignments are done
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -1391,6 +1391,7 @@
         assert b[-2] == 8
         raises(IndexError, "b[11]")
         raises(IndexError, "b[-11]")
+        raises(IndexError, 'b[0, 1]')
         assert b.index == 3
         assert b.coords == (0,3)
 
diff --git a/pypy/module/micronumpy/test/test_zjit.py b/pypy/module/micronumpy/test/test_zjit.py
--- a/pypy/module/micronumpy/test/test_zjit.py
+++ b/pypy/module/micronumpy/test/test_zjit.py
@@ -12,7 +12,7 @@
 from pypy.module.micronumpy.compile import (FakeSpace,
     IntObject, Parser, InterpreterState)
 from pypy.module.micronumpy.interp_numarray import (W_NDimArray,
-     BaseArray)
+     BaseArray, W_FlatIterator)
 from pypy.rlib.nonconst import NonConstant
 
 
@@ -50,6 +50,8 @@
             if not len(interp.results):
                 raise Exception("need results")
             w_res = interp.results[-1]
+            if isinstance(w_res, W_FlatIterator):
+                w_res = w_res.next()
             if isinstance(w_res, BaseArray):
                 concr = w_res.get_concrete_or_scalar()
                 sig = concr.find_sig()
@@ -369,7 +371,28 @@
                                 'setinteriorfield_raw': 1, 'int_add': 2,
                                 'int_ge': 1, 'guard_false': 1, 'jump': 1,
                                 'arraylen_gc': 1})
+    def define_flat_iter():
+        return '''
+        a = |30|
+        a -> flat
+        '''
 
+    def test_flat_iter(self):
+        result = self.run("flat_iter")
+        assert result == 0
+
+    def define_flat_getitem():
+        return '''
+        a = |30|
+        b = a -> flat
+        b -> 6
+        '''
+
+    def test_flat_getitem(self):
+        result = self.run("flat_getitem")
+        assert result == 4
+        self.check_trace_count(1)
+        self.check_simple_loop({})
 
 class TestNumpyOld(LLJitMixin):
     def setup_class(cls):


More information about the pypy-commit mailing list