[pypy-commit] pypy numpy-speed: create iter state object to help jit in loops

bdkearns noreply at buildbot.pypy.org
Fri Apr 18 00:09:07 CEST 2014


Author: Brian Kearns <bdkearns at gmail.com>
Branch: numpy-speed
Changeset: r70723:04699e4b9dd0
Date: 2014-04-17 17:40 -0400
http://bitbucket.org/pypy/pypy/changeset/04699e4b9dd0/

Log:	create iter state object to help jit in loops

diff --git a/pypy/module/micronumpy/iterators.py b/pypy/module/micronumpy/iterators.py
--- a/pypy/module/micronumpy/iterators.py
+++ b/pypy/module/micronumpy/iterators.py
@@ -78,6 +78,13 @@
         return [space.wrap(self.indexes[i]) for i in range(shapelen)]
 
 
+class IterState(object):
+    def __init__(self, index, indices, offset):
+        self.index = index
+        self.indices = indices
+        self.offset = offset
+
+
 class ArrayIter(object):
     _immutable_fields_ = ['array', 'size', 'ndim_m1', 'shape_m1[*]',
                           'strides[*]', 'backstrides[*]']
@@ -91,61 +98,59 @@
         self.strides = strides
         self.backstrides = backstrides
 
-        self.index = 0
-        self.indices = [0] * len(shape)
-        self.offset = array.start
+    def reset(self):
+        return IterState(0, [0] * len(self.shape_m1), self.array.start)
 
     @jit.unroll_safe
-    def reset(self):
-        self.index = 0
+    def next(self, state):
+        index = state.index + 1
+        indices = state.indices
+        offset = state.offset
         for i in xrange(self.ndim_m1, -1, -1):
-            self.indices[i] = 0
-        self.offset = self.array.start
+            idx = indices[i]
+            if idx < self.shape_m1[i]:
+                indices[i] = idx + 1
+                offset += self.strides[i]
+                break
+            else:
+                indices[i] = 0
+                offset -= self.backstrides[i]
+        return IterState(index, indices, offset)
 
     @jit.unroll_safe
-    def next(self):
-        self.index += 1
+    def next_skip_x(self, state, step):
+        assert step >= 0
+        if step == 0:
+            return state
+        index = state.index + step
+        indices = state.indices
+        offset = state.offset
         for i in xrange(self.ndim_m1, -1, -1):
-            idx = self.indices[i]
-            if idx < self.shape_m1[i]:
-                self.indices[i] = idx + 1
-                self.offset += self.strides[i]
+            idx = indices[i]
+            if idx < (self.shape_m1[i] + 1) - step:
+                indices[i] = idx + step
+                offset += self.strides[i] * step
                 break
             else:
-                self.indices[i] = 0
-                self.offset -= self.backstrides[i]
-
-    @jit.unroll_safe
-    def next_skip_x(self, step):
-        assert step >= 0
-        if step == 0:
-            return
-        self.index += step
-        for i in xrange(self.ndim_m1, -1, -1):
-            idx = self.indices[i]
-            if idx < (self.shape_m1[i] + 1) - step:
-                self.indices[i] = idx + step
-                self.offset += self.strides[i] * step
-                break
-            else:
-                rem_step = (self.indices[i] + step) // (self.shape_m1[i] + 1)
+                rem_step = (idx + step) // (self.shape_m1[i] + 1)
                 cur_step = step - rem_step * (self.shape_m1[i] + 1)
-                self.indices[i] += cur_step
-                self.offset += self.strides[i] * cur_step
+                indices[i] = idx + cur_step
+                offset += self.strides[i] * cur_step
                 step = rem_step
                 assert step > 0
+        return IterState(index, indices, offset)
 
-    def done(self):
-        return self.index >= self.size
+    def done(self, state):
+        return state.index >= self.size
 
-    def getitem(self):
-        return self.array.getitem(self.offset)
+    def getitem(self, state):
+        return self.array.getitem(state.offset)
 
-    def getitem_bool(self):
-        return self.array.getitem_bool(self.offset)
+    def getitem_bool(self, state):
+        return self.array.getitem_bool(state.offset)
 
-    def setitem(self, elem):
-        self.array.setitem(self.offset, elem)
+    def setitem(self, state, elem):
+        self.array.setitem(state.offset, elem)
 
 
 def AxisIter(array, shape, axis, cumulative):
