[pypy-commit] pypy default: merge nditer-external_loop which implements numpy's nditer external_loop argument

mattip noreply at buildbot.pypy.org
Sat Nov 1 21:48:42 CET 2014


Author: mattip <matti.picus at gmail.com>
Branch: 
Changeset: r74323:55056cc539e0
Date: 2014-10-31 08:30 +0200
http://bitbucket.org/pypy/pypy/changeset/55056cc539e0/

Log:	merge nditer-external_loop which implements numpy's nditer
	external_loop argument

diff --git a/pypy/doc/whatsnew-head.rst b/pypy/doc/whatsnew-head.rst
--- a/pypy/doc/whatsnew-head.rst
+++ b/pypy/doc/whatsnew-head.rst
@@ -39,3 +39,7 @@
 .. branch: kill-multimethod
 
 Kill multimethod machinery, all multimethods were removed earlier.
+
+.. branch: nditer-external_loop
+
+Implement the `external_loop` argument to numpy's nditer.
diff --git a/pypy/module/micronumpy/concrete.py b/pypy/module/micronumpy/concrete.py
--- a/pypy/module/micronumpy/concrete.py
+++ b/pypy/module/micronumpy/concrete.py
@@ -449,7 +449,7 @@
                 strides.reverse()
                 backstrides.reverse()
                 new_shape.reverse()
