[pypy-commit] pypy fix-result-types: Use promote_types() for binary ufunc resolution in some cases
rlamy
noreply at buildbot.pypy.org
Sat May 23 20:15:38 CEST 2015
Author: Ronan Lamy <ronan.lamy at gmail.com>
Branch: fix-result-types
Changeset: r77510:237700dbd639
Date: 2015-05-23 19:16 +0100
http://bitbucket.org/pypy/pypy/changeset/237700dbd639/
Log: Use promote_types() for binary ufunc resolution in some cases
diff --git a/pypy/module/micronumpy/test/test_ndarray.py b/pypy/module/micronumpy/test/test_ndarray.py
--- a/pypy/module/micronumpy/test/test_ndarray.py
+++ b/pypy/module/micronumpy/test/test_ndarray.py
@@ -1084,6 +1084,7 @@
b = a * a
for i in range(5):
assert b[i] == i * i
+ assert a.dtype.num == b.dtype.num
assert b.dtype is a.dtype
a = numpy.array(range(5), dtype=bool)
diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py
--- a/pypy/module/micronumpy/ufuncs.py
+++ b/pypy/module/micronumpy/ufuncs.py
@@ -20,7 +20,7 @@
from pypy.module.micronumpy.strides import shape_agreement
from pypy.module.micronumpy.support import (_parse_signature, product,
get_storage_as_int, is_rhs_priority_higher)
-from .casting import can_cast_type, find_result_type
+from .casting import can_cast_type, find_result_type, _promote_types
from .boxes import W_GenericBox, W_ObjectBox
def done_if_true(dtype, val):
@@ -538,7 +538,7 @@
class W_Ufunc2(W_Ufunc):
- _immutable_fields_ = ["func", "bool_result", "done_func"]
+ _immutable_fields_ = ["func", "bool_result", "done_func", "simple_binary"]
nin = 2
nout = 1
nargs = 3
@@ -557,6 +557,10 @@
self.done_func = done_if_true
else:
self.done_func = None
+ self.simple_binary = (
+ allow_complex and allow_bool and not bool_result and not int_only
+ and not complex_to_float and not promote_to_float
+ and not promote_bools)
def are_common_types(self, dtype1, dtype2):
if dtype1.is_bool() or dtype2.is_bool():
@@ -659,7 +663,7 @@
return w_val.w_obj
return w_val
- def find_specialization(self, space, l_dtype, r_dtype, out, casting):
+ def _find_specialization(self, space, l_dtype, r_dtype, out, casting):
if (not self.allow_bool and (l_dtype.is_bool() or
r_dtype.is_bool()) or
not self.allow_complex and (l_dtype.is_complex() or
@@ -674,6 +678,13 @@
dt_in, dt_out = self._calc_dtype(space, l_dtype, r_dtype, out, casting)
return dt_in, dt_out, self.func
+ def find_specialization(self, space, l_dtype, r_dtype, out, casting):
+ if self.simple_binary:
+ if out is None and not (l_dtype.is_object() or r_dtype.is_object()):
+ dtype = _promote_types(space, l_dtype, r_dtype)
+ return dtype, dtype, self.func
+ return self._find_specialization(space, l_dtype, r_dtype, out, casting)
+
def find_binop_type(self, space, dtype):
"""Find a valid dtype signature of the form xx->x"""
if dtype.is_object():
More information about the pypy-commit mailing list