[pypy-commit] pypy ufunc-reduce: Refactor do_axis_reduce() so that it takes the axis_flags list instead of a single axis

rlamy noreply at buildbot.pypy.org
Tue Jul 28 19:37:18 CEST 2015


Author: Ronan Lamy <ronan.lamy at gmail.com>
Branch: ufunc-reduce
Changeset: r78697:77850d7d684c
Date: 2015-07-28 18:37 +0100
http://bitbucket.org/pypy/pypy/changeset/77850d7d684c/

Log:	Refactor do_axis_reduce() so that it takes the axis_flags list
	instead of a single axis

diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -339,30 +339,83 @@
                                    greens=['shapelen', 'func', 'dtype'],
                                    reds='auto')
 
-def do_axis_reduce(space, func, arr, dtype, axis, out, identity):
-    out_iter = AxisIter(out.implementation, arr.get_shape(), axis)
-    out_state = out_iter.reset()
-    arr_iter, arr_state = arr.create_iter()
-    arr_iter.track_index = False
+def do_axis_reduce(space, func, arr, dtype, axis_flags, out, identity):
+    """Reduce `arr` with the binary `func` along the flagged axes.
+
+    Every axis i with a true axis_flags[i] is reduced; the other axes are
+    kept.  `arr` is walked through two iterators over its storage: the
+    outer one spans the kept axes (reduced axes collapsed to length 1 and
+    stride 0) and the inner one spans the reduced axes (kept axes
+    collapsed the same way).  For each outer position the inner iterator
+    is re-based at that offset, the elements it visits are folded into a
+    single value, and that value is written to the next element of `out`.
+
+    If `identity` is not None it is converted to `dtype` and seeds every
+    fold; otherwise the first element of each group does.  The running
+    value is passed to `func` as its left operand, so non-commutative
+    operations fold left to right.  Returns `out`.
+    """
+    out_iter, out_state = out.create_iter()
+    # out_iter is advanced in lock-step with outer_iter below; the loop
+    # never queries its done() or indices.
+    out_iter.track_index = False
+    shape = arr.get_shape()
+    strides = arr.implementation.get_strides()
+    backstrides = arr.implementation.get_backstrides()
+    shapelen = len(shape)
+    inner_shape = [-1] * shapelen
+    inner_strides = [-1] * shapelen
+    inner_backstrides = [-1] * shapelen
+    outer_shape = [-1] * shapelen
+    outer_strides = [-1] * shapelen
+    outer_backstrides = [-1] * shapelen
+    for i in range(shapelen):
+        if axis_flags[i]:
+            # reduced axis: traversed by the inner iterator only
+            inner_shape[i] = shape[i]
+            inner_strides[i] = strides[i]
+            inner_backstrides[i] = backstrides[i]
+            outer_shape[i] = 1
+            outer_strides[i] = 0
+            outer_backstrides[i] = 0
+        else:
+            # kept axis: traversed by the outer iterator only
+            outer_shape[i] = shape[i]
+            outer_strides[i] = strides[i]
+            outer_backstrides[i] = backstrides[i]
+            inner_shape[i] = 1
+            inner_strides[i] = 0
+            inner_backstrides[i] = 0
+    inner_iter = ArrayIter(arr.implementation, support.product(inner_shape),
+                           inner_shape, inner_strides, inner_backstrides)
+    outer_iter = ArrayIter(arr.implementation, support.product(outer_shape),
+                           outer_shape, outer_strides, outer_backstrides)
+    # exactly one output element per outer position
+    assert outer_iter.size == out_iter.size
+
     if identity is not None:
         identity = identity.convert_to(space, dtype)
-    shapelen = len(out.get_shape())
-    while not out_iter.done(out_state):
-        axis_reduce_driver.jit_merge_point(shapelen=shapelen, func=func,
-                                           dtype=dtype)
-        w_val = arr_iter.getitem(arr_state).convert_to(space, dtype)
-        arr_state = arr_iter.next(arr_state)
-
-        out_indices = out_iter.indices(out_state)
-        if out_indices[axis] == 0:
-            if identity is not None:
-                w_val = func(dtype, identity, w_val)
+    outer_state = outer_iter.reset()
+    while not outer_iter.done(outer_state):
+        inner_state = inner_iter.reset()
+        # start this group's inner walk at the current outer position
+        inner_state.offset = outer_state.offset
+        if identity is not None:
+            w_val = identity
         else:
-            cur = out_iter.getitem(out_state)
-            w_val = func(dtype, cur, w_val)
-
+            w_val = inner_iter.getitem(inner_state).convert_to(space, dtype)
+            inner_state = inner_iter.next(inner_state)
+        while not inner_iter.done(inner_state):
+            axis_reduce_driver.jit_merge_point(shapelen=shapelen, func=func,
+                                               dtype=dtype)
+            w_item = inner_iter.getitem(inner_state).convert_to(space, dtype)
+            # running value on the left so that non-commutative ops fold
+            # as ((a0 op a1) op a2) ..., as numpy's reduce does
+            w_val = func(dtype, w_val, w_item)
+            inner_state = inner_iter.next(inner_state)
         out_iter.setitem(out_state, w_val)
         out_state = out_iter.next(out_state)
+        outer_state = outer_iter.next(outer_state)
     return out
 
 
diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py
--- a/pypy/module/micronumpy/ufuncs.py
+++ b/pypy/module/micronumpy/ufuncs.py
@@ -410,7 +410,7 @@
                 if self.identity is not None:
                     out.fill(space, self.identity.convert_to(space, dtype))
                 return out
-            loop.do_axis_reduce(space, self.func, obj, dtype, axis,
+            loop.do_axis_reduce(space, self.func, obj, dtype, axis_flags,
                                 out, self.identity)
             if call__array_wrap__:
                 out = space.call_method(obj, '__array_wrap__', out)


More information about the pypy-commit mailing list