[pypy-commit] pypy numpy-share-iterators: in-progress work on sharing iterators. not really working

fijal noreply at buildbot.pypy.org
Sat Dec 3 07:55:44 CET 2011


Author: Maciej Fijalkowski <fijall at gmail.com>
Branch: numpy-share-iterators
Changeset: r50082:02ca7995bf12
Date: 2011-12-03 08:52 +0200
http://bitbucket.org/pypy/pypy/changeset/02ca7995bf12/

Log:	in-progress work on sharing iterators. not really working

diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -133,7 +133,7 @@
     )
     arr = NDimArray(size, shape[:], dtype=dtype, order=order)
     shapelen = len(shape)
-    arr_iter = arr.start_iter(arr.shape)
+    arr_iter = arr.start_iter([])
     for i in range(len(elems_w)):
         w_elem = elems_w[i]
         dtype.setitem_w(space, arr.storage, arr_iter.offset, w_elem)
@@ -153,6 +153,8 @@
 # in the original array, strides[i] == backstrides[i] == 0
 
 class BaseIterator(object):
+    _next = None
+    
     def next(self, shapelen):
         raise NotImplementedError
 
@@ -162,6 +164,19 @@
     def get_offset(self):
         raise NotImplementedError
 
+    def unique(self, all_iters):
+        for iter in all_iters:
+            if iter.compatible(self):
+                return ChildIterator(iter)
+        all_iters.append(self)
+        return self
+
+    def compatible(self, other):
+        return False
+
+    def clean_next(self):
+        self._next = None
+
 class ArrayIterator(BaseIterator):
     def __init__(self, size):
         self.offset = 0
@@ -171,6 +186,7 @@
         arr = instantiate(ArrayIterator)
         arr.size = self.size
         arr.offset = self.offset + 1
+        self._next = arr
         return arr
 
     def done(self):
@@ -179,6 +195,9 @@
     def get_offset(self):
         return self.offset
 
+    def compatible(self, other):
+        return isinstance(other, ArrayIterator) # there can be only one
+
 class OneDimIterator(BaseIterator):
     def __init__(self, start, step, stop):
         self.offset = start
@@ -190,6 +209,7 @@
         arr.size = self.size
         arr.step = self.step
         arr.offset = self.offset + self.step
+        self._next = arr
         return arr
 
     def done(self):
@@ -227,6 +247,7 @@
         res.indices = indices
         res.arr = self.arr
         res._done = done
+        self._next = res
         return res
 
     def done(self):
@@ -282,6 +303,7 @@
         res.strides = self.strides
         res.backstrides = self.backstrides
         res.res_shape = self.res_shape
+        self._next = res
         return res
 
     def done(self):
@@ -309,6 +331,10 @@
             return self.right.get_offset()
         return self.left.get_offset()
 
+    def clean_next(self):
+        self.left.clean_next()
+        self.right.clean_next()
+
 class Call1Iterator(BaseIterator):
     def __init__(self, child):
         self.child = child
@@ -322,6 +348,9 @@
     def get_offset(self):
         return self.child.get_offset()
 
+    def clean_next(self):
+        self.child.clean_next()
+
 class ConstantIterator(BaseIterator):
     def next(self, shapelen):
         return self
@@ -332,6 +361,20 @@
     def get_offset(self):
         return 0
 
+class ChildIterator(BaseIterator):
+    """ An iterator that just refers to some other iterator
+    """
+    def __init__(self, parent):
+        self.parent = parent
+
+    def next(self, shapelen):
+        return ChildIterator(self.parent._next)
+
+    def done(self):
+        return self.parent.done()
+
+    def get_offset(self):
+        return self.parent.get_offset()
 
 class BaseArray(Wrappable):
     _attrs_ = ["invalidates", "signature", "shape", "strides", "backstrides",
@@ -438,7 +481,7 @@
             reds=['result', 'idx', 'i', 'self', 'cur_best', 'dtype']
         )
         def loop(self):
