[pypy-commit] pypy scalar-operations: fix performance of ufunc(scalar, scalar)
rlamy
noreply at buildbot.pypy.org
Thu Jun 26 20:57:00 CEST 2014
Author: Ronan Lamy <ronan.lamy at gmail.com>
Branch: scalar-operations
Changeset: r72246:0620f0c12772
Date: 2014-06-26 18:10 +0100
http://bitbucket.org/pypy/pypy/changeset/0620f0c12772/
Log: fix performance of ufunc(scalar, scalar)
diff --git a/pypy/module/micronumpy/base.py b/pypy/module/micronumpy/base.py
--- a/pypy/module/micronumpy/base.py
+++ b/pypy/module/micronumpy/base.py
@@ -18,7 +18,12 @@
pass
-class W_NDimArray(W_Root):
+class W_NumpyObject(W_Root):
+ """Common base class for ndarrays (W_NDimArray) and numpy scalars (aka boxes, W_GenericBox)."""
+ _attrs_ = []  # declares no attributes itself; subclasses (e.g. W_GenericBox) list their own
+
+
+class W_NDimArray(W_NumpyObject):
__metaclass__ = extendabletype
def __init__(self, implementation):
diff --git a/pypy/module/micronumpy/boxes.py b/pypy/module/micronumpy/boxes.py
--- a/pypy/module/micronumpy/boxes.py
+++ b/pypy/module/micronumpy/boxes.py
@@ -1,4 +1,3 @@
-from pypy.interpreter.baseobjspace import W_Root
from pypy.interpreter.error import OperationError, oefmt
from pypy.interpreter.gateway import interp2app, unwrap_spec
from pypy.interpreter.mixedmodule import MixedModule
@@ -14,7 +13,7 @@
from rpython.rtyper.lltypesystem import lltype, rffi
from rpython.tool.sourcetools import func_with_new_name
from pypy.module.micronumpy import constants as NPY
-from pypy.module.micronumpy.base import W_NDimArray
+from pypy.module.micronumpy.base import W_NDimArray, W_NumpyObject
from pypy.module.micronumpy.concrete import VoidBoxStorage
from pypy.module.micronumpy.flagsobj import W_FlagsObject
@@ -126,7 +125,7 @@
return ret
-class W_GenericBox(W_Root):
+class W_GenericBox(W_NumpyObject):
_attrs_ = ['w_flags']
def descr__new__(space, w_subtype, __args__):
@@ -136,6 +135,12 @@
def get_dtype(self, space):
return self._get_dtype(space)
+ def is_scalar(self):
+ return True  # a box always represents a single scalar value
+
+ def get_scalar_value(self):
+ return self  # a box is its own scalar value, so no extraction is needed
+
def item(self, space):
return self.get_dtype(space).itemtype.to_builtin_type(space, self)
diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py
--- a/pypy/module/micronumpy/ufuncs.py
+++ b/pypy/module/micronumpy/ufuncs.py
@@ -385,10 +385,15 @@
else:
[w_lhs, w_rhs] = args_w
w_out = None
- w_lhs = convert_to_array(space, w_lhs)
- w_rhs = convert_to_array(space, w_rhs)
- w_ldtype = w_lhs.get_dtype()
- w_rdtype = w_rhs.get_dtype()
+ if (isinstance(w_lhs, boxes.W_GenericBox) and
+ isinstance(w_rhs, boxes.W_GenericBox)):
+ w_ldtype = w_lhs.get_dtype(space)
+ w_rdtype = w_rhs.get_dtype(space)
+ else:
+ w_lhs = convert_to_array(space, w_lhs)
+ w_rhs = convert_to_array(space, w_rhs)
+ w_ldtype = w_lhs.get_dtype()
+ w_rdtype = w_rhs.get_dtype()
if w_ldtype.is_str() and w_rdtype.is_str() and \
self.comparison_func:
pass
@@ -451,6 +456,8 @@
else:
out = arr
return out
+ assert isinstance(w_lhs, W_NDimArray)
+ assert isinstance(w_rhs, W_NDimArray)
new_shape = shape_agreement(space, w_lhs.get_shape(), w_rhs)
new_shape = shape_agreement(space, new_shape, out, broadcast_down=False)
return loop.call2(space, new_shape, self.func, calc_dtype,
More information about the pypy-commit
mailing list