[pypy-commit] pypy numpy-speed: change usages of iters to use state

bdkearns noreply at buildbot.pypy.org
Fri Apr 18 00:09:08 CEST 2014


Author: Brian Kearns <bdkearns at gmail.com>
Branch: numpy-speed
Changeset: r70724:71b86d3efc92
Date: 2014-04-17 16:31 -0400
http://bitbucket.org/pypy/pypy/changeset/71b86d3efc92/

Log:	change usages of iters to use state

diff --git a/pypy/module/micronumpy/concrete.py b/pypy/module/micronumpy/concrete.py
--- a/pypy/module/micronumpy/concrete.py
+++ b/pypy/module/micronumpy/concrete.py
@@ -284,9 +284,11 @@
                                             self.get_backstrides(),
                                             self.get_shape(), shape,
                                             backward_broadcast)
-            return ArrayIter(self, support.product(shape), shape, r[0], r[1])
-        return ArrayIter(self, self.get_size(), self.shape,
-                         self.strides, self.backstrides)
+            i = ArrayIter(self, support.product(shape), shape, r[0], r[1])
+        else:
+            i = ArrayIter(self, self.get_size(), self.shape,
+                          self.strides, self.backstrides)
+        return i, i.reset()
 
     def swapaxes(self, space, orig_arr, axis1, axis2):
         shape = self.get_shape()[:]
diff --git a/pypy/module/micronumpy/ctors.py b/pypy/module/micronumpy/ctors.py
--- a/pypy/module/micronumpy/ctors.py
+++ b/pypy/module/micronumpy/ctors.py
@@ -156,10 +156,10 @@
             "string is smaller than requested size"))
 
     a = W_NDimArray.from_shape(space, [num_items], dtype=dtype)
-    ai = a.create_iter()
+    ai, state = a.create_iter()
     for val in items:
-        ai.setitem(val)
-        ai.next()
+        ai.setitem(state, val)
+        state = ai.next(state)
 
     return space.wrap(a)
 
diff --git a/pypy/module/micronumpy/flatiter.py b/pypy/module/micronumpy/flatiter.py
--- a/pypy/module/micronumpy/flatiter.py
+++ b/pypy/module/micronumpy/flatiter.py
@@ -32,23 +32,23 @@
         self.reset()
 
     def reset(self):
-        self.iter = self.base.create_iter()
+        self.iter, self.state = self.base.create_iter()
 
     def descr_len(self, space):
         return space.wrap(self.base.get_size())
 
     def descr_next(self, space):
-        if self.iter.done():
+        if self.iter.done(self.state):
             raise OperationError(space.w_StopIteration, space.w_None)
-        w_res = self.iter.getitem()
-        self.iter.next()
+        w_res = self.iter.getitem(self.state)
+        self.state = self.iter.next(self.state)
         return w_res
 
     def descr_index(self, space):
-        return space.wrap(self.iter.index)
+        return space.wrap(self.state.index)
 
     def descr_coords(self, space):
-        coords = self.base.to_coords(space, space.wrap(self.iter.index))
+        coords = self.base.to_coords(space, space.wrap(self.state.index))
         return space.newtuple([space.wrap(c) for c in coords])
 
     def descr_getitem(self, space, w_idx):
@@ -58,13 +58,13 @@
         self.reset()
         base = self.base
         start, stop, step, length = space.decode_index4(w_idx, base.get_size())
-        base_iter = base.create_iter()
-        base_iter.next_skip_x(start)
+        base_iter, base_state = base.create_iter()
+        base_state = base_iter.next_skip_x(base_state, start)
         if length == 1:
-            return base_iter.getitem()
+            return base_iter.getitem(base_state)
         res = W_NDimArray.from_shape(space, [length], base.get_dtype(),
                                      base.get_order(), w_instance=base)
