[pypy-commit] pypy numpy-speed: change usages of iters to use state
bdkearns
noreply at buildbot.pypy.org
Fri Apr 18 00:09:08 CEST 2014
Author: Brian Kearns <bdkearns at gmail.com>
Branch: numpy-speed
Changeset: r70724:71b86d3efc92
Date: 2014-04-17 16:31 -0400
http://bitbucket.org/pypy/pypy/changeset/71b86d3efc92/
Log: change usages of iters to use state
diff --git a/pypy/module/micronumpy/concrete.py b/pypy/module/micronumpy/concrete.py
--- a/pypy/module/micronumpy/concrete.py
+++ b/pypy/module/micronumpy/concrete.py
@@ -284,9 +284,11 @@
self.get_backstrides(),
self.get_shape(), shape,
backward_broadcast)
- return ArrayIter(self, support.product(shape), shape, r[0], r[1])
- return ArrayIter(self, self.get_size(), self.shape,
- self.strides, self.backstrides)
+ i = ArrayIter(self, support.product(shape), shape, r[0], r[1])
+ else:
+ i = ArrayIter(self, self.get_size(), self.shape,
+ self.strides, self.backstrides)
+ return i, i.reset()
def swapaxes(self, space, orig_arr, axis1, axis2):
shape = self.get_shape()[:]
diff --git a/pypy/module/micronumpy/ctors.py b/pypy/module/micronumpy/ctors.py
--- a/pypy/module/micronumpy/ctors.py
+++ b/pypy/module/micronumpy/ctors.py
@@ -156,10 +156,10 @@
"string is smaller than requested size"))
a = W_NDimArray.from_shape(space, [num_items], dtype=dtype)
- ai = a.create_iter()
+ ai, state = a.create_iter()
for val in items:
- ai.setitem(val)
- ai.next()
+ ai.setitem(state, val)
+ state = ai.next(state)
return space.wrap(a)
diff --git a/pypy/module/micronumpy/flatiter.py b/pypy/module/micronumpy/flatiter.py
--- a/pypy/module/micronumpy/flatiter.py
+++ b/pypy/module/micronumpy/flatiter.py
@@ -32,23 +32,23 @@
self.reset()
def reset(self):
- self.iter = self.base.create_iter()
+ self.iter, self.state = self.base.create_iter()
def descr_len(self, space):
return space.wrap(self.base.get_size())
def descr_next(self, space):
- if self.iter.done():
+ if self.iter.done(self.state):
raise OperationError(space.w_StopIteration, space.w_None)
- w_res = self.iter.getitem()
- self.iter.next()
+ w_res = self.iter.getitem(self.state)
+ self.state = self.iter.next(self.state)
return w_res
def descr_index(self, space):
- return space.wrap(self.iter.index)
+ return space.wrap(self.state.index)
def descr_coords(self, space):
- coords = self.base.to_coords(space, space.wrap(self.iter.index))
+ coords = self.base.to_coords(space, space.wrap(self.state.index))
return space.newtuple([space.wrap(c) for c in coords])
def descr_getitem(self, space, w_idx):
@@ -58,13 +58,13 @@
self.reset()
base = self.base
start, stop, step, length = space.decode_index4(w_idx, base.get_size())
- base_iter = base.create_iter()
- base_iter.next_skip_x(start)
+ base_iter, base_state = base.create_iter()
+ base_state = base_iter.next_skip_x(base_state, start)
if length == 1:
- return base_iter.getitem()
+ return base_iter.getitem(base_state)
res = W_NDimArray.from_shape(space, [length], base.get_dtype(),
base.get_order(), w_instance=base)
- return loop.flatiter_getitem(res, base_iter, step)
+ return loop.flatiter_getitem(res, base_iter, base_state, step)
def descr_setitem(self, space, w_idx, w_value):
if not (space.isinstance_w(w_idx, space.w_int) or
diff --git a/pypy/module/micronumpy/iterators.py b/pypy/module/micronumpy/iterators.py
--- a/pypy/module/micronumpy/iterators.py
+++ b/pypy/module/micronumpy/iterators.py
@@ -51,19 +51,20 @@
self.shapelen = len(shape)
self.indexes = [0] * len(shape)
self._done = False
- self.idx_w = [None] * len(idx_w)
+ self.idx_w_i = [None] * len(idx_w)
+ self.idx_w_s = [None] * len(idx_w)
for i, w_idx in enumerate(idx_w):
if isinstance(w_idx, W_NDimArray):
- self.idx_w[i] = w_idx.create_iter(shape)
+ self.idx_w_i[i], self.idx_w_s[i] = w_idx.create_iter(shape)
def done(self):
return self._done
@jit.unroll_safe
def next(self):
- for w_idx in self.idx_w:
- if w_idx is not None:
- w_idx.next()
+ for i, idx_w_i in enumerate(self.idx_w_i):
+ if idx_w_i is not None:
+ self.idx_w_s[i] = idx_w_i.next(self.idx_w_s[i])
for i in range(self.shapelen - 1, -1, -1):
if self.indexes[i] < self.shape[i] - 1:
self.indexes[i] += 1
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -12,11 +12,10 @@
AllButAxisIter
-call2_driver = jit.JitDriver(name='numpy_call2',
- greens = ['shapelen', 'func', 'calc_dtype',
- 'res_dtype'],
- reds = ['shape', 'w_lhs', 'w_rhs', 'out',
- 'left_iter', 'right_iter', 'out_iter'])
+call2_driver = jit.JitDriver(
+ name='numpy_call2',
+ greens=['shapelen', 'func', 'calc_dtype', 'res_dtype'],
+ reds='auto')
def call2(space, shape, func, calc_dtype, res_dtype, w_lhs, w_rhs, out):
# handle array_priority
@@ -46,47 +45,40 @@
if out is None:
out = W_NDimArray.from_shape(space, shape, res_dtype,
w_instance=lhs_for_subtype)
- left_iter = w_lhs.create_iter(shape)
- right_iter = w_rhs.create_iter(shape)
- out_iter = out.create_iter(shape)
+ left_iter, left_state = w_lhs.create_iter(shape)
+ right_iter, right_state = w_rhs.create_iter(shape)
+ out_iter, out_state = out.create_iter(shape)
shapelen = len(shape)
- while not out_iter.done():
+ while not out_iter.done(out_state):
call2_driver.jit_merge_point(shapelen=shapelen, func=func,
- calc_dtype=calc_dtype, res_dtype=res_dtype,
- shape=shape, w_lhs=w_lhs, w_rhs=w_rhs,
- out=out,
- left_iter=left_iter, right_iter=right_iter,
- out_iter=out_iter)
- w_left = left_iter.getitem().convert_to(space, calc_dtype)
- w_right = right_iter.getitem().convert_to(space, calc_dtype)
- out_iter.setitem(func(calc_dtype, w_left, w_right).convert_to(
+ calc_dtype=calc_dtype, res_dtype=res_dtype)
+ w_left = left_iter.getitem(left_state).convert_to(space, calc_dtype)
+ w_right = right_iter.getitem(right_state).convert_to(space, calc_dtype)
+ out_iter.setitem(out_state, func(calc_dtype, w_left, w_right).convert_to(
space, res_dtype))
- left_iter.next()
- right_iter.next()
- out_iter.next()
+ left_state = left_iter.next(left_state)
+ right_state = right_iter.next(right_state)
+ out_state = out_iter.next(out_state)
return out
-call1_driver = jit.JitDriver(name='numpy_call1',
- greens = ['shapelen', 'func', 'calc_dtype',
- 'res_dtype'],
- reds = ['shape', 'w_obj', 'out', 'obj_iter',
- 'out_iter'])
+call1_driver = jit.JitDriver(
+ name='numpy_call1',
+ greens=['shapelen', 'func', 'calc_dtype', 'res_dtype'],
+ reds='auto')
def call1(space, shape, func, calc_dtype, res_dtype, w_obj, out):
if out is None:
out = W_NDimArray.from_shape(space, shape, res_dtype, w_instance=w_obj)
- obj_iter = w_obj.create_iter(shape)
- out_iter = out.create_iter(shape)
+ obj_iter, obj_state = w_obj.create_iter(shape)
+ out_iter, out_state = out.create_iter(shape)
shapelen = len(shape)
- while not out_iter.done():
+ while not out_iter.done(out_state):
call1_driver.jit_merge_point(shapelen=shapelen, func=func,
- calc_dtype=calc_dtype, res_dtype=res_dtype,
- shape=shape, w_obj=w_obj, out=out,
- obj_iter=obj_iter, out_iter=out_iter)
- elem = obj_iter.getitem().convert_to(space, calc_dtype)
- out_iter.setitem(func(calc_dtype, elem).convert_to(space, res_dtype))
- out_iter.next()
- obj_iter.next()
+ calc_dtype=calc_dtype, res_dtype=res_dtype)
+ elem = obj_iter.getitem(obj_state).convert_to(space, calc_dtype)
+ out_iter.setitem(out_state, func(calc_dtype, elem).convert_to(space, res_dtype))
+ out_state = out_iter.next(out_state)
+ obj_state = obj_iter.next(obj_state)
return out
setslice_driver = jit.JitDriver(name='numpy_setslice',
@@ -96,18 +88,20 @@
def setslice(space, shape, target, source):
# note that unlike everything else, target and source here are
# array implementations, not arrays
- target_iter = target.create_iter(shape)
- source_iter = source.create_iter(shape)
+ target_iter, target_state = target.create_iter(shape)
+ source_iter, source_state = source.create_iter(shape)
dtype = target.dtype
shapelen = len(shape)
- while not target_iter.done():
+ while not target_iter.done(target_state):
setslice_driver.jit_merge_point(shapelen=shapelen, dtype=dtype)
+ val = source_iter.getitem(source_state)
if dtype.is_str_or_unicode():
- target_iter.setitem(dtype.coerce(space, source_iter.getitem()))
+ val = dtype.coerce(space, val)
else:
- target_iter.setitem(source_iter.getitem().convert_to(space, dtype))
- target_iter.next()
- source_iter.next()
+ val = val.convert_to(space, dtype)
+ target_iter.setitem(target_state, val)
+ target_state = target_iter.next(target_state)
+ source_state = source_iter.next(source_state)
return target
reduce_driver = jit.JitDriver(name='numpy_reduce',
@@ -116,22 +110,22 @@
reds = 'auto')
def compute_reduce(space, obj, calc_dtype, func, done_func, identity):
- obj_iter = obj.create_iter()
+ obj_iter, obj_state = obj.create_iter()
if identity is None:
- cur_value = obj_iter.getitem().convert_to(space, calc_dtype)
- obj_iter.next()
+ cur_value = obj_iter.getitem(obj_state).convert_to(space, calc_dtype)
+ obj_state = obj_iter.next(obj_state)
else:
cur_value = identity.convert_to(space, calc_dtype)
shapelen = len(obj.get_shape())
- while not obj_iter.done():
+ while not obj_iter.done(obj_state):
reduce_driver.jit_merge_point(shapelen=shapelen, func=func,
done_func=done_func,
calc_dtype=calc_dtype)
- rval = obj_iter.getitem().convert_to(space, calc_dtype)
+ rval = obj_iter.getitem(obj_state).convert_to(space, calc_dtype)
if done_func is not None and done_func(calc_dtype, rval):
return rval
cur_value = func(calc_dtype, cur_value, rval)
- obj_iter.next()
+ obj_state = obj_iter.next(obj_state)
return cur_value
reduce_cum_driver = jit.JitDriver(name='numpy_reduce_cum_driver',
@@ -139,69 +133,76 @@
reds = 'auto')
def compute_reduce_cumulative(space, obj, out, calc_dtype, func, identity):
- obj_iter = obj.create_iter()
- out_iter = out.create_iter()
+ obj_iter, obj_state = obj.create_iter()
+ out_iter, out_state = out.create_iter()
if identity is None:
- cur_value = obj_iter.getitem().convert_to(space, calc_dtype)
- out_iter.setitem(cur_value)
- out_iter.next()
- obj_iter.next()
+ cur_value = obj_iter.getitem(obj_state).convert_to(space, calc_dtype)
+ out_iter.setitem(out_state, cur_value)
+ out_state = out_iter.next(out_state)
+ obj_state = obj_iter.next(obj_state)
else:
cur_value = identity.convert_to(space, calc_dtype)
shapelen = len(obj.get_shape())
- while not obj_iter.done():
+ while not obj_iter.done(obj_state):
reduce_cum_driver.jit_merge_point(shapelen=shapelen, func=func,
dtype=calc_dtype)
- rval = obj_iter.getitem().convert_to(space, calc_dtype)
+ rval = obj_iter.getitem(obj_state).convert_to(space, calc_dtype)
cur_value = func(calc_dtype, cur_value, rval)
- out_iter.setitem(cur_value)
- out_iter.next()
- obj_iter.next()
+ out_iter.setitem(out_state, cur_value)
+ out_state = out_iter.next(out_state)
+ obj_state = obj_iter.next(obj_state)
def fill(arr, box):
- arr_iter = arr.create_iter()
- while not arr_iter.done():
- arr_iter.setitem(box)
- arr_iter.next()
+ arr_iter, arr_state = arr.create_iter()
+ while not arr_iter.done(arr_state):
+ arr_iter.setitem(arr_state, box)
+ arr_state = arr_iter.next(arr_state)
def assign(space, arr, seq):
- arr_iter = arr.create_iter()
+ arr_iter, arr_state = arr.create_iter()
arr_dtype = arr.get_dtype()
for item in seq:
- arr_iter.setitem(arr_dtype.coerce(space, item))
- arr_iter.next()
+ arr_iter.setitem(arr_state, arr_dtype.coerce(space, item))
+ arr_state = arr_iter.next(arr_state)
where_driver = jit.JitDriver(name='numpy_where',
greens = ['shapelen', 'dtype', 'arr_dtype'],
reds = 'auto')
def where(space, out, shape, arr, x, y, dtype):
- out_iter = out.create_iter(shape)
- arr_iter = arr.create_iter(shape)
+ out_iter, out_state = out.create_iter(shape)
+ arr_iter, arr_state = arr.create_iter(shape)
arr_dtype = arr.get_dtype()
- x_iter = x.create_iter(shape)
- y_iter = y.create_iter(shape)
+ x_iter, x_state = x.create_iter(shape)
+ y_iter, y_state = y.create_iter(shape)
if x.is_scalar():
if y.is_scalar():
- iter = arr_iter
+ iter, state = arr_iter, arr_state
else:
- iter = y_iter
+ iter, state = y_iter, y_state
else:
- iter = x_iter
+ iter, state = x_iter, x_state
shapelen = len(shape)
- while not iter.done():
+ while not iter.done(state):
where_driver.jit_merge_point(shapelen=shapelen, dtype=dtype,
arr_dtype=arr_dtype)
- w_cond = arr_iter.getitem()
+ w_cond = arr_iter.getitem(arr_state)
if arr_dtype.itemtype.bool(w_cond):
- w_val = x_iter.getitem().convert_to(space, dtype)
+ w_val = x_iter.getitem(x_state).convert_to(space, dtype)
else:
- w_val = y_iter.getitem().convert_to(space, dtype)
- out_iter.setitem(w_val)
- out_iter.next()
- arr_iter.next()
- x_iter.next()
- y_iter.next()
+ w_val = y_iter.getitem(y_state).convert_to(space, dtype)
+ out_iter.setitem(out_state, w_val)
+ out_state = out_iter.next(out_state)
+ arr_state = arr_iter.next(arr_state)
+ x_state = x_iter.next(x_state)
+ y_state = y_iter.next(y_state)
+ if x.is_scalar():
+ if y.is_scalar():
+ state = arr_state
+ else:
+ state = y_state
+ else:
+ state = x_state
return out
axis_reduce__driver = jit.JitDriver(name='numpy_axis_reduce',
@@ -212,31 +213,36 @@
def do_axis_reduce(space, shape, func, arr, dtype, axis, out, identity, cumulative,
temp):
out_iter = AxisIter(out.implementation, arr.get_shape(), axis, cumulative)
+ out_state = out_iter.reset()
if cumulative:
temp_iter = AxisIter(temp.implementation, arr.get_shape(), axis, False)
+ temp_state = temp_iter.reset()
else:
- temp_iter = out_iter # hack
- arr_iter = arr.create_iter()
+ temp_iter = out_iter # hack
+ temp_state = out_state
+ arr_iter, arr_state = arr.create_iter()
if identity is not None:
identity = identity.convert_to(space, dtype)
shapelen = len(shape)
- while not out_iter.done():
+ while not out_iter.done(out_state):
axis_reduce__driver.jit_merge_point(shapelen=shapelen, func=func,
dtype=dtype)
- assert not arr_iter.done()
- w_val = arr_iter.getitem().convert_to(space, dtype)
- if out_iter.indices[axis] == 0:
+ assert not arr_iter.done(arr_state)
+ w_val = arr_iter.getitem(arr_state).convert_to(space, dtype)
+ if out_state.indices[axis] == 0:
if identity is not None:
w_val = func(dtype, identity, w_val)
else:
- cur = temp_iter.getitem()
+ cur = temp_iter.getitem(temp_state)
w_val = func(dtype, cur, w_val)
- out_iter.setitem(w_val)
+ out_iter.setitem(out_state, w_val)
+ out_state = out_iter.next(out_state)
if cumulative:
- temp_iter.setitem(w_val)
- temp_iter.next()
- arr_iter.next()
- out_iter.next()
+ temp_iter.setitem(temp_state, w_val)
+ temp_state = temp_iter.next(temp_state)
+ else:
+ temp_state = out_state
+ arr_state = arr_iter.next(arr_state)
return out
@@ -249,18 +255,18 @@
result = 0
idx = 1
dtype = arr.get_dtype()
- iter = arr.create_iter()
- cur_best = iter.getitem()
- iter.next()
+ iter, state = arr.create_iter()
+ cur_best = iter.getitem(state)
+ state = iter.next(state)
shapelen = len(arr.get_shape())
- while not iter.done():
+ while not iter.done(state):
arg_driver.jit_merge_point(shapelen=shapelen, dtype=dtype)
- w_val = iter.getitem()
+ w_val = iter.getitem(state)
new_best = getattr(dtype.itemtype, op_name)(cur_best, w_val)
if dtype.itemtype.ne(new_best, cur_best):
result = idx
cur_best = new_best
- iter.next()
+ state = iter.next(state)
idx += 1
return result
return argmin_argmax
@@ -291,17 +297,19 @@
right_impl = right.implementation
assert left_shape[-1] == right_shape[right_critical_dim]
assert result.get_dtype() == dtype
- outi = result.create_iter()
+ outi, outs = result.create_iter()
lefti = AllButAxisIter(left_impl, len(left_shape) - 1)
righti = AllButAxisIter(right_impl, right_critical_dim)
+ lefts = lefti.reset()
+ rights = righti.reset()
n = left_impl.shape[-1]
s1 = left_impl.strides[-1]
s2 = right_impl.strides[right_critical_dim]
- while not lefti.done():
- while not righti.done():
- oval = outi.getitem()
- i1 = lefti.offset
- i2 = righti.offset
+ while not lefti.done(lefts):
+ while not righti.done(rights):
+ oval = outi.getitem(outs)
+ i1 = lefts.offset
+ i2 = rights.offset
i = 0
while i < n:
i += 1
@@ -311,11 +319,11 @@
oval = dtype.itemtype.add(oval, dtype.itemtype.mul(lval, rval))
i1 += s1
i2 += s2
- outi.setitem(oval)
- outi.next()
- righti.next()
- righti.reset()
- lefti.next()
+ outi.setitem(outs, oval)
+ outs = outi.next(outs)
+ rights = righti.next(rights)
+ rights = righti.reset()
+ lefts = lefti.next(lefts)
return result
count_all_true_driver = jit.JitDriver(name = 'numpy_count',
@@ -324,13 +332,13 @@
def count_all_true_concrete(impl):
s = 0
- iter = impl.create_iter()
+ iter, state = impl.create_iter()
shapelen = len(impl.shape)
dtype = impl.dtype
- while not iter.done():
+ while not iter.done(state):
count_all_true_driver.jit_merge_point(shapelen=shapelen, dtype=dtype)
- s += iter.getitem_bool()
- iter.next()
+ s += iter.getitem_bool(state)
+ state = iter.next(state)
return s
def count_all_true(arr):
@@ -344,18 +352,18 @@
reds = 'auto')
def nonzero(res, arr, box):
- res_iter = res.create_iter()
- arr_iter = arr.create_iter()
+ res_iter, res_state = res.create_iter()
+ arr_iter, arr_state = arr.create_iter()
shapelen = len(arr.shape)
dtype = arr.dtype
dims = range(shapelen)
- while not arr_iter.done():
+ while not arr_iter.done(arr_state):
nonzero_driver.jit_merge_point(shapelen=shapelen, dims=dims, dtype=dtype)
- if arr_iter.getitem_bool():
+ if arr_iter.getitem_bool(arr_state):
for d in dims:
- res_iter.setitem(box(arr_iter.indices[d]))
- res_iter.next()
- arr_iter.next()
+ res_iter.setitem(res_state, box(arr_state.indices[d]))
+ res_state = res_iter.next(res_state)
+ arr_state = arr_iter.next(arr_state)
return res
@@ -365,26 +373,26 @@
reds = 'auto')
def getitem_filter(res, arr, index):
- res_iter = res.create_iter()
+ res_iter, res_state = res.create_iter()
shapelen = len(arr.get_shape())
if shapelen > 1 and len(index.get_shape()) < 2:
- index_iter = index.create_iter(arr.get_shape(), backward_broadcast=True)
+ index_iter, index_state = index.create_iter(arr.get_shape(), backward_broadcast=True)
else:
- index_iter = index.create_iter()
- arr_iter = arr.create_iter()
+ index_iter, index_state = index.create_iter()
+ arr_iter, arr_state = arr.create_iter()
arr_dtype = arr.get_dtype()
index_dtype = index.get_dtype()
# XXX length of shape of index as well?
- while not index_iter.done():
+ while not index_iter.done(index_state):
getitem_filter_driver.jit_merge_point(shapelen=shapelen,
index_dtype=index_dtype,
arr_dtype=arr_dtype,
)
- if index_iter.getitem_bool():
- res_iter.setitem(arr_iter.getitem())
- res_iter.next()
- index_iter.next()
- arr_iter.next()
+ if index_iter.getitem_bool(index_state):
+ res_iter.setitem(res_state, arr_iter.getitem(arr_state))
+ res_state = res_iter.next(res_state)
+ index_state = index_iter.next(index_state)
+ arr_state = arr_iter.next(arr_state)
return res
setitem_filter_driver = jit.JitDriver(name = 'numpy_setitem_bool',
@@ -393,41 +401,42 @@
reds = 'auto')
def setitem_filter(space, arr, index, value):
- arr_iter = arr.create_iter()
+ arr_iter, arr_state = arr.create_iter()
shapelen = len(arr.get_shape())
if shapelen > 1 and len(index.get_shape()) < 2:
- index_iter = index.create_iter(arr.get_shape(), backward_broadcast=True)
+ index_iter, index_state = index.create_iter(arr.get_shape(), backward_broadcast=True)
else:
- index_iter = index.create_iter()
+ index_iter, index_state = index.create_iter()
if value.get_size() == 1:
- value_iter = value.create_iter(arr.get_shape())
+ value_iter, value_state = value.create_iter(arr.get_shape())
else:
- value_iter = value.create_iter()
+ value_iter, value_state = value.create_iter()
index_dtype = index.get_dtype()
arr_dtype = arr.get_dtype()
- while not index_iter.done():
+ while not index_iter.done(index_state):
setitem_filter_driver.jit_merge_point(shapelen=shapelen,
index_dtype=index_dtype,
arr_dtype=arr_dtype,
)
- if index_iter.getitem_bool():
- arr_iter.setitem(arr_dtype.coerce(space, value_iter.getitem()))
- value_iter.next()
- arr_iter.next()
- index_iter.next()
+ if index_iter.getitem_bool(index_state):
+ val = arr_dtype.coerce(space, value_iter.getitem(value_state))
+ value_state = value_iter.next(value_state)
+ arr_iter.setitem(arr_state, val)
+ arr_state = arr_iter.next(arr_state)
+ index_state = index_iter.next(index_state)
flatiter_getitem_driver = jit.JitDriver(name = 'numpy_flatiter_getitem',
greens = ['dtype'],
reds = 'auto')
-def flatiter_getitem(res, base_iter, step):
- ri = res.create_iter()
+def flatiter_getitem(res, base_iter, base_state, step):
+ ri, rs = res.create_iter()
dtype = res.get_dtype()
- while not ri.done():
+ while not ri.done(rs):
flatiter_getitem_driver.jit_merge_point(dtype=dtype)
- ri.setitem(base_iter.getitem())
- base_iter.next_skip_x(step)
- ri.next()
+ ri.setitem(rs, base_iter.getitem(base_state))
+ base_state = base_iter.next_skip_x(base_state, step)
+ rs = ri.next(rs)
return res
flatiter_setitem_driver = jit.JitDriver(name = 'numpy_flatiter_setitem',
@@ -436,19 +445,21 @@
def flatiter_setitem(space, arr, val, start, step, length):
dtype = arr.get_dtype()
- arr_iter = arr.create_iter()
- val_iter = val.create_iter()
- arr_iter.next_skip_x(start)
+ arr_iter, arr_state = arr.create_iter()
+ val_iter, val_state = val.create_iter()
+ arr_state = arr_iter.next_skip_x(arr_state, start)
while length > 0:
flatiter_setitem_driver.jit_merge_point(dtype=dtype)
+ val = val_iter.getitem(val_state)
if dtype.is_str_or_unicode():
- arr_iter.setitem(dtype.coerce(space, val_iter.getitem()))
+ val = dtype.coerce(space, val)
else:
- arr_iter.setitem(val_iter.getitem().convert_to(space, dtype))
+ val = val.convert_to(space, dtype)
+ arr_iter.setitem(arr_state, val)
# need to repeat i_nput values until all assignments are done
- arr_iter.next_skip_x(step)
+ arr_state = arr_iter.next_skip_x(arr_state, step)
+ val_state = val_iter.next(val_state)
length -= 1
- val_iter.next()
fromstring_driver = jit.JitDriver(name = 'numpy_fromstring',
greens = ['itemsize', 'dtype'],
@@ -456,30 +467,30 @@
def fromstring_loop(space, a, dtype, itemsize, s):
i = 0
- ai = a.create_iter()
- while not ai.done():
+ ai, state = a.create_iter()
+ while not ai.done(state):
fromstring_driver.jit_merge_point(dtype=dtype, itemsize=itemsize)
sub = s[i*itemsize:i*itemsize + itemsize]
if dtype.is_str_or_unicode():
val = dtype.coerce(space, space.wrap(sub))
else:
val = dtype.itemtype.runpack_str(space, sub)
- ai.setitem(val)
- ai.next()
+ ai.setitem(state, val)
+ state = ai.next(state)
i += 1
def tostring(space, arr):
builder = StringBuilder()
- iter = arr.create_iter()
+ iter, state = arr.create_iter()
w_res_str = W_NDimArray.from_shape(space, [1], arr.get_dtype(), order='C')
itemsize = arr.get_dtype().elsize
res_str_casted = rffi.cast(rffi.CArrayPtr(lltype.Char),
w_res_str.implementation.get_storage_as_int(space))
- while not iter.done():
- w_res_str.implementation.setitem(0, iter.getitem())
+ while not iter.done(state):
+ w_res_str.implementation.setitem(0, iter.getitem(state))
for i in range(itemsize):
builder.append(res_str_casted[i])
- iter.next()
+ state = iter.next(state)
return builder.build()
getitem_int_driver = jit.JitDriver(name = 'numpy_getitem_int',
@@ -500,8 +511,8 @@
# prepare the index
index_w = [None] * indexlen
for i in range(indexlen):
- if iter.idx_w[i] is not None:
- index_w[i] = iter.idx_w[i].getitem()
+ if iter.idx_w_i[i] is not None:
+ index_w[i] = iter.idx_w_i[i].getitem(iter.idx_w_s[i])
else:
index_w[i] = indexes_w[i]
res.descr_setitem(space, space.newtuple(prefix_w[:prefixlen] +
@@ -528,8 +539,8 @@
# prepare the index
index_w = [None] * indexlen
for i in range(indexlen):
- if iter.idx_w[i] is not None:
- index_w[i] = iter.idx_w[i].getitem()
+ if iter.idx_w_i[i] is not None:
+ index_w[i] = iter.idx_w_i[i].getitem(iter.idx_w_s[i])
else:
index_w[i] = indexes_w[i]
w_idx = space.newtuple(prefix_w[:prefixlen] + iter.get_index(space,
@@ -547,13 +558,14 @@
def byteswap(from_, to):
dtype = from_.dtype
- from_iter = from_.create_iter()
- to_iter = to.create_iter()
- while not from_iter.done():
+ from_iter, from_state = from_.create_iter()
+ to_iter, to_state = to.create_iter()
+ while not from_iter.done(from_state):
byteswap_driver.jit_merge_point(dtype=dtype)
- to_iter.setitem(dtype.itemtype.byteswap(from_iter.getitem()))
- to_iter.next()
- from_iter.next()
+ val = dtype.itemtype.byteswap(from_iter.getitem(from_state))
+ to_iter.setitem(to_state, val)
+ to_state = to_iter.next(to_state)
+ from_state = from_iter.next(from_state)
choose_driver = jit.JitDriver(name='numpy_choose_driver',
greens = ['shapelen', 'mode', 'dtype'],
@@ -561,13 +573,15 @@
def choose(space, arr, choices, shape, dtype, out, mode):
shapelen = len(shape)
- iterators = [a.create_iter(shape) for a in choices]
- arr_iter = arr.create_iter(shape)
- out_iter = out.create_iter(shape)
- while not arr_iter.done():
+ pairs = [a.create_iter(shape) for a in choices]
+ iterators = [i[0] for i in pairs]
+ states = [i[1] for i in pairs]
+ arr_iter, arr_state = arr.create_iter(shape)
+ out_iter, out_state = out.create_iter(shape)
+ while not arr_iter.done(arr_state):
choose_driver.jit_merge_point(shapelen=shapelen, dtype=dtype,
mode=mode)
- index = support.index_w(space, arr_iter.getitem())
+ index = support.index_w(space, arr_iter.getitem(arr_state))
if index < 0 or index >= len(iterators):
if mode == NPY.RAISE:
raise OperationError(space.w_ValueError, space.wrap(
@@ -580,72 +594,73 @@
index = 0
else:
index = len(iterators) - 1
- out_iter.setitem(iterators[index].getitem().convert_to(space, dtype))
- for iter in iterators:
- iter.next()
- out_iter.next()
- arr_iter.next()
+ val = iterators[index].getitem(states[index]).convert_to(space, dtype)
+ out_iter.setitem(out_state, val)
+ for i in range(len(iterators)):
+ states[i] = iterators[i].next(states[i])
+ out_state = out_iter.next(out_state)
+ arr_state = arr_iter.next(arr_state)
clip_driver = jit.JitDriver(name='numpy_clip_driver',
greens = ['shapelen', 'dtype'],
reds = 'auto')
def clip(space, arr, shape, min, max, out):
- arr_iter = arr.create_iter(shape)
+ arr_iter, arr_state = arr.create_iter(shape)
dtype = out.get_dtype()
shapelen = len(shape)
- min_iter = min.create_iter(shape)
- max_iter = max.create_iter(shape)
- out_iter = out.create_iter(shape)
- while not arr_iter.done():
+ min_iter, min_state = min.create_iter(shape)
+ max_iter, max_state = max.create_iter(shape)
+ out_iter, out_state = out.create_iter(shape)
+ while not arr_iter.done(arr_state):
clip_driver.jit_merge_point(shapelen=shapelen, dtype=dtype)
- w_v = arr_iter.getitem().convert_to(space, dtype)
- w_min = min_iter.getitem().convert_to(space, dtype)
- w_max = max_iter.getitem().convert_to(space, dtype)
+ w_v = arr_iter.getitem(arr_state).convert_to(space, dtype)
+ w_min = min_iter.getitem(min_state).convert_to(space, dtype)
+ w_max = max_iter.getitem(max_state).convert_to(space, dtype)
if dtype.itemtype.lt(w_v, w_min):
w_v = w_min
elif dtype.itemtype.gt(w_v, w_max):
w_v = w_max
- out_iter.setitem(w_v)
- arr_iter.next()
- max_iter.next()
- out_iter.next()
- min_iter.next()
+ out_iter.setitem(out_state, w_v)
+ arr_state = arr_iter.next(arr_state)
+ min_state = min_iter.next(min_state)
+ max_state = max_iter.next(max_state)
+ out_state = out_iter.next(out_state)
round_driver = jit.JitDriver(name='numpy_round_driver',
greens = ['shapelen', 'dtype'],
reds = 'auto')
def round(space, arr, dtype, shape, decimals, out):
- arr_iter = arr.create_iter(shape)
+ arr_iter, arr_state = arr.create_iter(shape)
+ out_iter, out_state = out.create_iter(shape)
shapelen = len(shape)
- out_iter = out.create_iter(shape)
- while not arr_iter.done():
+ while not arr_iter.done(arr_state):
round_driver.jit_merge_point(shapelen=shapelen, dtype=dtype)
- w_v = arr_iter.getitem().convert_to(space, dtype)
+ w_v = arr_iter.getitem(arr_state).convert_to(space, dtype)
w_v = dtype.itemtype.round(w_v, decimals)
- out_iter.setitem(w_v)
- arr_iter.next()
- out_iter.next()
+ out_iter.setitem(out_state, w_v)
+ arr_state = arr_iter.next(arr_state)
+ out_state = out_iter.next(out_state)
diagonal_simple_driver = jit.JitDriver(name='numpy_diagonal_simple_driver',
greens = ['axis1', 'axis2'],
reds = 'auto')
def diagonal_simple(space, arr, out, offset, axis1, axis2, size):
- out_iter = out.create_iter()
+ out_iter, out_state = out.create_iter()
i = 0
index = [0] * 2
while i < size:
diagonal_simple_driver.jit_merge_point(axis1=axis1, axis2=axis2)
index[axis1] = i
index[axis2] = i + offset
- out_iter.setitem(arr.getitem_index(space, index))
+ out_iter.setitem(out_state, arr.getitem_index(space, index))
i += 1
- out_iter.next()
+ out_state = out_iter.next(out_state)
def diagonal_array(space, arr, out, offset, axis1, axis2, shape):
- out_iter = out.create_iter()
+ out_iter, out_state = out.create_iter()
iter = PureShapeIter(shape, [])
shapelen_minus_1 = len(shape) - 1
assert shapelen_minus_1 >= 0
@@ -667,6 +682,6 @@
indexes = (iter.indexes[:a] + [last_index + offset] +
iter.indexes[a:b] + [last_index] +
iter.indexes[b:shapelen_minus_1])
- out_iter.setitem(arr.getitem_index(space, indexes))
+ out_iter.setitem(out_state, arr.getitem_index(space, indexes))
iter.next()
- out_iter.next()
+ out_state = out_iter.next(out_state)
diff --git a/pypy/module/micronumpy/ndarray.py b/pypy/module/micronumpy/ndarray.py
--- a/pypy/module/micronumpy/ndarray.py
+++ b/pypy/module/micronumpy/ndarray.py
@@ -260,24 +260,24 @@
return space.call_function(cache.w_array_str, self)
def dump_data(self, prefix='array(', separator=',', suffix=')'):
- i = self.create_iter()
+ i, state = self.create_iter()
first = True
dtype = self.get_dtype()
s = StringBuilder()
s.append(prefix)
if not self.is_scalar():
s.append('[')
- while not i.done():
+ while not i.done(state):
if first:
first = False
else:
s.append(separator)
s.append(' ')
if self.is_scalar() and dtype.is_str():
- s.append(dtype.itemtype.to_str(i.getitem()))
+ s.append(dtype.itemtype.to_str(i.getitem(state)))
else:
- s.append(dtype.itemtype.str_format(i.getitem()))
- i.next()
+ s.append(dtype.itemtype.str_format(i.getitem(state)))
+ state = i.next(state)
if not self.is_scalar():
s.append(']')
s.append(suffix)
@@ -818,8 +818,8 @@
if self.get_size() > 1:
raise OperationError(space.w_ValueError, space.wrap(
"The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()"))
- iter = self.create_iter()
- return space.wrap(space.is_true(iter.getitem()))
+ iter, state = self.create_iter()
+ return space.wrap(space.is_true(iter.getitem(state)))
def _binop_impl(ufunc_name):
def impl(self, space, w_other, w_out=None):
@@ -1095,11 +1095,11 @@
builder = StringBuilder()
if isinstance(self.implementation, SliceArray):
- iter = self.implementation.create_iter()
- while not iter.done():
- box = iter.getitem()
+ iter, state = self.implementation.create_iter()
+ while not iter.done(state):
+ box = iter.getitem(state)
builder.append(box.raw_str())
- iter.next()
+ state = iter.next(state)
else:
builder.append_charpsize(self.implementation.get_storage(), self.implementation.get_storage_size())
diff --git a/pypy/module/micronumpy/sort.py b/pypy/module/micronumpy/sort.py
--- a/pypy/module/micronumpy/sort.py
+++ b/pypy/module/micronumpy/sort.py
@@ -148,20 +148,22 @@
if axis < 0 or axis >= len(shape):
raise oefmt(space.w_IndexError, "Wrong axis %d", axis)
arr_iter = AllButAxisIter(arr, axis)
+ arr_state = arr_iter.reset()
index_impl = index_arr.implementation
index_iter = AllButAxisIter(index_impl, axis)
+ index_state = index_iter.reset()
stride_size = arr.strides[axis]
index_stride_size = index_impl.strides[axis]
axis_size = arr.shape[axis]
- while not arr_iter.done():
+ while not arr_iter.done(arr_state):
for i in range(axis_size):
raw_storage_setitem(storage, i * index_stride_size +
- index_iter.offset, i)
+ index_state.offset, i)
r = Repr(index_stride_size, stride_size, axis_size,
- arr.get_storage(), storage, index_iter.offset, arr_iter.offset)
+ arr.get_storage(), storage, index_state.offset, arr_state.offset)
ArgSort(r).sort()
- arr_iter.next()
- index_iter.next()
+ arr_state = arr_iter.next(arr_state)
+ index_state = index_iter.next(index_state)
return index_arr
return argsort
@@ -292,12 +294,13 @@
if axis < 0 or axis >= len(shape):
raise oefmt(space.w_IndexError, "Wrong axis %d", axis)
arr_iter = AllButAxisIter(arr, axis)
+ arr_state = arr_iter.reset()
stride_size = arr.strides[axis]
axis_size = arr.shape[axis]
- while not arr_iter.done():
- r = Repr(stride_size, axis_size, arr.get_storage(), arr_iter.offset)
+ while not arr_iter.done(arr_state):
+ r = Repr(stride_size, axis_size, arr.get_storage(), arr_state.offset)
ArgSort(r).sort()
- arr_iter.next()
+ arr_state = arr_iter.next(arr_state)
return sort
diff --git a/pypy/module/micronumpy/test/test_zjit.py b/pypy/module/micronumpy/test/test_zjit.py
--- a/pypy/module/micronumpy/test/test_zjit.py
+++ b/pypy/module/micronumpy/test/test_zjit.py
@@ -47,7 +47,8 @@
raise Exception("need results")
w_res = interp.results[-1]
if isinstance(w_res, W_NDimArray):
- w_res = w_res.create_iter().getitem()
+ i, s = w_res.create_iter()
+ w_res = i.getitem(s)
if isinstance(w_res, boxes.W_Float64Box):
return w_res.value
if isinstance(w_res, boxes.W_Int64Box):
More information about the pypy-commit mailing list