-            return SliceArray(self.start, strides, backstrides, new_shape,
+            return self.__class__(self.start, strides, backstrides, new_shape,
                               self, orig_array)
         new_strides = calc_new_strides(new_shape, self.get_shape(),
                                        self.get_strides(),
@@ -460,10 +460,16 @@
         new_backstrides = [0] * len(new_shape)
         for nd in range(len(new_shape)):
             new_backstrides[nd] = (new_shape[nd] - 1) * new_strides[nd]
-        return SliceArray(self.start, new_strides, new_backstrides, new_shape,
+        return self.__class__(self.start, new_strides, new_backstrides, new_shape,
                           self, orig_array)
 
 
+class NonWritableSliceArray(SliceArray):
+    def descr_setitem(self, space, orig_array, w_index, w_value):
+        raise OperationError(space.w_ValueError, space.wrap(
+            "assignment destination is read-only"))
+
+
 class VoidBoxStorage(BaseConcreteArray):
     def __init__(self, size, dtype):
         self.storage = alloc_raw_storage(size)
diff --git a/pypy/module/micronumpy/iterators.py b/pypy/module/micronumpy/iterators.py
--- a/pypy/module/micronumpy/iterators.py
+++ b/pypy/module/micronumpy/iterators.py
@@ -8,8 +8,8 @@
 At which byte in x.data does the item x[3,4] begin?
 if x.strides==[1,5]:
     pData = x.pData + (x.start + 3*1 + 4*5)*sizeof(x.pData[0])
-    pData = x.pData + (x.start + 24) * sizeof(x.pData[0])
-so the offset of the element is 24 elements after the first
+    pData = x.pData + (x.start + 23) * sizeof(x.pData[0])
+so the offset of the element is 23 elements after the first
 
 What is the next element in x after coordinates [3,4]?
 if x.order =='C':
@@ -33,7 +33,7 @@
   which is x.strides[1] * (x.shape[1] - 1) + x.strides[0]
 so if we precalculate the overflow backstride as
 [x.strides[i] * (x.shape[i] - 1) for i in range(len(x.shape))]
-we can go faster.
+we only need additions (no multiplications) while iterating.
 All the calculations happen in next()
 """
 from rpython.rlib import jit
@@ -41,6 +41,16 @@
 from pypy.module.micronumpy.base import W_NDimArray
 from pypy.module.micronumpy.flagsobj import _update_contiguous_flags
 
+class OpFlag(object):
+    def __init__(self):
+        self.rw = ''
+        self.broadcast = True
+        self.force_contig = False
+        self.force_align = False
+        self.native_byte_order = False
+        self.tmp_copy = ''
+        self.allocate = False
+
 
 class PureShapeIter(object):
     def __init__(self, shape, idx_w):
@@ -89,11 +99,13 @@
 class ArrayIter(object):
     _immutable_fields_ = ['contiguous', 'array', 'size', 'ndim_m1', 'shape_m1[*]',
                           'strides[*]', 'backstrides[*]', 'factors[*]',
-                          'track_index']
+                          'slice_shape', 'slice_stride', 'slice_backstride',
+                          'track_index', 'operand_type', 'slice_operand_type']
 
     track_index = True
 
-    def __init__(self, array, size, shape, strides, backstrides):
+    def __init__(self, array, size, shape, strides, backstrides, op_flags=OpFlag()):
+        from pypy.module.micronumpy import concrete
         assert len(shape) == len(strides) == len(backstrides)
         _update_contiguous_flags(array)
         self.contiguous = (array.flags & NPY.ARRAY_C_CONTIGUOUS and
@@ -105,6 +117,12 @@
         self.shape_m1 = [s - 1 for s in shape]
         self.strides = strides
         self.backstrides = backstrides
+        self.slice_shape = 1
+        self.slice_stride = -1
+        if strides:
+            self.slice_stride = strides[-1]
+        self.slice_backstride = 1
+        self.slice_operand_type = concrete.SliceArray
 
         ndim = len(shape)
         factors = [0] * ndim
@@ -114,6 +132,10 @@
             else:
                 factors[ndim-i-1] = factors[ndim-i] * shape[ndim-i]
         self.factors = factors
+        if op_flags.rw == 'r':
+            self.operand_type = concrete.ConcreteNonWritableArrayWithBase
+        else:
+            self.operand_type = concrete.ConcreteArrayWithBase
 
     @jit.unroll_safe
     def reset(self, state=None):
@@ -193,6 +215,12 @@
         assert state.iterator is self
         self.array.setitem(state.offset, elem)
 
+    def getoperand(self, st, base):
+        impl = self.operand_type
+        res = impl([], self.array.dtype, self.array.order, [], [],
+                   self.array.storage, base)
+        res.start = st.offset
+        return res
 
 def AxisIter(array, shape, axis, cumulative):
     strides = array.get_strides()
@@ -216,3 +244,42 @@
         size /= shape[axis]
     shape[axis] = backstrides[axis] = 0
     return ArrayIter(array, size, shape, array.strides, backstrides)
+
+class SliceIter(ArrayIter):
+    '''
+    used with external loops: getoperand returns a SliceArray view into
+    the original array (getitem, getitem_bool and setitem are unsupported)
+    '''
+    _immutable_fields_ = ['base', 'slice_shape[*]', 'slice_stride[*]', 'slice_backstride[*]']
+
+    def __init__(self, array, size, shape, strides, backstrides, slice_shape,
+                 slice_stride, slice_backstride, op_flags, base):
+        from pypy.module.micronumpy import concrete
+        ArrayIter.__init__(self, array, size, shape, strides, backstrides, op_flags)
+        self.slice_shape = slice_shape
+        self.slice_stride = slice_stride
+        self.slice_backstride = slice_backstride
+        self.base = base
+        if op_flags.rw == 'r':
+            self.slice_operand_type = concrete.NonWritableSliceArray
+        else:
+            self.slice_operand_type = concrete.SliceArray
+
+    def getitem(self, state):
+        # XXX cannot be called - must return a boxed value
+        assert False
+
+    def getitem_bool(self, state):
+        # XXX cannot be called - must return a boxed value
+        assert False
+
+    def setitem(self, state, elem):
+        # XXX cannot be called - must return a boxed value
+        assert False
+
+    def getoperand(self, state, base):
+        assert state.iterator is self
+        impl = self.slice_operand_type
+        arr = impl(state.offset, [self.slice_stride], [self.slice_backstride],
+                   [self.slice_shape], self.array, self.base)
+        return arr
diff --git a/pypy/module/micronumpy/ndarray.py b/pypy/module/micronumpy/ndarray.py
--- a/pypy/module/micronumpy/ndarray.py
+++ b/pypy/module/micronumpy/ndarray.py
@@ -83,8 +83,12 @@
         raise OperationError(space.w_AttributeError, space.wrap(
             "Cannot delete array dtype"))
 
+    def ndims(self):
+        return len(self.get_shape())
+    ndims._always_inline_ = True
+
     def descr_get_ndim(self, space):
-        return space.wrap(len(self.get_shape()))
+        return space.wrap(self.ndims())
 
     def descr_get_itemsize(self, space):
         return space.wrap(self.get_dtype().elsize)
@@ -103,14 +107,14 @@
         return space.wrap(loop.tostring(space, self))
 
     def getitem_filter(self, space, arr):
-        if len(arr.get_shape()) > 1 and arr.get_shape() != self.get_shape():
+        if arr.ndims() > 1 and arr.get_shape() != self.get_shape():
             raise OperationError(space.w_ValueError, space.wrap(
                 "boolean index array should have 1 dimension"))
         if arr.get_size() > self.get_size():
             raise OperationError(space.w_ValueError, space.wrap(
                 "index out of range for array"))
         size = loop.count_all_true(arr)
-        if len(arr.get_shape()) == 1:
+        if arr.ndims() == 1:
             res_shape = [size] + self.get_shape()[1:]
         else:
             res_shape = [size]
@@ -119,7 +123,7 @@
         return loop.getitem_filter(w_res, self, arr)
 
     def setitem_filter(self, space, idx, val):
-        if len(idx.get_shape()) > 1 and idx.get_shape() != self.get_shape():
+        if idx.ndims() > 1 and idx.get_shape() != self.get_shape():
             raise OperationError(space.w_ValueError, space.wrap(
                 "boolean index array should have 1 dimension"))
         if idx.get_size() > self.get_size():
@@ -210,7 +214,7 @@
         if space.is_w(w_idx, space.w_Ellipsis):
             return self
         elif isinstance(w_idx, W_NDimArray) and w_idx.get_dtype().is_bool() \
-                and len(w_idx.get_shape()) > 0:
+                and w_idx.ndims() > 0:
             return self.getitem_filter(space, w_idx)
         try:
             return self.implementation.descr_getitem(space, self, w_idx)
@@ -228,7 +232,7 @@
             self.implementation.setslice(space, convert_to_array(space, w_value))
             return
         elif isinstance(w_idx, W_NDimArray) and w_idx.get_dtype().is_bool() \
-                and len(w_idx.get_shape()) > 0:
+                and w_idx.ndims() > 0:
             self.setitem_filter(space, w_idx, convert_to_array(space, w_value))
             return
         try:
@@ -289,7 +293,7 @@
             shape=shape, backward_broadcast=backward_broadcast)
 
     def is_scalar(self):
-        return len(self.get_shape()) == 0
+        return self.ndims() == 0
 
     def set_scalar_value(self, w_val):
         return self.implementation.setitem(self.implementation.start, w_val)
@@ -408,7 +412,7 @@
         """
         if axis1 == axis2:
             return self
-        n = len(self.get_shape())
+        n = self.ndims()
         if n <= 1:
             return self
         if axis1 < 0:
@@ -426,7 +430,7 @@
         return self.implementation.nonzero(space, index_type)
 
     def descr_tolist(self, space):
-        if len(self.get_shape()) == 0:
+        if self.ndims() == 0:
             return self.get_scalar_value().item(space)
         l_w = []
         for i in range(self.get_shape()[0]):
@@ -514,7 +518,7 @@
         if len(args_w) == 0:
             raise OperationError(space.w_ValueError, space.wrap(
                 "itemset must have at least one argument"))
-        if len(args_w) != len(self.get_shape()) + 1:
+        if len(args_w) != self.ndims() + 1:
             raise OperationError(space.w_ValueError, space.wrap(
                 "incorrect number of indices for array"))
         self.descr_setitem(space, space.newtuple(args_w[:-1]), args_w[-1])
@@ -647,14 +651,14 @@
 
     @unwrap_spec(offset=int, axis1=int, axis2=int)
     def descr_diagonal(self, space, offset=0, axis1=0, axis2=1):
-        if len(self.get_shape()) < 2:
+        if self.ndims() < 2:
             raise OperationError(space.w_ValueError, space.wrap(
                 "need at least 2 dimensions for diagonal"))
-        if (axis1 < 0 or axis2 < 0 or axis1 >= len(self.get_shape()) or
-                axis2 >= len(self.get_shape())):
+        if (axis1 < 0 or axis2 < 0 or axis1 >= self.ndims() or
+                axis2 >= self.ndims()):
             raise oefmt(space.w_ValueError,
                         "axis1(=%d) and axis2(=%d) must be withing range "
-                        "(ndim=%d)", axis1, axis2, len(self.get_shape()))
+                        "(ndim=%d)", axis1, axis2, self.ndims())
         if axis1 == axis2:
             raise OperationError(space.w_ValueError, space.wrap(
                 "axis1 and axis2 cannot be the same"))
@@ -733,7 +737,7 @@
             raise OperationError(space.w_NotImplementedError, space.wrap(
                 'sorter not supported in searchsort'))
         side = searchside_converter(space, w_side)
-        if len(self.get_shape()) != 1:
+        if self.ndims() != 1:
             raise oefmt(space.w_ValueError, "a must be a 1-d array")
         v = convert_to_array(space, w_v)
         ret = W_NDimArray.from_shape(
@@ -972,7 +976,7 @@
         if other.is_scalar():
             #Note: w_out is not modified, this is numpy compliant.
             return self.descr_mul(space, other)
-        elif len(self.get_shape()) < 2 and len(other.get_shape()) < 2:
+        elif self.ndims() < 2 and other.ndims() < 2:
             w_res = self.descr_mul(space, other)
             assert isinstance(w_res, W_NDimArray)
             return w_res.descr_sum(space, space.wrap(-1), out)
@@ -989,7 +993,7 @@
                 matches = False
             elif not out.implementation.order == "C":
                 matches = False
-            elif len(out.get_shape()) != len(out_shape):
+            elif out.ndims() != len(out_shape):
                 matches = False
             else:
                 for i in range(len(out_shape)):
diff --git a/pypy/module/micronumpy/nditer.py b/pypy/module/micronumpy/nditer.py
--- a/pypy/module/micronumpy/nditer.py
+++ b/pypy/module/micronumpy/nditer.py
@@ -5,7 +5,7 @@
 from pypy.module.micronumpy import ufuncs, support, concrete
 from pypy.module.micronumpy.base import W_NDimArray, convert_to_array
 from pypy.module.micronumpy.descriptor import decode_w_dtype
-from pypy.module.micronumpy.iterators import ArrayIter
+from pypy.module.micronumpy.iterators import ArrayIter, SliceIter, OpFlag
 from pypy.module.micronumpy.strides import (calculate_broadcast_strides,
                                             shape_agreement, shape_agreement_multiple)
 
@@ -35,17 +35,6 @@
     return ret
 
 
-class OpFlag(object):
-    def __init__(self):
-        self.rw = ''
-        self.broadcast = True
-        self.force_contig = False
-        self.force_align = False
-        self.native_byte_order = False
-        self.tmp_copy = ''
-        self.allocate = False
-
-
 def parse_op_flag(space, lst):
     op_flag = OpFlag()
     for w_item in lst:
@@ -71,17 +60,17 @@
         elif item == 'allocate':
             op_flag.allocate = True
         elif item == 'no_subtype':
-            raise OperationError(space.w_NotImplementedError, space.wrap(
-                '"no_subtype" op_flag not implemented yet'))
+            raise oefmt(space.w_NotImplementedError,
+                '"no_subtype" op_flag not implemented yet')
         elif item == 'arraymask':
-            raise OperationError(space.w_NotImplementedError, space.wrap(
-                '"arraymask" op_flag not implemented yet'))
+            raise oefmt(space.w_NotImplementedError,
+                '"arraymask" op_flag not implemented yet')
         elif item == 'writemask':
-            raise OperationError(space.w_NotImplementedError, space.wrap(
-                '"writemask" op_flag not implemented yet'))
+            raise oefmt(space.w_NotImplementedError,
+                '"writemask" op_flag not implemented yet')
         else:
-            raise OperationError(space.w_ValueError, space.wrap(
-                'op_flags must be a tuple or array of per-op flag-tuples'))
+            raise oefmt(space.w_ValueError,
+                'op_flags must be a tuple or array of per-op flag-tuples')
     if op_flag.rw == '':
         raise oefmt(space.w_ValueError,
                     "None of the iterator flags READWRITE, READONLY, or "
@@ -94,8 +83,8 @@
         return
     elif not space.isinstance_w(w_flags, space.w_tuple) and not \
             space.isinstance_w(w_flags, space.w_list):
-        raise OperationError(space.w_ValueError, space.wrap(
-            'Iter global flags must be a list or tuple of strings'))
+        raise oefmt(space.w_ValueError,
+            'Iter global flags must be a list or tuple of strings')
     lst = space.listview(w_flags)
     for w_item in lst:
         if not space.isinstance_w(w_item, space.w_str) and not \
@@ -106,12 +95,10 @@
                         typename)
         item = space.str_w(w_item)
         if item == 'external_loop':
-            raise OperationError(space.w_NotImplementedError, space.wrap(
-                'nditer external_loop not implemented yet'))
             nditer.external_loop = True
         elif item == 'buffered':
-            raise OperationError(space.w_NotImplementedError, space.wrap(
-                'nditer buffered not implemented yet'))
+            raise oefmt(space.w_NotImplementedError,
+                'nditer buffered not implemented yet')
             # For numpy compatability
             nditer.buffered = True
         elif item == 'c_index':
@@ -131,8 +118,8 @@
         elif item == 'refs_ok':
             nditer.refs_ok = True
         elif item == 'reduce_ok':
-            raise OperationError(space.w_NotImplementedError, space.wrap(
-                'nditer reduce_ok not implemented yet'))
+            raise oefmt(space.w_NotImplementedError,
+                'nditer reduce_ok not implemented yet')
             nditer.reduce_ok = True
         elif item == 'zerosize_ok':
             nditer.zerosize_ok = True
@@ -141,9 +128,9 @@
                         'Unexpected iterator global flag "%s"',
                         item)
     if nditer.tracked_index and nditer.external_loop:
-        raise OperationError(space.w_ValueError, space.wrap(
+        raise oefmt(space.w_ValueError,
             'Iterator flag EXTERNAL_LOOP cannot be used if an index or '
-            'multi-index is being tracked'))
+            'multi-index is being tracked')
 
 
 def is_backward(imp, order):
@@ -155,11 +142,11 @@
         raise NotImplementedError('not implemented yet')
 
 
-def get_iter(space, order, arr, shape, dtype):
+def get_iter(space, order, arr, shape, dtype, op_flags):
     imp = arr.implementation
     backward = is_backward(imp, order)
     if arr.is_scalar():
-        return ArrayIter(imp, 1, [], [], [])
+        return ArrayIter(imp, 1, [], [], [], op_flags=op_flags)
     if (imp.strides[0] < imp.strides[-1] and not backward) or \
        (imp.strides[0] > imp.strides[-1] and backward):
         # flip the strides. Is this always true for multidimension?
@@ -174,8 +161,103 @@
         backstrides = imp.backstrides
     r = calculate_broadcast_strides(strides, backstrides, imp.shape,
                                     shape, backward)
-    return ArrayIter(imp, imp.get_size(), shape, r[0], r[1])
+    return ArrayIter(imp, imp.get_size(), shape, r[0], r[1], op_flags=op_flags)
 
+def calculate_ndim(op_in, oa_ndim):
+    if oa_ndim >=0:
+        return oa_ndim
+    else:
+        ndim = 0
+        for op in op_in:
+            if op is None:
+                continue
+            assert isinstance(op, W_NDimArray)
+            ndim = max(ndim, op.ndims())
+    return ndim
+
+def coalesce_axes(it, space):
+    # Copy logic from npyiter_coalesce_axes, used in ufunc iterators
+    # and in nditer's with 'external_loop' flag
+    can_coalesce = True
+    if it.order == 'F':
+        fastest = 0
+    else:
+        fastest = -1
+    for idim in range(it.ndim - 1):
+        for op_it, _ in it.iters:
+            if op_it is None:
+                continue
+            assert isinstance(op_it, ArrayIter)
+            indx = len(op_it.strides)
+            if it.order == 'F':
+                indx = len(op_it.array.strides) - indx
+                assert indx >=0
+                astrides = op_it.array.strides[indx:]
+            else:
+                astrides = op_it.array.strides[:indx]
+            # does op_it iters over array "naturally"
+            if astrides != op_it.strides:
+                can_coalesce = False
+                break
+        if can_coalesce:
+            for i in range(len(it.iters)):
+                old_iter = it.iters[i][0]
+                shape = [s+1 for s in old_iter.shape_m1]
+                strides = old_iter.strides
+                backstrides = old_iter.backstrides
+                if it.order == 'F':
+                    new_shape = shape[1:]
+                    new_strides = strides[1:]
+                    new_backstrides = backstrides[1:]
+                    _stride = min(strides[0], old_iter.slice_stride)
+                else:
+                    new_shape = shape[:-1]
+                    new_strides = strides[:-1]
+                    new_backstrides = backstrides[:-1]
+                    _stride = old_iter.slice_stride
+                # We always want the "fastest" iterator in external loops
+                _shape = shape[fastest] * old_iter.slice_shape
+                _backstride = (_shape - 1) * _stride
+                new_iter = SliceIter(old_iter.array, old_iter.size / shape[fastest],
+                            new_shape, new_strides, new_backstrides,
+                            _shape, _stride, _backstride,
+                            it.op_flags[i], it)
+                it.iters[i] = (new_iter, new_iter.reset())
+            if len(it.shape) > 1:
+                if it.order == 'F':
+                    it.shape = it.shape[1:]
+                else:
+                    it.shape = it.shape[:-1]
+            else:
+                it.shape = [1]
+        else:
+            break
+    # Always coalesce at least one
+    for i in range(len(it.iters)):
+        old_iter = it.iters[i][0]
+        shape = [s+1 for s in old_iter.shape_m1]
+        strides = old_iter.strides
+        backstrides = old_iter.backstrides
+        new_shape = shape[:-1]
+        new_strides = strides[:-1]
+        new_backstrides = backstrides[:-1]
+        _shape = shape[-1] * old_iter.slice_shape
+        # use the operand's iterator's rightmost stride,
+        # even if it is not the fastest (for 'F' or swapped axis)
+        _stride = old_iter.slice_stride
+        _backstride = (_shape - 1) * _stride
+        new_iter = SliceIter(old_iter.array, old_iter.size / shape[-1],
+                    new_shape, new_strides, new_backstrides,
+                    _shape, _stride, _backstride,
+                    it.op_flags[i], it)
+        it.iters[i] = (new_iter, new_iter.reset())
+    if len(it.shape) > 1:
+        if it.order == 'F':
+            it.shape = it.shape[1:]
+        else:
+            it.shape = it.shape[:-1]
+    else:
+        it.shape = [1]
 
 class IndexIterator(object):
     def __init__(self, shape, backward=False):
@@ -205,6 +287,7 @@
 
 
 class W_NDIter(W_Root):
+    _immutable_fields_ = ['ndim', ]
     def __init__(self, space, w_seq, w_flags, w_op_flags, w_op_dtypes, w_casting,
                  w_op_axes, w_itershape, w_buffersize, order):
         self.order = order
@@ -236,28 +319,29 @@
         self.op_flags = parse_op_arg(space, 'op_flags', w_op_flags,
                                      len(self.seq), parse_op_flag)
         # handle w_op_axes
+        oa_ndim = -1
         if not space.is_none(w_op_axes):
-            self.set_op_axes(space, w_op_axes)
+            oa_ndim = self.set_op_axes(space, w_op_axes)
+        self.ndim = calculate_ndim(self.seq, oa_ndim)
 
         # handle w_op_dtypes part 1: creating self.dtypes list from input
         if not space.is_none(w_op_dtypes):
             w_seq_as_list = space.listview(w_op_dtypes)
             self.dtypes = [decode_w_dtype(space, w_elem) for w_elem in w_seq_as_list]
             if len(self.dtypes) != len(self.seq):
-                raise OperationError(space.w_ValueError, space.wrap(
-                    "op_dtypes must be a tuple/list matching the number of ops"))
+                raise oefmt(space.w_ValueError,
+                    "op_dtypes must be a tuple/list matching the number of ops")
         else:
             self.dtypes = []
 
         # handle None or writable operands, calculate my shape
-        self.iters = []
         outargs = [i for i in range(len(self.seq))
                    if self.seq[i] is None or self.op_flags[i].rw == 'w']
         if len(outargs) > 0:
             out_shape = shape_agreement_multiple(space, [self.seq[i] for i in outargs])
         else:
             out_shape = None
-        self.shape = iter_shape = shape_agreement_multiple(space, self.seq,
+        self.shape = shape_agreement_multiple(space, self.seq,
                                                            shape=out_shape)
         if len(outargs) > 0:
             # Make None operands writeonly and flagged for allocation
@@ -276,11 +360,11 @@
             for i in outargs:
                 if self.seq[i] is None:
                     # XXX can we postpone allocation to later?
-                    self.seq[i] = W_NDimArray.from_shape(space, iter_shape, out_dtype)
+                    self.seq[i] = W_NDimArray.from_shape(space, self.shape, out_dtype)
                 else:
                     if not self.op_flags[i].broadcast:
                         # Raises if ooutput cannot be broadcast
-                        shape_agreement(space, iter_shape, self.seq[i], False)
+                        shape_agreement(space, self.shape, self.seq[i], False)
 
         if self.tracked_index != "":
             if self.order == "K":
@@ -289,7 +373,7 @@
                 backward = False
             else:
                 backward = self.order != self.tracked_index
-            self.index_iter = IndexIterator(iter_shape, backward=backward)
+            self.index_iter = IndexIterator(self.shape, backward=backward)
 
         # handle w_op_dtypes part 2: copy where needed if possible
         if len(self.dtypes) > 0:
@@ -311,49 +395,49 @@
             self.dtypes = [s.get_dtype() for s in self.seq]
 
         # create an iterator for each operand
+        self.iters = []
         for i in range(len(self.seq)):
-            it = get_iter(space, self.order, self.seq[i], iter_shape, self.dtypes[i])
+            it = get_iter(space, self.order, self.seq[i], self.shape,
+                          self.dtypes[i], self.op_flags[i])
             it.contiguous = False
             self.iters.append((it, it.reset()))
 
+        if self.external_loop:
+            coalesce_axes(self, space)
+
     def set_op_axes(self, space, w_op_axes):
         if space.len_w(w_op_axes) != len(self.seq):
             raise oefmt(space.w_ValueError,
                         "op_axes must be a tuple/list matching the number of ops")
         op_axes = space.listview(w_op_axes)
-        l = -1
+        oa_ndim = -1
         for w_axis in op_axes:
             if not space.is_none(w_axis):
                 axis_len = space.len_w(w_axis)
-                if l == -1:
-                    l = axis_len
-                elif axis_len != l:
+                if oa_ndim == -1:
+                    oa_ndim = axis_len
+                elif axis_len != oa_ndim:
                     raise oefmt(space.w_ValueError,
                                 "Each entry of op_axes must have the same size")
                 self.op_axes.append([space.int_w(x) if not space.is_none(x) else -1
                                      for x in space.listview(w_axis)])
-        if l == -1:
+        if oa_ndim == -1:
             raise oefmt(space.w_ValueError,
                         "If op_axes is provided, at least one list of axes "
                         "must be contained within it")
-        raise Exception('xxx TODO')
+        raise oefmt(space.w_NotImplementedError, "op_axis not finished yet")
         # Check that values make sense:
         # - in bounds for each operand
         # ValueError: Iterator input op_axes[0][3] (==3) is not a valid axis of op[0], which has 2 dimensions
         # - no repeat axis
         # ValueError: The 'op_axes' provided to the iterator constructor for operand 1 contained duplicate value 0
+        return oa_ndim
 
     def descr_iter(self, space):
         return space.wrap(self)
 
-    def getitem(self, it, st, op_flags):
-        if op_flags.rw == 'r':
-            impl = concrete.ConcreteNonWritableArrayWithBase
-        else:
-            impl = concrete.ConcreteArrayWithBase
-        res = impl([], it.array.dtype, it.array.order, [], [],
-                   it.array.storage, self)
-        res.start = st.offset
+    def getitem(self, it, st):
+        res = it.getoperand(st, self)
         return W_NDimArray(res)
 
     def descr_getitem(self, space, w_idx):
@@ -363,7 +447,7 @@
         except IndexError:
             raise oefmt(space.w_IndexError,
                         "Iterator operand index %d is out of bounds", idx)
-        return self.getitem(it, st, self.op_flags[idx])
+        return self.getitem(it, st)
 
     def descr_setitem(self, space, w_idx, w_value):
         raise oefmt(space.w_NotImplementedError, "not implemented yet")
@@ -385,7 +469,7 @@
             else:
                 self.first_next = False
         for i, (it, st) in enumerate(self.iters):
-            res.append(self.getitem(it, st, self.op_flags[i]))
+            res.append(self.getitem(it, st))
             self.iters[i] = (it, it.next(st))
         if len(res) < 2:
             return res[0]
@@ -477,7 +561,7 @@
         raise oefmt(space.w_NotImplementedError, "not implemented yet")
 
     def descr_get_ndim(self, space):
-        raise oefmt(space.w_NotImplementedError, "not implemented yet")
+        return space.wrap(self.ndim)
 
     def descr_get_nop(self, space):
         raise oefmt(space.w_NotImplementedError, "not implemented yet")
diff --git a/pypy/module/micronumpy/test/test_nditer.py b/pypy/module/micronumpy/test/test_nditer.py
--- a/pypy/module/micronumpy/test/test_nditer.py
+++ b/pypy/module/micronumpy/test/test_nditer.py
@@ -63,9 +63,6 @@
         from numpy import arange, nditer, array
         a = arange(24).reshape(2, 3, 4)
         import sys
-        if '__pypy__' in sys.builtin_module_names:
-            raises(NotImplementedError, nditer, a, flags=['external_loop'])
-            skip('nditer external_loop not implmented')
         r = []
         n = 0
         for x in nditer(a, flags=['external_loop']):
@@ -79,7 +76,9 @@
             r.append(x)
             n += 1
         assert n == 12
-        assert (array(r) == [[ 0, 12], [ 4, 16], [ 8, 20], [ 1, 13], [ 5, 17], [ 9, 21], [ 2, 14], [ 6, 18], [10, 22], [ 3, 15], [ 7, 19], [11, 23]]).all()
+        assert (array(r) == [[ 0, 12], [ 4, 16], [ 8, 20], [ 1, 13], [ 5, 17], [ 9, 21],
+                             [ 2, 14], [ 6, 18], [10, 22], [ 3, 15], [ 7, 19], [11, 23],
+                            ]).all()
         e = raises(ValueError, 'r[0][0] = 0')
         assert str(e.value) == 'assignment destination is read-only'
         r = []
@@ -222,9 +221,6 @@
     def test_outarg(self):
         from numpy import nditer, zeros, arange
         import sys
-        if '__pypy__' in sys.builtin_module_names:
-            raises(NotImplementedError, nditer, [1, 2], flags=['external_loop'])
-            skip('nditer external_loop not implmented')
 
         def square1(a):
             it = nditer([a, None])
@@ -233,6 +229,9 @@
             return it.operands[1]
         assert (square1([1, 2, 3]) == [1, 4, 9]).all()
 
+        if '__pypy__' in sys.builtin_module_names:
+            raises(NotImplementedError, nditer, [1, 2], flags=['buffered'])
+            skip('nditer buffered not implmented')
         def square2(a, out=None):
             it = nditer([a, out], flags=['external_loop', 'buffered'],
                         op_flags=[['readonly'],
@@ -252,10 +251,11 @@
         from numpy import nditer, arange
         a = arange(3)
         import sys
+        b = arange(8).reshape(2,4)
         if '__pypy__' in sys.builtin_module_names:
-            raises(NotImplementedError, nditer, a, flags=['external_loop'])
-            skip('nditer external_loop not implmented')
-        b = arange(8).reshape(2,4)
+            raises(NotImplementedError, nditer, [a, b, None], flags=['external_loop'],
+                   op_axes=[[0, -1, -1], [-1, 0, 1], None])
+            skip('nditer op_axes not implemented yet')
         it = nditer([a, b, None], flags=['external_loop'],
                     op_axes=[[0, -1, -1], [-1, 0, 1], None])
         for x, y, z in it:


More information about the pypy-commit mailing list