[pypy-commit] pypy numpy-searchsorted: merge default into branch
mattip
noreply at buildbot.pypy.org
Sat Apr 19 23:53:33 CEST 2014
Author: Matti Picus <matti.picus at gmail.com>
Branch: numpy-searchsorted
Changeset: r70788:e0560bcc6840
Date: 2014-04-19 23:56 +0300
http://bitbucket.org/pypy/pypy/changeset/e0560bcc6840/
Log: merge default into branch
diff too long, truncating to 2000 out of 2685 lines
diff --git a/pypy/doc/stm.rst b/pypy/doc/stm.rst
--- a/pypy/doc/stm.rst
+++ b/pypy/doc/stm.rst
@@ -40,7 +40,7 @@
``pypy-stm`` project is to improve what is so far the state-of-the-art
for using multiple CPUs, which for cases where separate processes don't
work is done by writing explicitly multi-threaded programs. Instead,
-``pypy-stm`` is flushing forward an approach to *hide* the threads, as
+``pypy-stm`` is pushing forward an approach to *hide* the threads, as
described below in `atomic sections`_.
diff --git a/pypy/doc/whatsnew-head.rst b/pypy/doc/whatsnew-head.rst
--- a/pypy/doc/whatsnew-head.rst
+++ b/pypy/doc/whatsnew-head.rst
@@ -140,3 +140,6 @@
.. branch: numpypy-nditer
Implement the core of nditer, without many of the fancy flags (external_loop, buffered)
+
+.. branch: numpy-speed
+Separate iterator from its state so jit can optimize better
diff --git a/pypy/module/micronumpy/concrete.py b/pypy/module/micronumpy/concrete.py
--- a/pypy/module/micronumpy/concrete.py
+++ b/pypy/module/micronumpy/concrete.py
@@ -284,9 +284,11 @@
self.get_backstrides(),
self.get_shape(), shape,
backward_broadcast)
- return ArrayIter(self, support.product(shape), shape, r[0], r[1])
- return ArrayIter(self, self.get_size(), self.shape,
- self.strides, self.backstrides)
+ i = ArrayIter(self, support.product(shape), shape, r[0], r[1])
+ else:
+ i = ArrayIter(self, self.get_size(), self.shape,
+ self.strides, self.backstrides)
+ return i, i.reset()
def swapaxes(self, space, orig_arr, axis1, axis2):
shape = self.get_shape()[:]
diff --git a/pypy/module/micronumpy/ctors.py b/pypy/module/micronumpy/ctors.py
--- a/pypy/module/micronumpy/ctors.py
+++ b/pypy/module/micronumpy/ctors.py
@@ -2,7 +2,7 @@
from pypy.interpreter.gateway import unwrap_spec, WrappedDefault
from rpython.rlib.rstring import strip_spaces
from rpython.rtyper.lltypesystem import lltype, rffi
-from pypy.module.micronumpy import descriptor, loop, ufuncs
+from pypy.module.micronumpy import descriptor, loop
from pypy.module.micronumpy.base import W_NDimArray, convert_to_array
from pypy.module.micronumpy.converters import shape_converter
@@ -156,10 +156,10 @@
"string is smaller than requested size"))
a = W_NDimArray.from_shape(space, [num_items], dtype=dtype)
- ai = a.create_iter()
+ ai, state = a.create_iter()
for val in items:
- ai.setitem(val)
- ai.next()
+ ai.setitem(state, val)
+ state = ai.next(state)
return space.wrap(a)
diff --git a/pypy/module/micronumpy/flatiter.py b/pypy/module/micronumpy/flatiter.py
--- a/pypy/module/micronumpy/flatiter.py
+++ b/pypy/module/micronumpy/flatiter.py
@@ -32,24 +32,23 @@
self.reset()
def reset(self):
- self.iter = self.base.create_iter()
+ self.iter, self.state = self.base.create_iter()
def descr_len(self, space):
- return space.wrap(self.base.get_size())
+ return space.wrap(self.iter.size)
def descr_next(self, space):
- if self.iter.done():
+ if self.iter.done(self.state):
raise OperationError(space.w_StopIteration, space.w_None)
- w_res = self.iter.getitem()
- self.iter.next()
+ w_res = self.iter.getitem(self.state)
+ self.state = self.iter.next(self.state)
return w_res
def descr_index(self, space):
- return space.wrap(self.iter.index)
+ return space.wrap(self.state.index)
def descr_coords(self, space):
- coords = self.base.to_coords(space, space.wrap(self.iter.index))
- return space.newtuple([space.wrap(c) for c in coords])
+ return space.newtuple([space.wrap(c) for c in self.state.indices])
def descr_getitem(self, space, w_idx):
if not (space.isinstance_w(w_idx, space.w_int) or
@@ -58,13 +57,13 @@
self.reset()
base = self.base
start, stop, step, length = space.decode_index4(w_idx, base.get_size())
- base_iter = base.create_iter()
- base_iter.next_skip_x(start)
+ base_iter, base_state = base.create_iter()
+ base_state = base_iter.next_skip_x(base_state, start)
if length == 1:
- return base_iter.getitem()
+ return base_iter.getitem(base_state)
res = W_NDimArray.from_shape(space, [length], base.get_dtype(),
base.get_order(), w_instance=base)
- return loop.flatiter_getitem(res, base_iter, step)
+ return loop.flatiter_getitem(res, base_iter, base_state, step)
def descr_setitem(self, space, w_idx, w_value):
if not (space.isinstance_w(w_idx, space.w_int) or
diff --git a/pypy/module/micronumpy/iterators.py b/pypy/module/micronumpy/iterators.py
--- a/pypy/module/micronumpy/iterators.py
+++ b/pypy/module/micronumpy/iterators.py
@@ -42,7 +42,6 @@
"""
from rpython.rlib import jit
from pypy.module.micronumpy import support
-from pypy.module.micronumpy.strides import calc_strides
from pypy.module.micronumpy.base import W_NDimArray
@@ -52,19 +51,20 @@
self.shapelen = len(shape)
self.indexes = [0] * len(shape)
self._done = False
- self.idx_w = [None] * len(idx_w)
+ self.idx_w_i = [None] * len(idx_w)
+ self.idx_w_s = [None] * len(idx_w)
for i, w_idx in enumerate(idx_w):
if isinstance(w_idx, W_NDimArray):
- self.idx_w[i] = w_idx.create_iter(shape)
+ self.idx_w_i[i], self.idx_w_s[i] = w_idx.create_iter(shape)
def done(self):
return self._done
@jit.unroll_safe
def next(self):
- for w_idx in self.idx_w:
- if w_idx is not None:
- w_idx.next()
+ for i, idx_w_i in enumerate(self.idx_w_i):
+ if idx_w_i is not None:
+ self.idx_w_s[i] = idx_w_i.next(self.idx_w_s[i])
for i in range(self.shapelen - 1, -1, -1):
if self.indexes[i] < self.shape[i] - 1:
self.indexes[i] += 1
@@ -79,6 +79,16 @@
return [space.wrap(self.indexes[i]) for i in range(shapelen)]
+class IterState(object):
+ _immutable_fields_ = ['iterator', 'index', 'indices[*]', 'offset']
+
+ def __init__(self, iterator, index, indices, offset):
+ self.iterator = iterator
+ self.index = index
+ self.indices = indices
+ self.offset = offset
+
+
class ArrayIter(object):
_immutable_fields_ = ['array', 'size', 'ndim_m1', 'shape_m1[*]',
'strides[*]', 'backstrides[*]']
@@ -91,90 +101,66 @@
self.shape_m1 = [s - 1 for s in shape]
self.strides = strides
self.backstrides = backstrides
- self.reset()
def reset(self):
- self.index = 0
- self.indices = [0] * len(self.shape_m1)
- self.offset = self.array.start
+ return IterState(self, 0, [0] * len(self.shape_m1), self.array.start)
@jit.unroll_safe
- def next(self):
- self.index += 1
+ def next(self, state):
+ assert state.iterator is self
+ index = state.index + 1
+ indices = state.indices
+ offset = state.offset
for i in xrange(self.ndim_m1, -1, -1):
- idx = self.indices[i]
+ idx = indices[i]
if idx < self.shape_m1[i]:
- self.indices[i] = idx + 1
- self.offset += self.strides[i]
+ indices[i] = idx + 1
+ offset += self.strides[i]
break
else:
- self.indices[i] = 0
- self.offset -= self.backstrides[i]
+ indices[i] = 0
+ offset -= self.backstrides[i]
+ return IterState(self, index, indices, offset)
@jit.unroll_safe
- def next_skip_x(self, step):
+ def next_skip_x(self, state, step):
+ assert state.iterator is self
assert step >= 0
if step == 0:
- return
- self.index += step
+ return state
+ index = state.index + step
+ indices = state.indices
+ offset = state.offset
for i in xrange(self.ndim_m1, -1, -1):
- idx = self.indices[i]
+ idx = indices[i]
if idx < (self.shape_m1[i] + 1) - step:
- self.indices[i] = idx + step
- self.offset += self.strides[i] * step
+ indices[i] = idx + step
+ offset += self.strides[i] * step
break
else:
- rem_step = (self.indices[i] + step) // (self.shape_m1[i] + 1)
+ rem_step = (idx + step) // (self.shape_m1[i] + 1)
cur_step = step - rem_step * (self.shape_m1[i] + 1)
- self.indices[i] += cur_step
- self.offset += self.strides[i] * cur_step
+ indices[i] = idx + cur_step
+ offset += self.strides[i] * cur_step
step = rem_step
assert step > 0
+ return IterState(self, index, indices, offset)
- def done(self):
- return self.index >= self.size
+ def done(self, state):
+ assert state.iterator is self
+ return state.index >= self.size
- def getitem(self):
- return self.array.getitem(self.offset)
+ def getitem(self, state):
+ assert state.iterator is self
+ return self.array.getitem(state.offset)
- def getitem_bool(self):
- return self.array.getitem_bool(self.offset)
+ def getitem_bool(self, state):
+ assert state.iterator is self
+ return self.array.getitem_bool(state.offset)
- def setitem(self, elem):
- self.array.setitem(self.offset, elem)
-
-
-class SliceIterator(ArrayIter):
- def __init__(self, arr, strides, backstrides, shape, order="C",
- backward=False, dtype=None):
- if dtype is None:
- dtype = arr.implementation.dtype
- self.dtype = dtype
- self.arr = arr
- if backward:
- self.slicesize = shape[0]
- self.gap = [support.product(shape[1:]) * dtype.elsize]
- strides = strides[1:]
- backstrides = backstrides[1:]
- shape = shape[1:]
- strides.reverse()
- backstrides.reverse()
- shape.reverse()
- size = support.product(shape)
- else:
- shape = [support.product(shape)]
- strides, backstrides = calc_strides(shape, dtype, order)
- size = 1
- self.slicesize = support.product(shape)
- self.gap = strides
-
- ArrayIter.__init__(self, arr.implementation, size, shape, strides, backstrides)
-
- def getslice(self):
- from pypy.module.micronumpy.concrete import SliceArray
- retVal = SliceArray(self.offset, self.gap, self.backstrides,
- [self.slicesize], self.arr.implementation, self.arr, self.dtype)
- return retVal
+ def setitem(self, state, elem):
+ assert state.iterator is self
+ self.array.setitem(state.offset, elem)
def AxisIter(array, shape, axis, cumulative):
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -12,11 +12,10 @@
AllButAxisIter
-call2_driver = jit.JitDriver(name='numpy_call2',
- greens = ['shapelen', 'func', 'calc_dtype',
- 'res_dtype'],
- reds = ['shape', 'w_lhs', 'w_rhs', 'out',
- 'left_iter', 'right_iter', 'out_iter'])
+call2_driver = jit.JitDriver(
+ name='numpy_call2',
+ greens=['shapelen', 'func', 'calc_dtype', 'res_dtype'],
+ reds='auto')
def call2(space, shape, func, calc_dtype, res_dtype, w_lhs, w_rhs, out):
# handle array_priority
@@ -46,47 +45,40 @@
if out is None:
out = W_NDimArray.from_shape(space, shape, res_dtype,
w_instance=lhs_for_subtype)
- left_iter = w_lhs.create_iter(shape)
- right_iter = w_rhs.create_iter(shape)
- out_iter = out.create_iter(shape)
+ left_iter, left_state = w_lhs.create_iter(shape)
+ right_iter, right_state = w_rhs.create_iter(shape)
+ out_iter, out_state = out.create_iter(shape)
shapelen = len(shape)
- while not out_iter.done():
+ while not out_iter.done(out_state):
call2_driver.jit_merge_point(shapelen=shapelen, func=func,
- calc_dtype=calc_dtype, res_dtype=res_dtype,
- shape=shape, w_lhs=w_lhs, w_rhs=w_rhs,
- out=out,
- left_iter=left_iter, right_iter=right_iter,
- out_iter=out_iter)
- w_left = left_iter.getitem().convert_to(space, calc_dtype)
- w_right = right_iter.getitem().convert_to(space, calc_dtype)
- out_iter.setitem(func(calc_dtype, w_left, w_right).convert_to(
+ calc_dtype=calc_dtype, res_dtype=res_dtype)
+ w_left = left_iter.getitem(left_state).convert_to(space, calc_dtype)
+ w_right = right_iter.getitem(right_state).convert_to(space, calc_dtype)
+ out_iter.setitem(out_state, func(calc_dtype, w_left, w_right).convert_to(
space, res_dtype))
- left_iter.next()
- right_iter.next()
- out_iter.next()
+ left_state = left_iter.next(left_state)
+ right_state = right_iter.next(right_state)
+ out_state = out_iter.next(out_state)
return out
-call1_driver = jit.JitDriver(name='numpy_call1',
- greens = ['shapelen', 'func', 'calc_dtype',
- 'res_dtype'],
- reds = ['shape', 'w_obj', 'out', 'obj_iter',
- 'out_iter'])
+call1_driver = jit.JitDriver(
+ name='numpy_call1',
+ greens=['shapelen', 'func', 'calc_dtype', 'res_dtype'],
+ reds='auto')
def call1(space, shape, func, calc_dtype, res_dtype, w_obj, out):
if out is None:
out = W_NDimArray.from_shape(space, shape, res_dtype, w_instance=w_obj)
- obj_iter = w_obj.create_iter(shape)
- out_iter = out.create_iter(shape)
+ obj_iter, obj_state = w_obj.create_iter(shape)
+ out_iter, out_state = out.create_iter(shape)
shapelen = len(shape)
- while not out_iter.done():
+ while not out_iter.done(out_state):
call1_driver.jit_merge_point(shapelen=shapelen, func=func,
- calc_dtype=calc_dtype, res_dtype=res_dtype,
- shape=shape, w_obj=w_obj, out=out,
- obj_iter=obj_iter, out_iter=out_iter)
- elem = obj_iter.getitem().convert_to(space, calc_dtype)
- out_iter.setitem(func(calc_dtype, elem).convert_to(space, res_dtype))
- out_iter.next()
- obj_iter.next()
+ calc_dtype=calc_dtype, res_dtype=res_dtype)
+ elem = obj_iter.getitem(obj_state).convert_to(space, calc_dtype)
+ out_iter.setitem(out_state, func(calc_dtype, elem).convert_to(space, res_dtype))
+ out_state = out_iter.next(out_state)
+ obj_state = obj_iter.next(obj_state)
return out
setslice_driver = jit.JitDriver(name='numpy_setslice',
@@ -96,18 +88,20 @@
def setslice(space, shape, target, source):
# note that unlike everything else, target and source here are
# array implementations, not arrays
- target_iter = target.create_iter(shape)
- source_iter = source.create_iter(shape)
+ target_iter, target_state = target.create_iter(shape)
+ source_iter, source_state = source.create_iter(shape)
dtype = target.dtype
shapelen = len(shape)
- while not target_iter.done():
+ while not target_iter.done(target_state):
setslice_driver.jit_merge_point(shapelen=shapelen, dtype=dtype)
+ val = source_iter.getitem(source_state)
if dtype.is_str_or_unicode():
- target_iter.setitem(dtype.coerce(space, source_iter.getitem()))
+ val = dtype.coerce(space, val)
else:
- target_iter.setitem(source_iter.getitem().convert_to(space, dtype))
- target_iter.next()
- source_iter.next()
+ val = val.convert_to(space, dtype)
+ target_iter.setitem(target_state, val)
+ target_state = target_iter.next(target_state)
+ source_state = source_iter.next(source_state)
return target
reduce_driver = jit.JitDriver(name='numpy_reduce',
@@ -116,22 +110,22 @@
reds = 'auto')
def compute_reduce(space, obj, calc_dtype, func, done_func, identity):
- obj_iter = obj.create_iter()
+ obj_iter, obj_state = obj.create_iter()
if identity is None:
- cur_value = obj_iter.getitem().convert_to(space, calc_dtype)
- obj_iter.next()
+ cur_value = obj_iter.getitem(obj_state).convert_to(space, calc_dtype)
+ obj_state = obj_iter.next(obj_state)
else:
cur_value = identity.convert_to(space, calc_dtype)
shapelen = len(obj.get_shape())
- while not obj_iter.done():
+ while not obj_iter.done(obj_state):
reduce_driver.jit_merge_point(shapelen=shapelen, func=func,
done_func=done_func,
calc_dtype=calc_dtype)
- rval = obj_iter.getitem().convert_to(space, calc_dtype)
+ rval = obj_iter.getitem(obj_state).convert_to(space, calc_dtype)
if done_func is not None and done_func(calc_dtype, rval):
return rval
cur_value = func(calc_dtype, cur_value, rval)
- obj_iter.next()
+ obj_state = obj_iter.next(obj_state)
return cur_value
reduce_cum_driver = jit.JitDriver(name='numpy_reduce_cum_driver',
@@ -139,69 +133,76 @@
reds = 'auto')
def compute_reduce_cumulative(space, obj, out, calc_dtype, func, identity):
- obj_iter = obj.create_iter()
- out_iter = out.create_iter()
+ obj_iter, obj_state = obj.create_iter()
+ out_iter, out_state = out.create_iter()
if identity is None:
- cur_value = obj_iter.getitem().convert_to(space, calc_dtype)
- out_iter.setitem(cur_value)
- out_iter.next()
- obj_iter.next()
+ cur_value = obj_iter.getitem(obj_state).convert_to(space, calc_dtype)
+ out_iter.setitem(out_state, cur_value)
+ out_state = out_iter.next(out_state)
+ obj_state = obj_iter.next(obj_state)
else:
cur_value = identity.convert_to(space, calc_dtype)
shapelen = len(obj.get_shape())
- while not obj_iter.done():
+ while not obj_iter.done(obj_state):
reduce_cum_driver.jit_merge_point(shapelen=shapelen, func=func,
dtype=calc_dtype)
- rval = obj_iter.getitem().convert_to(space, calc_dtype)
+ rval = obj_iter.getitem(obj_state).convert_to(space, calc_dtype)
cur_value = func(calc_dtype, cur_value, rval)
- out_iter.setitem(cur_value)
- out_iter.next()
- obj_iter.next()
+ out_iter.setitem(out_state, cur_value)
+ out_state = out_iter.next(out_state)
+ obj_state = obj_iter.next(obj_state)
def fill(arr, box):
- arr_iter = arr.create_iter()
- while not arr_iter.done():
- arr_iter.setitem(box)
- arr_iter.next()
+ arr_iter, arr_state = arr.create_iter()
+ while not arr_iter.done(arr_state):
+ arr_iter.setitem(arr_state, box)
+ arr_state = arr_iter.next(arr_state)
def assign(space, arr, seq):
- arr_iter = arr.create_iter()
+ arr_iter, arr_state = arr.create_iter()
arr_dtype = arr.get_dtype()
for item in seq:
- arr_iter.setitem(arr_dtype.coerce(space, item))
- arr_iter.next()
+ arr_iter.setitem(arr_state, arr_dtype.coerce(space, item))
+ arr_state = arr_iter.next(arr_state)
where_driver = jit.JitDriver(name='numpy_where',
greens = ['shapelen', 'dtype', 'arr_dtype'],
reds = 'auto')
def where(space, out, shape, arr, x, y, dtype):
- out_iter = out.create_iter(shape)
- arr_iter = arr.create_iter(shape)
+ out_iter, out_state = out.create_iter(shape)
+ arr_iter, arr_state = arr.create_iter(shape)
arr_dtype = arr.get_dtype()
- x_iter = x.create_iter(shape)
- y_iter = y.create_iter(shape)
+ x_iter, x_state = x.create_iter(shape)
+ y_iter, y_state = y.create_iter(shape)
if x.is_scalar():
if y.is_scalar():
- iter = arr_iter
+ iter, state = arr_iter, arr_state
else:
- iter = y_iter
+ iter, state = y_iter, y_state
else:
- iter = x_iter
+ iter, state = x_iter, x_state
shapelen = len(shape)
- while not iter.done():
+ while not iter.done(state):
where_driver.jit_merge_point(shapelen=shapelen, dtype=dtype,
arr_dtype=arr_dtype)
- w_cond = arr_iter.getitem()
+ w_cond = arr_iter.getitem(arr_state)
if arr_dtype.itemtype.bool(w_cond):
- w_val = x_iter.getitem().convert_to(space, dtype)
+ w_val = x_iter.getitem(x_state).convert_to(space, dtype)
else:
- w_val = y_iter.getitem().convert_to(space, dtype)
- out_iter.setitem(w_val)
- out_iter.next()
- arr_iter.next()
- x_iter.next()
- y_iter.next()
+ w_val = y_iter.getitem(y_state).convert_to(space, dtype)
+ out_iter.setitem(out_state, w_val)
+ out_state = out_iter.next(out_state)
+ arr_state = arr_iter.next(arr_state)
+ x_state = x_iter.next(x_state)
+ y_state = y_iter.next(y_state)
+ if x.is_scalar():
+ if y.is_scalar():
+ state = arr_state
+ else:
+ state = y_state
+ else:
+ state = x_state
return out
axis_reduce__driver = jit.JitDriver(name='numpy_axis_reduce',
@@ -212,31 +213,36 @@
def do_axis_reduce(space, shape, func, arr, dtype, axis, out, identity, cumulative,
temp):
out_iter = AxisIter(out.implementation, arr.get_shape(), axis, cumulative)
+ out_state = out_iter.reset()
if cumulative:
temp_iter = AxisIter(temp.implementation, arr.get_shape(), axis, False)
+ temp_state = temp_iter.reset()
else:
- temp_iter = out_iter # hack
- arr_iter = arr.create_iter()
+ temp_iter = out_iter # hack
+ temp_state = out_state
+ arr_iter, arr_state = arr.create_iter()
if identity is not None:
identity = identity.convert_to(space, dtype)
shapelen = len(shape)
- while not out_iter.done():
+ while not out_iter.done(out_state):
axis_reduce__driver.jit_merge_point(shapelen=shapelen, func=func,
dtype=dtype)
- assert not arr_iter.done()
- w_val = arr_iter.getitem().convert_to(space, dtype)
- if out_iter.indices[axis] == 0:
+ assert not arr_iter.done(arr_state)
+ w_val = arr_iter.getitem(arr_state).convert_to(space, dtype)
+ if out_state.indices[axis] == 0:
if identity is not None:
w_val = func(dtype, identity, w_val)
else:
- cur = temp_iter.getitem()
+ cur = temp_iter.getitem(temp_state)
w_val = func(dtype, cur, w_val)
- out_iter.setitem(w_val)
+ out_iter.setitem(out_state, w_val)
+ out_state = out_iter.next(out_state)
if cumulative:
- temp_iter.setitem(w_val)
- temp_iter.next()
- arr_iter.next()
- out_iter.next()
+ temp_iter.setitem(temp_state, w_val)
+ temp_state = temp_iter.next(temp_state)
+ else:
+ temp_state = out_state
+ arr_state = arr_iter.next(arr_state)
return out
@@ -249,18 +255,18 @@
result = 0
idx = 1
dtype = arr.get_dtype()
- iter = arr.create_iter()
- cur_best = iter.getitem()
- iter.next()
+ iter, state = arr.create_iter()
+ cur_best = iter.getitem(state)
+ state = iter.next(state)
shapelen = len(arr.get_shape())
- while not iter.done():
+ while not iter.done(state):
arg_driver.jit_merge_point(shapelen=shapelen, dtype=dtype)
- w_val = iter.getitem()
+ w_val = iter.getitem(state)
new_best = getattr(dtype.itemtype, op_name)(cur_best, w_val)
if dtype.itemtype.ne(new_best, cur_best):
result = idx
cur_best = new_best
- iter.next()
+ state = iter.next(state)
idx += 1
return result
return argmin_argmax
@@ -291,17 +297,19 @@
right_impl = right.implementation
assert left_shape[-1] == right_shape[right_critical_dim]
assert result.get_dtype() == dtype
- outi = result.create_iter()
+ outi, outs = result.create_iter()
lefti = AllButAxisIter(left_impl, len(left_shape) - 1)
righti = AllButAxisIter(right_impl, right_critical_dim)
+ lefts = lefti.reset()
+ rights = righti.reset()
n = left_impl.shape[-1]
s1 = left_impl.strides[-1]
s2 = right_impl.strides[right_critical_dim]
- while not lefti.done():
- while not righti.done():
- oval = outi.getitem()
- i1 = lefti.offset
- i2 = righti.offset
+ while not lefti.done(lefts):
+ while not righti.done(rights):
+ oval = outi.getitem(outs)
+ i1 = lefts.offset
+ i2 = rights.offset
i = 0
while i < n:
i += 1
@@ -311,11 +319,11 @@
oval = dtype.itemtype.add(oval, dtype.itemtype.mul(lval, rval))
i1 += s1
i2 += s2
- outi.setitem(oval)
- outi.next()
- righti.next()
- righti.reset()
- lefti.next()
+ outi.setitem(outs, oval)
+ outs = outi.next(outs)
+ rights = righti.next(rights)
+ rights = righti.reset()
+ lefts = lefti.next(lefts)
return result
count_all_true_driver = jit.JitDriver(name = 'numpy_count',
@@ -324,13 +332,13 @@
def count_all_true_concrete(impl):
s = 0
- iter = impl.create_iter()
+ iter, state = impl.create_iter()
shapelen = len(impl.shape)
dtype = impl.dtype
- while not iter.done():
+ while not iter.done(state):
count_all_true_driver.jit_merge_point(shapelen=shapelen, dtype=dtype)
- s += iter.getitem_bool()
- iter.next()
+ s += iter.getitem_bool(state)
+ state = iter.next(state)
return s
def count_all_true(arr):
@@ -344,18 +352,18 @@
reds = 'auto')
def nonzero(res, arr, box):
- res_iter = res.create_iter()
- arr_iter = arr.create_iter()
+ res_iter, res_state = res.create_iter()
+ arr_iter, arr_state = arr.create_iter()
shapelen = len(arr.shape)
dtype = arr.dtype
dims = range(shapelen)
- while not arr_iter.done():
+ while not arr_iter.done(arr_state):
nonzero_driver.jit_merge_point(shapelen=shapelen, dims=dims, dtype=dtype)
- if arr_iter.getitem_bool():
+ if arr_iter.getitem_bool(arr_state):
for d in dims:
- res_iter.setitem(box(arr_iter.indices[d]))
- res_iter.next()
- arr_iter.next()
+ res_iter.setitem(res_state, box(arr_state.indices[d]))
+ res_state = res_iter.next(res_state)
+ arr_state = arr_iter.next(arr_state)
return res
@@ -365,26 +373,26 @@
reds = 'auto')
def getitem_filter(res, arr, index):
- res_iter = res.create_iter()
+ res_iter, res_state = res.create_iter()
shapelen = len(arr.get_shape())
if shapelen > 1 and len(index.get_shape()) < 2:
- index_iter = index.create_iter(arr.get_shape(), backward_broadcast=True)
+ index_iter, index_state = index.create_iter(arr.get_shape(), backward_broadcast=True)
else:
- index_iter = index.create_iter()
- arr_iter = arr.create_iter()
+ index_iter, index_state = index.create_iter()
+ arr_iter, arr_state = arr.create_iter()
arr_dtype = arr.get_dtype()
index_dtype = index.get_dtype()
# XXX length of shape of index as well?
- while not index_iter.done():
+ while not index_iter.done(index_state):
getitem_filter_driver.jit_merge_point(shapelen=shapelen,
index_dtype=index_dtype,
arr_dtype=arr_dtype,
)
- if index_iter.getitem_bool():
- res_iter.setitem(arr_iter.getitem())
- res_iter.next()
- index_iter.next()
- arr_iter.next()
+ if index_iter.getitem_bool(index_state):
+ res_iter.setitem(res_state, arr_iter.getitem(arr_state))
+ res_state = res_iter.next(res_state)
+ index_state = index_iter.next(index_state)
+ arr_state = arr_iter.next(arr_state)
return res
setitem_filter_driver = jit.JitDriver(name = 'numpy_setitem_bool',
@@ -393,41 +401,42 @@
reds = 'auto')
def setitem_filter(space, arr, index, value):
- arr_iter = arr.create_iter()
+ arr_iter, arr_state = arr.create_iter()
shapelen = len(arr.get_shape())
if shapelen > 1 and len(index.get_shape()) < 2:
- index_iter = index.create_iter(arr.get_shape(), backward_broadcast=True)
+ index_iter, index_state = index.create_iter(arr.get_shape(), backward_broadcast=True)
else:
- index_iter = index.create_iter()
+ index_iter, index_state = index.create_iter()
if value.get_size() == 1:
- value_iter = value.create_iter(arr.get_shape())
+ value_iter, value_state = value.create_iter(arr.get_shape())
else:
- value_iter = value.create_iter()
+ value_iter, value_state = value.create_iter()
index_dtype = index.get_dtype()
arr_dtype = arr.get_dtype()
- while not index_iter.done():
+ while not index_iter.done(index_state):
setitem_filter_driver.jit_merge_point(shapelen=shapelen,
index_dtype=index_dtype,
arr_dtype=arr_dtype,
)
- if index_iter.getitem_bool():
- arr_iter.setitem(arr_dtype.coerce(space, value_iter.getitem()))
- value_iter.next()
- arr_iter.next()
- index_iter.next()
+ if index_iter.getitem_bool(index_state):
+ val = arr_dtype.coerce(space, value_iter.getitem(value_state))
+ value_state = value_iter.next(value_state)
+ arr_iter.setitem(arr_state, val)
+ arr_state = arr_iter.next(arr_state)
+ index_state = index_iter.next(index_state)
flatiter_getitem_driver = jit.JitDriver(name = 'numpy_flatiter_getitem',
greens = ['dtype'],
reds = 'auto')
-def flatiter_getitem(res, base_iter, step):
- ri = res.create_iter()
+def flatiter_getitem(res, base_iter, base_state, step):
+ ri, rs = res.create_iter()
dtype = res.get_dtype()
- while not ri.done():
+ while not ri.done(rs):
flatiter_getitem_driver.jit_merge_point(dtype=dtype)
- ri.setitem(base_iter.getitem())
- base_iter.next_skip_x(step)
- ri.next()
+ ri.setitem(rs, base_iter.getitem(base_state))
+ base_state = base_iter.next_skip_x(base_state, step)
+ rs = ri.next(rs)
return res
flatiter_setitem_driver = jit.JitDriver(name = 'numpy_flatiter_setitem',
@@ -436,19 +445,21 @@
def flatiter_setitem(space, arr, val, start, step, length):
dtype = arr.get_dtype()
- arr_iter = arr.create_iter()
- val_iter = val.create_iter()
- arr_iter.next_skip_x(start)
+ arr_iter, arr_state = arr.create_iter()
+ val_iter, val_state = val.create_iter()
+ arr_state = arr_iter.next_skip_x(arr_state, start)
while length > 0:
flatiter_setitem_driver.jit_merge_point(dtype=dtype)
+ val = val_iter.getitem(val_state)
if dtype.is_str_or_unicode():
- arr_iter.setitem(dtype.coerce(space, val_iter.getitem()))
+ val = dtype.coerce(space, val)
else:
- arr_iter.setitem(val_iter.getitem().convert_to(space, dtype))
+ val = val.convert_to(space, dtype)
+ arr_iter.setitem(arr_state, val)
# need to repeat i_nput values until all assignments are done
- arr_iter.next_skip_x(step)
+ arr_state = arr_iter.next_skip_x(arr_state, step)
+ val_state = val_iter.next(val_state)
length -= 1
- val_iter.next()
fromstring_driver = jit.JitDriver(name = 'numpy_fromstring',
greens = ['itemsize', 'dtype'],
@@ -456,30 +467,30 @@
def fromstring_loop(space, a, dtype, itemsize, s):
i = 0
- ai = a.create_iter()
- while not ai.done():
+ ai, state = a.create_iter()
+ while not ai.done(state):
fromstring_driver.jit_merge_point(dtype=dtype, itemsize=itemsize)
sub = s[i*itemsize:i*itemsize + itemsize]
if dtype.is_str_or_unicode():
val = dtype.coerce(space, space.wrap(sub))
else:
val = dtype.itemtype.runpack_str(space, sub)
- ai.setitem(val)
- ai.next()
+ ai.setitem(state, val)
+ state = ai.next(state)
i += 1
def tostring(space, arr):
builder = StringBuilder()
- iter = arr.create_iter()
+ iter, state = arr.create_iter()
w_res_str = W_NDimArray.from_shape(space, [1], arr.get_dtype(), order='C')
itemsize = arr.get_dtype().elsize
res_str_casted = rffi.cast(rffi.CArrayPtr(lltype.Char),
w_res_str.implementation.get_storage_as_int(space))
- while not iter.done():
- w_res_str.implementation.setitem(0, iter.getitem())
+ while not iter.done(state):
+ w_res_str.implementation.setitem(0, iter.getitem(state))
for i in range(itemsize):
builder.append(res_str_casted[i])
- iter.next()
+ state = iter.next(state)
return builder.build()
getitem_int_driver = jit.JitDriver(name = 'numpy_getitem_int',
@@ -500,8 +511,8 @@
# prepare the index
index_w = [None] * indexlen
for i in range(indexlen):
- if iter.idx_w[i] is not None:
- index_w[i] = iter.idx_w[i].getitem()
+ if iter.idx_w_i[i] is not None:
+ index_w[i] = iter.idx_w_i[i].getitem(iter.idx_w_s[i])
else:
index_w[i] = indexes_w[i]
res.descr_setitem(space, space.newtuple(prefix_w[:prefixlen] +
@@ -528,8 +539,8 @@
# prepare the index
index_w = [None] * indexlen
for i in range(indexlen):
- if iter.idx_w[i] is not None:
- index_w[i] = iter.idx_w[i].getitem()
+ if iter.idx_w_i[i] is not None:
+ index_w[i] = iter.idx_w_i[i].getitem(iter.idx_w_s[i])
else:
index_w[i] = indexes_w[i]
w_idx = space.newtuple(prefix_w[:prefixlen] + iter.get_index(space,
@@ -547,13 +558,14 @@
def byteswap(from_, to):
dtype = from_.dtype
- from_iter = from_.create_iter()
- to_iter = to.create_iter()
- while not from_iter.done():
+ from_iter, from_state = from_.create_iter()
+ to_iter, to_state = to.create_iter()
+ while not from_iter.done(from_state):
byteswap_driver.jit_merge_point(dtype=dtype)
- to_iter.setitem(dtype.itemtype.byteswap(from_iter.getitem()))
- to_iter.next()
- from_iter.next()
+ val = dtype.itemtype.byteswap(from_iter.getitem(from_state))
+ to_iter.setitem(to_state, val)
+ to_state = to_iter.next(to_state)
+ from_state = from_iter.next(from_state)
choose_driver = jit.JitDriver(name='numpy_choose_driver',
greens = ['shapelen', 'mode', 'dtype'],
@@ -561,13 +573,15 @@
def choose(space, arr, choices, shape, dtype, out, mode):
shapelen = len(shape)
- iterators = [a.create_iter(shape) for a in choices]
- arr_iter = arr.create_iter(shape)
- out_iter = out.create_iter(shape)
- while not arr_iter.done():
+ pairs = [a.create_iter(shape) for a in choices]
+ iterators = [i[0] for i in pairs]
+ states = [i[1] for i in pairs]
+ arr_iter, arr_state = arr.create_iter(shape)
+ out_iter, out_state = out.create_iter(shape)
+ while not arr_iter.done(arr_state):
choose_driver.jit_merge_point(shapelen=shapelen, dtype=dtype,
mode=mode)
- index = support.index_w(space, arr_iter.getitem())
+ index = support.index_w(space, arr_iter.getitem(arr_state))
if index < 0 or index >= len(iterators):
if mode == NPY.RAISE:
raise OperationError(space.w_ValueError, space.wrap(
@@ -580,72 +594,73 @@
index = 0
else:
index = len(iterators) - 1
- out_iter.setitem(iterators[index].getitem().convert_to(space, dtype))
- for iter in iterators:
- iter.next()
- out_iter.next()
- arr_iter.next()
+ val = iterators[index].getitem(states[index]).convert_to(space, dtype)
+ out_iter.setitem(out_state, val)
+ for i in range(len(iterators)):
+ states[i] = iterators[i].next(states[i])
+ out_state = out_iter.next(out_state)
+ arr_state = arr_iter.next(arr_state)
clip_driver = jit.JitDriver(name='numpy_clip_driver',
greens = ['shapelen', 'dtype'],
reds = 'auto')
def clip(space, arr, shape, min, max, out):
- arr_iter = arr.create_iter(shape)
+ arr_iter, arr_state = arr.create_iter(shape)
dtype = out.get_dtype()
shapelen = len(shape)
- min_iter = min.create_iter(shape)
- max_iter = max.create_iter(shape)
- out_iter = out.create_iter(shape)
- while not arr_iter.done():
+ min_iter, min_state = min.create_iter(shape)
+ max_iter, max_state = max.create_iter(shape)
+ out_iter, out_state = out.create_iter(shape)
+ while not arr_iter.done(arr_state):
clip_driver.jit_merge_point(shapelen=shapelen, dtype=dtype)
- w_v = arr_iter.getitem().convert_to(space, dtype)
- w_min = min_iter.getitem().convert_to(space, dtype)
- w_max = max_iter.getitem().convert_to(space, dtype)
+ w_v = arr_iter.getitem(arr_state).convert_to(space, dtype)
+ w_min = min_iter.getitem(min_state).convert_to(space, dtype)
+ w_max = max_iter.getitem(max_state).convert_to(space, dtype)
if dtype.itemtype.lt(w_v, w_min):
w_v = w_min
elif dtype.itemtype.gt(w_v, w_max):
w_v = w_max
- out_iter.setitem(w_v)
- arr_iter.next()
- max_iter.next()
- out_iter.next()
- min_iter.next()
+ out_iter.setitem(out_state, w_v)
+ arr_state = arr_iter.next(arr_state)
+ min_state = min_iter.next(min_state)
+ max_state = max_iter.next(max_state)
+ out_state = out_iter.next(out_state)
round_driver = jit.JitDriver(name='numpy_round_driver',
greens = ['shapelen', 'dtype'],
reds = 'auto')
def round(space, arr, dtype, shape, decimals, out):
- arr_iter = arr.create_iter(shape)
+ arr_iter, arr_state = arr.create_iter(shape)
+ out_iter, out_state = out.create_iter(shape)
shapelen = len(shape)
- out_iter = out.create_iter(shape)
- while not arr_iter.done():
+ while not arr_iter.done(arr_state):
round_driver.jit_merge_point(shapelen=shapelen, dtype=dtype)
- w_v = arr_iter.getitem().convert_to(space, dtype)
+ w_v = arr_iter.getitem(arr_state).convert_to(space, dtype)
w_v = dtype.itemtype.round(w_v, decimals)
- out_iter.setitem(w_v)
- arr_iter.next()
- out_iter.next()
+ out_iter.setitem(out_state, w_v)
+ arr_state = arr_iter.next(arr_state)
+ out_state = out_iter.next(out_state)
diagonal_simple_driver = jit.JitDriver(name='numpy_diagonal_simple_driver',
greens = ['axis1', 'axis2'],
reds = 'auto')
def diagonal_simple(space, arr, out, offset, axis1, axis2, size):
- out_iter = out.create_iter()
+ out_iter, out_state = out.create_iter()
i = 0
index = [0] * 2
while i < size:
diagonal_simple_driver.jit_merge_point(axis1=axis1, axis2=axis2)
index[axis1] = i
index[axis2] = i + offset
- out_iter.setitem(arr.getitem_index(space, index))
+ out_iter.setitem(out_state, arr.getitem_index(space, index))
i += 1
- out_iter.next()
+ out_state = out_iter.next(out_state)
def diagonal_array(space, arr, out, offset, axis1, axis2, shape):
- out_iter = out.create_iter()
+ out_iter, out_state = out.create_iter()
iter = PureShapeIter(shape, [])
shapelen_minus_1 = len(shape) - 1
assert shapelen_minus_1 >= 0
@@ -667,6 +682,6 @@
indexes = (iter.indexes[:a] + [last_index + offset] +
iter.indexes[a:b] + [last_index] +
iter.indexes[b:shapelen_minus_1])
- out_iter.setitem(arr.getitem_index(space, indexes))
+ out_iter.setitem(out_state, arr.getitem_index(space, indexes))
iter.next()
- out_iter.next()
+ out_state = out_iter.next(out_state)
diff --git a/pypy/module/micronumpy/ndarray.py b/pypy/module/micronumpy/ndarray.py
--- a/pypy/module/micronumpy/ndarray.py
+++ b/pypy/module/micronumpy/ndarray.py
@@ -18,7 +18,7 @@
multi_axis_converter
from pypy.module.micronumpy.flagsobj import W_FlagsObject
from pypy.module.micronumpy.flatiter import W_FlatIterator
-from pypy.module.micronumpy.strides import get_shape_from_iterable, to_coords, \
+from pypy.module.micronumpy.strides import get_shape_from_iterable, \
shape_agreement, shape_agreement_multiple
@@ -260,24 +260,24 @@
return space.call_function(cache.w_array_str, self)
def dump_data(self, prefix='array(', separator=',', suffix=')'):
- i = self.create_iter()
+ i, state = self.create_iter()
first = True
dtype = self.get_dtype()
s = StringBuilder()
s.append(prefix)
if not self.is_scalar():
s.append('[')
- while not i.done():
+ while not i.done(state):
if first:
first = False
else:
s.append(separator)
s.append(' ')
if self.is_scalar() and dtype.is_str():
- s.append(dtype.itemtype.to_str(i.getitem()))
+ s.append(dtype.itemtype.to_str(i.getitem(state)))
else:
- s.append(dtype.itemtype.str_format(i.getitem()))
- i.next()
+ s.append(dtype.itemtype.str_format(i.getitem(state)))
+ state = i.next(state)
if not self.is_scalar():
s.append(']')
s.append(suffix)
@@ -469,29 +469,33 @@
def descr_get_flatiter(self, space):
return space.wrap(W_FlatIterator(self))
- def to_coords(self, space, w_index):
- coords, _, _ = to_coords(space, self.get_shape(),
- self.get_size(), self.get_order(),
- w_index)
- return coords
-
- def descr_item(self, space, w_arg=None):
- if space.is_none(w_arg):
+ def descr_item(self, space, __args__):
+ args_w, kw_w = __args__.unpack()
+ if len(args_w) == 1 and space.isinstance_w(args_w[0], space.w_tuple):
+ args_w = space.fixedview(args_w[0])
+ shape = self.get_shape()
+ coords = [0] * len(shape)
+ if len(args_w) == 0:
if self.get_size() == 1:
w_obj = self.get_scalar_value()
assert isinstance(w_obj, boxes.W_GenericBox)
return w_obj.item(space)
raise oefmt(space.w_ValueError,
"can only convert an array of size 1 to a Python scalar")
- if space.isinstance_w(w_arg, space.w_int):
- if self.is_scalar():
- raise oefmt(space.w_IndexError, "index out of bounds")
- i = self.to_coords(space, w_arg)
- item = self.getitem(space, i)
- assert isinstance(item, boxes.W_GenericBox)
- return item.item(space)
- raise OperationError(space.w_NotImplementedError, space.wrap(
- "non-int arg not supported"))
+ elif len(args_w) == 1 and len(shape) != 1:
+ value = support.index_w(space, args_w[0])
+ value = support.check_and_adjust_index(space, value, self.get_size(), -1)
+ for idim in range(len(shape) - 1, -1, -1):
+ coords[idim] = value % shape[idim]
+ value //= shape[idim]
+ elif len(args_w) == len(shape):
+ for idim in range(len(shape)):
+ coords[idim] = support.index_w(space, args_w[idim])
+ else:
+ raise oefmt(space.w_ValueError, "incorrect number of indices for array")
+ item = self.getitem(space, coords)
+ assert isinstance(item, boxes.W_GenericBox)
+ return item.item(space)
def descr_itemset(self, space, args_w):
if len(args_w) == 0:
@@ -841,8 +845,8 @@
if self.get_size() > 1:
raise OperationError(space.w_ValueError, space.wrap(
"The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()"))
- iter = self.create_iter()
- return space.wrap(space.is_true(iter.getitem()))
+ iter, state = self.create_iter()
+ return space.wrap(space.is_true(iter.getitem(state)))
def _binop_impl(ufunc_name):
def impl(self, space, w_other, w_out=None):
@@ -1019,7 +1023,8 @@
descr_cumsum = _reduce_ufunc_impl('add', cumulative=True)
descr_cumprod = _reduce_ufunc_impl('multiply', cumulative=True)
- def _reduce_argmax_argmin_impl(op_name):
+ def _reduce_argmax_argmin_impl(raw_name):
+ op_name = "arg%s" % raw_name
def impl(self, space, w_axis=None, w_out=None):
if not space.is_none(w_axis):
raise oefmt(space.w_NotImplementedError,
@@ -1030,18 +1035,17 @@
if self.get_size() == 0:
raise oefmt(space.w_ValueError,
"Can't call %s on zero-size arrays", op_name)
- op = getattr(loop, op_name)
try:
- res = op(self)
+ getattr(self.get_dtype().itemtype, raw_name)
except AttributeError:
raise oefmt(space.w_NotImplementedError,
'%s not implemented for %s',
op_name, self.get_dtype().get_name())
- return space.wrap(res)
- return func_with_new_name(impl, "reduce_arg%s_impl" % op_name)
+ return space.wrap(getattr(loop, op_name)(self))
+ return func_with_new_name(impl, "reduce_%s_impl" % op_name)
- descr_argmax = _reduce_argmax_argmin_impl("argmax")
- descr_argmin = _reduce_argmax_argmin_impl("argmin")
+ descr_argmax = _reduce_argmax_argmin_impl("max")
+ descr_argmin = _reduce_argmax_argmin_impl("min")
def descr_int(self, space):
if self.get_size() != 1:
@@ -1118,11 +1122,11 @@
builder = StringBuilder()
if isinstance(self.implementation, SliceArray):
- iter = self.implementation.create_iter()
- while not iter.done():
- box = iter.getitem()
+ iter, state = self.implementation.create_iter()
+ while not iter.done(state):
+ box = iter.getitem(state)
builder.append(box.raw_str())
- iter.next()
+ state = iter.next(state)
else:
builder.append_charpsize(self.implementation.get_storage(), self.implementation.get_storage_size())
diff --git a/pypy/module/micronumpy/nditer.py b/pypy/module/micronumpy/nditer.py
--- a/pypy/module/micronumpy/nditer.py
+++ b/pypy/module/micronumpy/nditer.py
@@ -1,99 +1,50 @@
from pypy.interpreter.baseobjspace import W_Root
from pypy.interpreter.typedef import TypeDef, GetSetProperty
from pypy.interpreter.gateway import interp2app, unwrap_spec, WrappedDefault
-from pypy.interpreter.error import OperationError
+from pypy.interpreter.error import OperationError, oefmt
+from pypy.module.micronumpy import ufuncs, support, concrete
from pypy.module.micronumpy.base import W_NDimArray, convert_to_array
+from pypy.module.micronumpy.descriptor import decode_w_dtype
+from pypy.module.micronumpy.iterators import ArrayIter
from pypy.module.micronumpy.strides import (calculate_broadcast_strides,
- shape_agreement, shape_agreement_multiple)
-from pypy.module.micronumpy.iterators import ArrayIter, SliceIterator
-from pypy.module.micronumpy.concrete import SliceArray
-from pypy.module.micronumpy.descriptor import decode_w_dtype
-from pypy.module.micronumpy import ufuncs, support
+ shape_agreement, shape_agreement_multiple)
-class AbstractIterator(object):
- def done(self):
- raise NotImplementedError("Abstract Class")
-
- def next(self):
- raise NotImplementedError("Abstract Class")
-
- def getitem(self, space, array):
- raise NotImplementedError("Abstract Class")
-
-class IteratorMixin(object):
- _mixin_ = True
- def __init__(self, it, op_flags):
- self.it = it
- self.op_flags = op_flags
-
- def done(self):
- return self.it.done()
-
- def next(self):
- self.it.next()
-
- def getitem(self, space, array):
- return self.op_flags.get_it_item[self.index](space, array, self.it)
-
- def setitem(self, space, array, val):
- xxx
-
-class BoxIterator(IteratorMixin, AbstractIterator):
- index = 0
-
-class ExternalLoopIterator(IteratorMixin, AbstractIterator):
- index = 1
-
def parse_op_arg(space, name, w_op_flags, n, parse_one_arg):
+ if space.is_w(w_op_flags, space.w_None):
+ w_op_flags = space.newtuple([space.wrap('readonly')])
+ if not space.isinstance_w(w_op_flags, space.w_tuple) and not \
+ space.isinstance_w(w_op_flags, space.w_list):
+ raise oefmt(space.w_ValueError,
+ '%s must be a tuple or array of per-op flag-tuples',
+ name)
ret = []
- if space.is_w(w_op_flags, space.w_None):
+ w_lst = space.listview(w_op_flags)
+ if space.isinstance_w(w_lst[0], space.w_tuple) or \
+ space.isinstance_w(w_lst[0], space.w_list):
+ if len(w_lst) != n:
+ raise oefmt(space.w_ValueError,
+ '%s must be a tuple or array of per-op flag-tuples',
+ name)
+ for item in w_lst:
+ ret.append(parse_one_arg(space, space.listview(item)))
+ else:
+ op_flag = parse_one_arg(space, w_lst)
for i in range(n):
- ret.append(OpFlag())
- elif not space.isinstance_w(w_op_flags, space.w_tuple) and not \
- space.isinstance_w(w_op_flags, space.w_list):
- raise OperationError(space.w_ValueError, space.wrap(
- '%s must be a tuple or array of per-op flag-tuples' % name))
- else:
- w_lst = space.listview(w_op_flags)
- if space.isinstance_w(w_lst[0], space.w_tuple) or \
- space.isinstance_w(w_lst[0], space.w_list):
- if len(w_lst) != n:
- raise OperationError(space.w_ValueError, space.wrap(
- '%s must be a tuple or array of per-op flag-tuples' % name))
- for item in w_lst:
- ret.append(parse_one_arg(space, space.listview(item)))
- else:
- op_flag = parse_one_arg(space, w_lst)
- for i in range(n):
- ret.append(op_flag)
+ ret.append(op_flag)
return ret
+
class OpFlag(object):
def __init__(self):
- self.rw = 'r'
+ self.rw = ''
self.broadcast = True
self.force_contig = False
self.force_align = False
self.native_byte_order = False
self.tmp_copy = ''
self.allocate = False
- self.get_it_item = (get_readonly_item, get_readonly_slice)
-def get_readonly_item(space, array, it):
- return space.wrap(it.getitem())
-
-def get_readwrite_item(space, array, it):
- #create a single-value view (since scalars are not views)
- res = SliceArray(it.array.start + it.offset, [0], [0], [1,], it.array, array)
- #it.dtype.setitem(res, 0, it.getitem())
- return W_NDimArray(res)
-
-def get_readonly_slice(space, array, it):
- return W_NDimArray(it.getslice().readonly())
-
-def get_readwrite_slice(space, array, it):
- return W_NDimArray(it.getslice())
def parse_op_flag(space, lst):
op_flag = OpFlag()
@@ -121,39 +72,38 @@
op_flag.allocate = True
elif item == 'no_subtype':
raise OperationError(space.w_NotImplementedError, space.wrap(
- '"no_subtype" op_flag not implemented yet'))
+ '"no_subtype" op_flag not implemented yet'))
elif item == 'arraymask':
raise OperationError(space.w_NotImplementedError, space.wrap(
- '"arraymask" op_flag not implemented yet'))
+ '"arraymask" op_flag not implemented yet'))
elif item == 'writemask':
raise OperationError(space.w_NotImplementedError, space.wrap(
- '"writemask" op_flag not implemented yet'))
+ '"writemask" op_flag not implemented yet'))
else:
raise OperationError(space.w_ValueError, space.wrap(
- 'op_flags must be a tuple or array of per-op flag-tuples'))
- if op_flag.rw == 'r':
- op_flag.get_it_item = (get_readonly_item, get_readonly_slice)
- elif op_flag.rw == 'rw':
- op_flag.get_it_item = (get_readwrite_item, get_readwrite_slice)
- elif op_flag.rw == 'w':
- # XXX Extra logic needed to make sure writeonly
- op_flag.get_it_item = (get_readwrite_item, get_readwrite_slice)
+ 'op_flags must be a tuple or array of per-op flag-tuples'))
+ if op_flag.rw == '':
+ raise oefmt(space.w_ValueError,
+ "None of the iterator flags READWRITE, READONLY, or "
+ "WRITEONLY were specified for an operand")
return op_flag
+
def parse_func_flags(space, nditer, w_flags):
if space.is_w(w_flags, space.w_None):
return
elif not space.isinstance_w(w_flags, space.w_tuple) and not \
- space.isinstance_w(w_flags, space.w_list):
+ space.isinstance_w(w_flags, space.w_list):
raise OperationError(space.w_ValueError, space.wrap(
- 'Iter global flags must be a list or tuple of strings'))
+ 'Iter global flags must be a list or tuple of strings'))
lst = space.listview(w_flags)
for w_item in lst:
if not space.isinstance_w(w_item, space.w_str) and not \
- space.isinstance_w(w_item, space.w_unicode):
+ space.isinstance_w(w_item, space.w_unicode):
typename = space.type(w_item).getname(space)
- raise OperationError(space.w_TypeError, space.wrap(
- 'expected string or Unicode object, %s found' % typename))
+ raise oefmt(space.w_TypeError,
+ 'expected string or Unicode object, %s found',
+ typename)
item = space.str_w(w_item)
if item == 'external_loop':
raise OperationError(space.w_NotImplementedError, space.wrap(
@@ -187,21 +137,24 @@
elif item == 'zerosize_ok':
nditer.zerosize_ok = True
else:
- raise OperationError(space.w_ValueError, space.wrap(
- 'Unexpected iterator global flag "%s"' % item))
+ raise oefmt(space.w_ValueError,
+ 'Unexpected iterator global flag "%s"',
+ item)
if nditer.tracked_index and nditer.external_loop:
- raise OperationError(space.w_ValueError, space.wrap(
- 'Iterator flag EXTERNAL_LOOP cannot be used if an index or '
- 'multi-index is being tracked'))
+ raise OperationError(space.w_ValueError, space.wrap(
+ 'Iterator flag EXTERNAL_LOOP cannot be used if an index or '
+ 'multi-index is being tracked'))
+
def is_backward(imp, order):
if order == 'K' or (order == 'C' and imp.order == 'C'):
return False
- elif order =='F' and imp.order == 'C':
+ elif order == 'F' and imp.order == 'C':
return True
else:
raise NotImplementedError('not implemented yet')
+
def get_iter(space, order, arr, shape, dtype):
imp = arr.implementation
backward = is_backward(imp, order)
@@ -223,19 +176,6 @@
shape, backward)
return ArrayIter(imp, imp.get_size(), shape, r[0], r[1])
-def get_external_loop_iter(space, order, arr, shape):
- imp = arr.implementation
- backward = is_backward(imp, order)
- return SliceIterator(arr, imp.strides, imp.backstrides, shape, order=order, backward=backward)
-
-def convert_to_array_or_none(space, w_elem):
- '''
- None will be passed through, all others will be converted
- '''
- if space.is_none(w_elem):
- return None
- return convert_to_array(space, w_elem)
-
class IndexIterator(object):
def __init__(self, shape, backward=False):
@@ -263,10 +203,10 @@
ret += self.index[i] * self.shape[i - 1]
return ret
+
class W_NDIter(W_Root):
-
def __init__(self, space, w_seq, w_flags, w_op_flags, w_op_dtypes, w_casting,
- w_op_axes, w_itershape, w_buffersize, order):
+ w_op_axes, w_itershape, w_buffersize, order):
self.order = order
self.external_loop = False
self.buffered = False
@@ -286,9 +226,11 @@
if space.isinstance_w(w_seq, space.w_tuple) or \
space.isinstance_w(w_seq, space.w_list):
w_seq_as_list = space.listview(w_seq)
- self.seq = [convert_to_array_or_none(space, w_elem) for w_elem in w_seq_as_list]
+ self.seq = [convert_to_array(space, w_elem)
+ if not space.is_none(w_elem) else None
+ for w_elem in w_seq_as_list]
else:
- self.seq =[convert_to_array(space, w_seq)]
+ self.seq = [convert_to_array(space, w_seq)]
parse_func_flags(space, self, w_flags)
self.op_flags = parse_op_arg(space, 'op_flags', w_op_flags,
@@ -308,9 +250,9 @@
self.dtypes = []
# handle None or writable operands, calculate my shape
- self.iters=[]
- outargs = [i for i in range(len(self.seq)) \
- if self.seq[i] is None or self.op_flags[i].rw == 'w']
+ self.iters = []
+ outargs = [i for i in range(len(self.seq))
+ if self.seq[i] is None or self.op_flags[i].rw == 'w']
if len(outargs) > 0:
out_shape = shape_agreement_multiple(space, [self.seq[i] for i in outargs])
else:
@@ -325,14 +267,12 @@
out_dtype = None
for i in range(len(self.seq)):
if self.seq[i] is None:
- self.op_flags[i].get_it_item = (get_readwrite_item,
- get_readwrite_slice)
self.op_flags[i].allocate = True
continue
if self.op_flags[i].rw == 'w':
continue
- out_dtype = ufuncs.find_binop_result_dtype(space,
- self.seq[i].get_dtype(), out_dtype)
+ out_dtype = ufuncs.find_binop_result_dtype(
+ space, self.seq[i].get_dtype(), out_dtype)
for i in outargs:
if self.seq[i] is None:
# XXX can we postpone allocation to later?
@@ -360,8 +300,9 @@
self.dtypes[i] = seq_d
elif selfd != seq_d:
if not 'r' in self.op_flags[i].tmp_copy:
- raise OperationError(space.w_TypeError, space.wrap(
- "Iterator operand required copying or buffering for operand %d" % i))
+ raise oefmt(space.w_TypeError,
+ "Iterator operand required copying or "
+ "buffering for operand %d", i)
impl = self.seq[i].implementation
new_impl = impl.astype(space, selfd)
self.seq[i] = W_NDimArray(new_impl)
@@ -370,18 +311,14 @@
self.dtypes = [s.get_dtype() for s in self.seq]
# create an iterator for each operand
- if self.external_loop:
- for i in range(len(self.seq)):
- self.iters.append(ExternalLoopIterator(get_external_loop_iter(space, self.order,
- self.seq[i], iter_shape), self.op_flags[i]))
- else:
- for i in range(len(self.seq)):
- self.iters.append(BoxIterator(get_iter(space, self.order,
- self.seq[i], iter_shape, self.dtypes[i]),
- self.op_flags[i]))
+ for i in range(len(self.seq)):
+ it = get_iter(space, self.order, self.seq[i], iter_shape, self.dtypes[i])
+ self.iters.append((it, it.reset()))
+
def set_op_axes(self, space, w_op_axes):
if space.len_w(w_op_axes) != len(self.seq):
- raise OperationError(space.w_ValueError, space.wrap("op_axes must be a tuple/list matching the number of ops"))
+ raise oefmt(space.w_ValueError,
+ "op_axes must be a tuple/list matching the number of ops")
op_axes = space.listview(w_op_axes)
l = -1
for w_axis in op_axes:
@@ -390,10 +327,14 @@
if l == -1:
l = axis_len
elif axis_len != l:
- raise OperationError(space.w_ValueError, space.wrap("Each entry of op_axes must have the same size"))
- self.op_axes.append([space.int_w(x) if not space.is_none(x) else -1 for x in space.listview(w_axis)])
+ raise oefmt(space.w_ValueError,
+ "Each entry of op_axes must have the same size")
+ self.op_axes.append([space.int_w(x) if not space.is_none(x) else -1
+ for x in space.listview(w_axis)])
if l == -1:
- raise OperationError(space.w_ValueError, space.wrap("If op_axes is provided, at least one list of axes must be contained within it"))
+ raise oefmt(space.w_ValueError,
+ "If op_axes is provided, at least one list of axes "
+ "must be contained within it")
raise Exception('xxx TODO')
# Check that values make sense:
# - in bounds for each operand
@@ -404,24 +345,34 @@
def descr_iter(self, space):
return space.wrap(self)
+ def getitem(self, it, st, op_flags):
+ if op_flags.rw == 'r':
+ impl = concrete.ConcreteNonWritableArrayWithBase
+ else:
+ impl = concrete.ConcreteArrayWithBase
+ res = impl([], it.array.dtype, it.array.order, [], [],
+ it.array.storage, self)
+ res.start = st.offset
+ return W_NDimArray(res)
+
def descr_getitem(self, space, w_idx):
idx = space.int_w(w_idx)
try:
- ret = space.wrap(self.iters[idx].getitem(space, self.seq[idx]))
+ it, st = self.iters[idx]
except IndexError:
- raise OperationError(space.w_IndexError, space.wrap("Iterator operand index %d is out of bounds" % idx))
- return ret
+ raise oefmt(space.w_IndexError,
+ "Iterator operand index %d is out of bounds", idx)
+ return self.getitem(it, st, self.op_flags[idx])
def descr_setitem(self, space, w_idx, w_value):
- raise OperationError(space.w_NotImplementedError, space.wrap(
- 'not implemented yet'))
+ raise oefmt(space.w_NotImplementedError, "not implemented yet")
def descr_len(self, space):
space.wrap(len(self.iters))
def descr_next(self, space):
- for it in self.iters:
- if not it.done():
+ for it, st in self.iters:
+ if not it.done(st):
break
else:
self.done = True
@@ -432,20 +383,20 @@
self.index_iter.next()
else:
self.first_next = False
- for i in range(len(self.iters)):
- res.append(self.iters[i].getitem(space, self.seq[i]))
- self.iters[i].next()
- if len(res) <2:
+ for i, (it, st) in enumerate(self.iters):
+ res.append(self.getitem(it, st, self.op_flags[i]))
+ self.iters[i] = (it, it.next(st))
+ if len(res) < 2:
return res[0]
return space.newtuple(res)
def iternext(self):
if self.index_iter:
self.index_iter.next()
- for i in range(len(self.iters)):
- self.iters[i].next()
- for it in self.iters:
- if not it.done():
+ for i, (it, st) in enumerate(self.iters):
+ self.iters[i] = (it, it.next(st))
+ for it, st in self.iters:
+ if not it.done(st):
break
else:
self.done = True
@@ -456,29 +407,23 @@
return space.wrap(self.iternext())
def descr_copy(self, space):
- raise OperationError(space.w_NotImplementedError, space.wrap(
- 'not implemented yet'))
+ raise oefmt(space.w_NotImplementedError, "not implemented yet")
def descr_debug_print(self, space):
- raise OperationError(space.w_NotImplementedError, space.wrap(
- 'not implemented yet'))
+ raise oefmt(space.w_NotImplementedError, "not implemented yet")
def descr_enable_external_loop(self, space):
- raise OperationError(space.w_NotImplementedError, space.wrap(
- 'not implemented yet'))
+ raise oefmt(space.w_NotImplementedError, "not implemented yet")
@unwrap_spec(axis=int)
def descr_remove_axis(self, space, axis):
- raise OperationError(space.w_NotImplementedError, space.wrap(
- 'not implemented yet'))
+ raise oefmt(space.w_NotImplementedError, "not implemented yet")
def descr_remove_multi_index(self, space, w_multi_index):
- raise OperationError(space.w_NotImplementedError, space.wrap(
- 'not implemented yet'))
+ raise oefmt(space.w_NotImplementedError, "not implemented yet")
def descr_reset(self, space):
- raise OperationError(space.w_NotImplementedError, space.wrap(
- 'not implemented yet'))
+ raise oefmt(space.w_NotImplementedError, "not implemented yet")
def descr_get_operands(self, space):
l_w = []
@@ -496,17 +441,16 @@
return space.wrap(self.done)
def descr_get_has_delayed_bufalloc(self, space):
- raise OperationError(space.w_NotImplementedError, space.wrap(
- 'not implemented yet'))
+ raise oefmt(space.w_NotImplementedError, "not implemented yet")
def descr_get_has_index(self, space):
return space.wrap(self.tracked_index in ["C", "F"])
def descr_get_index(self, space):
if not self.tracked_index in ["C", "F"]:
- raise OperationError(space.w_ValueError, space.wrap("Iterator does not have an index"))
+ raise oefmt(space.w_ValueError, "Iterator does not have an index")
if self.done:
- raise OperationError(space.w_ValueError, space.wrap("Iterator is past the end"))
+ raise oefmt(space.w_ValueError, "Iterator is past the end")
return space.wrap(self.index_iter.getvalue())
def descr_get_has_multi_index(self, space):
@@ -514,51 +458,44 @@
def descr_get_multi_index(self, space):
if not self.tracked_index == "multi":
- raise OperationError(space.w_ValueError, space.wrap("Iterator is not tracking a multi-index"))
+ raise oefmt(space.w_ValueError, "Iterator is not tracking a multi-index")
if self.done:
- raise OperationError(space.w_ValueError, space.wrap("Iterator is past the end"))
+ raise oefmt(space.w_ValueError, "Iterator is past the end")
return space.newtuple([space.wrap(x) for x in self.index_iter.index])
def descr_get_iterationneedsapi(self, space):
- raise OperationError(space.w_NotImplementedError, space.wrap(
- 'not implemented yet'))
+ raise oefmt(space.w_NotImplementedError, "not implemented yet")
def descr_get_iterindex(self, space):
- raise OperationError(space.w_NotImplementedError, space.wrap(
- 'not implemented yet'))
+ raise oefmt(space.w_NotImplementedError, "not implemented yet")
def descr_get_itersize(self, space):
return space.wrap(support.product(self.shape))
def descr_get_itviews(self, space):
- raise OperationError(space.w_NotImplementedError, space.wrap(
- 'not implemented yet'))
+ raise oefmt(space.w_NotImplementedError, "not implemented yet")
def descr_get_ndim(self, space):
- raise OperationError(space.w_NotImplementedError, space.wrap(
- 'not implemented yet'))
+ raise oefmt(space.w_NotImplementedError, "not implemented yet")
def descr_get_nop(self, space):
- raise OperationError(space.w_NotImplementedError, space.wrap(
- 'not implemented yet'))
+ raise oefmt(space.w_NotImplementedError, "not implemented yet")
def descr_get_shape(self, space):
- raise OperationError(space.w_NotImplementedError, space.wrap(
- 'not implemented yet'))
+ raise oefmt(space.w_NotImplementedError, "not implemented yet")
def descr_get_value(self, space):
- raise OperationError(space.w_NotImplementedError, space.wrap(
- 'not implemented yet'))
+ raise oefmt(space.w_NotImplementedError, "not implemented yet")
-@unwrap_spec(w_flags = WrappedDefault(None), w_op_flags=WrappedDefault(None),
- w_op_dtypes = WrappedDefault(None), order=str,
+@unwrap_spec(w_flags=WrappedDefault(None), w_op_flags=WrappedDefault(None),
+ w_op_dtypes=WrappedDefault(None), order=str,
w_casting=WrappedDefault(None), w_op_axes=WrappedDefault(None),
w_itershape=WrappedDefault(None), w_buffersize=WrappedDefault(None))
def nditer(space, w_seq, w_flags, w_op_flags, w_op_dtypes, w_casting, w_op_axes,
- w_itershape, w_buffersize, order='K'):
+ w_itershape, w_buffersize, order='K'):
return W_NDIter(space, w_seq, w_flags, w_op_flags, w_op_dtypes, w_casting, w_op_axes,
- w_itershape, w_buffersize, order)
+ w_itershape, w_buffersize, order)
W_NDIter.typedef = TypeDef(
'nditer',
diff --git a/pypy/module/micronumpy/sort.py b/pypy/module/micronumpy/sort.py
--- a/pypy/module/micronumpy/sort.py
+++ b/pypy/module/micronumpy/sort.py
@@ -148,20 +148,22 @@
if axis < 0 or axis >= len(shape):
raise oefmt(space.w_IndexError, "Wrong axis %d", axis)
arr_iter = AllButAxisIter(arr, axis)
+ arr_state = arr_iter.reset()
index_impl = index_arr.implementation
index_iter = AllButAxisIter(index_impl, axis)
+ index_state = index_iter.reset()
stride_size = arr.strides[axis]
index_stride_size = index_impl.strides[axis]
axis_size = arr.shape[axis]
- while not arr_iter.done():
+ while not arr_iter.done(arr_state):
for i in range(axis_size):
raw_storage_setitem(storage, i * index_stride_size +
- index_iter.offset, i)
+ index_state.offset, i)
r = Repr(index_stride_size, stride_size, axis_size,
- arr.get_storage(), storage, index_iter.offset, arr_iter.offset)
+ arr.get_storage(), storage, index_state.offset, arr_state.offset)
ArgSort(r).sort()
- arr_iter.next()
- index_iter.next()
+ arr_state = arr_iter.next(arr_state)
+ index_state = index_iter.next(index_state)
return index_arr
return argsort
@@ -292,12 +294,13 @@
if axis < 0 or axis >= len(shape):
raise oefmt(space.w_IndexError, "Wrong axis %d", axis)
arr_iter = AllButAxisIter(arr, axis)
+ arr_state = arr_iter.reset()
stride_size = arr.strides[axis]
axis_size = arr.shape[axis]
- while not arr_iter.done():
- r = Repr(stride_size, axis_size, arr.get_storage(), arr_iter.offset)
+ while not arr_iter.done(arr_state):
+ r = Repr(stride_size, axis_size, arr.get_storage(), arr_state.offset)
ArgSort(r).sort()
- arr_iter.next()
+ arr_state = arr_iter.next(arr_state)
return sort
diff --git a/pypy/module/micronumpy/strides.py b/pypy/module/micronumpy/strides.py
--- a/pypy/module/micronumpy/strides.py
+++ b/pypy/module/micronumpy/strides.py
@@ -233,30 +233,6 @@
return dtype
-def to_coords(space, shape, size, order, w_item_or_slice):
- '''Returns a start coord, step, and length.
- '''
- start = lngth = step = 0
- if not (space.isinstance_w(w_item_or_slice, space.w_int) or
- space.isinstance_w(w_item_or_slice, space.w_slice)):
- raise OperationError(space.w_IndexError,
- space.wrap('unsupported iterator index'))
-
- start, stop, step, lngth = space.decode_index4(w_item_or_slice, size)
-
- coords = [0] * len(shape)
- i = start
- if order == 'C':
- for s in range(len(shape) -1, -1, -1):
- coords[s] = i % shape[s]
- i //= shape[s]
- else:
- for s in range(len(shape)):
- coords[s] = i % shape[s]
- i //= shape[s]
- return coords, step, lngth
-
-
@jit.unroll_safe
def shape_agreement(space, shape1, w_arr2, broadcast_down=True):
if w_arr2 is None:
diff --git a/pypy/module/micronumpy/support.py b/pypy/module/micronumpy/support.py
--- a/pypy/module/micronumpy/support.py
+++ b/pypy/module/micronumpy/support.py
@@ -25,3 +25,18 @@
for x in s:
i *= x
return i
+
+
+def check_and_adjust_index(space, index, size, axis):
+ if index < -size or index >= size:
+ if axis >= 0:
+ raise oefmt(space.w_IndexError,
+ "index %d is out of bounds for axis %d with size %d",
+ index, axis, size)
+ else:
+ raise oefmt(space.w_IndexError,
+ "index %d is out of bounds for size %d",
+ index, size)
+ if index < 0:
+ index += size
+ return index
diff --git a/pypy/module/micronumpy/test/test_iterators.py b/pypy/module/micronumpy/test/test_iterators.py
--- a/pypy/module/micronumpy/test/test_iterators.py
+++ b/pypy/module/micronumpy/test/test_iterators.py
@@ -16,17 +16,18 @@
assert backstrides == [10, 4]
i = ArrayIter(MockArray, support.product(shape), shape,
strides, backstrides)
- i.next()
- i.next()
- i.next()
- assert i.offset == 3
- assert not i.done()
- assert i.indices == [0,3]
+ s = i.reset()
+ s = i.next(s)
+ s = i.next(s)
+ s = i.next(s)
+ assert s.offset == 3
+ assert not i.done(s)
+ assert s.indices == [0,3]
#cause a dimension overflow
- i.next()
- i.next()
- assert i.offset == 5
- assert i.indices == [1,0]
+ s = i.next(s)
+ s = i.next(s)
+ assert s.offset == 5
+ assert s.indices == [1,0]
#Now what happens if the array is transposed? strides[-1] != 1
# therefore layout is non-contiguous
@@ -35,17 +36,18 @@
assert backstrides == [2, 12]
i = ArrayIter(MockArray, support.product(shape), shape,
strides, backstrides)
- i.next()
- i.next()
- i.next()
- assert i.offset == 9
- assert not i.done()
- assert i.indices == [0,3]
+ s = i.reset()
+ s = i.next(s)
+ s = i.next(s)
+ s = i.next(s)
+ assert s.offset == 9
+ assert not i.done(s)
+ assert s.indices == [0,3]
#cause a dimension overflow
- i.next()
- i.next()
- assert i.offset == 1
- assert i.indices == [1,0]
+ s = i.next(s)
+ s = i.next(s)
+ assert s.offset == 1
+ assert s.indices == [1,0]
def test_iterator_step(self):
#iteration in C order with #contiguous layout => strides[-1] is 1
@@ -56,22 +58,23 @@
assert backstrides == [10, 4]
i = ArrayIter(MockArray, support.product(shape), shape,
strides, backstrides)
- i.next_skip_x(2)
- i.next_skip_x(2)
- i.next_skip_x(2)
- assert i.offset == 6
- assert not i.done()
- assert i.indices == [1,1]
+ s = i.reset()
+ s = i.next_skip_x(s, 2)
+ s = i.next_skip_x(s, 2)
+ s = i.next_skip_x(s, 2)
+ assert s.offset == 6
+ assert not i.done(s)
+ assert s.indices == [1,1]
#And for some big skips
- i.next_skip_x(5)
- assert i.offset == 11
- assert i.indices == [2,1]
- i.next_skip_x(5)
+ s = i.next_skip_x(s, 5)
+ assert s.offset == 11
+ assert s.indices == [2,1]
+ s = i.next_skip_x(s, 5)
# Note: the offset does not overflow but recycles,
# this is good for broadcast
- assert i.offset == 1
- assert i.indices == [0,1]
- assert i.done()
+ assert s.offset == 1
+ assert s.indices == [0,1]
+ assert i.done(s)
#Now what happens if the array is transposed? strides[-1] != 1
# therefore layout is non-contiguous
@@ -80,17 +83,18 @@
assert backstrides == [2, 12]
i = ArrayIter(MockArray, support.product(shape), shape,
strides, backstrides)
- i.next_skip_x(2)
- i.next_skip_x(2)
- i.next_skip_x(2)
- assert i.offset == 4
- assert i.indices == [1,1]
- assert not i.done()
- i.next_skip_x(5)
- assert i.offset == 5
- assert i.indices == [2,1]
- assert not i.done()
- i.next_skip_x(5)
- assert i.indices == [0,1]
- assert i.offset == 3
- assert i.done()
+ s = i.reset()
+ s = i.next_skip_x(s, 2)
+ s = i.next_skip_x(s, 2)
+ s = i.next_skip_x(s, 2)
+ assert s.offset == 4
+ assert s.indices == [1,1]
+ assert not i.done(s)
+ s = i.next_skip_x(s, 5)
+ assert s.offset == 5
+ assert s.indices == [2,1]
+ assert not i.done(s)
+ s = i.next_skip_x(s, 5)
+ assert s.indices == [0,1]
+ assert s.offset == 3
+ assert i.done(s)
diff --git a/pypy/module/micronumpy/test/test_ndarray.py b/pypy/module/micronumpy/test/test_ndarray.py
--- a/pypy/module/micronumpy/test/test_ndarray.py
+++ b/pypy/module/micronumpy/test/test_ndarray.py
@@ -164,24 +164,6 @@
assert calc_new_strides([1, 1, 105, 1, 1], [7, 15], [1, 7],'F') == \
[1, 1, 1, 105, 105]
- def test_to_coords(self):
- from pypy.module.micronumpy.strides import to_coords
-
More information about the pypy-commit mailing list