diff --git a/pypy/module/micronumpy/test/test_iterators.py b/pypy/module/micronumpy/test/test_iterators.py
--- a/pypy/module/micronumpy/test/test_iterators.py
+++ b/pypy/module/micronumpy/test/test_iterators.py
@@ -16,17 +16,18 @@
         assert backstrides == [10, 4]
         i = ArrayIter(MockArray, support.product(shape), shape,
                       strides, backstrides)
-        i.next()
-        i.next()
-        i.next()
-        assert i.offset == 3
-        assert not i.done()
-        assert i.indices == [0,3]
+        s = i.reset()
+        s = i.next(s)
+        s = i.next(s)
+        s = i.next(s)
+        assert s.offset == 3
+        assert not i.done(s)
+        assert s.indices == [0,3]
         #cause a dimension overflow
-        i.next()
-        i.next()
-        assert i.offset == 5
-        assert i.indices == [1,0]
+        s = i.next(s)
+        s = i.next(s)
+        assert s.offset == 5
+        assert s.indices == [1,0]
 
         #Now what happens if the array is transposed? strides[-1] != 1
         # therefore layout is non-contiguous
@@ -35,17 +36,18 @@
         assert backstrides == [2, 12]
         i = ArrayIter(MockArray, support.product(shape), shape,
                       strides, backstrides)
-        i.next()
-        i.next()
-        i.next()
-        assert i.offset == 9
-        assert not i.done()
-        assert i.indices == [0,3]
+        s = i.reset()
+        s = i.next(s)
+        s = i.next(s)
+        s = i.next(s)
+        assert s.offset == 9
+        assert not i.done(s)
+        assert s.indices == [0,3]
         #cause a dimension overflow
-        i.next()
-        i.next()
-        assert i.offset == 1
-        assert i.indices == [1,0]
+        s = i.next(s)
+        s = i.next(s)
+        assert s.offset == 1
+        assert s.indices == [1,0]
 
     def test_iterator_step(self):
         #iteration in C order with #contiguous layout => strides[-1] is 1
@@ -56,22 +58,23 @@
         assert backstrides == [10, 4]
         i = ArrayIter(MockArray, support.product(shape), shape,
                       strides, backstrides)
-        i.next_skip_x(2)
-        i.next_skip_x(2)
-        i.next_skip_x(2)
-        assert i.offset == 6
-        assert not i.done()
-        assert i.indices == [1,1]
+        s = i.reset()
+        s = i.next_skip_x(s, 2)
+        s = i.next_skip_x(s, 2)
+        s = i.next_skip_x(s, 2)
+        assert s.offset == 6
+        assert not i.done(s)
+        assert s.indices == [1,1]
         #And for some big skips
-        i.next_skip_x(5)
-        assert i.offset == 11
-        assert i.indices == [2,1]
-        i.next_skip_x(5)
+        s = i.next_skip_x(s, 5)
+        assert s.offset == 11
+        assert s.indices == [2,1]
+        s = i.next_skip_x(s, 5)
         # Note: the offset does not overflow but recycles,
         # this is good for broadcast
-        assert i.offset == 1
-        assert i.indices == [0,1]
-        assert i.done()
+        assert s.offset == 1
+        assert s.indices == [0,1]
+        assert i.done(s)
 
         #Now what happens if the array is transposed? strides[-1] != 1
         # therefore layout is non-contiguous
@@ -80,17 +83,18 @@
         assert backstrides == [2, 12]
         i = ArrayIter(MockArray, support.product(shape), shape,
                       strides, backstrides)
-        i.next_skip_x(2)
-        i.next_skip_x(2)
-        i.next_skip_x(2)
-        assert i.offset == 4
-        assert i.indices == [1,1]
-        assert not i.done()
-        i.next_skip_x(5)
-        assert i.offset == 5
-        assert i.indices == [2,1]
-        assert not i.done()
-        i.next_skip_x(5)
-        assert i.indices == [0,1]
-        assert i.offset == 3
-        assert i.done()
+        s = i.reset()
+        s = i.next_skip_x(s, 2)
+        s = i.next_skip_x(s, 2)
+        s = i.next_skip_x(s, 2)
+        assert s.offset == 4
+        assert s.indices == [1,1]
+        assert not i.done(s)
+        s = i.next_skip_x(s, 5)
+        assert s.offset == 5
+        assert s.indices == [2,1]
+        assert not i.done(s)
+        s = i.next_skip_x(s, 5)
+        assert s.indices == [0,1]
+        assert s.offset == 3
+        assert i.done(s)


More information about the pypy-commit mailing list