-            i = self.start_iter()
+            i = self.start_iter([])
             cur_best = self.eval(i)
             shapelen = len(self.shape)
             i = i.next(shapelen)
@@ -469,7 +512,7 @@
 
     def _all(self):
         dtype = self.find_dtype()
-        i = self.start_iter()
+        i = self.start_iter([])
         shapelen = len(self.shape)
         while not i.done():
             all_driver.jit_merge_point(signature=self.signature,
@@ -484,7 +527,7 @@
 
     def _any(self):
         dtype = self.find_dtype()
-        i = self.start_iter()
+        i = self.start_iter([])
         shapelen = len(self.shape)
         while not i.done():
             any_driver.jit_merge_point(signature=self.signature,
@@ -778,7 +821,7 @@
             raise OperationError(space.w_ValueError, space.wrap(
                 "The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()"))
         return space.wrap(space.is_true(self.get_concrete().eval(
-            self.start_iter(self.shape)).wrap(space)))
+            self.start_iter([])).wrap(space)))
 
     def descr_get_transpose(self, space):
         concrete = self.get_concrete()
@@ -803,7 +846,7 @@
     def getitem(self, item):
         raise NotImplementedError
 
-    def start_iter(self, res_shape=None):
+    def start_iter(self, all_iters, res_shape=None):
         raise NotImplementedError
 
     def descr_debug_repr(self, space):
@@ -854,8 +897,8 @@
     def eval(self, iter):
         return self.value
 
-    def start_iter(self, res_shape=None):
-        return ConstantIterator()
+    def start_iter(self, all_iters, res_shape=None):
+        return ConstantIterator().unique(all_iters)
 
     def to_str(self, space, comma, builder, indent=' ', use_ellipsis=False):
         builder.append(self.dtype.str_format(self.value))
@@ -886,16 +929,19 @@
         result_size = self.find_size()
         result = NDimArray(result_size, self.shape, self.find_dtype())
         shapelen = len(self.shape)
-        i = self.start_iter()
-        ri = result.start_iter()
+        all_iters = []
+        i = self.start_iter(all_iters)
+        ri = result.start_iter(all_iters)
         while not ri.done():
             numpy_driver.jit_merge_point(signature=signature,
                                          shapelen=shapelen,
                                          result_size=result_size, i=i, ri=ri,
                                          self=self, result=result)
-            result.dtype.setitem(result.storage, ri.offset, self.eval(i))
+            result.dtype.setitem(result.storage, ri.get_offset(), self.eval(i))
             i = i.next(shapelen)
             ri = ri.next(shapelen)
+            i.clean_next()
+            ri.clean_next()
         return result
 
     def force_if_needed(self):
@@ -952,10 +998,10 @@
         assert isinstance(call_sig, signature.Call1)
         return call_sig.func(self.res_dtype, val)
 
-    def start_iter(self, res_shape=None):
+    def start_iter(self, all_iters, res_shape=None):
         if self.forced_result is not None:
-            return self.forced_result.start_iter(res_shape)
-        return Call1Iterator(self.values.start_iter(res_shape))
+            return self.forced_result.start_iter(all_iters, res_shape)
+        return Call1Iterator(self.values.start_iter(all_iters, res_shape))
 
     def debug_repr(self):
         sig = self.signature
@@ -989,13 +1035,13 @@
     def _find_size(self):
         return self.size
 
-    def start_iter(self, res_shape=None):
+    def start_iter(self, all_iters, res_shape=None):
         if self.forced_result is not None:
-            return self.forced_result.start_iter(res_shape)
+            return self.forced_result.start_iter(all_iters, res_shape)
         if res_shape is None:
             res_shape = self.shape  # we still force the shape on children
-        return Call2Iterator(self.left.start_iter(res_shape),
-                             self.right.start_iter(res_shape))
+        return Call2Iterator(self.left.start_iter(all_iters, res_shape),
+                             self.right.start_iter(all_iters, res_shape))
 
     def _eval(self, iter):
         assert isinstance(iter, Call2Iterator)
@@ -1083,8 +1129,9 @@
         self._sliceloop(w_value, res_shape)
 
     def _sliceloop(self, source, res_shape):
