[pypy-commit] pypy numpy-refactor: fix some scalar get/set cases

bdkearns noreply at buildbot.pypy.org
Thu Feb 27 01:59:49 CET 2014


Author: Brian Kearns <bdkearns at gmail.com>
Branch: numpy-refactor
Changeset: r69476:20666c9f42b8
Date: 2014-02-26 16:07 -0500
http://bitbucket.org/pypy/pypy/changeset/20666c9f42b8/

Log:	fix some scalar get/set cases

diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -306,7 +306,7 @@
         return len(self.get_shape()) == 0
 
     def set_scalar_value(self, w_val):
-        return self.implementation.setitem(0, w_val)
+        return self.implementation.setitem(self.implementation.start, w_val)
 
     def fill(self, space, box):
         self.implementation.fill(space, box)
@@ -318,8 +318,8 @@
         return self.implementation.get_size()
 
     def get_scalar_value(self):
-        assert len(self.get_shape()) == 0
-        return self.implementation.getitem(0)
+        assert self.get_size() == 1
+        return self.implementation.getitem(self.implementation.start)
 
     def descr_copy(self, space, w_order=None):
         order = order_converter(space, w_order, NPY.KEEPORDER)
@@ -490,19 +490,15 @@
 
     def descr_item(self, space, w_arg=None):
         if space.is_none(w_arg):
-            if self.is_scalar():
-                return self.get_scalar_value().item(space)
             if self.get_size() == 1:
-                w_obj = self.getitem(space,
-                                     [0] * len(self.get_shape()))
+                w_obj = self.get_scalar_value()
                 assert isinstance(w_obj, interp_boxes.W_GenericBox)
                 return w_obj.item(space)
-            raise OperationError(space.w_ValueError,
-                                 space.wrap("can only convert an array of size 1 to a Python scalar"))
+            raise oefmt(space.w_ValueError,
+                "can only convert an array of size 1 to a Python scalar")
         if space.isinstance_w(w_arg, space.w_int):
             if self.is_scalar():
-                raise OperationError(space.w_IndexError,
-                                     space.wrap("index out of bounds"))
+                raise oefmt(space.w_IndexError, "index out of bounds")
             i = self.to_coords(space, w_arg)
             item = self.getitem(space, i)
             assert isinstance(item, interp_boxes.W_GenericBox)
@@ -1041,7 +1037,7 @@
         if self.get_dtype().is_str_or_unicode():
             raise OperationError(space.w_TypeError, space.wrap(
                 "don't know how to convert scalar number to int"))
-        value = self.implementation.getitem(0)
+        value = self.get_scalar_value()
         return space.int(value)
 
     def descr_long(self, space):
@@ -1051,7 +1047,7 @@
         if self.get_dtype().is_str_or_unicode():
             raise OperationError(space.w_TypeError, space.wrap(
                 "don't know how to convert scalar number to long"))
-        value = self.implementation.getitem(0)
+        value = self.get_scalar_value()
         return space.long(value)
 
     def descr_float(self, space):
@@ -1061,7 +1057,7 @@
         if self.get_dtype().is_str_or_unicode():
             raise OperationError(space.w_TypeError, space.wrap(
                 "don't know how to convert scalar number to float"))
-        value = self.implementation.getitem(0)
+        value = self.get_scalar_value()
         return space.float(value)
 
     def descr_index(self, space):
@@ -1070,7 +1066,7 @@
             raise OperationError(space.w_TypeError, space.wrap(
                 "only integer arrays with one element "
                 "can be converted to an index"))
-        value = self.implementation.getitem(0)
+        value = self.get_scalar_value()
         assert isinstance(value, interp_boxes.W_GenericBox)
         return value.item(space)
 
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -1415,6 +1415,12 @@
         b = a.sum(out=d)
         assert b == d
         assert b is d
+        c = array(1.5+2.5j)
+        assert c.real == 1.5
+        assert c.imag == 2.5
+        a.sum(out=c.imag)
+        assert c.real == 1.5
+        assert c.imag == 5
 
         assert list(zeros((0, 2)).sum(axis=1)) == []
 


More information about the pypy-commit mailing list