[pypy-commit] pypy numpypy-argminmax: make translation work

mattip noreply at buildbot.pypy.org
Wed Jul 4 23:05:18 CEST 2012


Author: mattip <matti.picus at gmail.com>
Branch: numpypy-argminmax
Changeset: r55922:5f31264d67ef
Date: 2012-07-05 00:00 +0300
http://bitbucket.org/pypy/pypy/changeset/5f31264d67ef/

Log:	make translation work

diff --git a/pypy/module/micronumpy/interp_iter.py b/pypy/module/micronumpy/interp_iter.py
--- a/pypy/module/micronumpy/interp_iter.py
+++ b/pypy/module/micronumpy/interp_iter.py
@@ -364,7 +364,10 @@
         self.indices = [0] * len(arr.shape)
         self.done = False
         self.offset = arr.start
-        self.dimorder = [dim] +range(len(arr.shape)-1, dim, -1) + range(dim-1, -1, -1)
+        # range is an iterator, make its result concrete
+        second_piece = [i for i in range(len(arr.shape)-1, dim, -1)]
+        third_piece = [i for i in range(dim-1, -1, -1)]
+        self.dimorder = [dim] + second_piece + third_piece
 
     def next(self):
         for i in self.dimorder:
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -186,11 +186,11 @@
             name=name
         )
         def do_argminmax(self, space, axis, out):
+            res_dtype = interp_dtype.get_dtype_cache(space).w_int32dtype
             if isinstance(self, Scalar):
-                return 0
+                return Scalar(res_dtype, res_dtype.box(0))
             dtype = self.find_dtype()
             # numpy compatability demands int32 not uint32
-            res_dtype = interp_dtype.get_dtype_cache(space).w_int32dtype
             assert axis>=0
             if axis < len(self.shape):
                 if out:
@@ -227,17 +227,18 @@
             # Use a AxisFirstIterator to walk along self, with dimensions
             # reordered to move along 'axis' fastest. Every time 'axis' 's
             # index is 0, move to the next value of out.
-            dtype = self.find_dtype()
-            source = AxisFirstIterator(self, axis)
-            dest = ViewIterator(out.start, out.strides, out.backstrides, 
-                                out.shape)
+            concr = self.get_concrete()
+            dtype = concr.find_dtype()
+            source = AxisFirstIterator(concr, axis)
+            dest = out.create_iter()
             firsttime = True
+            cur_best = concr.getitem(source.offset)
             while not source.done:
-                cur_val = self.getitem(source.offset)
+                cur_val = concr.getitem(source.offset)
                 cur_index = source.get_dim_index()
                 if cur_index == 0:
                     if not firsttime:
-                        dest = dest.next(len(self.shape))
+                        dest = dest.next(len(concr.shape))
                     firsttime = False    
                     cur_best = cur_val
                     out.setitem(dest.offset, dtype.box(0))
@@ -268,14 +269,14 @@
                     shape = [1]
                 #Test for shape agreement
                 if len(out.shape) > len(shape):
-                    raise OperationError(space.w_TypesError,
+                    raise OperationError(space.w_TypeError,
                         space.wrap('invalid shape for output array'))
                 elif len(out.shape) < len(shape):
-                    raise OperationError(space.w_TypesError,
-                        space.wrape('invalid shape for output array'))
+                    raise OperationError(space.w_TypeError,
+                        space.wrap('invalid shape for output array'))
                 elif out.shape != shape:
-                    raise OperationError(space.w_TypesError,
-                        space.wrape('invalid shape for output array'))
+                    raise OperationError(space.w_TypeError,
+                        space.wrap('invalid shape for output array'))
                 #Test for dtype agreement, perhaps create an itermediate
                 #if out.dtype != self.dtype:
                 #    raise OperationError(space.w_TypeError, space.wrap(


More information about the pypy-commit mailing list