-        source_iter = source.start_iter(res_shape)
-        res_iter = self.start_iter(res_shape)
+        all_iters = []
+        source_iter = source.start_iter(all_iters, res_shape)
+        res_iter = self.start_iter(all_iters, res_shape)
         shapelen = len(res_shape)
         while not res_iter.done():
             slice_driver.jit_merge_point(signature=source.signature,
@@ -1097,11 +1144,11 @@
             source_iter = source_iter.next(shapelen)
             res_iter = res_iter.next(shapelen)
 
-    def start_iter(self, res_shape=None):
+    def start_iter(self, all_iters, res_shape=None):
         if res_shape is not None and res_shape != self.shape:
-            return BroadcastIterator(self, res_shape)
+            return BroadcastIterator(self, res_shape).unique(all_iters)
         if len(self.shape) == 1:
-            return OneDimIterator(self.start, self.strides[0], self.shape[0])
+            return OneDimIterator(self.start, self.strides[0], self.shape[0]).unique(all_iters)
         return ViewIterator(self)
 
     def setitem(self, item, value):
@@ -1112,7 +1159,7 @@
 
     def copy(self):
         array = NDimArray(self.size, self.shape[:], self.find_dtype())
-        iter = self.start_iter()
+        iter = self.start_iter([])
         while not iter.done():
             array.setitem(iter.offset, self.getitem(iter.offset))
             iter = iter.next(len(self.shape))
@@ -1167,11 +1214,11 @@
         self.invalidated()
         self.dtype.setitem(self.storage, item, value)
 
-    def start_iter(self, res_shape=None):
+    def start_iter(self, all_iters, res_shape=None):
         if self.order == 'C':
             if res_shape is not None and res_shape != self.shape:
-                return BroadcastIterator(self, res_shape)
-            return ArrayIterator(self.size)
+                return BroadcastIterator(self, res_shape).unique(all_iters)
+            return ArrayIterator(self.size).unique(all_iters)
         raise NotImplementedError  # use ViewIterator simply, test it
 
     def debug_repr(self):
@@ -1292,13 +1339,13 @@
                            [arr.backstrides[-1]], [size])
         self.shapelen = len(arr.shape)
         self.arr = arr
-        self.iter = self.start_iter()
+        self.iter = self.start_iter([])
 
-    def start_iter(self, res_shape=None):
+    def start_iter(self, all_iters, res_shape=None):
         if res_shape is not None and res_shape != self.shape:
-            return BroadcastIterator(self, res_shape)
+            return BroadcastIterator(self, res_shape).unique(all_iters)
         return OneDimIterator(self.arr.start, self.strides[0],
-                              self.shape[0])
+                              self.shape[0]).unique(all_iters)
 
     def find_dtype(self):
         return self.arr.find_dtype()
diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py
--- a/pypy/module/micronumpy/interp_ufuncs.py
+++ b/pypy/module/micronumpy/interp_ufuncs.py
@@ -64,7 +64,7 @@
             space, obj.find_dtype(),
             promote_to_largest=True
         )
-        start = obj.start_iter(obj.shape)
+        start = obj.start_iter([])
         shapelen = len(obj.shape)
         if shapelen > 1 and not multidim:
             raise OperationError(space.w_NotImplementedError,
diff --git a/pypy/module/micronumpy/test/test_zjit.py b/pypy/module/micronumpy/test/test_zjit.py
--- a/pypy/module/micronumpy/test/test_zjit.py
+++ b/pypy/module/micronumpy/test/test_zjit.py
@@ -50,7 +50,7 @@
             interp.run(space)
             res = interp.results[-1]
             assert isinstance(res, BaseArray)
-            w_res = res.eval(res.start_iter()).wrap(interp.space)
+            w_res = res.eval(res.start_iter([])).wrap(interp.space)
             if isinstance(w_res, BoolObject):
                 return float(w_res.boolval)
             elif isinstance(w_res, FloatObject):


More information about the pypy-commit mailing list