[pypy-commit] pypy fix-result-types: Use promote_types() for binary ufunc resolution in some cases

rlamy noreply at buildbot.pypy.org
Sat May 23 20:15:38 CEST 2015


Author: Ronan Lamy <ronan.lamy at gmail.com>
Branch: fix-result-types
Changeset: r77510:237700dbd639
Date: 2015-05-23 19:16 +0100
http://bitbucket.org/pypy/pypy/changeset/237700dbd639/

Log:	Use promote_types() for binary ufunc resolution in some cases

diff --git a/pypy/module/micronumpy/test/test_ndarray.py b/pypy/module/micronumpy/test/test_ndarray.py
--- a/pypy/module/micronumpy/test/test_ndarray.py
+++ b/pypy/module/micronumpy/test/test_ndarray.py
@@ -1084,6 +1084,7 @@
         b = a * a
         for i in range(5):
             assert b[i] == i * i
+        assert a.dtype.num == b.dtype.num
         assert b.dtype is a.dtype
 
         a = numpy.array(range(5), dtype=bool)
diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py
--- a/pypy/module/micronumpy/ufuncs.py
+++ b/pypy/module/micronumpy/ufuncs.py
@@ -20,7 +20,7 @@
 from pypy.module.micronumpy.strides import shape_agreement
 from pypy.module.micronumpy.support import (_parse_signature, product,
         get_storage_as_int, is_rhs_priority_higher)
-from .casting import can_cast_type, find_result_type
+from .casting import can_cast_type, find_result_type, _promote_types
 from .boxes import W_GenericBox, W_ObjectBox
 
 def done_if_true(dtype, val):
@@ -538,7 +538,7 @@
 
 
 class W_Ufunc2(W_Ufunc):
-    _immutable_fields_ = ["func", "bool_result", "done_func"]
+    _immutable_fields_ = ["func", "bool_result", "done_func", "simple_binary"]
     nin = 2
     nout = 1
     nargs = 3
@@ -557,6 +557,10 @@
             self.done_func = done_if_true
         else:
             self.done_func = None
+        self.simple_binary = (
+            allow_complex and allow_bool and not bool_result and not int_only
+            and not complex_to_float and not promote_to_float
+            and not promote_bools)
 
     def are_common_types(self, dtype1, dtype2):
         if dtype1.is_bool() or dtype2.is_bool():
@@ -659,7 +663,7 @@
             return w_val.w_obj
         return w_val
 
-    def find_specialization(self, space, l_dtype, r_dtype, out, casting):
+    def _find_specialization(self, space, l_dtype, r_dtype, out, casting):
         if (not self.allow_bool and (l_dtype.is_bool() or
                                          r_dtype.is_bool()) or
                 not self.allow_complex and (l_dtype.is_complex() or
@@ -674,6 +678,13 @@
         dt_in, dt_out = self._calc_dtype(space, l_dtype, r_dtype, out, casting)
         return dt_in, dt_out, self.func
 
+    def find_specialization(self, space, l_dtype, r_dtype, out, casting):
+        if self.simple_binary:
+            if out is None and not (l_dtype.is_object() or r_dtype.is_object()):
+                dtype = _promote_types(space, l_dtype, r_dtype)
+                return dtype, dtype, self.func
+        return self._find_specialization(space, l_dtype, r_dtype, out, casting)
+
     def find_binop_type(self, space, dtype):
         """Find a valid dtype signature of the form xx->x"""
         if dtype.is_object():


More information about the pypy-commit mailing list