[pypy-commit] pypy scalar-operations: fix performance of ufunc(scalar, scalar)

rlamy noreply at buildbot.pypy.org
Thu Jun 26 20:57:00 CEST 2014


Author: Ronan Lamy <ronan.lamy at gmail.com>
Branch: scalar-operations
Changeset: r72246:0620f0c12772
Date: 2014-06-26 18:10 +0100
http://bitbucket.org/pypy/pypy/changeset/0620f0c12772/

Log:	fix performance of ufunc(scalar, scalar)

diff --git a/pypy/module/micronumpy/base.py b/pypy/module/micronumpy/base.py
--- a/pypy/module/micronumpy/base.py
+++ b/pypy/module/micronumpy/base.py
@@ -18,7 +18,12 @@
     pass
 
 
-class W_NDimArray(W_Root):
+class W_NumpyObject(W_Root):
+    """Common base class for ndarrays (W_NDimArray) and scalars (aka boxes)."""
+    _attrs_ = []  # empty: this base presumably declares no instance attributes (RPython _attrs_)
+
+
+class W_NDimArray(W_NumpyObject):
     __metaclass__ = extendabletype
 
     def __init__(self, implementation):
diff --git a/pypy/module/micronumpy/boxes.py b/pypy/module/micronumpy/boxes.py
--- a/pypy/module/micronumpy/boxes.py
+++ b/pypy/module/micronumpy/boxes.py
@@ -1,4 +1,3 @@
-from pypy.interpreter.baseobjspace import W_Root
 from pypy.interpreter.error import OperationError, oefmt
 from pypy.interpreter.gateway import interp2app, unwrap_spec
 from pypy.interpreter.mixedmodule import MixedModule
@@ -14,7 +13,7 @@
 from rpython.rtyper.lltypesystem import lltype, rffi
 from rpython.tool.sourcetools import func_with_new_name
 from pypy.module.micronumpy import constants as NPY
-from pypy.module.micronumpy.base import W_NDimArray
+from pypy.module.micronumpy.base import W_NDimArray, W_NumpyObject
 from pypy.module.micronumpy.concrete import VoidBoxStorage
 from pypy.module.micronumpy.flagsobj import W_FlagsObject
 
@@ -126,7 +125,7 @@
         return ret
 
 
-class W_GenericBox(W_Root):
+class W_GenericBox(W_NumpyObject):
     _attrs_ = ['w_flags']
 
     def descr__new__(space, w_subtype, __args__):
@@ -136,6 +135,12 @@
     def get_dtype(self, space):
         return self._get_dtype(space)
 
+    def is_scalar(self):  # unconditionally True: every box is a scalar
+        return True
+
+    def get_scalar_value(self):  # a box is its own scalar value, so return it unchanged
+        return self
+
     def item(self, space):
         return self.get_dtype(space).itemtype.to_builtin_type(space, self)
 
diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py
--- a/pypy/module/micronumpy/ufuncs.py
+++ b/pypy/module/micronumpy/ufuncs.py
@@ -385,10 +385,15 @@
         else:
             [w_lhs, w_rhs] = args_w
             w_out = None
-        w_lhs = convert_to_array(space, w_lhs)
-        w_rhs = convert_to_array(space, w_rhs)
-        w_ldtype = w_lhs.get_dtype()
-        w_rdtype = w_rhs.get_dtype()
+        if (isinstance(w_lhs, boxes.W_GenericBox) and
+                isinstance(w_rhs, boxes.W_GenericBox)):
+            w_ldtype = w_lhs.get_dtype(space)
+            w_rdtype = w_rhs.get_dtype(space)
+        else:
+            w_lhs = convert_to_array(space, w_lhs)
+            w_rhs = convert_to_array(space, w_rhs)
+            w_ldtype = w_lhs.get_dtype()
+            w_rdtype = w_rhs.get_dtype()
         if w_ldtype.is_str() and w_rdtype.is_str() and \
                 self.comparison_func:
             pass
@@ -451,6 +456,8 @@
             else:
                 out = arr
             return out
+        assert isinstance(w_lhs, W_NDimArray)
+        assert isinstance(w_rhs, W_NDimArray)
         new_shape = shape_agreement(space, w_lhs.get_shape(), w_rhs)
         new_shape = shape_agreement(space, new_shape, out, broadcast_down=False)
         return loop.call2(space, new_shape, self.func, calc_dtype,


More information about the pypy-commit mailing list