-        return loop.flatiter_getitem(res, base_iter, step)
+        return loop.flatiter_getitem(res, base_iter, base_state, step)
 
     def descr_setitem(self, space, w_idx, w_value):
         if not (space.isinstance_w(w_idx, space.w_int) or
diff --git a/pypy/module/micronumpy/iterators.py b/pypy/module/micronumpy/iterators.py
--- a/pypy/module/micronumpy/iterators.py
+++ b/pypy/module/micronumpy/iterators.py
@@ -51,19 +51,20 @@
         self.shapelen = len(shape)
         self.indexes = [0] * len(shape)
         self._done = False
-        self.idx_w = [None] * len(idx_w)
+        self.idx_w_i = [None] * len(idx_w)
+        self.idx_w_s = [None] * len(idx_w)
         for i, w_idx in enumerate(idx_w):
             if isinstance(w_idx, W_NDimArray):
-                self.idx_w[i] = w_idx.create_iter(shape)
+                self.idx_w_i[i], self.idx_w_s[i] = w_idx.create_iter(shape)
 
     def done(self):
         return self._done
 
     @jit.unroll_safe
     def next(self):
-        for w_idx in self.idx_w:
-            if w_idx is not None:
-                w_idx.next()
+        for i, idx_w_i in enumerate(self.idx_w_i):
+            if idx_w_i is not None:
+                self.idx_w_s[i] = idx_w_i.next(self.idx_w_s[i])
         for i in range(self.shapelen - 1, -1, -1):
             if self.indexes[i] < self.shape[i] - 1:
                 self.indexes[i] += 1
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -12,11 +12,10 @@
     AllButAxisIter
 
 
-call2_driver = jit.JitDriver(name='numpy_call2',
-                             greens = ['shapelen', 'func', 'calc_dtype',
-                                       'res_dtype'],
-                             reds = ['shape', 'w_lhs', 'w_rhs', 'out',
-                                     'left_iter', 'right_iter', 'out_iter'])
+call2_driver = jit.JitDriver(
+    name='numpy_call2',
+    greens=['shapelen', 'func', 'calc_dtype', 'res_dtype'],
+    reds='auto')
 
 def call2(space, shape, func, calc_dtype, res_dtype, w_lhs, w_rhs, out):
     # handle array_priority
@@ -46,47 +45,40 @@
     if out is None:
         out = W_NDimArray.from_shape(space, shape, res_dtype,
                                      w_instance=lhs_for_subtype)
-    left_iter = w_lhs.create_iter(shape)
-    right_iter = w_rhs.create_iter(shape)
-    out_iter = out.create_iter(shape)
+    left_iter, left_state = w_lhs.create_iter(shape)
+    right_iter, right_state = w_rhs.create_iter(shape)
+    out_iter, out_state = out.create_iter(shape)
     shapelen = len(shape)
-    while not out_iter.done():
+    while not out_iter.done(out_state):
         call2_driver.jit_merge_point(shapelen=shapelen, func=func,
-                                     calc_dtype=calc_dtype, res_dtype=res_dtype,
-                                     shape=shape, w_lhs=w_lhs, w_rhs=w_rhs,
-                                     out=out,
-                                     left_iter=left_iter, right_iter=right_iter,
-                                     out_iter=out_iter)
-        w_left = left_iter.getitem().convert_to(space, calc_dtype)
-        w_right = right_iter.getitem().convert_to(space, calc_dtype)
-        out_iter.setitem(func(calc_dtype, w_left, w_right).convert_to(
+                                     calc_dtype=calc_dtype, res_dtype=res_dtype)
+        w_left = left_iter.getitem(left_state).convert_to(space, calc_dtype)
+        w_right = right_iter.getitem(right_state).convert_to(space, calc_dtype)
+        out_iter.setitem(out_state, func(calc_dtype, w_left, w_right).convert_to(
             space, res_dtype))
-        left_iter.next()
-        right_iter.next()
-        out_iter.next()
+        left_state = left_iter.next(left_state)
+        right_state = right_iter.next(right_state)
+        out_state = out_iter.next(out_state)
     return out
 
-call1_driver = jit.JitDriver(name='numpy_call1',
-                             greens = ['shapelen', 'func', 'calc_dtype',
-                                       'res_dtype'],
-                             reds = ['shape', 'w_obj', 'out', 'obj_iter',
-                                     'out_iter'])
+call1_driver = jit.JitDriver(
+    name='numpy_call1',
+    greens=['shapelen', 'func', 'calc_dtype', 'res_dtype'],
+    reds='auto')
 
 def call1(space, shape, func, calc_dtype, res_dtype, w_obj, out):
     if out is None:
         out = W_NDimArray.from_shape(space, shape, res_dtype, w_instance=w_obj)
-    obj_iter = w_obj.create_iter(shape)
-    out_iter = out.create_iter(shape)
+    obj_iter, obj_state = w_obj.create_iter(shape)
+    out_iter, out_state = out.create_iter(shape)
     shapelen = len(shape)
-    while not out_iter.done():
+    while not out_iter.done(out_state):
         call1_driver.jit_merge_point(shapelen=shapelen, func=func,
-                                     calc_dtype=calc_dtype, res_dtype=res_dtype,
-                                     shape=shape, w_obj=w_obj, out=out,
-                                     obj_iter=obj_iter, out_iter=out_iter)
-        elem = obj_iter.getitem().convert_to(space, calc_dtype)
-        out_iter.setitem(func(calc_dtype, elem).convert_to(space, res_dtype))
-        out_iter.next()
-        obj_iter.next()
+                                     calc_dtype=calc_dtype, res_dtype=res_dtype)
+        elem = obj_iter.getitem(obj_state).convert_to(space, calc_dtype)
+        out_iter.setitem(out_state, func(calc_dtype, elem).convert_to(space, res_dtype))
+        out_state = out_iter.next(out_state)
+        obj_state = obj_iter.next(obj_state)
     return out
 
 setslice_driver = jit.JitDriver(name='numpy_setslice',
@@ -96,18 +88,20 @@
 def setslice(space, shape, target, source):
     # note that unlike everything else, target and source here are
     # array implementations, not arrays
-    target_iter = target.create_iter(shape)
-    source_iter = source.create_iter(shape)
+    target_iter, target_state = target.create_iter(shape)
+    source_iter, source_state = source.create_iter(shape)
     dtype = target.dtype
     shapelen = len(shape)
-    while not target_iter.done():
+    while not target_iter.done(target_state):
         setslice_driver.jit_merge_point(shapelen=shapelen, dtype=dtype)
+        val = source_iter.getitem(source_state)
         if dtype.is_str_or_unicode():
-            target_iter.setitem(dtype.coerce(space, source_iter.getitem()))
+            val = dtype.coerce(space, val)
         else:
-            target_iter.setitem(source_iter.getitem().convert_to(space, dtype))
-        target_iter.next()
-        source_iter.next()
+            val = val.convert_to(space, dtype)
+        target_iter.setitem(target_state, val)
+        target_state = target_iter.next(target_state)
+        source_state = source_iter.next(source_state)
     return target
 
 reduce_driver = jit.JitDriver(name='numpy_reduce',
@@ -116,22 +110,22 @@
                               reds = 'auto')
 
 def compute_reduce(space, obj, calc_dtype, func, done_func, identity):
-    obj_iter = obj.create_iter()
+    obj_iter, obj_state = obj.create_iter()
     if identity is None:
-        cur_value = obj_iter.getitem().convert_to(space, calc_dtype)
-        obj_iter.next()
+        cur_value = obj_iter.getitem(obj_state).convert_to(space, calc_dtype)
+        obj_state = obj_iter.next(obj_state)
     else:
         cur_value = identity.convert_to(space, calc_dtype)
     shapelen = len(obj.get_shape())
-    while not obj_iter.done():
+    while not obj_iter.done(obj_state):
         reduce_driver.jit_merge_point(shapelen=shapelen, func=func,
                                       done_func=done_func,
                                       calc_dtype=calc_dtype)
-        rval = obj_iter.getitem().convert_to(space, calc_dtype)
+        rval = obj_iter.getitem(obj_state).convert_to(space, calc_dtype)
         if done_func is not None and done_func(calc_dtype, rval):
             return rval
         cur_value = func(calc_dtype, cur_value, rval)
-        obj_iter.next()
+        obj_state = obj_iter.next(obj_state)
     return cur_value
 
 reduce_cum_driver = jit.JitDriver(name='numpy_reduce_cum_driver',
@@ -139,69 +133,76 @@
                                   reds = 'auto')
 
 def compute_reduce_cumulative(space, obj, out, calc_dtype, func, identity):
-    obj_iter = obj.create_iter()
-    out_iter = out.create_iter()
+    obj_iter, obj_state = obj.create_iter()
+    out_iter, out_state = out.create_iter()
     if identity is None:
-        cur_value = obj_iter.getitem().convert_to(space, calc_dtype)
-        out_iter.setitem(cur_value)
-        out_iter.next()
-        obj_iter.next()
+        cur_value = obj_iter.getitem(obj_state).convert_to(space, calc_dtype)
+        out_iter.setitem(out_state, cur_value)
+        out_state = out_iter.next(out_state)
+        obj_state = obj_iter.next(obj_state)
     else:
         cur_value = identity.convert_to(space, calc_dtype)
     shapelen = len(obj.get_shape())
-    while not obj_iter.done():
+    while not obj_iter.done(obj_state):
         reduce_cum_driver.jit_merge_point(shapelen=shapelen, func=func,
                                           dtype=calc_dtype)
-        rval = obj_iter.getitem().convert_to(space, calc_dtype)
+        rval = obj_iter.getitem(obj_state).convert_to(space, calc_dtype)
         cur_value = func(calc_dtype, cur_value, rval)
-        out_iter.setitem(cur_value)
-        out_iter.next()
-        obj_iter.next()
+        out_iter.setitem(out_state, cur_value)
+        out_state = out_iter.next(out_state)
+        obj_state = obj_iter.next(obj_state)
 
 def fill(arr, box):
-    arr_iter = arr.create_iter()
-    while not arr_iter.done():
-        arr_iter.setitem(box)
-        arr_iter.next()
+    arr_iter, arr_state = arr.create_iter()
+    while not arr_iter.done(arr_state):
+        arr_iter.setitem(arr_state, box)
+        arr_state = arr_iter.next(arr_state)
 
 def assign(space, arr, seq):
-    arr_iter = arr.create_iter()
+    arr_iter, arr_state = arr.create_iter()
     arr_dtype = arr.get_dtype()
     for item in seq:
-        arr_iter.setitem(arr_dtype.coerce(space, item))
-        arr_iter.next()
+        arr_iter.setitem(arr_state, arr_dtype.coerce(space, item))
+        arr_state = arr_iter.next(arr_state)
 
 where_driver = jit.JitDriver(name='numpy_where',
                              greens = ['shapelen', 'dtype', 'arr_dtype'],
                              reds = 'auto')
 
 def where(space, out, shape, arr, x, y, dtype):
-    out_iter = out.create_iter(shape)
-    arr_iter = arr.create_iter(shape)
+    out_iter, out_state = out.create_iter(shape)
+    arr_iter, arr_state = arr.create_iter(shape)
     arr_dtype = arr.get_dtype()
-    x_iter = x.create_iter(shape)
-    y_iter = y.create_iter(shape)
+    x_iter, x_state = x.create_iter(shape)
+    y_iter, y_state = y.create_iter(shape)
     if x.is_scalar():
         if y.is_scalar():
-            iter = arr_iter
+            iter, state = arr_iter, arr_state
         else:
-            iter = y_iter
+            iter, state = y_iter, y_state
     else:
-        iter = x_iter
+        iter, state = x_iter, x_state
     shapelen = len(shape)
-    while not iter.done():
+    while not iter.done(state):
         where_driver.jit_merge_point(shapelen=shapelen, dtype=dtype,
                                         arr_dtype=arr_dtype)
-        w_cond = arr_iter.getitem()
+        w_cond = arr_iter.getitem(arr_state)
         if arr_dtype.itemtype.bool(w_cond):
-            w_val = x_iter.getitem().convert_to(space, dtype)
+            w_val = x_iter.getitem(x_state).convert_to(space, dtype)
         else:
-            w_val = y_iter.getitem().convert_to(space, dtype)
-        out_iter.setitem(w_val)
-        out_iter.next()
-        arr_iter.next()
-        x_iter.next()
-        y_iter.next()
+            w_val = y_iter.getitem(y_state).convert_to(space, dtype)
+        out_iter.setitem(out_state, w_val)
+        out_state = out_iter.next(out_state)
+        arr_state = arr_iter.next(arr_state)
+        x_state = x_iter.next(x_state)
+        y_state = y_iter.next(y_state)
+        if x.is_scalar():
+            if y.is_scalar():
+                state = arr_state
+            else:
+                state = y_state
+        else:
+            state = x_state
     return out
 
 axis_reduce__driver = jit.JitDriver(name='numpy_axis_reduce',
@@ -212,31 +213,36 @@
 def do_axis_reduce(space, shape, func, arr, dtype, axis, out, identity, cumulative,
                    temp):
     out_iter = AxisIter(out.implementation, arr.get_shape(), axis, cumulative)
+    out_state = out_iter.reset()
     if cumulative:
         temp_iter = AxisIter(temp.implementation, arr.get_shape(), axis, False)
+        temp_state = temp_iter.reset()
     else:
-        temp_iter = out_iter # hack
-    arr_iter = arr.create_iter()
+        temp_iter = out_iter  # hack
+        temp_state = out_state
+    arr_iter, arr_state = arr.create_iter()
     if identity is not None:
         identity = identity.convert_to(space, dtype)
     shapelen = len(shape)
-    while not out_iter.done():
+    while not out_iter.done(out_state):
         axis_reduce__driver.jit_merge_point(shapelen=shapelen, func=func,
                                             dtype=dtype)
-        assert not arr_iter.done()
-        w_val = arr_iter.getitem().convert_to(space, dtype)
-        if out_iter.indices[axis] == 0:
+        assert not arr_iter.done(arr_state)
+        w_val = arr_iter.getitem(arr_state).convert_to(space, dtype)
+        if out_state.indices[axis] == 0:
             if identity is not None:
                 w_val = func(dtype, identity, w_val)
         else:
-            cur = temp_iter.getitem()
+            cur = temp_iter.getitem(temp_state)
             w_val = func(dtype, cur, w_val)
-        out_iter.setitem(w_val)
+        out_iter.setitem(out_state, w_val)
+        out_state = out_iter.next(out_state)
         if cumulative:
-            temp_iter.setitem(w_val)
-            temp_iter.next()
-        arr_iter.next()
-        out_iter.next()
+            temp_iter.setitem(temp_state, w_val)
+            temp_state = temp_iter.next(temp_state)
+        else:
+            temp_state = out_state
+        arr_state = arr_iter.next(arr_state)
     return out
 
 
@@ -249,18 +255,18 @@
         result = 0
         idx = 1
         dtype = arr.get_dtype()
-        iter = arr.create_iter()
-        cur_best = iter.getitem()
-        iter.next()
+        iter, state = arr.create_iter()
+        cur_best = iter.getitem(state)
+        state = iter.next(state)
         shapelen = len(arr.get_shape())
-        while not iter.done():
+        while not iter.done(state):
             arg_driver.jit_merge_point(shapelen=shapelen, dtype=dtype)
-            w_val = iter.getitem()
+            w_val = iter.getitem(state)
             new_best = getattr(dtype.itemtype, op_name)(cur_best, w_val)
             if dtype.itemtype.ne(new_best, cur_best):
                 result = idx
                 cur_best = new_best
-            iter.next()
+            state = iter.next(state)
             idx += 1
         return result
     return argmin_argmax
@@ -291,17 +297,19 @@
     right_impl = right.implementation
     assert left_shape[-1] == right_shape[right_critical_dim]
     assert result.get_dtype() == dtype
-    outi = result.create_iter()
+    outi, outs = result.create_iter()
     lefti = AllButAxisIter(left_impl, len(left_shape) - 1)
     righti = AllButAxisIter(right_impl, right_critical_dim)
+    lefts = lefti.reset()
+    rights = righti.reset()
     n = left_impl.shape[-1]
     s1 = left_impl.strides[-1]
     s2 = right_impl.strides[right_critical_dim]
-    while not lefti.done():
-        while not righti.done():
-            oval = outi.getitem()
-            i1 = lefti.offset
-            i2 = righti.offset
+    while not lefti.done(lefts):
+        while not righti.done(rights):
+            oval = outi.getitem(outs)
+            i1 = lefts.offset
+            i2 = rights.offset
             i = 0
             while i < n:
                 i += 1
@@ -311,11 +319,11 @@
                 oval = dtype.itemtype.add(oval, dtype.itemtype.mul(lval, rval))
                 i1 += s1
                 i2 += s2
-            outi.setitem(oval)
-            outi.next()
-            righti.next()
-        righti.reset()
-        lefti.next()
+            outi.setitem(outs, oval)
+            outs = outi.next(outs)
+            rights = righti.next(rights)
+        rights = righti.reset()
+        lefts = lefti.next(lefts)
     return result
 
 count_all_true_driver = jit.JitDriver(name = 'numpy_count',
@@ -324,13 +332,13 @@
 
 def count_all_true_concrete(impl):
     s = 0
-    iter = impl.create_iter()
+    iter, state = impl.create_iter()
     shapelen = len(impl.shape)
     dtype = impl.dtype
-    while not iter.done():
+    while not iter.done(state):
         count_all_true_driver.jit_merge_point(shapelen=shapelen, dtype=dtype)
-        s += iter.getitem_bool()
-        iter.next()
+        s += iter.getitem_bool(state)
+        state = iter.next(state)
     return s
 
 def count_all_true(arr):
@@ -344,18 +352,18 @@
                                reds = 'auto')
 
 def nonzero(res, arr, box):
-    res_iter = res.create_iter()
-    arr_iter = arr.create_iter()
+    res_iter, res_state = res.create_iter()
+    arr_iter, arr_state = arr.create_iter()
     shapelen = len(arr.shape)
     dtype = arr.dtype
     dims = range(shapelen)
-    while not arr_iter.done():
+    while not arr_iter.done(arr_state):
         nonzero_driver.jit_merge_point(shapelen=shapelen, dims=dims, dtype=dtype)
-        if arr_iter.getitem_bool():
+        if arr_iter.getitem_bool(arr_state):
             for d in dims:
-                res_iter.setitem(box(arr_iter.indices[d]))
-                res_iter.next()
-        arr_iter.next()
+                res_iter.setitem(res_state, box(arr_state.indices[d]))
+                res_state = res_iter.next(res_state)
+        arr_state = arr_iter.next(arr_state)
     return res
 
 
@@ -365,26 +373,26 @@
                                       reds = 'auto')
 
 def getitem_filter(res, arr, index):
-    res_iter = res.create_iter()
+    res_iter, res_state = res.create_iter()
     shapelen = len(arr.get_shape())
     if shapelen > 1 and len(index.get_shape()) < 2:
-        index_iter = index.create_iter(arr.get_shape(), backward_broadcast=True)
+        index_iter, index_state = index.create_iter(arr.get_shape(), backward_broadcast=True)
     else:
-        index_iter = index.create_iter()
-    arr_iter = arr.create_iter()
+        index_iter, index_state = index.create_iter()
+    arr_iter, arr_state = arr.create_iter()
     arr_dtype = arr.get_dtype()
     index_dtype = index.get_dtype()
     # XXX length of shape of index as well?
-    while not index_iter.done():
+    while not index_iter.done(index_state):
         getitem_filter_driver.jit_merge_point(shapelen=shapelen,
                                               index_dtype=index_dtype,
                                               arr_dtype=arr_dtype,
                                               )
-        if index_iter.getitem_bool():
-            res_iter.setitem(arr_iter.getitem())
-            res_iter.next()
-        index_iter.next()
-        arr_iter.next()
+        if index_iter.getitem_bool(index_state):
+            res_iter.setitem(res_state, arr_iter.getitem(arr_state))
+            res_state = res_iter.next(res_state)
+        index_state = index_iter.next(index_state)
+        arr_state = arr_iter.next(arr_state)
     return res
 
 setitem_filter_driver = jit.JitDriver(name = 'numpy_setitem_bool',
@@ -393,41 +401,42 @@
                                       reds = 'auto')
 
 def setitem_filter(space, arr, index, value):
-    arr_iter = arr.create_iter()
+    arr_iter, arr_state = arr.create_iter()
     shapelen = len(arr.get_shape())
     if shapelen > 1 and len(index.get_shape()) < 2:
-        index_iter = index.create_iter(arr.get_shape(), backward_broadcast=True)
+        index_iter, index_state = index.create_iter(arr.get_shape(), backward_broadcast=True)
     else:
-        index_iter = index.create_iter()
+        index_iter, index_state = index.create_iter()
     if value.get_size() == 1:
-        value_iter = value.create_iter(arr.get_shape())
+        value_iter, value_state = value.create_iter(arr.get_shape())
     else:
-        value_iter = value.create_iter()
+        value_iter, value_state = value.create_iter()
     index_dtype = index.get_dtype()
     arr_dtype = arr.get_dtype()
-    while not index_iter.done():
+    while not index_iter.done(index_state):
         setitem_filter_driver.jit_merge_point(shapelen=shapelen,
                                               index_dtype=index_dtype,
                                               arr_dtype=arr_dtype,
                                              )
-        if index_iter.getitem_bool():
-            arr_iter.setitem(arr_dtype.coerce(space, value_iter.getitem()))
-            value_iter.next()
-        arr_iter.next()
-        index_iter.next()
+        if index_iter.getitem_bool(index_state):
+            val = arr_dtype.coerce(space, value_iter.getitem(value_state))
+            value_state = value_iter.next(value_state)
+            arr_iter.setitem(arr_state, val)
+        arr_state = arr_iter.next(arr_state)
+        index_state = index_iter.next(index_state)
 
 flatiter_getitem_driver = jit.JitDriver(name = 'numpy_flatiter_getitem',
                                         greens = ['dtype'],
                                         reds = 'auto')
 
-def flatiter_getitem(res, base_iter, step):
-    ri = res.create_iter()
+def flatiter_getitem(res, base_iter, base_state, step):
+    ri, rs = res.create_iter()
     dtype = res.get_dtype()
-    while not ri.done():
+    while not ri.done(rs):
         flatiter_getitem_driver.jit_merge_point(dtype=dtype)
-        ri.setitem(base_iter.getitem())
-        base_iter.next_skip_x(step)
-        ri.next()
+        ri.setitem(rs, base_iter.getitem(base_state))
+        base_state = base_iter.next_skip_x(base_state, step)
+        rs = ri.next(rs)
     return res
 
 flatiter_setitem_driver = jit.JitDriver(name = 'numpy_flatiter_setitem',
@@ -436,19 +445,21 @@
 
 def flatiter_setitem(space, arr, val, start, step, length):
     dtype = arr.get_dtype()
-    arr_iter = arr.create_iter()
-    val_iter = val.create_iter()
-    arr_iter.next_skip_x(start)
+    arr_iter, arr_state = arr.create_iter()
+    val_iter, val_state = val.create_iter()
+    arr_state = arr_iter.next_skip_x(arr_state, start)
     while length > 0:
         flatiter_setitem_driver.jit_merge_point(dtype=dtype)
+        val = val_iter.getitem(val_state)
         if dtype.is_str_or_unicode():
-            arr_iter.setitem(dtype.coerce(space, val_iter.getitem()))
+            val = dtype.coerce(space, val)
         else:
-            arr_iter.setitem(val_iter.getitem().convert_to(space, dtype))
+            val = val.convert_to(space, dtype)
+        arr_iter.setitem(arr_state, val)
         # need to repeat i_nput values until all assignments are done
-        arr_iter.next_skip_x(step)
+        arr_state = arr_iter.next_skip_x(arr_state, step)
+        val_state = val_iter.next(val_state)
         length -= 1
-        val_iter.next()
 
 fromstring_driver = jit.JitDriver(name = 'numpy_fromstring',
                                   greens = ['itemsize', 'dtype'],
@@ -456,30 +467,30 @@
 
 def fromstring_loop(space, a, dtype, itemsize, s):
     i = 0
-    ai = a.create_iter()
-    while not ai.done():
+    ai, state = a.create_iter()
+    while not ai.done(state):
         fromstring_driver.jit_merge_point(dtype=dtype, itemsize=itemsize)
         sub = s[i*itemsize:i*itemsize + itemsize]
         if dtype.is_str_or_unicode():
             val = dtype.coerce(space, space.wrap(sub))
         else:
             val = dtype.itemtype.runpack_str(space, sub)
-        ai.setitem(val)
-        ai.next()
+        ai.setitem(state, val)
+        state = ai.next(state)
         i += 1
 
 def tostring(space, arr):
     builder = StringBuilder()
-    iter = arr.create_iter()
+    iter, state = arr.create_iter()
     w_res_str = W_NDimArray.from_shape(space, [1], arr.get_dtype(), order='C')
     itemsize = arr.get_dtype().elsize
     res_str_casted = rffi.cast(rffi.CArrayPtr(lltype.Char),
                                w_res_str.implementation.get_storage_as_int(space))
-    while not iter.done():
-        w_res_str.implementation.setitem(0, iter.getitem())
+    while not iter.done(state):
+        w_res_str.implementation.setitem(0, iter.getitem(state))
         for i in range(itemsize):
             builder.append(res_str_casted[i])
-        iter.next()
+        state = iter.next(state)
     return builder.build()
 
 getitem_int_driver = jit.JitDriver(name = 'numpy_getitem_int',
@@ -500,8 +511,8 @@
         # prepare the index
         index_w = [None] * indexlen
         for i in range(indexlen):
-            if iter.idx_w[i] is not None:
-                index_w[i] = iter.idx_w[i].getitem()
+            if iter.idx_w_i[i] is not None:
+                index_w[i] = iter.idx_w_i[i].getitem(iter.idx_w_s[i])
             else:
                 index_w[i] = indexes_w[i]
         res.descr_setitem(space, space.newtuple(prefix_w[:prefixlen] +
@@ -528,8 +539,8 @@
         # prepare the index
         index_w = [None] * indexlen
         for i in range(indexlen):
-            if iter.idx_w[i] is not None:
-                index_w[i] = iter.idx_w[i].getitem()
+            if iter.idx_w_i[i] is not None:
+                index_w[i] = iter.idx_w_i[i].getitem(iter.idx_w_s[i])
             else:
                 index_w[i] = indexes_w[i]
         w_idx = space.newtuple(prefix_w[:prefixlen] + iter.get_index(space,
@@ -547,13 +558,14 @@
 
 def byteswap(from_, to):
     dtype = from_.dtype
-    from_iter = from_.create_iter()
-    to_iter = to.create_iter()
-    while not from_iter.done():
+    from_iter, from_state = from_.create_iter()
+    to_iter, to_state = to.create_iter()
+    while not from_iter.done(from_state):
         byteswap_driver.jit_merge_point(dtype=dtype)
-        to_iter.setitem(dtype.itemtype.byteswap(from_iter.getitem()))
-        to_iter.next()
-        from_iter.next()
+        val = dtype.itemtype.byteswap(from_iter.getitem(from_state))
+        to_iter.setitem(to_state, val)
+        to_state = to_iter.next(to_state)
+        from_state = from_iter.next(from_state)
 
 choose_driver = jit.JitDriver(name='numpy_choose_driver',
                               greens = ['shapelen', 'mode', 'dtype'],
@@ -561,13 +573,15 @@
 
 def choose(space, arr, choices, shape, dtype, out, mode):
     shapelen = len(shape)
-    iterators = [a.create_iter(shape) for a in choices]
-    arr_iter = arr.create_iter(shape)
-    out_iter = out.create_iter(shape)
-    while not arr_iter.done():
+    pairs = [a.create_iter(shape) for a in choices]
+    iterators = [i[0] for i in pairs]
+    states = [i[1] for i in pairs]
+    arr_iter, arr_state = arr.create_iter(shape)
+    out_iter, out_state = out.create_iter(shape)
+    while not arr_iter.done(arr_state):
         choose_driver.jit_merge_point(shapelen=shapelen, dtype=dtype,
                                       mode=mode)
-        index = support.index_w(space, arr_iter.getitem())
+        index = support.index_w(space, arr_iter.getitem(arr_state))
         if index < 0 or index >= len(iterators):
             if mode == NPY.RAISE:
                 raise OperationError(space.w_ValueError, space.wrap(
@@ -580,72 +594,73 @@
                     index = 0
                 else:
                     index = len(iterators) - 1
-        out_iter.setitem(iterators[index].getitem().convert_to(space, dtype))
-        for iter in iterators:
-            iter.next()
-        out_iter.next()
-        arr_iter.next()
+        val = iterators[index].getitem(states[index]).convert_to(space, dtype)
+        out_iter.setitem(out_state, val)
+        for i in range(len(iterators)):
+            states[i] = iterators[i].next(states[i])
+        out_state = out_iter.next(out_state)
+        arr_state = arr_iter.next(arr_state)
 
 clip_driver = jit.JitDriver(name='numpy_clip_driver',
                             greens = ['shapelen', 'dtype'],
                             reds = 'auto')
 
 def clip(space, arr, shape, min, max, out):
-    arr_iter = arr.create_iter(shape)
+    arr_iter, arr_state = arr.create_iter(shape)
     dtype = out.get_dtype()
     shapelen = len(shape)
-    min_iter = min.create_iter(shape)
-    max_iter = max.create_iter(shape)
-    out_iter = out.create_iter(shape)
-    while not arr_iter.done():
+    min_iter, min_state = min.create_iter(shape)
+    max_iter, max_state = max.create_iter(shape)
+    out_iter, out_state = out.create_iter(shape)
+    while not arr_iter.done(arr_state):
         clip_driver.jit_merge_point(shapelen=shapelen, dtype=dtype)
-        w_v = arr_iter.getitem().convert_to(space, dtype)
-        w_min = min_iter.getitem().convert_to(space, dtype)
-        w_max = max_iter.getitem().convert_to(space, dtype)
+        w_v = arr_iter.getitem(arr_state).convert_to(space, dtype)
+        w_min = min_iter.getitem(min_state).convert_to(space, dtype)
+        w_max = max_iter.getitem(max_state).convert_to(space, dtype)
         if dtype.itemtype.lt(w_v, w_min):
             w_v = w_min
         elif dtype.itemtype.gt(w_v, w_max):
             w_v = w_max
-        out_iter.setitem(w_v)
-        arr_iter.next()
-        max_iter.next()
-        out_iter.next()
-        min_iter.next()
+        out_iter.setitem(out_state, w_v)
+        arr_state = arr_iter.next(arr_state)
+        min_state = min_iter.next(min_state)
+        max_state = max_iter.next(max_state)
+        out_state = out_iter.next(out_state)
 
 round_driver = jit.JitDriver(name='numpy_round_driver',
                              greens = ['shapelen', 'dtype'],
                              reds = 'auto')
 
 def round(space, arr, dtype, shape, decimals, out):
-    arr_iter = arr.create_iter(shape)
+    arr_iter, arr_state = arr.create_iter(shape)
+    out_iter, out_state = out.create_iter(shape)
     shapelen = len(shape)
-    out_iter = out.create_iter(shape)
-    while not arr_iter.done():
+    while not arr_iter.done(arr_state):
         round_driver.jit_merge_point(shapelen=shapelen, dtype=dtype)
-        w_v = arr_iter.getitem().convert_to(space, dtype)
+        w_v = arr_iter.getitem(arr_state).convert_to(space, dtype)
         w_v = dtype.itemtype.round(w_v, decimals)
-        out_iter.setitem(w_v)
-        arr_iter.next()
-        out_iter.next()
+        out_iter.setitem(out_state, w_v)
+        arr_state = arr_iter.next(arr_state)
+        out_state = out_iter.next(out_state)
 
 diagonal_simple_driver = jit.JitDriver(name='numpy_diagonal_simple_driver',
                                        greens = ['axis1', 'axis2'],
                                        reds = 'auto')
 
 def diagonal_simple(space, arr, out, offset, axis1, axis2, size):
-    out_iter = out.create_iter()
+    out_iter, out_state = out.create_iter()
     i = 0
     index = [0] * 2
     while i < size:
         diagonal_simple_driver.jit_merge_point(axis1=axis1, axis2=axis2)
         index[axis1] = i
         index[axis2] = i + offset
-        out_iter.setitem(arr.getitem_index(space, index))
+        out_iter.setitem(out_state, arr.getitem_index(space, index))
         i += 1
-        out_iter.next()
+        out_state = out_iter.next(out_state)
 
 def diagonal_array(space, arr, out, offset, axis1, axis2, shape):
-    out_iter = out.create_iter()
+    out_iter, out_state = out.create_iter()
     iter = PureShapeIter(shape, [])
     shapelen_minus_1 = len(shape) - 1
     assert shapelen_minus_1 >= 0
@@ -667,6 +682,6 @@
             indexes = (iter.indexes[:a] + [last_index + offset] +
                        iter.indexes[a:b] + [last_index] +
                        iter.indexes[b:shapelen_minus_1])
-        out_iter.setitem(arr.getitem_index(space, indexes))
+        out_iter.setitem(out_state, arr.getitem_index(space, indexes))
         iter.next()
-        out_iter.next()
+        out_state = out_iter.next(out_state)
diff --git a/pypy/module/micronumpy/ndarray.py b/pypy/module/micronumpy/ndarray.py
--- a/pypy/module/micronumpy/ndarray.py
+++ b/pypy/module/micronumpy/ndarray.py
@@ -260,24 +260,24 @@
         return space.call_function(cache.w_array_str, self)
 
     def dump_data(self, prefix='array(', separator=',', suffix=')'):
-        i = self.create_iter()
+        i, state = self.create_iter()
         first = True
         dtype = self.get_dtype()
         s = StringBuilder()
         s.append(prefix)
         if not self.is_scalar():
             s.append('[')
-        while not i.done():
+        while not i.done(state):
             if first:
                 first = False
             else:
                 s.append(separator)
                 s.append(' ')
             if self.is_scalar() and dtype.is_str():
-                s.append(dtype.itemtype.to_str(i.getitem()))
+                s.append(dtype.itemtype.to_str(i.getitem(state)))
             else:
-                s.append(dtype.itemtype.str_format(i.getitem()))
-            i.next()
+                s.append(dtype.itemtype.str_format(i.getitem(state)))
+            state = i.next(state)
         if not self.is_scalar():
             s.append(']')
         s.append(suffix)
@@ -818,8 +818,8 @@
         if self.get_size() > 1:
             raise OperationError(space.w_ValueError, space.wrap(
                 "The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()"))
-        iter = self.create_iter()
-        return space.wrap(space.is_true(iter.getitem()))
+        iter, state = self.create_iter()
+        return space.wrap(space.is_true(iter.getitem(state)))
 
     def _binop_impl(ufunc_name):
         def impl(self, space, w_other, w_out=None):
@@ -1095,11 +1095,11 @@
 
         builder = StringBuilder()
         if isinstance(self.implementation, SliceArray):
-            iter = self.implementation.create_iter()
-            while not iter.done():
-                box = iter.getitem()
+            iter, state = self.implementation.create_iter()
+            while not iter.done(state):
+                box = iter.getitem(state)
                 builder.append(box.raw_str())
-                iter.next()
+                state = iter.next(state)
         else:
             builder.append_charpsize(self.implementation.get_storage(), self.implementation.get_storage_size())
 
diff --git a/pypy/module/micronumpy/sort.py b/pypy/module/micronumpy/sort.py
--- a/pypy/module/micronumpy/sort.py
+++ b/pypy/module/micronumpy/sort.py
@@ -148,20 +148,22 @@
             if axis < 0 or axis >= len(shape):
                 raise oefmt(space.w_IndexError, "Wrong axis %d", axis)
             arr_iter = AllButAxisIter(arr, axis)
+            arr_state = arr_iter.reset()
             index_impl = index_arr.implementation
             index_iter = AllButAxisIter(index_impl, axis)
+            index_state = index_iter.reset()
             stride_size = arr.strides[axis]
             index_stride_size = index_impl.strides[axis]
             axis_size = arr.shape[axis]
-            while not arr_iter.done():
+            while not arr_iter.done(arr_state):
                 for i in range(axis_size):
                     raw_storage_setitem(storage, i * index_stride_size +
-                                        index_iter.offset, i)
+                                        index_state.offset, i)
                 r = Repr(index_stride_size, stride_size, axis_size,
-                         arr.get_storage(), storage, index_iter.offset, arr_iter.offset)
+                         arr.get_storage(), storage, index_state.offset, arr_state.offset)
                 ArgSort(r).sort()
-                arr_iter.next()
-                index_iter.next()
+                arr_state = arr_iter.next(arr_state)
+                index_state = index_iter.next(index_state)
         return index_arr
 
     return argsort
@@ -292,12 +294,13 @@
             if axis < 0 or axis >= len(shape):
                 raise oefmt(space.w_IndexError, "Wrong axis %d", axis)
             arr_iter = AllButAxisIter(arr, axis)
+            arr_state = arr_iter.reset()
             stride_size = arr.strides[axis]
             axis_size = arr.shape[axis]
-            while not arr_iter.done():
-                r = Repr(stride_size, axis_size, arr.get_storage(), arr_iter.offset)
+            while not arr_iter.done(arr_state):
+                r = Repr(stride_size, axis_size, arr.get_storage(), arr_state.offset)
                 ArgSort(r).sort()
-                arr_iter.next()
+                arr_state = arr_iter.next(arr_state)
 
     return sort
 
diff --git a/pypy/module/micronumpy/test/test_zjit.py b/pypy/module/micronumpy/test/test_zjit.py
--- a/pypy/module/micronumpy/test/test_zjit.py
+++ b/pypy/module/micronumpy/test/test_zjit.py
@@ -47,7 +47,8 @@
                 raise Exception("need results")
             w_res = interp.results[-1]
             if isinstance(w_res, W_NDimArray):
-                w_res = w_res.create_iter().getitem()
+                i, s = w_res.create_iter()
+                w_res = i.getitem(s)
             if isinstance(w_res, boxes.W_Float64Box):
                 return w_res.value
             if isinstance(w_res, boxes.W_Int64Box):


More information about the pypy-commit mailing list