[pypy-commit] pypy numpypy-out: add more passing tests

mattip noreply at buildbot.pypy.org
Sat Feb 25 22:25:09 CET 2012


Author: mattip
Branch: numpypy-out
Changeset: r52904:0705b078c3c2
Date: 2012-02-25 23:15 +0200
http://bitbucket.org/pypy/pypy/changeset/0705b078c3c2/

Log:	add more passing tests

diff --git a/pypy/module/micronumpy/app_numpy.py b/pypy/module/micronumpy/app_numpy.py
--- a/pypy/module/micronumpy/app_numpy.py
+++ b/pypy/module/micronumpy/app_numpy.py
@@ -16,7 +16,7 @@
         a[i][i] = 1
     return a
 
-def sum(a,axis=None):
+def sum(a,axis=None, out=None):
     '''sum(a, axis=None)
     Sum of array elements over a given axis.
 
@@ -43,17 +43,17 @@
     # TODO: add to doc (once it's implemented): cumsum : Cumulative sum of array elements.
     if not hasattr(a, "sum"):
         a = _numpypy.array(a)
-    return a.sum(axis)
+    return a.sum(axis=axis, out=out)
 
-def min(a, axis=None):
+def min(a, axis=None, out=None):
     if not hasattr(a, "min"):
         a = _numpypy.array(a)
-    return a.min(axis)
+    return a.min(axis=axis, out=out)
 
-def max(a, axis=None):
+def max(a, axis=None, out=None):
     if not hasattr(a, "max"):
         a = _numpypy.array(a)
-    return a.max(axis)
+    return a.max(axis=axis, out=out)
 
 def arange(start, stop=None, step=1, dtype=None):
     '''arange([start], stop[, step], dtype=None)
diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py
--- a/pypy/module/micronumpy/interp_ufuncs.py
+++ b/pypy/module/micronumpy/interp_ufuncs.py
@@ -321,10 +321,17 @@
         else:
             res_dtype = calc_dtype
         if isinstance(w_lhs, Scalar) and isinstance(w_rhs, Scalar):
-            return space.wrap(self.func(calc_dtype,
+            arr = self.func(calc_dtype,
                 w_lhs.value.convert_to(calc_dtype),
                 w_rhs.value.convert_to(calc_dtype)
-            ))
+            )
+            if isinstance(out,Scalar):
+                out.value=arr
+            elif isinstance(out, BaseArray):
+                out.fill(space, arr)
+            else:
+                out = arr
+            return space.wrap(out)
         new_shape = shape_agreement(space, w_lhs.shape, w_rhs.shape)
         # Test correctness of out.shape
         if out and out.shape != shape_agreement(space, new_shape, out.shape):
diff --git a/pypy/module/micronumpy/test/test_outarg.py b/pypy/module/micronumpy/test/test_outarg.py
--- a/pypy/module/micronumpy/test/test_outarg.py
+++ b/pypy/module/micronumpy/test/test_outarg.py
@@ -70,21 +70,44 @@
         assert c.dtype is a.dtype
         c[0,0] = 100
         assert out[0, 0] == 100
+        out[:] = 100
         raises(ValueError, 'c = add(a, a, out=out[1])')
         c = add(a[0], a[1], out=out[1])
         assert (c == out[1]).all()
         assert (c == [4, 6]).all()
+        assert (out[0] == 100).all()
         c = add(a[0], a[1], out=out)
         assert (c == out[1]).all()
         assert (c == out[0]).all()
+        out = array(16, dtype=int)
+        b = add(10, 10, out=out)
+        assert b==out
+        assert b.dtype == out.dtype
         
-
+    def test_applevel(self):
+        from _numpypy import array, sum, max, min
+        a = array([[1, 2], [3, 4]])
+        out = array([[0, 0], [0, 0]])
+        c = sum(a, axis=0, out=out[0])
+        assert (c == [4, 6]).all()
+        assert (c == out[0]).all()
+        assert (c != out[1]).all()
+        c = max(a, axis=1, out=out[0])
+        assert (c == [2, 4]).all()
+        assert (c == out[0]).all()
+        assert (c != out[1]).all()
+        
     def test_ufunc_cast(self):
-        from _numpypy import array, negative
+        from _numpypy import array, negative, add, sum
         a = array(16, dtype = int)
         c = array(0, dtype = float)
         b = negative(a, out=c)
         assert b == c
+        b = add(a, a, out=c)
+        assert b == c
+        d = array([16, 16], dtype=int)
+        b = sum(d, out=c)
+        assert b == c
         try:
             from _numpypy import version
             v = version.version.split('.')
@@ -93,6 +116,10 @@
         if v[0]<'2':
             b = negative(c, out=a)
             assert b == a
+            b = add(c, c, out=a)
+            assert b == a
+            b = sum(array([16, 16], dtype=float), out=a)
+            assert b == a
         else:
             cast_error = raises(TypeError, negative, c, a)
             assert str(cast_error.value) == \


More information about the pypy-commit mailing list