[pypy-commit] pypy numpy-refactor: argmin/argmax

fijal noreply at buildbot.pypy.org
Wed Sep 5 18:00:01 CEST 2012


Author: Maciej Fijalkowski <fijall at gmail.com>
Branch: numpy-refactor
Changeset: r57152:d3558890855c
Date: 2012-09-05 17:59 +0200
http://bitbucket.org/pypy/pypy/changeset/d3558890855c/

Log:	argmin/argmax

diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -8,6 +8,7 @@
      get_shape_from_iterable
 from pypy.module.micronumpy.interp_support import unwrap_axis_arg
 from pypy.module.micronumpy.appbridge import get_appbridge_cache
+from pypy.module.micronumpy import loop
 from pypy.tool.sourcetools import func_with_new_name
 from pypy.rlib import jit
 from pypy.rlib.rstring import StringBuilder
@@ -259,6 +260,17 @@
             w_denom = space.wrap(self.get_shape()[axis])
         return space.div(self.descr_sum_promote(space, w_axis, w_out), w_denom)
 
+    def _reduce_argmax_argmin_impl(op_name):
+        def impl(self, space):
+            if self.get_size() == 0:
+                raise OperationError(space.w_ValueError,
+                    space.wrap("Can't call %s on zero-size arrays" % op_name))
+            return space.wrap(loop.argmin_argmax(op_name, self))
+        return func_with_new_name(impl, "reduce_arg%s_impl" % op_name)
+
+    descr_argmax = _reduce_argmax_argmin_impl("max")
+    descr_argmin = _reduce_argmax_argmin_impl("min")
+
 
 @unwrap_spec(offset=int)
 def descr_new_array(space, w_subtype, w_shape, w_dtype=None, w_buffer=None,
@@ -339,8 +351,8 @@
     prod = interp2app(W_NDimArray.descr_prod),
     max = interp2app(W_NDimArray.descr_max),
     min = interp2app(W_NDimArray.descr_min),
-    #argmax = interp2app(W_NDimArray.descr_argmax),
-    #argmin = interp2app(W_NDimArray.descr_argmin),
+    argmax = interp2app(W_NDimArray.descr_argmax),
+    argmin = interp2app(W_NDimArray.descr_argmin),
     all = interp2app(W_NDimArray.descr_all),
     any = interp2app(W_NDimArray.descr_any),
     #dot = interp2app(W_NDimArray.descr_dot),
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -3,6 +3,7 @@
 signatures
 """
 
+from pypy.rlib.objectmodel import specialize
 from pypy.module.micronumpy.base import W_NDimArray
 
 def call2(shape, func, name, calc_dtype, res_dtype, w_lhs, w_rhs, out):
@@ -100,3 +101,21 @@
         arr_iter.next()
         out_iter.next()
     return out
+
+@specialize.arg(0)
+def argmin_argmax(op_name, arr):
+    result = 0
+    idx = 1
+    dtype = arr.get_dtype()
+    iter = arr.create_iter(arr.get_shape())
+    cur_best = iter.getitem()
+    iter.next()
+    while not iter.done():
+        w_val = iter.getitem()
+        new_best = getattr(dtype.itemtype, op_name)(cur_best, w_val)
+        if dtype.itemtype.ne(new_best, cur_best):
+            result = idx
+            cur_best = new_best
+        iter.next()
+        idx += 1
+    return result


More information about the pypy-commit mailing list