[pypy-commit] pypy default: correct the dtype of scalar arrays

alex_gaynor noreply at buildbot.pypy.org
Mon Nov 28 18:20:48 CET 2011


Author: Alex Gaynor <alex.gaynor at gmail.com>
Branch: 
Changeset: r49922:74b0db67fd56
Date: 2011-11-28 12:11 -0500
http://bitbucket.org/pypy/pypy/changeset/74b0db67fd56/

Log:	correct the dtype of scalar arrays

diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -102,11 +102,12 @@
                     w_order=NoneNotWrapped):
     # find scalar
     if not space.issequence_w(w_item_or_iterable):
-        w_dtype = interp_ufuncs.find_dtype_for_scalar(space,
-                                                      w_item_or_iterable,
-                                                      w_dtype)
+        if space.is_w(w_dtype, space.w_None):
+            w_dtype = interp_ufuncs.find_dtype_for_scalar(space,
+                                                          w_item_or_iterable)
         dtype = space.interp_w(interp_dtype.W_Dtype,
-           space.call_function(space.gettypefor(interp_dtype.W_Dtype), w_dtype))
+            space.call_function(space.gettypefor(interp_dtype.W_Dtype), w_dtype)
+        )
         return scalar_w(space, dtype, w_item_or_iterable)
     if w_order is None:
         order = 'C'
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -176,8 +176,7 @@
 
     def test_size(self):
         from numpypy import array
-        # XXX fixed on multidim branch
-        #assert array(3).size == 1
+        assert array(3).size == 1
         a = array([1, 2, 3])
         assert a.size == 3
         assert (a + a).size == 3
@@ -302,12 +301,13 @@
         assert a[3] == 0.
 
     def test_scalar(self):
-        from numpypy import array
+        from numpypy import array, dtype
         a = array(3)
         #assert a[0] == 3
         raises(IndexError, "a[0]")
         assert a.size == 1
         assert a.shape == ()
+        assert a.dtype is dtype(int)
 
     def test_len(self):
         from numpypy import array


More information about the pypy-commit mailing list