[pypy-commit] pypy numpy-searchsorted: merge default into branch

mattip noreply at buildbot.pypy.org
Sat Apr 19 23:53:33 CEST 2014


Author: Matti Picus <matti.picus at gmail.com>
Branch: numpy-searchsorted
Changeset: r70788:e0560bcc6840
Date: 2014-04-19 23:56 +0300
http://bitbucket.org/pypy/pypy/changeset/e0560bcc6840/

Log:	merge default into branch

diff too long, truncating to 2000 out of 2685 lines

diff --git a/pypy/doc/stm.rst b/pypy/doc/stm.rst
--- a/pypy/doc/stm.rst
+++ b/pypy/doc/stm.rst
@@ -40,7 +40,7 @@
 ``pypy-stm`` project is to improve what is so far the state-of-the-art
 for using multiple CPUs, which for cases where separate processes don't
 work is done by writing explicitly multi-threaded programs.  Instead,
-``pypy-stm`` is flushing forward an approach to *hide* the threads, as
+``pypy-stm`` is pushing forward an approach to *hide* the threads, as
 described below in `atomic sections`_.
 
 
diff --git a/pypy/doc/whatsnew-head.rst b/pypy/doc/whatsnew-head.rst
--- a/pypy/doc/whatsnew-head.rst
+++ b/pypy/doc/whatsnew-head.rst
@@ -140,3 +140,6 @@
 
 .. branch: numpypy-nditer
 Implement the core of nditer, without many of the fancy flags (external_loop, buffered)
+
+.. branch: numpy-speed
+Separate iterator from its state so jit can optimize better
diff --git a/pypy/module/micronumpy/concrete.py b/pypy/module/micronumpy/concrete.py
--- a/pypy/module/micronumpy/concrete.py
+++ b/pypy/module/micronumpy/concrete.py
@@ -284,9 +284,11 @@
                                             self.get_backstrides(),
                                             self.get_shape(), shape,
                                             backward_broadcast)
-            return ArrayIter(self, support.product(shape), shape, r[0], r[1])
-        return ArrayIter(self, self.get_size(), self.shape,
-                         self.strides, self.backstrides)
+            i = ArrayIter(self, support.product(shape), shape, r[0], r[1])
+        else:
+            i = ArrayIter(self, self.get_size(), self.shape,
+                          self.strides, self.backstrides)
+        return i, i.reset()
 
     def swapaxes(self, space, orig_arr, axis1, axis2):
         shape = self.get_shape()[:]
diff --git a/pypy/module/micronumpy/ctors.py b/pypy/module/micronumpy/ctors.py
--- a/pypy/module/micronumpy/ctors.py
+++ b/pypy/module/micronumpy/ctors.py
@@ -2,7 +2,7 @@
 from pypy.interpreter.gateway import unwrap_spec, WrappedDefault
 from rpython.rlib.rstring import strip_spaces
 from rpython.rtyper.lltypesystem import lltype, rffi
-from pypy.module.micronumpy import descriptor, loop, ufuncs
+from pypy.module.micronumpy import descriptor, loop
 from pypy.module.micronumpy.base import W_NDimArray, convert_to_array
 from pypy.module.micronumpy.converters import shape_converter
 
@@ -156,10 +156,10 @@
             "string is smaller than requested size"))
 
     a = W_NDimArray.from_shape(space, [num_items], dtype=dtype)
-    ai = a.create_iter()
+    ai, state = a.create_iter()
     for val in items:
-        ai.setitem(val)
-        ai.next()
+        ai.setitem(state, val)
+        state = ai.next(state)
 
     return space.wrap(a)
 
diff --git a/pypy/module/micronumpy/flatiter.py b/pypy/module/micronumpy/flatiter.py
--- a/pypy/module/micronumpy/flatiter.py
+++ b/pypy/module/micronumpy/flatiter.py
@@ -32,24 +32,23 @@
         self.reset()
 
     def reset(self):
-        self.iter = self.base.create_iter()
+        self.iter, self.state = self.base.create_iter()
 
     def descr_len(self, space):
-        return space.wrap(self.base.get_size())
+        return space.wrap(self.iter.size)
 
     def descr_next(self, space):
-        if self.iter.done():
+        if self.iter.done(self.state):
             raise OperationError(space.w_StopIteration, space.w_None)
-        w_res = self.iter.getitem()
-        self.iter.next()
+        w_res = self.iter.getitem(self.state)
+        self.state = self.iter.next(self.state)
         return w_res
 
     def descr_index(self, space):
-        return space.wrap(self.iter.index)
+        return space.wrap(self.state.index)
 
     def descr_coords(self, space):
