[pypy-commit] pypy numpy-back-to-applevel: implement keepdims=True

fijal noreply at buildbot.pypy.org
Sat Jan 21 18:35:22 CET 2012


Author: Maciej Fijalkowski <fijall at gmail.com>
Branch: numpy-back-to-applevel
Changeset: r51598:0fcad0cba011
Date: 2012-01-21 19:34 +0200
http://bitbucket.org/pypy/pypy/changeset/0fcad0cba011/

Log:	implement keepdims=True

diff --git a/pypy/module/micronumpy/interp_iter.py b/pypy/module/micronumpy/interp_iter.py
--- a/pypy/module/micronumpy/interp_iter.py
+++ b/pypy/module/micronumpy/interp_iter.py
@@ -153,8 +153,13 @@
 class AxisIterator(BaseIterator):
     def __init__(self, start, dim, shape, strides, backstrides):
         self.res_shape = shape[:]
-        self.strides = strides[:dim] + [0] + strides[dim:]
-        self.backstrides = backstrides[:dim] + [0] + backstrides[dim:]
+        if len(shape) == len(strides):
+            # keepdims = True
+            self.strides = strides[:dim] + [0] + strides[dim + 1:]
+            self.backstrides = backstrides[:dim] + [0] + backstrides[dim + 1:]
+        else:
+            self.strides = strides[:dim] + [0] + strides[dim:]
+            self.backstrides = backstrides[:dim] + [0] + backstrides[dim:]
         self.first_line = True
         self.indices = [0] * len(shape)
         self._done = False
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -1077,6 +1077,8 @@
 def array(space, w_item_or_iterable, w_dtype=None, w_order=None,
           subok=True, copy=False, w_maskna=None, ownmaskna=False):
     # find scalar
+    if w_maskna is None:
+        w_maskna = space.w_None
     if (not subok or copy or not space.is_w(w_maskna, space.w_None) or
         ownmaskna):
         raise OperationError(space.w_NotImplementedError, space.wrap("Unsupported args"))
@@ -1088,7 +1090,7 @@
             space.call_function(space.gettypefor(interp_dtype.W_Dtype), w_dtype)
         )
         return scalar_w(space, dtype, w_item_or_iterable)
-    if space.is_w(w_order, space.w_None):
+    if space.is_w(w_order, space.w_None) or w_order is None:
         order = 'C'
     else:
         order = space.str_w(w_order)
diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py
--- a/pypy/module/micronumpy/interp_ufuncs.py
+++ b/pypy/module/micronumpy/interp_ufuncs.py
@@ -46,7 +46,8 @@
         return self.identity
 
     def descr_call(self, space, __args__):
-        if __args__.keywords or len(__args__.arguments_w) < self.argcount:
+        # XXX do something with strange keywords
+        if len(__args__.arguments_w) < self.argcount:
             raise OperationError(space.w_ValueError,
                 space.wrap("invalid number of arguments")
             )
@@ -60,7 +61,7 @@
 
     @unwrap_spec(skipna=bool, keepdims=bool)
     def descr_reduce(self, space, w_obj, w_axis=None, w_dtype=None,
-                     skipna=False, keepdims=True, w_out=None):
+                     skipna=False, keepdims=False, w_out=None):
         """reduce(...)
         reduce(a, axis=0)
 
@@ -120,9 +121,9 @@
             axis = -1
         else:
             axis = space.int_w(w_axis)
-        return self.reduce(space, w_obj, False, False, axis)
+        return self.reduce(space, w_obj, False, False, axis, keepdims)
 
-    def reduce(self, space, w_obj, multidim, promote_to_largest, dim):
+    def reduce(self, space, w_obj, multidim, promote_to_largest, dim, keepdims):
         from pypy.module.micronumpy.interp_numarray import convert_to_array, \
                                                            Scalar
         if self.argcount != 2:
@@ -148,7 +149,7 @@
             raise operationerrfmt(space.w_ValueError, "zero-size array to "
                     "%s.reduce without identity", self.name)
         if shapelen > 1 and dim >= 0:
-            res = self.do_axis_reduce(obj, dtype, dim)
+            res = self.do_axis_reduce(obj, dtype, dim, keepdims)
             return space.wrap(res)
         scalarsig = ScalarSignature(dtype)
         sig = find_sig(ReduceSignature(self.func, self.name, dtype,
@@ -162,11 +163,14 @@
             value = self.identity.convert_to(dtype)
         return self.reduce_loop(shapelen, sig, frame, value, obj, dtype)
 
-    def do_axis_reduce(self, obj, dtype, dim):
+    def do_axis_reduce(self, obj, dtype, dim, keepdims):
         from pypy.module.micronumpy.interp_numarray import AxisReduce,\
              W_NDimArray
-        
-        shape = obj.shape[0:dim] + obj.shape[dim + 1:len(obj.shape)]
+
+        if keepdims:
+            shape = obj.shape[:dim] + [1] + obj.shape[dim + 1:]
+        else:
+            shape = obj.shape[:dim] + obj.shape[dim + 1:]
         size = 1
         for s in shape:
             size *= s
diff --git a/pypy/module/micronumpy/strides.py b/pypy/module/micronumpy/strides.py
--- a/pypy/module/micronumpy/strides.py
+++ b/pypy/module/micronumpy/strides.py
@@ -1,5 +1,5 @@
 from pypy.rlib import jit
-
+from pypy.interpreter.error import OperationError
 
 @jit.look_inside_iff(lambda shape, start, strides, backstrides, chunks:
     jit.isconstant(len(chunks))
diff --git a/pypy/module/micronumpy/test/test_ufuncs.py b/pypy/module/micronumpy/test/test_ufuncs.py
--- a/pypy/module/micronumpy/test/test_ufuncs.py
+++ b/pypy/module/micronumpy/test/test_ufuncs.py
@@ -344,7 +344,7 @@
         from _numpypy import sin, add
 
         raises(ValueError, sin.reduce, [1, 2, 3])
-        raises(ValueError, add.reduce, 1)
+        raises(TypeError, add.reduce, 1)
 
     def test_reduce_1d(self):
         from _numpypy import add, maximum
@@ -360,6 +360,14 @@
         assert (add.reduce(a, 0) == [12, 15, 18, 21]).all()
         assert (add.reduce(a, 1) == [6.0, 22.0, 38.0]).all()
 
+    def test_reduce_keepdims(self):
+        from _numpypy import add, arange
+        a = arange(12).reshape(3, 4)
+        b = add.reduce(a, 0, keepdims=True)
+        assert b.shape == (1, 4)
+        assert (add.reduce(a, 0, keepdims=True) == [12, 15, 18, 21]).all()
+        
+
     def test_bitwise(self):
         from _numpypy import bitwise_and, bitwise_or, arange, array
         a = arange(6).reshape(2, 3)


More information about the pypy-commit mailing list