[pypy-commit] pypy numpy-refactor: argmin/argmax
fijal
noreply at buildbot.pypy.org
Wed Sep 5 18:00:01 CEST 2012
Author: Maciej Fijalkowski <fijall at gmail.com>
Branch: numpy-refactor
Changeset: r57152:d3558890855c
Date: 2012-09-05 17:59 +0200
http://bitbucket.org/pypy/pypy/changeset/d3558890855c/
Log: argmin/argmax
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -8,6 +8,7 @@
get_shape_from_iterable
from pypy.module.micronumpy.interp_support import unwrap_axis_arg
from pypy.module.micronumpy.appbridge import get_appbridge_cache
+from pypy.module.micronumpy import loop
from pypy.tool.sourcetools import func_with_new_name
from pypy.rlib import jit
from pypy.rlib.rstring import StringBuilder
@@ -259,6 +260,17 @@
w_denom = space.wrap(self.get_shape()[axis])
return space.div(self.descr_sum_promote(space, w_axis, w_out), w_denom)
+ def _reduce_argmax_argmin_impl(op_name):
+ def impl(self, space):
+ if self.get_size() == 0:
+ raise OperationError(space.w_ValueError,
+ space.wrap("Can't call %s on zero-size arrays" % op_name))
+ return space.wrap(loop.argmin_argmax(op_name, self))
+ return func_with_new_name(impl, "reduce_arg%s_impl" % op_name)
+
+ descr_argmax = _reduce_argmax_argmin_impl("max")
+ descr_argmin = _reduce_argmax_argmin_impl("min")
+
@unwrap_spec(offset=int)
def descr_new_array(space, w_subtype, w_shape, w_dtype=None, w_buffer=None,
@@ -339,8 +351,8 @@
prod = interp2app(W_NDimArray.descr_prod),
max = interp2app(W_NDimArray.descr_max),
min = interp2app(W_NDimArray.descr_min),
- #argmax = interp2app(W_NDimArray.descr_argmax),
- #argmin = interp2app(W_NDimArray.descr_argmin),
+ argmax = interp2app(W_NDimArray.descr_argmax),
+ argmin = interp2app(W_NDimArray.descr_argmin),
all = interp2app(W_NDimArray.descr_all),
any = interp2app(W_NDimArray.descr_any),
#dot = interp2app(W_NDimArray.descr_dot),
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -3,6 +3,7 @@
signatures
"""
+from pypy.rlib.objectmodel import specialize
from pypy.module.micronumpy.base import W_NDimArray
def call2(shape, func, name, calc_dtype, res_dtype, w_lhs, w_rhs, out):
@@ -100,3 +101,21 @@
arr_iter.next()
out_iter.next()
return out
+
+@specialize.arg(0)
+def argmin_argmax(op_name, arr):
+ result = 0
+ idx = 1
+ dtype = arr.get_dtype()
+ iter = arr.create_iter(arr.get_shape())
+ cur_best = iter.getitem()
+ iter.next()
+ while not iter.done():
+ w_val = iter.getitem()
+ new_best = getattr(dtype.itemtype, op_name)(cur_best, w_val)
+ if dtype.itemtype.ne(new_best, cur_best):
+ result = idx
+ cur_best = new_best
+ iter.next()
+ idx += 1
+ return result
More information about the pypy-commit
mailing list