-        coords = self.base.to_coords(space, space.wrap(self.iter.index))
-        return space.newtuple([space.wrap(c) for c in coords])
+        return space.newtuple([space.wrap(c) for c in self.state.indices])
 
     def descr_getitem(self, space, w_idx):
         if not (space.isinstance_w(w_idx, space.w_int) or
@@ -58,13 +57,13 @@
         self.reset()
         base = self.base
         start, stop, step, length = space.decode_index4(w_idx, base.get_size())
-        base_iter = base.create_iter()
-        base_iter.next_skip_x(start)
+        base_iter, base_state = base.create_iter()
+        base_state = base_iter.next_skip_x(base_state, start)
         if length == 1:
-            return base_iter.getitem()
+            return base_iter.getitem(base_state)
         res = W_NDimArray.from_shape(space, [length], base.get_dtype(),
                                      base.get_order(), w_instance=base)
-        return loop.flatiter_getitem(res, base_iter, step)
+        return loop.flatiter_getitem(res, base_iter, base_state, step)
 
     def descr_setitem(self, space, w_idx, w_value):
         if not (space.isinstance_w(w_idx, space.w_int) or
diff --git a/pypy/module/micronumpy/iterators.py b/pypy/module/micronumpy/iterators.py
--- a/pypy/module/micronumpy/iterators.py
+++ b/pypy/module/micronumpy/iterators.py
@@ -42,7 +42,6 @@
 """
 from rpython.rlib import jit
 from pypy.module.micronumpy import support
-from pypy.module.micronumpy.strides import calc_strides
 from pypy.module.micronumpy.base import W_NDimArray
 
 
@@ -52,19 +51,20 @@
         self.shapelen = len(shape)
         self.indexes = [0] * len(shape)
         self._done = False
-        self.idx_w = [None] * len(idx_w)
+        self.idx_w_i = [None] * len(idx_w)
+        self.idx_w_s = [None] * len(idx_w)
         for i, w_idx in enumerate(idx_w):
             if isinstance(w_idx, W_NDimArray):
-                self.idx_w[i] = w_idx.create_iter(shape)
+                self.idx_w_i[i], self.idx_w_s[i] = w_idx.create_iter(shape)
 
     def done(self):
         return self._done
 
     @jit.unroll_safe
     def next(self):
-        for w_idx in self.idx_w:
-            if w_idx is not None:
-                w_idx.next()
+        for i, idx_w_i in enumerate(self.idx_w_i):
+            if idx_w_i is not None:
+                self.idx_w_s[i] = idx_w_i.next(self.idx_w_s[i])
         for i in range(self.shapelen - 1, -1, -1):
             if self.indexes[i] < self.shape[i] - 1:
                 self.indexes[i] += 1
@@ -79,6 +79,16 @@
         return [space.wrap(self.indexes[i]) for i in range(shapelen)]
 
 
+class IterState(object):
+    _immutable_fields_ = ['iterator', 'index', 'indices[*]', 'offset']
+
+    def __init__(self, iterator, index, indices, offset):
+        self.iterator = iterator
+        self.index = index
+        self.indices = indices
+        self.offset = offset
+
+
 class ArrayIter(object):
     _immutable_fields_ = ['array', 'size', 'ndim_m1', 'shape_m1[*]',
                           'strides[*]', 'backstrides[*]']
@@ -91,90 +101,66 @@
         self.shape_m1 = [s - 1 for s in shape]
         self.strides = strides
         self.backstrides = backstrides
-        self.reset()
 
     def reset(self):
-        self.index = 0
-        self.indices = [0] * len(self.shape_m1)
-        self.offset = self.array.start
+        return IterState(self, 0, [0] * len(self.shape_m1), self.array.start)
 
     @jit.unroll_safe
-    def next(self):
-        self.index += 1
+    def next(self, state):
+        assert state.iterator is self
+        index = state.index + 1
+        indices = state.indices
+        offset = state.offset
         for i in xrange(self.ndim_m1, -1, -1):
-            idx = self.indices[i]
+            idx = indices[i]
             if idx < self.shape_m1[i]:
-                self.indices[i] = idx + 1
-                self.offset += self.strides[i]
+                indices[i] = idx + 1
+                offset += self.strides[i]
                 break
             else:
-                self.indices[i] = 0
-                self.offset -= self.backstrides[i]
+                indices[i] = 0
+                offset -= self.backstrides[i]
+        return IterState(self, index, indices, offset)
 
     @jit.unroll_safe
-    def next_skip_x(self, step):
+    def next_skip_x(self, state, step):
+        assert state.iterator is self
         assert step >= 0
         if step == 0:
-            return
-        self.index += step
+            return state
+        index = state.index + step
+        indices = state.indices
+        offset = state.offset
         for i in xrange(self.ndim_m1, -1, -1):
-            idx = self.indices[i]
+            idx = indices[i]
             if idx < (self.shape_m1[i] + 1) - step:
-                self.indices[i] = idx + step
-                self.offset += self.strides[i] * step
+                indices[i] = idx + step
+                offset += self.strides[i] * step
                 break
             else:
-                rem_step = (self.indices[i] + step) // (self.shape_m1[i] + 1)
+                rem_step = (idx + step) // (self.shape_m1[i] + 1)
                 cur_step = step - rem_step * (self.shape_m1[i] + 1)
-                self.indices[i] += cur_step
-                self.offset += self.strides[i] * cur_step
+                indices[i] = idx + cur_step
+                offset += self.strides[i] * cur_step
                 step = rem_step
                 assert step > 0
+        return IterState(self, index, indices, offset)
 
-    def done(self):
-        return self.index >= self.size
+    def done(self, state):
+        assert state.iterator is self
+        return state.index >= self.size
 
-    def getitem(self):
-        return self.array.getitem(self.offset)
+    def getitem(self, state):
+        assert state.iterator is self
+        return self.array.getitem(state.offset)
 
-    def getitem_bool(self):
-        return self.array.getitem_bool(self.offset)
+    def getitem_bool(self, state):
+        assert state.iterator is self
+        return self.array.getitem_bool(state.offset)
 
-    def setitem(self, elem):
-        self.array.setitem(self.offset, elem)
-
-
-class SliceIterator(ArrayIter):
-    def __init__(self, arr, strides, backstrides, shape, order="C",
-                    backward=False, dtype=None):
-        if dtype is None:
-            dtype = arr.implementation.dtype
-        self.dtype = dtype
-        self.arr = arr
-        if backward:
-            self.slicesize = shape[0]
-            self.gap = [support.product(shape[1:]) * dtype.elsize]
-            strides = strides[1:]
-            backstrides = backstrides[1:]
-            shape = shape[1:]
-            strides.reverse()
-            backstrides.reverse()
-            shape.reverse()
-            size = support.product(shape)
-        else:
-            shape = [support.product(shape)]
-            strides, backstrides = calc_strides(shape, dtype, order)
-            size = 1
-            self.slicesize = support.product(shape)
-            self.gap = strides
-
-        ArrayIter.__init__(self, arr.implementation, size, shape, strides, backstrides)
-
-    def getslice(self):
-        from pypy.module.micronumpy.concrete import SliceArray
-        retVal = SliceArray(self.offset, self.gap, self.backstrides,
-        [self.slicesize], self.arr.implementation, self.arr, self.dtype)
-        return retVal
+    def setitem(self, state, elem):
+        assert state.iterator is self
+        self.array.setitem(state.offset, elem)
 
 
 def AxisIter(array, shape, axis, cumulative):
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -12,11 +12,10 @@
     AllButAxisIter
 
 
-call2_driver = jit.JitDriver(name='numpy_call2',
-                             greens = ['shapelen', 'func', 'calc_dtype',
-                                       'res_dtype'],
-                             reds = ['shape', 'w_lhs', 'w_rhs', 'out',
-                                     'left_iter', 'right_iter', 'out_iter'])
+call2_driver = jit.JitDriver(
+    name='numpy_call2',
+    greens=['shapelen', 'func', 'calc_dtype', 'res_dtype'],
+    reds='auto')
 
 def call2(space, shape, func, calc_dtype, res_dtype, w_lhs, w_rhs, out):
     # handle array_priority
@@ -46,47 +45,40 @@
     if out is None:
         out = W_NDimArray.from_shape(space, shape, res_dtype,
                                      w_instance=lhs_for_subtype)
-    left_iter = w_lhs.create_iter(shape)
-    right_iter = w_rhs.create_iter(shape)
-    out_iter = out.create_iter(shape)
+    left_iter, left_state = w_lhs.create_iter(shape)
+    right_iter, right_state = w_rhs.create_iter(shape)
+    out_iter, out_state = out.create_iter(shape)
     shapelen = len(shape)
-    while not out_iter.done():
+    while not out_iter.done(out_state):
         call2_driver.jit_merge_point(shapelen=shapelen, func=func,
-                                     calc_dtype=calc_dtype, res_dtype=res_dtype,
-                                     shape=shape, w_lhs=w_lhs, w_rhs=w_rhs,
-                                     out=out,
-                                     left_iter=left_iter, right_iter=right_iter,
-                                     out_iter=out_iter)
-        w_left = left_iter.getitem().convert_to(space, calc_dtype)
-        w_right = right_iter.getitem().convert_to(space, calc_dtype)
-        out_iter.setitem(func(calc_dtype, w_left, w_right).convert_to(
+                                     calc_dtype=calc_dtype, res_dtype=res_dtype)
+        w_left = left_iter.getitem(left_state).convert_to(space, calc_dtype)
+        w_right = right_iter.getitem(right_state).convert_to(space, calc_dtype)
+        out_iter.setitem(out_state, func(calc_dtype, w_left, w_right).convert_to(
             space, res_dtype))
-        left_iter.next()
-        right_iter.next()
-        out_iter.next()
+        left_state = left_iter.next(left_state)
+        right_state = right_iter.next(right_state)
+        out_state = out_iter.next(out_state)
     return out
 
-call1_driver = jit.JitDriver(name='numpy_call1',
-                             greens = ['shapelen', 'func', 'calc_dtype',
-                                       'res_dtype'],
-                             reds = ['shape', 'w_obj', 'out', 'obj_iter',
-                                     'out_iter'])
+call1_driver = jit.JitDriver(
+    name='numpy_call1',
+    greens=['shapelen', 'func', 'calc_dtype', 'res_dtype'],
+    reds='auto')
 
 def call1(space, shape, func, calc_dtype, res_dtype, w_obj, out):
     if out is None:
         out = W_NDimArray.from_shape(space, shape, res_dtype, w_instance=w_obj)
-    obj_iter = w_obj.create_iter(shape)
-    out_iter = out.create_iter(shape)
+    obj_iter, obj_state = w_obj.create_iter(shape)
+    out_iter, out_state = out.create_iter(shape)
     shapelen = len(shape)
-    while not out_iter.done():
+    while not out_iter.done(out_state):
         call1_driver.jit_merge_point(shapelen=shapelen, func=func,
-                                     calc_dtype=calc_dtype, res_dtype=res_dtype,
-                                     shape=shape, w_obj=w_obj, out=out,
-                                     obj_iter=obj_iter, out_iter=out_iter)
-        elem = obj_iter.getitem().convert_to(space, calc_dtype)
-        out_iter.setitem(func(calc_dtype, elem).convert_to(space, res_dtype))
-        out_iter.next()
-        obj_iter.next()
+                                     calc_dtype=calc_dtype, res_dtype=res_dtype)
+        elem = obj_iter.getitem(obj_state).convert_to(space, calc_dtype)
+        out_iter.setitem(out_state, func(calc_dtype, elem).convert_to(space, res_dtype))
+        out_state = out_iter.next(out_state)
+        obj_state = obj_iter.next(obj_state)
     return out
 
 setslice_driver = jit.JitDriver(name='numpy_setslice',
@@ -96,18 +88,20 @@
 def setslice(space, shape, target, source):
     # note that unlike everything else, target and source here are
     # array implementations, not arrays
-    target_iter = target.create_iter(shape)
-    source_iter = source.create_iter(shape)
+    target_iter, target_state = target.create_iter(shape)
+    source_iter, source_state = source.create_iter(shape)
     dtype = target.dtype
     shapelen = len(shape)
-    while not target_iter.done():
+    while not target_iter.done(target_state):
         setslice_driver.jit_merge_point(shapelen=shapelen, dtype=dtype)
+        val = source_iter.getitem(source_state)
         if dtype.is_str_or_unicode():
-            target_iter.setitem(dtype.coerce(space, source_iter.getitem()))
+            val = dtype.coerce(space, val)
         else:
-            target_iter.setitem(source_iter.getitem().convert_to(space, dtype))
-        target_iter.next()
-        source_iter.next()
+            val = val.convert_to(space, dtype)
+        target_iter.setitem(target_state, val)
+        target_state = target_iter.next(target_state)
+        source_state = source_iter.next(source_state)
     return target
 
 reduce_driver = jit.JitDriver(name='numpy_reduce',
@@ -116,22 +110,22 @@
                               reds = 'auto')
 
 def compute_reduce(space, obj, calc_dtype, func, done_func, identity):
-    obj_iter = obj.create_iter()
+    obj_iter, obj_state = obj.create_iter()
     if identity is None:
-        cur_value = obj_iter.getitem().convert_to(space, calc_dtype)
-        obj_iter.next()
+        cur_value = obj_iter.getitem(obj_state).convert_to(space, calc_dtype)
+        obj_state = obj_iter.next(obj_state)
     else:
         cur_value = identity.convert_to(space, calc_dtype)
     shapelen = len(obj.get_shape())
-    while not obj_iter.done():
+    while not obj_iter.done(obj_state):
         reduce_driver.jit_merge_point(shapelen=shapelen, func=func,
                                       done_func=done_func,
                                       calc_dtype=calc_dtype)
-        rval = obj_iter.getitem().convert_to(space, calc_dtype)
+        rval = obj_iter.getitem(obj_state).convert_to(space, calc_dtype)
         if done_func is not None and done_func(calc_dtype, rval):
             return rval
         cur_value = func(calc_dtype, cur_value, rval)
-        obj_iter.next()
+        obj_state = obj_iter.next(obj_state)
     return cur_value
 
 reduce_cum_driver = jit.JitDriver(name='numpy_reduce_cum_driver',
@@ -139,69 +133,76 @@
                                   reds = 'auto')
 
 def compute_reduce_cumulative(space, obj, out, calc_dtype, func, identity):
-    obj_iter = obj.create_iter()
-    out_iter = out.create_iter()
+    obj_iter, obj_state = obj.create_iter()
+    out_iter, out_state = out.create_iter()
     if identity is None:
-        cur_value = obj_iter.getitem().convert_to(space, calc_dtype)
-        out_iter.setitem(cur_value)
-        out_iter.next()
-        obj_iter.next()
+        cur_value = obj_iter.getitem(obj_state).convert_to(space, calc_dtype)
+        out_iter.setitem(out_state, cur_value)
+        out_state = out_iter.next(out_state)
+        obj_state = obj_iter.next(obj_state)
     else:
         cur_value = identity.convert_to(space, calc_dtype)
     shapelen = len(obj.get_shape())
-    while not obj_iter.done():
+    while not obj_iter.done(obj_state):
         reduce_cum_driver.jit_merge_point(shapelen=shapelen, func=func,
                                           dtype=calc_dtype)
-        rval = obj_iter.getitem().convert_to(space, calc_dtype)
+        rval = obj_iter.getitem(obj_state).convert_to(space, calc_dtype)
         cur_value = func(calc_dtype, cur_value, rval)
-        out_iter.setitem(cur_value)
-        out_iter.next()
-        obj_iter.next()
+        out_iter.setitem(out_state, cur_value)
+        out_state = out_iter.next(out_state)
+        obj_state = obj_iter.next(obj_state)
 
 def fill(arr, box):
-    arr_iter = arr.create_iter()
-    while not arr_iter.done():
-        arr_iter.setitem(box)
-        arr_iter.next()
+    arr_iter, arr_state = arr.create_iter()
+    while not arr_iter.done(arr_state):
+        arr_iter.setitem(arr_state, box)
+        arr_state = arr_iter.next(arr_state)
 
 def assign(space, arr, seq):
-    arr_iter = arr.create_iter()
+    arr_iter, arr_state = arr.create_iter()
     arr_dtype = arr.get_dtype()
     for item in seq:
-        arr_iter.setitem(arr_dtype.coerce(space, item))
-        arr_iter.next()
+        arr_iter.setitem(arr_state, arr_dtype.coerce(space, item))
+        arr_state = arr_iter.next(arr_state)
 
 where_driver = jit.JitDriver(name='numpy_where',
                              greens = ['shapelen', 'dtype', 'arr_dtype'],
                              reds = 'auto')
 
 def where(space, out, shape, arr, x, y, dtype):
-    out_iter = out.create_iter(shape)
-    arr_iter = arr.create_iter(shape)
+    out_iter, out_state = out.create_iter(shape)
+    arr_iter, arr_state = arr.create_iter(shape)
     arr_dtype = arr.get_dtype()
-    x_iter = x.create_iter(shape)
-    y_iter = y.create_iter(shape)
+    x_iter, x_state = x.create_iter(shape)
+    y_iter, y_state = y.create_iter(shape)
     if x.is_scalar():
         if y.is_scalar():
-            iter = arr_iter
+            iter, state = arr_iter, arr_state
         else:
-            iter = y_iter
+            iter, state = y_iter, y_state
     else:
-        iter = x_iter
+        iter, state = x_iter, x_state
     shapelen = len(shape)
-    while not iter.done():
+    while not iter.done(state):
         where_driver.jit_merge_point(shapelen=shapelen, dtype=dtype,
                                         arr_dtype=arr_dtype)
-        w_cond = arr_iter.getitem()
+        w_cond = arr_iter.getitem(arr_state)
         if arr_dtype.itemtype.bool(w_cond):
-            w_val = x_iter.getitem().convert_to(space, dtype)
+            w_val = x_iter.getitem(x_state).convert_to(space, dtype)
         else:
-            w_val = y_iter.getitem().convert_to(space, dtype)
-        out_iter.setitem(w_val)
-        out_iter.next()
-        arr_iter.next()
-        x_iter.next()
-        y_iter.next()
+            w_val = y_iter.getitem(y_state).convert_to(space, dtype)
+        out_iter.setitem(out_state, w_val)
+        out_state = out_iter.next(out_state)
+        arr_state = arr_iter.next(arr_state)
+        x_state = x_iter.next(x_state)
+        y_state = y_iter.next(y_state)
+        if x.is_scalar():
+            if y.is_scalar():
+                state = arr_state
+            else:
+                state = y_state
+        else:
+            state = x_state
     return out
 
 axis_reduce__driver = jit.JitDriver(name='numpy_axis_reduce',
@@ -212,31 +213,36 @@
 def do_axis_reduce(space, shape, func, arr, dtype, axis, out, identity, cumulative,
                    temp):
     out_iter = AxisIter(out.implementation, arr.get_shape(), axis, cumulative)
+    out_state = out_iter.reset()
     if cumulative:
         temp_iter = AxisIter(temp.implementation, arr.get_shape(), axis, False)
+        temp_state = temp_iter.reset()
     else:
-        temp_iter = out_iter # hack
-    arr_iter = arr.create_iter()
+        temp_iter = out_iter  # hack
+        temp_state = out_state
+    arr_iter, arr_state = arr.create_iter()
     if identity is not None:
         identity = identity.convert_to(space, dtype)
     shapelen = len(shape)
-    while not out_iter.done():
+    while not out_iter.done(out_state):
         axis_reduce__driver.jit_merge_point(shapelen=shapelen, func=func,
                                             dtype=dtype)
-        assert not arr_iter.done()
-        w_val = arr_iter.getitem().convert_to(space, dtype)
-        if out_iter.indices[axis] == 0:
+        assert not arr_iter.done(arr_state)
+        w_val = arr_iter.getitem(arr_state).convert_to(space, dtype)
+        if out_state.indices[axis] == 0:
             if identity is not None:
                 w_val = func(dtype, identity, w_val)
         else:
-            cur = temp_iter.getitem()
+            cur = temp_iter.getitem(temp_state)
             w_val = func(dtype, cur, w_val)
-        out_iter.setitem(w_val)
+        out_iter.setitem(out_state, w_val)
+        out_state = out_iter.next(out_state)
         if cumulative:
-            temp_iter.setitem(w_val)
-            temp_iter.next()
-        arr_iter.next()
-        out_iter.next()
+            temp_iter.setitem(temp_state, w_val)
+            temp_state = temp_iter.next(temp_state)
+        else:
+            temp_state = out_state
+        arr_state = arr_iter.next(arr_state)
     return out
 
 
@@ -249,18 +255,18 @@
         result = 0
         idx = 1
         dtype = arr.get_dtype()
-        iter = arr.create_iter()
-        cur_best = iter.getitem()
-        iter.next()
+        iter, state = arr.create_iter()
+        cur_best = iter.getitem(state)
+        state = iter.next(state)
         shapelen = len(arr.get_shape())
-        while not iter.done():
+        while not iter.done(state):
             arg_driver.jit_merge_point(shapelen=shapelen, dtype=dtype)
-            w_val = iter.getitem()
+            w_val = iter.getitem(state)
             new_best = getattr(dtype.itemtype, op_name)(cur_best, w_val)
             if dtype.itemtype.ne(new_best, cur_best):
                 result = idx
                 cur_best = new_best
-            iter.next()
+            state = iter.next(state)
             idx += 1
         return result
     return argmin_argmax
@@ -291,17 +297,19 @@
     right_impl = right.implementation
     assert left_shape[-1] == right_shape[right_critical_dim]
     assert result.get_dtype() == dtype
-    outi = result.create_iter()
+    outi, outs = result.create_iter()
     lefti = AllButAxisIter(left_impl, len(left_shape) - 1)
     righti = AllButAxisIter(right_impl, right_critical_dim)
+    lefts = lefti.reset()
+    rights = righti.reset()
     n = left_impl.shape[-1]
     s1 = left_impl.strides[-1]
     s2 = right_impl.strides[right_critical_dim]
-    while not lefti.done():
-        while not righti.done():
-            oval = outi.getitem()
-            i1 = lefti.offset
-            i2 = righti.offset
+    while not lefti.done(lefts):
+        while not righti.done(rights):
+            oval = outi.getitem(outs)
+            i1 = lefts.offset
+            i2 = rights.offset
             i = 0
             while i < n:
                 i += 1
@@ -311,11 +319,11 @@
                 oval = dtype.itemtype.add(oval, dtype.itemtype.mul(lval, rval))
                 i1 += s1
                 i2 += s2
-            outi.setitem(oval)
-            outi.next()
-            righti.next()
-        righti.reset()
-        lefti.next()
+            outi.setitem(outs, oval)
+            outs = outi.next(outs)
+            rights = righti.next(rights)
+        rights = righti.reset()
+        lefts = lefti.next(lefts)
     return result
 
 count_all_true_driver = jit.JitDriver(name = 'numpy_count',
@@ -324,13 +332,13 @@
 
 def count_all_true_concrete(impl):
     s = 0
-    iter = impl.create_iter()
+    iter, state = impl.create_iter()
     shapelen = len(impl.shape)
     dtype = impl.dtype
-    while not iter.done():
+    while not iter.done(state):
         count_all_true_driver.jit_merge_point(shapelen=shapelen, dtype=dtype)
-        s += iter.getitem_bool()
-        iter.next()
+        s += iter.getitem_bool(state)
+        state = iter.next(state)
     return s
 
 def count_all_true(arr):
@@ -344,18 +352,18 @@
                                reds = 'auto')
 
 def nonzero(res, arr, box):
-    res_iter = res.create_iter()
-    arr_iter = arr.create_iter()
+    res_iter, res_state = res.create_iter()
+    arr_iter, arr_state = arr.create_iter()
     shapelen = len(arr.shape)
     dtype = arr.dtype
     dims = range(shapelen)
-    while not arr_iter.done():
+    while not arr_iter.done(arr_state):
         nonzero_driver.jit_merge_point(shapelen=shapelen, dims=dims, dtype=dtype)
-        if arr_iter.getitem_bool():
+        if arr_iter.getitem_bool(arr_state):
             for d in dims:
-                res_iter.setitem(box(arr_iter.indices[d]))
-                res_iter.next()
-        arr_iter.next()
+                res_iter.setitem(res_state, box(arr_state.indices[d]))
+                res_state = res_iter.next(res_state)
+        arr_state = arr_iter.next(arr_state)
     return res
 
 
@@ -365,26 +373,26 @@
                                       reds = 'auto')
 
 def getitem_filter(res, arr, index):
-    res_iter = res.create_iter()
+    res_iter, res_state = res.create_iter()
     shapelen = len(arr.get_shape())
     if shapelen > 1 and len(index.get_shape()) < 2:
-        index_iter = index.create_iter(arr.get_shape(), backward_broadcast=True)
+        index_iter, index_state = index.create_iter(arr.get_shape(), backward_broadcast=True)
     else:
-        index_iter = index.create_iter()
-    arr_iter = arr.create_iter()
+        index_iter, index_state = index.create_iter()
+    arr_iter, arr_state = arr.create_iter()
     arr_dtype = arr.get_dtype()
     index_dtype = index.get_dtype()
     # XXX length of shape of index as well?
-    while not index_iter.done():
+    while not index_iter.done(index_state):
         getitem_filter_driver.jit_merge_point(shapelen=shapelen,
                                               index_dtype=index_dtype,
                                               arr_dtype=arr_dtype,
                                               )
-        if index_iter.getitem_bool():
-            res_iter.setitem(arr_iter.getitem())
-            res_iter.next()
-        index_iter.next()
-        arr_iter.next()
+        if index_iter.getitem_bool(index_state):
+            res_iter.setitem(res_state, arr_iter.getitem(arr_state))
+            res_state = res_iter.next(res_state)
+        index_state = index_iter.next(index_state)
+        arr_state = arr_iter.next(arr_state)
     return res
 
 setitem_filter_driver = jit.JitDriver(name = 'numpy_setitem_bool',
@@ -393,41 +401,42 @@
                                       reds = 'auto')
 
 def setitem_filter(space, arr, index, value):
-    arr_iter = arr.create_iter()
+    arr_iter, arr_state = arr.create_iter()
     shapelen = len(arr.get_shape())
     if shapelen > 1 and len(index.get_shape()) < 2:
-        index_iter = index.create_iter(arr.get_shape(), backward_broadcast=True)
+        index_iter, index_state = index.create_iter(arr.get_shape(), backward_broadcast=True)
     else:
-        index_iter = index.create_iter()
+        index_iter, index_state = index.create_iter()
     if value.get_size() == 1:
-        value_iter = value.create_iter(arr.get_shape())
+        value_iter, value_state = value.create_iter(arr.get_shape())
     else:
-        value_iter = value.create_iter()
+        value_iter, value_state = value.create_iter()
     index_dtype = index.get_dtype()
     arr_dtype = arr.get_dtype()
-    while not index_iter.done():
+    while not index_iter.done(index_state):
         setitem_filter_driver.jit_merge_point(shapelen=shapelen,
                                               index_dtype=index_dtype,
                                               arr_dtype=arr_dtype,
                                              )
-        if index_iter.getitem_bool():
-            arr_iter.setitem(arr_dtype.coerce(space, value_iter.getitem()))
-            value_iter.next()
-        arr_iter.next()
-        index_iter.next()
+        if index_iter.getitem_bool(index_state):
+            val = arr_dtype.coerce(space, value_iter.getitem(value_state))
+            value_state = value_iter.next(value_state)
+            arr_iter.setitem(arr_state, val)
+        arr_state = arr_iter.next(arr_state)
+        index_state = index_iter.next(index_state)
 
 flatiter_getitem_driver = jit.JitDriver(name = 'numpy_flatiter_getitem',
                                         greens = ['dtype'],
                                         reds = 'auto')
 
-def flatiter_getitem(res, base_iter, step):
-    ri = res.create_iter()
+def flatiter_getitem(res, base_iter, base_state, step):
+    ri, rs = res.create_iter()
     dtype = res.get_dtype()
-    while not ri.done():
+    while not ri.done(rs):
         flatiter_getitem_driver.jit_merge_point(dtype=dtype)
-        ri.setitem(base_iter.getitem())
-        base_iter.next_skip_x(step)
-        ri.next()
+        ri.setitem(rs, base_iter.getitem(base_state))
+        base_state = base_iter.next_skip_x(base_state, step)
+        rs = ri.next(rs)
     return res
 
 flatiter_setitem_driver = jit.JitDriver(name = 'numpy_flatiter_setitem',
@@ -436,19 +445,21 @@
 
 def flatiter_setitem(space, arr, val, start, step, length):
     dtype = arr.get_dtype()
-    arr_iter = arr.create_iter()
-    val_iter = val.create_iter()
-    arr_iter.next_skip_x(start)
+    arr_iter, arr_state = arr.create_iter()
+    val_iter, val_state = val.create_iter()
+    arr_state = arr_iter.next_skip_x(arr_state, start)
     while length > 0:
         flatiter_setitem_driver.jit_merge_point(dtype=dtype)
+        val = val_iter.getitem(val_state)
         if dtype.is_str_or_unicode():
-            arr_iter.setitem(dtype.coerce(space, val_iter.getitem()))
+            val = dtype.coerce(space, val)
         else:
-            arr_iter.setitem(val_iter.getitem().convert_to(space, dtype))
+            val = val.convert_to(space, dtype)
+        arr_iter.setitem(arr_state, val)
         # need to repeat i_nput values until all assignments are done
-        arr_iter.next_skip_x(step)
+        arr_state = arr_iter.next_skip_x(arr_state, step)
+        val_state = val_iter.next(val_state)
         length -= 1
-        val_iter.next()
 
 fromstring_driver = jit.JitDriver(name = 'numpy_fromstring',
                                   greens = ['itemsize', 'dtype'],
@@ -456,30 +467,30 @@
 
 def fromstring_loop(space, a, dtype, itemsize, s):
     i = 0
-    ai = a.create_iter()
-    while not ai.done():
+    ai, state = a.create_iter()
+    while not ai.done(state):
         fromstring_driver.jit_merge_point(dtype=dtype, itemsize=itemsize)
         sub = s[i*itemsize:i*itemsize + itemsize]
         if dtype.is_str_or_unicode():
             val = dtype.coerce(space, space.wrap(sub))
         else:
             val = dtype.itemtype.runpack_str(space, sub)
-        ai.setitem(val)
-        ai.next()
+        ai.setitem(state, val)
+        state = ai.next(state)
         i += 1
 
 def tostring(space, arr):
     builder = StringBuilder()
-    iter = arr.create_iter()
+    iter, state = arr.create_iter()
     w_res_str = W_NDimArray.from_shape(space, [1], arr.get_dtype(), order='C')
     itemsize = arr.get_dtype().elsize
     res_str_casted = rffi.cast(rffi.CArrayPtr(lltype.Char),
                                w_res_str.implementation.get_storage_as_int(space))
-    while not iter.done():
-        w_res_str.implementation.setitem(0, iter.getitem())
+    while not iter.done(state):
+        w_res_str.implementation.setitem(0, iter.getitem(state))
         for i in range(itemsize):
             builder.append(res_str_casted[i])
-        iter.next()
+        state = iter.next(state)
     return builder.build()
 
 getitem_int_driver = jit.JitDriver(name = 'numpy_getitem_int',
@@ -500,8 +511,8 @@
         # prepare the index
         index_w = [None] * indexlen
         for i in range(indexlen):
-            if iter.idx_w[i] is not None:
-                index_w[i] = iter.idx_w[i].getitem()
+            if iter.idx_w_i[i] is not None:
+                index_w[i] = iter.idx_w_i[i].getitem(iter.idx_w_s[i])
             else:
                 index_w[i] = indexes_w[i]
         res.descr_setitem(space, space.newtuple(prefix_w[:prefixlen] +
@@ -528,8 +539,8 @@
         # prepare the index
         index_w = [None] * indexlen
         for i in range(indexlen):
-            if iter.idx_w[i] is not None:
-                index_w[i] = iter.idx_w[i].getitem()
+            if iter.idx_w_i[i] is not None:
+                index_w[i] = iter.idx_w_i[i].getitem(iter.idx_w_s[i])
             else:
                 index_w[i] = indexes_w[i]
         w_idx = space.newtuple(prefix_w[:prefixlen] + iter.get_index(space,
@@ -547,13 +558,14 @@
 
 def byteswap(from_, to):
     dtype = from_.dtype
-    from_iter = from_.create_iter()
-    to_iter = to.create_iter()
-    while not from_iter.done():
+    from_iter, from_state = from_.create_iter()
+    to_iter, to_state = to.create_iter()
+    while not from_iter.done(from_state):
         byteswap_driver.jit_merge_point(dtype=dtype)
-        to_iter.setitem(dtype.itemtype.byteswap(from_iter.getitem()))
-        to_iter.next()
-        from_iter.next()
+        val = dtype.itemtype.byteswap(from_iter.getitem(from_state))
+        to_iter.setitem(to_state, val)
+        to_state = to_iter.next(to_state)
+        from_state = from_iter.next(from_state)
 
 choose_driver = jit.JitDriver(name='numpy_choose_driver',
                               greens = ['shapelen', 'mode', 'dtype'],
@@ -561,13 +573,15 @@
 
 def choose(space, arr, choices, shape, dtype, out, mode):
     shapelen = len(shape)
-    iterators = [a.create_iter(shape) for a in choices]
-    arr_iter = arr.create_iter(shape)
-    out_iter = out.create_iter(shape)
-    while not arr_iter.done():
+    pairs = [a.create_iter(shape) for a in choices]
+    iterators = [i[0] for i in pairs]
+    states = [i[1] for i in pairs]
+    arr_iter, arr_state = arr.create_iter(shape)
+    out_iter, out_state = out.create_iter(shape)
+    while not arr_iter.done(arr_state):
         choose_driver.jit_merge_point(shapelen=shapelen, dtype=dtype,
                                       mode=mode)
-        index = support.index_w(space, arr_iter.getitem())
+        index = support.index_w(space, arr_iter.getitem(arr_state))
         if index < 0 or index >= len(iterators):
             if mode == NPY.RAISE:
                 raise OperationError(space.w_ValueError, space.wrap(
@@ -580,72 +594,73 @@
                     index = 0
                 else:
                     index = len(iterators) - 1
-        out_iter.setitem(iterators[index].getitem().convert_to(space, dtype))
-        for iter in iterators:
-            iter.next()
-        out_iter.next()
-        arr_iter.next()
+        val = iterators[index].getitem(states[index]).convert_to(space, dtype)
+        out_iter.setitem(out_state, val)
+        for i in range(len(iterators)):
+            states[i] = iterators[i].next(states[i])
+        out_state = out_iter.next(out_state)
+        arr_state = arr_iter.next(arr_state)
 
 clip_driver = jit.JitDriver(name='numpy_clip_driver',
                             greens = ['shapelen', 'dtype'],
                             reds = 'auto')
 
 def clip(space, arr, shape, min, max, out):
-    arr_iter = arr.create_iter(shape)
+    arr_iter, arr_state = arr.create_iter(shape)
     dtype = out.get_dtype()
     shapelen = len(shape)
-    min_iter = min.create_iter(shape)
-    max_iter = max.create_iter(shape)
-    out_iter = out.create_iter(shape)
-    while not arr_iter.done():
+    min_iter, min_state = min.create_iter(shape)
+    max_iter, max_state = max.create_iter(shape)
+    out_iter, out_state = out.create_iter(shape)
+    while not arr_iter.done(arr_state):
         clip_driver.jit_merge_point(shapelen=shapelen, dtype=dtype)
-        w_v = arr_iter.getitem().convert_to(space, dtype)
-        w_min = min_iter.getitem().convert_to(space, dtype)
-        w_max = max_iter.getitem().convert_to(space, dtype)
+        w_v = arr_iter.getitem(arr_state).convert_to(space, dtype)
+        w_min = min_iter.getitem(min_state).convert_to(space, dtype)
+        w_max = max_iter.getitem(max_state).convert_to(space, dtype)
         if dtype.itemtype.lt(w_v, w_min):
             w_v = w_min
         elif dtype.itemtype.gt(w_v, w_max):
             w_v = w_max
-        out_iter.setitem(w_v)
-        arr_iter.next()
-        max_iter.next()
-        out_iter.next()
-        min_iter.next()
+        out_iter.setitem(out_state, w_v)
+        arr_state = arr_iter.next(arr_state)
+        min_state = min_iter.next(min_state)
+        max_state = max_iter.next(max_state)
+        out_state = out_iter.next(out_state)
 
 round_driver = jit.JitDriver(name='numpy_round_driver',
                              greens = ['shapelen', 'dtype'],
                              reds = 'auto')
 
 def round(space, arr, dtype, shape, decimals, out):
-    arr_iter = arr.create_iter(shape)
+    arr_iter, arr_state = arr.create_iter(shape)
+    out_iter, out_state = out.create_iter(shape)
     shapelen = len(shape)
-    out_iter = out.create_iter(shape)
-    while not arr_iter.done():
+    while not arr_iter.done(arr_state):
         round_driver.jit_merge_point(shapelen=shapelen, dtype=dtype)
-        w_v = arr_iter.getitem().convert_to(space, dtype)
+        w_v = arr_iter.getitem(arr_state).convert_to(space, dtype)
         w_v = dtype.itemtype.round(w_v, decimals)
-        out_iter.setitem(w_v)
-        arr_iter.next()
-        out_iter.next()
+        out_iter.setitem(out_state, w_v)
+        arr_state = arr_iter.next(arr_state)
+        out_state = out_iter.next(out_state)
 
 diagonal_simple_driver = jit.JitDriver(name='numpy_diagonal_simple_driver',
                                        greens = ['axis1', 'axis2'],
                                        reds = 'auto')
 
 def diagonal_simple(space, arr, out, offset, axis1, axis2, size):
-    out_iter = out.create_iter()
+    out_iter, out_state = out.create_iter()
     i = 0
     index = [0] * 2
     while i < size:
         diagonal_simple_driver.jit_merge_point(axis1=axis1, axis2=axis2)
         index[axis1] = i
         index[axis2] = i + offset
-        out_iter.setitem(arr.getitem_index(space, index))
+        out_iter.setitem(out_state, arr.getitem_index(space, index))
         i += 1
-        out_iter.next()
+        out_state = out_iter.next(out_state)
 
 def diagonal_array(space, arr, out, offset, axis1, axis2, shape):
-    out_iter = out.create_iter()
+    out_iter, out_state = out.create_iter()
     iter = PureShapeIter(shape, [])
     shapelen_minus_1 = len(shape) - 1
     assert shapelen_minus_1 >= 0
@@ -667,6 +682,6 @@
             indexes = (iter.indexes[:a] + [last_index + offset] +
                        iter.indexes[a:b] + [last_index] +
                        iter.indexes[b:shapelen_minus_1])
-        out_iter.setitem(arr.getitem_index(space, indexes))
+        out_iter.setitem(out_state, arr.getitem_index(space, indexes))
         iter.next()
-        out_iter.next()
+        out_state = out_iter.next(out_state)
diff --git a/pypy/module/micronumpy/ndarray.py b/pypy/module/micronumpy/ndarray.py
--- a/pypy/module/micronumpy/ndarray.py
+++ b/pypy/module/micronumpy/ndarray.py
@@ -18,7 +18,7 @@
     multi_axis_converter
 from pypy.module.micronumpy.flagsobj import W_FlagsObject
 from pypy.module.micronumpy.flatiter import W_FlatIterator
-from pypy.module.micronumpy.strides import get_shape_from_iterable, to_coords, \
+from pypy.module.micronumpy.strides import get_shape_from_iterable, \
     shape_agreement, shape_agreement_multiple
 
 
@@ -260,24 +260,24 @@
         return space.call_function(cache.w_array_str, self)
 
     def dump_data(self, prefix='array(', separator=',', suffix=')'):
-        i = self.create_iter()
+        i, state = self.create_iter()
         first = True
         dtype = self.get_dtype()
         s = StringBuilder()
         s.append(prefix)
         if not self.is_scalar():
             s.append('[')
-        while not i.done():
+        while not i.done(state):
             if first:
                 first = False
             else:
                 s.append(separator)
                 s.append(' ')
             if self.is_scalar() and dtype.is_str():
-                s.append(dtype.itemtype.to_str(i.getitem()))
+                s.append(dtype.itemtype.to_str(i.getitem(state)))
             else:
-                s.append(dtype.itemtype.str_format(i.getitem()))
-            i.next()
+                s.append(dtype.itemtype.str_format(i.getitem(state)))
+            state = i.next(state)
         if not self.is_scalar():
             s.append(']')
         s.append(suffix)
@@ -469,29 +469,33 @@
     def descr_get_flatiter(self, space):
         return space.wrap(W_FlatIterator(self))
 
-    def to_coords(self, space, w_index):
-        coords, _, _ = to_coords(space, self.get_shape(),
-                                 self.get_size(), self.get_order(),
-                                 w_index)
-        return coords
-
-    def descr_item(self, space, w_arg=None):
-        if space.is_none(w_arg):
+    def descr_item(self, space, __args__):
+        args_w, kw_w = __args__.unpack()
+        if len(args_w) == 1 and space.isinstance_w(args_w[0], space.w_tuple):
+            args_w = space.fixedview(args_w[0])
+        shape = self.get_shape()
+        coords = [0] * len(shape)
+        if len(args_w) == 0:
             if self.get_size() == 1:
                 w_obj = self.get_scalar_value()
                 assert isinstance(w_obj, boxes.W_GenericBox)
                 return w_obj.item(space)
             raise oefmt(space.w_ValueError,
                 "can only convert an array of size 1 to a Python scalar")
-        if space.isinstance_w(w_arg, space.w_int):
-            if self.is_scalar():
-                raise oefmt(space.w_IndexError, "index out of bounds")
-            i = self.to_coords(space, w_arg)
-            item = self.getitem(space, i)
-            assert isinstance(item, boxes.W_GenericBox)
-            return item.item(space)
-        raise OperationError(space.w_NotImplementedError, space.wrap(
-            "non-int arg not supported"))
+        elif len(args_w) == 1 and len(shape) != 1:
+            value = support.index_w(space, args_w[0])
+            value = support.check_and_adjust_index(space, value, self.get_size(), -1)
+            for idim in range(len(shape) - 1, -1, -1):
+                coords[idim] = value % shape[idim]
+                value //= shape[idim]
+        elif len(args_w) == len(shape):
+            for idim in range(len(shape)):
+                coords[idim] = support.index_w(space, args_w[idim])
+        else:
+            raise oefmt(space.w_ValueError, "incorrect number of indices for array")
+        item = self.getitem(space, coords)
+        assert isinstance(item, boxes.W_GenericBox)
+        return item.item(space)
 
     def descr_itemset(self, space, args_w):
         if len(args_w) == 0:
@@ -841,8 +845,8 @@
         if self.get_size() > 1:
             raise OperationError(space.w_ValueError, space.wrap(
                 "The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()"))
-        iter = self.create_iter()
-        return space.wrap(space.is_true(iter.getitem()))
+        iter, state = self.create_iter()
+        return space.wrap(space.is_true(iter.getitem(state)))
 
     def _binop_impl(ufunc_name):
         def impl(self, space, w_other, w_out=None):
@@ -1019,7 +1023,8 @@
     descr_cumsum = _reduce_ufunc_impl('add', cumulative=True)
     descr_cumprod = _reduce_ufunc_impl('multiply', cumulative=True)
 
-    def _reduce_argmax_argmin_impl(op_name):
+    def _reduce_argmax_argmin_impl(raw_name):
+        op_name = "arg%s" % raw_name
         def impl(self, space, w_axis=None, w_out=None):
             if not space.is_none(w_axis):
                 raise oefmt(space.w_NotImplementedError,
@@ -1030,18 +1035,17 @@
             if self.get_size() == 0:
                 raise oefmt(space.w_ValueError,
                     "Can't call %s on zero-size arrays", op_name)
-            op = getattr(loop, op_name)
             try:
-                res = op(self)
+                getattr(self.get_dtype().itemtype, raw_name)
             except AttributeError:
                 raise oefmt(space.w_NotImplementedError,
                             '%s not implemented for %s',
                             op_name, self.get_dtype().get_name())
-            return space.wrap(res)
-        return func_with_new_name(impl, "reduce_arg%s_impl" % op_name)
+            return space.wrap(getattr(loop, op_name)(self))
+        return func_with_new_name(impl, "reduce_%s_impl" % op_name)
 
-    descr_argmax = _reduce_argmax_argmin_impl("argmax")
-    descr_argmin = _reduce_argmax_argmin_impl("argmin")
+    descr_argmax = _reduce_argmax_argmin_impl("max")
+    descr_argmin = _reduce_argmax_argmin_impl("min")
 
     def descr_int(self, space):
         if self.get_size() != 1:
@@ -1118,11 +1122,11 @@
 
         builder = StringBuilder()
         if isinstance(self.implementation, SliceArray):
-            iter = self.implementation.create_iter()
-            while not iter.done():
-                box = iter.getitem()
+            iter, state = self.implementation.create_iter()
+            while not iter.done(state):
+                box = iter.getitem(state)
                 builder.append(box.raw_str())
-                iter.next()
+                state = iter.next(state)
         else:
             builder.append_charpsize(self.implementation.get_storage(), self.implementation.get_storage_size())
 
diff --git a/pypy/module/micronumpy/nditer.py b/pypy/module/micronumpy/nditer.py
--- a/pypy/module/micronumpy/nditer.py
+++ b/pypy/module/micronumpy/nditer.py
@@ -1,99 +1,50 @@
 from pypy.interpreter.baseobjspace import W_Root
 from pypy.interpreter.typedef import TypeDef, GetSetProperty
 from pypy.interpreter.gateway import interp2app, unwrap_spec, WrappedDefault
-from pypy.interpreter.error import OperationError
+from pypy.interpreter.error import OperationError, oefmt
+from pypy.module.micronumpy import ufuncs, support, concrete
 from pypy.module.micronumpy.base import W_NDimArray, convert_to_array
+from pypy.module.micronumpy.descriptor import decode_w_dtype
+from pypy.module.micronumpy.iterators import ArrayIter
 from pypy.module.micronumpy.strides import (calculate_broadcast_strides,
-                                             shape_agreement, shape_agreement_multiple)
-from pypy.module.micronumpy.iterators import ArrayIter, SliceIterator
-from pypy.module.micronumpy.concrete import SliceArray
-from pypy.module.micronumpy.descriptor import decode_w_dtype
-from pypy.module.micronumpy import ufuncs, support
+                                            shape_agreement, shape_agreement_multiple)
 
 
-class AbstractIterator(object):
-    def done(self):
-        raise NotImplementedError("Abstract Class")
-
-    def next(self):
-        raise NotImplementedError("Abstract Class")
-
-    def getitem(self, space, array):
-        raise NotImplementedError("Abstract Class")
-
-class IteratorMixin(object):
-    _mixin_ = True
-    def __init__(self, it, op_flags):
-        self.it = it
-        self.op_flags = op_flags
-
-    def done(self):
-        return self.it.done()
-
-    def next(self):
-        self.it.next()
-
-    def getitem(self, space, array):
-        return self.op_flags.get_it_item[self.index](space, array, self.it)
-
-    def setitem(self, space, array, val):
-        xxx
-
-class BoxIterator(IteratorMixin, AbstractIterator):
-    index = 0
-
-class ExternalLoopIterator(IteratorMixin, AbstractIterator):
-    index = 1
-
 def parse_op_arg(space, name, w_op_flags, n, parse_one_arg):
+    if space.is_w(w_op_flags, space.w_None):
+        w_op_flags = space.newtuple([space.wrap('readonly')])
+    if not space.isinstance_w(w_op_flags, space.w_tuple) and not \
+            space.isinstance_w(w_op_flags, space.w_list):
+        raise oefmt(space.w_ValueError,
+                    '%s must be a tuple or array of per-op flag-tuples',
+                    name)
     ret = []
-    if space.is_w(w_op_flags, space.w_None):
+    w_lst = space.listview(w_op_flags)
+    if space.isinstance_w(w_lst[0], space.w_tuple) or \
+       space.isinstance_w(w_lst[0], space.w_list):
+        if len(w_lst) != n:
+            raise oefmt(space.w_ValueError,
+                        '%s must be a tuple or array of per-op flag-tuples',
+                        name)
+        for item in w_lst:
+            ret.append(parse_one_arg(space, space.listview(item)))
+    else:
+        op_flag = parse_one_arg(space, w_lst)
         for i in range(n):
-            ret.append(OpFlag())
-    elif not space.isinstance_w(w_op_flags, space.w_tuple) and not \
-             space.isinstance_w(w_op_flags, space.w_list):
-        raise OperationError(space.w_ValueError, space.wrap(
-                '%s must be a tuple or array of per-op flag-tuples' % name))
-    else:
-        w_lst = space.listview(w_op_flags)
-        if space.isinstance_w(w_lst[0], space.w_tuple) or \
-           space.isinstance_w(w_lst[0], space.w_list):
-            if len(w_lst) != n:
-                raise OperationError(space.w_ValueError, space.wrap(
-                   '%s must be a tuple or array of per-op flag-tuples' % name))
-            for item in w_lst:
-                ret.append(parse_one_arg(space, space.listview(item)))
-        else:
-            op_flag = parse_one_arg(space, w_lst)
-            for i in range(n):
-                ret.append(op_flag)
+            ret.append(op_flag)
     return ret
 
+
 class OpFlag(object):
     def __init__(self):
-        self.rw = 'r'
+        self.rw = ''
         self.broadcast = True
         self.force_contig = False
         self.force_align = False
         self.native_byte_order = False
         self.tmp_copy = ''
         self.allocate = False
-        self.get_it_item = (get_readonly_item, get_readonly_slice)
 
-def get_readonly_item(space, array, it):
-    return space.wrap(it.getitem())
-
-def get_readwrite_item(space, array, it):
-    #create a single-value view (since scalars are not views)
-    res = SliceArray(it.array.start + it.offset, [0], [0], [1,], it.array, array)
-    #it.dtype.setitem(res, 0, it.getitem())
-    return W_NDimArray(res)
-
-def get_readonly_slice(space, array, it):
-    return W_NDimArray(it.getslice().readonly())
-
-def get_readwrite_slice(space, array, it):
-    return W_NDimArray(it.getslice())
 
 def parse_op_flag(space, lst):
     op_flag = OpFlag()
@@ -121,39 +72,38 @@
             op_flag.allocate = True
         elif item == 'no_subtype':
             raise OperationError(space.w_NotImplementedError, space.wrap(
-                    '"no_subtype" op_flag not implemented yet'))
+                '"no_subtype" op_flag not implemented yet'))
         elif item == 'arraymask':
             raise OperationError(space.w_NotImplementedError, space.wrap(
-                    '"arraymask" op_flag not implemented yet'))
+                '"arraymask" op_flag not implemented yet'))
         elif item == 'writemask':
             raise OperationError(space.w_NotImplementedError, space.wrap(
-                    '"writemask" op_flag not implemented yet'))
+                '"writemask" op_flag not implemented yet'))
         else:
             raise OperationError(space.w_ValueError, space.wrap(
-                    'op_flags must be a tuple or array of per-op flag-tuples'))
-        if op_flag.rw == 'r':
-            op_flag.get_it_item = (get_readonly_item, get_readonly_slice)
-        elif op_flag.rw == 'rw':
-            op_flag.get_it_item = (get_readwrite_item, get_readwrite_slice)
-        elif op_flag.rw == 'w':
-            # XXX Extra logic needed to make sure writeonly
-            op_flag.get_it_item = (get_readwrite_item, get_readwrite_slice)
+                'op_flags must be a tuple or array of per-op flag-tuples'))
+    if op_flag.rw == '':
+        raise oefmt(space.w_ValueError,
+                    "None of the iterator flags READWRITE, READONLY, or "
+                    "WRITEONLY were specified for an operand")
     return op_flag
 
+
 def parse_func_flags(space, nditer, w_flags):
     if space.is_w(w_flags, space.w_None):
         return
     elif not space.isinstance_w(w_flags, space.w_tuple) and not \
-             space.isinstance_w(w_flags, space.w_list):
+            space.isinstance_w(w_flags, space.w_list):
         raise OperationError(space.w_ValueError, space.wrap(
-                'Iter global flags must be a list or tuple of strings'))
+            'Iter global flags must be a list or tuple of strings'))
     lst = space.listview(w_flags)
     for w_item in lst:
         if not space.isinstance_w(w_item, space.w_str) and not \
-               space.isinstance_w(w_item, space.w_unicode):
+                space.isinstance_w(w_item, space.w_unicode):
             typename = space.type(w_item).getname(space)
-            raise OperationError(space.w_TypeError, space.wrap(
-                    'expected string or Unicode object, %s found' % typename))
+            raise oefmt(space.w_TypeError,
+                        'expected string or Unicode object, %s found',
+                        typename)
         item = space.str_w(w_item)
         if item == 'external_loop':
             raise OperationError(space.w_NotImplementedError, space.wrap(
@@ -187,21 +137,24 @@
         elif item == 'zerosize_ok':
             nditer.zerosize_ok = True
         else:
-            raise OperationError(space.w_ValueError, space.wrap(
-                    'Unexpected iterator global flag "%s"' % item))
+            raise oefmt(space.w_ValueError,
+                        'Unexpected iterator global flag "%s"',
+                        item)
     if nditer.tracked_index and nditer.external_loop:
-            raise OperationError(space.w_ValueError, space.wrap(
-                'Iterator flag EXTERNAL_LOOP cannot be used if an index or '
-                'multi-index is being tracked'))
+        raise OperationError(space.w_ValueError, space.wrap(
+            'Iterator flag EXTERNAL_LOOP cannot be used if an index or '
+            'multi-index is being tracked'))
+
 
 def is_backward(imp, order):
     if order == 'K' or (order == 'C' and imp.order == 'C'):
         return False
-    elif order =='F' and imp.order == 'C':
+    elif order == 'F' and imp.order == 'C':
         return True
     else:
         raise NotImplementedError('not implemented yet')
 
+
 def get_iter(space, order, arr, shape, dtype):
     imp = arr.implementation
     backward = is_backward(imp, order)
@@ -223,19 +176,6 @@
                                     shape, backward)
     return ArrayIter(imp, imp.get_size(), shape, r[0], r[1])
 
-def get_external_loop_iter(space, order, arr, shape):
-    imp = arr.implementation
-    backward = is_backward(imp, order)
-    return SliceIterator(arr, imp.strides, imp.backstrides, shape, order=order, backward=backward)
-
-def convert_to_array_or_none(space, w_elem):
-    '''
-    None will be passed through, all others will be converted
-    '''
-    if space.is_none(w_elem):
-        return None
-    return convert_to_array(space, w_elem)
-
 
 class IndexIterator(object):
     def __init__(self, shape, backward=False):
@@ -263,10 +203,10 @@
                 ret += self.index[i] * self.shape[i - 1]
         return ret
 
+
 class W_NDIter(W_Root):
-
     def __init__(self, space, w_seq, w_flags, w_op_flags, w_op_dtypes, w_casting,
-            w_op_axes, w_itershape, w_buffersize, order):
+                 w_op_axes, w_itershape, w_buffersize, order):
         self.order = order
         self.external_loop = False
         self.buffered = False
@@ -286,9 +226,11 @@
         if space.isinstance_w(w_seq, space.w_tuple) or \
            space.isinstance_w(w_seq, space.w_list):
             w_seq_as_list = space.listview(w_seq)
-            self.seq = [convert_to_array_or_none(space, w_elem) for w_elem in w_seq_as_list]
+            self.seq = [convert_to_array(space, w_elem)
+                        if not space.is_none(w_elem) else None
+                        for w_elem in w_seq_as_list]
         else:
-            self.seq =[convert_to_array(space, w_seq)]
+            self.seq = [convert_to_array(space, w_seq)]
 
         parse_func_flags(space, self, w_flags)
         self.op_flags = parse_op_arg(space, 'op_flags', w_op_flags,
@@ -308,9 +250,9 @@
             self.dtypes = []
 
         # handle None or writable operands, calculate my shape
-        self.iters=[]
-        outargs = [i for i in range(len(self.seq)) \
-                        if self.seq[i] is None or self.op_flags[i].rw == 'w']
+        self.iters = []
+        outargs = [i for i in range(len(self.seq))
+                   if self.seq[i] is None or self.op_flags[i].rw == 'w']
         if len(outargs) > 0:
             out_shape = shape_agreement_multiple(space, [self.seq[i] for i in outargs])
         else:
@@ -325,14 +267,12 @@
                 out_dtype = None
                 for i in range(len(self.seq)):
                     if self.seq[i] is None:
-                        self.op_flags[i].get_it_item = (get_readwrite_item,
-                                                    get_readwrite_slice)
                         self.op_flags[i].allocate = True
                         continue
                     if self.op_flags[i].rw == 'w':
                         continue
-                    out_dtype = ufuncs.find_binop_result_dtype(space,
-                                                self.seq[i].get_dtype(), out_dtype)
+                    out_dtype = ufuncs.find_binop_result_dtype(
+                        space, self.seq[i].get_dtype(), out_dtype)
             for i in outargs:
                 if self.seq[i] is None:
                     # XXX can we postpone allocation to later?
@@ -360,8 +300,9 @@
                     self.dtypes[i] = seq_d
                 elif selfd != seq_d:
                     if not 'r' in self.op_flags[i].tmp_copy:
-                        raise OperationError(space.w_TypeError, space.wrap(
-                            "Iterator operand required copying or buffering for operand %d" % i))
+                        raise oefmt(space.w_TypeError,
+                                    "Iterator operand required copying or "
+                                    "buffering for operand %d", i)
                     impl = self.seq[i].implementation
                     new_impl = impl.astype(space, selfd)
                     self.seq[i] = W_NDimArray(new_impl)
@@ -370,18 +311,14 @@
             self.dtypes = [s.get_dtype() for s in self.seq]
 
         # create an iterator for each operand
-        if self.external_loop:
-            for i in range(len(self.seq)):
-                self.iters.append(ExternalLoopIterator(get_external_loop_iter(space, self.order,
-                                self.seq[i], iter_shape), self.op_flags[i]))
-        else:
-            for i in range(len(self.seq)):
-                self.iters.append(BoxIterator(get_iter(space, self.order,
-                                    self.seq[i], iter_shape, self.dtypes[i]),
-                                 self.op_flags[i]))
+        for i in range(len(self.seq)):
+            it = get_iter(space, self.order, self.seq[i], iter_shape, self.dtypes[i])
+            self.iters.append((it, it.reset()))
+
     def set_op_axes(self, space, w_op_axes):
         if space.len_w(w_op_axes) != len(self.seq):
-            raise OperationError(space.w_ValueError, space.wrap("op_axes must be a tuple/list matching the number of ops"))
+            raise oefmt(space.w_ValueError,
+                        "op_axes must be a tuple/list matching the number of ops")
         op_axes = space.listview(w_op_axes)
         l = -1
         for w_axis in op_axes:
@@ -390,10 +327,14 @@
                 if l == -1:
                     l = axis_len
                 elif axis_len != l:
-                    raise OperationError(space.w_ValueError, space.wrap("Each entry of op_axes must have the same size"))
-                self.op_axes.append([space.int_w(x) if not space.is_none(x) else -1 for x in space.listview(w_axis)])
+                    raise oefmt(space.w_ValueError,
+                                "Each entry of op_axes must have the same size")
+                self.op_axes.append([space.int_w(x) if not space.is_none(x) else -1
+                                     for x in space.listview(w_axis)])
         if l == -1:
-            raise OperationError(space.w_ValueError, space.wrap("If op_axes is provided, at least one list of axes must be contained within it"))
+            raise oefmt(space.w_ValueError,
+                        "If op_axes is provided, at least one list of axes "
+                        "must be contained within it")
         raise Exception('xxx TODO')
         # Check that values make sense:
         # - in bounds for each operand
@@ -404,24 +345,34 @@
     def descr_iter(self, space):
         return space.wrap(self)
 
+    def getitem(self, it, st, op_flags):
+        if op_flags.rw == 'r':
+            impl = concrete.ConcreteNonWritableArrayWithBase
+        else:
+            impl = concrete.ConcreteArrayWithBase
+        res = impl([], it.array.dtype, it.array.order, [], [],
+                   it.array.storage, self)
+        res.start = st.offset
+        return W_NDimArray(res)
+
     def descr_getitem(self, space, w_idx):
         idx = space.int_w(w_idx)
         try:
-            ret = space.wrap(self.iters[idx].getitem(space, self.seq[idx]))
+            it, st = self.iters[idx]
         except IndexError:
-            raise OperationError(space.w_IndexError, space.wrap("Iterator operand index %d is out of bounds" % idx))
-        return ret
+            raise oefmt(space.w_IndexError,
+                        "Iterator operand index %d is out of bounds", idx)
+        return self.getitem(it, st, self.op_flags[idx])
 
     def descr_setitem(self, space, w_idx, w_value):
-        raise OperationError(space.w_NotImplementedError, space.wrap(
-            'not implemented yet'))
+        raise oefmt(space.w_NotImplementedError, "not implemented yet")
 
     def descr_len(self, space):
         space.wrap(len(self.iters))
 
     def descr_next(self, space):
-        for it in self.iters:
-            if not it.done():
+        for it, st in self.iters:
+            if not it.done(st):
                 break
         else:
             self.done = True
@@ -432,20 +383,20 @@
                 self.index_iter.next()
             else:
                 self.first_next = False
-        for i in range(len(self.iters)):
-            res.append(self.iters[i].getitem(space, self.seq[i]))
-            self.iters[i].next()
-        if len(res) <2:
+        for i, (it, st) in enumerate(self.iters):
+            res.append(self.getitem(it, st, self.op_flags[i]))
+            self.iters[i] = (it, it.next(st))
+        if len(res) < 2:
             return res[0]
         return space.newtuple(res)
 
     def iternext(self):
         if self.index_iter:
             self.index_iter.next()
-        for i in range(len(self.iters)):
-            self.iters[i].next()
-        for it in self.iters:
-            if not it.done():
+        for i, (it, st) in enumerate(self.iters):
+            self.iters[i] = (it, it.next(st))
+        for it, st in self.iters:
+            if not it.done(st):
                 break
         else:
             self.done = True
@@ -456,29 +407,23 @@
         return space.wrap(self.iternext())
 
     def descr_copy(self, space):
-        raise OperationError(space.w_NotImplementedError, space.wrap(
-            'not implemented yet'))
+        raise oefmt(space.w_NotImplementedError, "not implemented yet")
 
     def descr_debug_print(self, space):
-        raise OperationError(space.w_NotImplementedError, space.wrap(
-            'not implemented yet'))
+        raise oefmt(space.w_NotImplementedError, "not implemented yet")
 
     def descr_enable_external_loop(self, space):
-        raise OperationError(space.w_NotImplementedError, space.wrap(
-            'not implemented yet'))
+        raise oefmt(space.w_NotImplementedError, "not implemented yet")
 
     @unwrap_spec(axis=int)
     def descr_remove_axis(self, space, axis):
-        raise OperationError(space.w_NotImplementedError, space.wrap(
-            'not implemented yet'))
+        raise oefmt(space.w_NotImplementedError, "not implemented yet")
 
     def descr_remove_multi_index(self, space, w_multi_index):
-        raise OperationError(space.w_NotImplementedError, space.wrap(
-            'not implemented yet'))
+        raise oefmt(space.w_NotImplementedError, "not implemented yet")
 
     def descr_reset(self, space):
-        raise OperationError(space.w_NotImplementedError, space.wrap(
-            'not implemented yet'))
+        raise oefmt(space.w_NotImplementedError, "not implemented yet")
 
     def descr_get_operands(self, space):
         l_w = []
@@ -496,17 +441,16 @@
         return space.wrap(self.done)
 
     def descr_get_has_delayed_bufalloc(self, space):
-        raise OperationError(space.w_NotImplementedError, space.wrap(
-            'not implemented yet'))
+        raise oefmt(space.w_NotImplementedError, "not implemented yet")
 
     def descr_get_has_index(self, space):
         return space.wrap(self.tracked_index in ["C", "F"])
 
     def descr_get_index(self, space):
         if not self.tracked_index in ["C", "F"]:
-            raise OperationError(space.w_ValueError, space.wrap("Iterator does not have an index"))
+            raise oefmt(space.w_ValueError, "Iterator does not have an index")
         if self.done:
-            raise OperationError(space.w_ValueError, space.wrap("Iterator is past the end"))
+            raise oefmt(space.w_ValueError, "Iterator is past the end")
         return space.wrap(self.index_iter.getvalue())
 
     def descr_get_has_multi_index(self, space):
@@ -514,51 +458,44 @@
 
     def descr_get_multi_index(self, space):
         if not self.tracked_index == "multi":
-            raise OperationError(space.w_ValueError, space.wrap("Iterator is not tracking a multi-index"))
+            raise oefmt(space.w_ValueError, "Iterator is not tracking a multi-index")
         if self.done:
-            raise OperationError(space.w_ValueError, space.wrap("Iterator is past the end"))
+            raise oefmt(space.w_ValueError, "Iterator is past the end")
         return space.newtuple([space.wrap(x) for x in self.index_iter.index])
 
     def descr_get_iterationneedsapi(self, space):
-        raise OperationError(space.w_NotImplementedError, space.wrap(
-            'not implemented yet'))
+        raise oefmt(space.w_NotImplementedError, "not implemented yet")
 
     def descr_get_iterindex(self, space):
-        raise OperationError(space.w_NotImplementedError, space.wrap(
-            'not implemented yet'))
+        raise oefmt(space.w_NotImplementedError, "not implemented yet")
 
     def descr_get_itersize(self, space):
         return space.wrap(support.product(self.shape))
 
     def descr_get_itviews(self, space):
-        raise OperationError(space.w_NotImplementedError, space.wrap(
-            'not implemented yet'))
+        raise oefmt(space.w_NotImplementedError, "not implemented yet")
 
     def descr_get_ndim(self, space):
-        raise OperationError(space.w_NotImplementedError, space.wrap(
-            'not implemented yet'))
+        raise oefmt(space.w_NotImplementedError, "not implemented yet")
 
     def descr_get_nop(self, space):
-        raise OperationError(space.w_NotImplementedError, space.wrap(
-            'not implemented yet'))
+        raise oefmt(space.w_NotImplementedError, "not implemented yet")
 
     def descr_get_shape(self, space):
-        raise OperationError(space.w_NotImplementedError, space.wrap(
-            'not implemented yet'))
+        raise oefmt(space.w_NotImplementedError, "not implemented yet")
 
     def descr_get_value(self, space):
-        raise OperationError(space.w_NotImplementedError, space.wrap(
-            'not implemented yet'))
+        raise oefmt(space.w_NotImplementedError, "not implemented yet")
 
 
-@unwrap_spec(w_flags = WrappedDefault(None), w_op_flags=WrappedDefault(None),
-             w_op_dtypes = WrappedDefault(None), order=str,
+@unwrap_spec(w_flags=WrappedDefault(None), w_op_flags=WrappedDefault(None),
+             w_op_dtypes=WrappedDefault(None), order=str,
              w_casting=WrappedDefault(None), w_op_axes=WrappedDefault(None),
              w_itershape=WrappedDefault(None), w_buffersize=WrappedDefault(None))
 def nditer(space, w_seq, w_flags, w_op_flags, w_op_dtypes, w_casting, w_op_axes,
-             w_itershape, w_buffersize, order='K'):
+           w_itershape, w_buffersize, order='K'):
     return W_NDIter(space, w_seq, w_flags, w_op_flags, w_op_dtypes, w_casting, w_op_axes,
-            w_itershape, w_buffersize, order)
+                    w_itershape, w_buffersize, order)
 
 W_NDIter.typedef = TypeDef(
     'nditer',
diff --git a/pypy/module/micronumpy/sort.py b/pypy/module/micronumpy/sort.py
--- a/pypy/module/micronumpy/sort.py
+++ b/pypy/module/micronumpy/sort.py
@@ -148,20 +148,22 @@
             if axis < 0 or axis >= len(shape):
                 raise oefmt(space.w_IndexError, "Wrong axis %d", axis)
             arr_iter = AllButAxisIter(arr, axis)
+            arr_state = arr_iter.reset()
             index_impl = index_arr.implementation
             index_iter = AllButAxisIter(index_impl, axis)
+            index_state = index_iter.reset()
             stride_size = arr.strides[axis]
             index_stride_size = index_impl.strides[axis]
             axis_size = arr.shape[axis]
-            while not arr_iter.done():
+            while not arr_iter.done(arr_state):
                 for i in range(axis_size):
                     raw_storage_setitem(storage, i * index_stride_size +
-                                        index_iter.offset, i)
+                                        index_state.offset, i)
                 r = Repr(index_stride_size, stride_size, axis_size,
-                         arr.get_storage(), storage, index_iter.offset, arr_iter.offset)
+                         arr.get_storage(), storage, index_state.offset, arr_state.offset)
                 ArgSort(r).sort()
-                arr_iter.next()
-                index_iter.next()
+                arr_state = arr_iter.next(arr_state)
+                index_state = index_iter.next(index_state)
         return index_arr
 
     return argsort
@@ -292,12 +294,13 @@
             if axis < 0 or axis >= len(shape):
                 raise oefmt(space.w_IndexError, "Wrong axis %d", axis)
             arr_iter = AllButAxisIter(arr, axis)
+            arr_state = arr_iter.reset()
             stride_size = arr.strides[axis]
             axis_size = arr.shape[axis]
-            while not arr_iter.done():
-                r = Repr(stride_size, axis_size, arr.get_storage(), arr_iter.offset)
+            while not arr_iter.done(arr_state):
+                r = Repr(stride_size, axis_size, arr.get_storage(), arr_state.offset)
                 ArgSort(r).sort()
-                arr_iter.next()
+                arr_state = arr_iter.next(arr_state)
 
     return sort
 
diff --git a/pypy/module/micronumpy/strides.py b/pypy/module/micronumpy/strides.py
--- a/pypy/module/micronumpy/strides.py
+++ b/pypy/module/micronumpy/strides.py
@@ -233,30 +233,6 @@
     return dtype
 
 
-def to_coords(space, shape, size, order, w_item_or_slice):
-    '''Returns a start coord, step, and length.
-    '''
-    start = lngth = step = 0
-    if not (space.isinstance_w(w_item_or_slice, space.w_int) or
-            space.isinstance_w(w_item_or_slice, space.w_slice)):
-        raise OperationError(space.w_IndexError,
-                             space.wrap('unsupported iterator index'))
-
-    start, stop, step, lngth = space.decode_index4(w_item_or_slice, size)
-
-    coords = [0] * len(shape)
-    i = start
-    if order == 'C':
-        for s in range(len(shape) -1, -1, -1):
-            coords[s] = i % shape[s]
-            i //= shape[s]
-    else:
-        for s in range(len(shape)):
-            coords[s] = i % shape[s]
-            i //= shape[s]
-    return coords, step, lngth
-
-
 @jit.unroll_safe
 def shape_agreement(space, shape1, w_arr2, broadcast_down=True):
     if w_arr2 is None:
diff --git a/pypy/module/micronumpy/support.py b/pypy/module/micronumpy/support.py
--- a/pypy/module/micronumpy/support.py
+++ b/pypy/module/micronumpy/support.py
@@ -25,3 +25,18 @@
     for x in s:
         i *= x
     return i
+
+
+def check_and_adjust_index(space, index, size, axis):
+    if index < -size or index >= size:
+        if axis >= 0:
+            raise oefmt(space.w_IndexError,
+                        "index %d is out of bounds for axis %d with size %d",
+                        index, axis, size)
+        else:
+            raise oefmt(space.w_IndexError,
+                        "index %d is out of bounds for size %d",
+                        index, size)
+    if index < 0:
+        index += size
+    return index
diff --git a/pypy/module/micronumpy/test/test_iterators.py b/pypy/module/micronumpy/test/test_iterators.py
--- a/pypy/module/micronumpy/test/test_iterators.py
+++ b/pypy/module/micronumpy/test/test_iterators.py
@@ -16,17 +16,18 @@
         assert backstrides == [10, 4]
         i = ArrayIter(MockArray, support.product(shape), shape,
                       strides, backstrides)
-        i.next()
-        i.next()
-        i.next()
-        assert i.offset == 3
-        assert not i.done()
-        assert i.indices == [0,3]
+        s = i.reset()
+        s = i.next(s)
+        s = i.next(s)
+        s = i.next(s)
+        assert s.offset == 3
+        assert not i.done(s)
+        assert s.indices == [0,3]
         #cause a dimension overflow
-        i.next()
-        i.next()
-        assert i.offset == 5
-        assert i.indices == [1,0]
+        s = i.next(s)
+        s = i.next(s)
+        assert s.offset == 5
+        assert s.indices == [1,0]
 
         #Now what happens if the array is transposed? strides[-1] != 1
         # therefore layout is non-contiguous
@@ -35,17 +36,18 @@
         assert backstrides == [2, 12]
         i = ArrayIter(MockArray, support.product(shape), shape,
                       strides, backstrides)
-        i.next()
-        i.next()
-        i.next()
-        assert i.offset == 9
-        assert not i.done()
-        assert i.indices == [0,3]
+        s = i.reset()
+        s = i.next(s)
+        s = i.next(s)
+        s = i.next(s)
+        assert s.offset == 9
+        assert not i.done(s)
+        assert s.indices == [0,3]
         #cause a dimension overflow
-        i.next()
-        i.next()
-        assert i.offset == 1
-        assert i.indices == [1,0]
+        s = i.next(s)
+        s = i.next(s)
+        assert s.offset == 1
+        assert s.indices == [1,0]
 
     def test_iterator_step(self):
         #iteration in C order with #contiguous layout => strides[-1] is 1
@@ -56,22 +58,23 @@
         assert backstrides == [10, 4]
         i = ArrayIter(MockArray, support.product(shape), shape,
                       strides, backstrides)
-        i.next_skip_x(2)
-        i.next_skip_x(2)
-        i.next_skip_x(2)
-        assert i.offset == 6
-        assert not i.done()
-        assert i.indices == [1,1]
+        s = i.reset()
+        s = i.next_skip_x(s, 2)
+        s = i.next_skip_x(s, 2)
+        s = i.next_skip_x(s, 2)
+        assert s.offset == 6
+        assert not i.done(s)
+        assert s.indices == [1,1]
         #And for some big skips
-        i.next_skip_x(5)
-        assert i.offset == 11
-        assert i.indices == [2,1]
-        i.next_skip_x(5)
+        s = i.next_skip_x(s, 5)
+        assert s.offset == 11
+        assert s.indices == [2,1]
+        s = i.next_skip_x(s, 5)
         # Note: the offset does not overflow but recycles,
         # this is good for broadcast
-        assert i.offset == 1
-        assert i.indices == [0,1]
-        assert i.done()
+        assert s.offset == 1
+        assert s.indices == [0,1]
+        assert i.done(s)
 
         #Now what happens if the array is transposed? strides[-1] != 1
         # therefore layout is non-contiguous
@@ -80,17 +83,18 @@
         assert backstrides == [2, 12]
         i = ArrayIter(MockArray, support.product(shape), shape,
                       strides, backstrides)
-        i.next_skip_x(2)
-        i.next_skip_x(2)
-        i.next_skip_x(2)
-        assert i.offset == 4
-        assert i.indices == [1,1]
-        assert not i.done()
-        i.next_skip_x(5)
-        assert i.offset == 5
-        assert i.indices == [2,1]
-        assert not i.done()
-        i.next_skip_x(5)
-        assert i.indices == [0,1]
-        assert i.offset == 3
-        assert i.done()
+        s = i.reset()
+        s = i.next_skip_x(s, 2)
+        s = i.next_skip_x(s, 2)
+        s = i.next_skip_x(s, 2)
+        assert s.offset == 4
+        assert s.indices == [1,1]
+        assert not i.done(s)
+        s = i.next_skip_x(s, 5)
+        assert s.offset == 5
+        assert s.indices == [2,1]
+        assert not i.done(s)
+        s = i.next_skip_x(s, 5)
+        assert s.indices == [0,1]
+        assert s.offset == 3
+        assert i.done(s)
diff --git a/pypy/module/micronumpy/test/test_ndarray.py b/pypy/module/micronumpy/test/test_ndarray.py
--- a/pypy/module/micronumpy/test/test_ndarray.py
+++ b/pypy/module/micronumpy/test/test_ndarray.py
@@ -164,24 +164,6 @@
         assert calc_new_strides([1, 1, 105, 1, 1], [7, 15], [1, 7],'F') == \
                                     [1, 1, 1, 105, 105]
 
-    def test_to_coords(self):
-        from pypy.module.micronumpy.strides import to_coords
-


More information about the pypy-commit mailing list