[pypy-commit] pypy fix-result-types: first step towards computing the loop's output type in _calc_dtype()

rlamy noreply at buildbot.pypy.org
Wed May 13 19:52:33 CEST 2015


Author: Ronan Lamy <ronan.lamy at gmail.com>
Branch: fix-result-types
Changeset: r77312:e46d1376e0d6
Date: 2015-05-13 18:33 +0100
http://bitbucket.org/pypy/pypy/changeset/e46d1376e0d6/

Log:	first step towards computing the loop's output type in _calc_dtype()

diff --git a/pypy/module/micronumpy/test/test_ufuncs.py b/pypy/module/micronumpy/test/test_ufuncs.py
--- a/pypy/module/micronumpy/test/test_ufuncs.py
+++ b/pypy/module/micronumpy/test/test_ufuncs.py
@@ -59,14 +59,14 @@
         dt_float16 = get_dtype_cache(space).w_float16dtype
         dt_int32 = get_dtype_cache(space).w_int32dtype
         ufunc = W_Ufunc1(None, 'x', int_only=True)
-        assert ufunc._calc_dtype(space, dt_bool) == dt_bool
+        assert ufunc._calc_dtype(space, dt_bool, out=None) == (dt_bool, dt_bool)
         assert ufunc.allowed_types(space)  # XXX: shouldn't contain too much stuff
 
         ufunc = W_Ufunc1(None, 'x', promote_to_float=True)
-        assert ufunc._calc_dtype(space, dt_bool) == dt_float16
+        assert ufunc._calc_dtype(space, dt_bool, out=None) == (dt_float16, dt_float16)
 
         ufunc = W_Ufunc1(None, 'x')
-        assert ufunc._calc_dtype(space, dt_int32) == dt_int32
+        assert ufunc._calc_dtype(space, dt_int32, out=None) == (dt_int32, dt_int32)
 
 class AppTestUfuncs(BaseNumpyAppTest):
     def test_constants(self):
diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py
--- a/pypy/module/micronumpy/ufuncs.py
+++ b/pypy/module/micronumpy/ufuncs.py
@@ -371,7 +371,7 @@
                 not self.allow_complex and dtype.is_complex()):
             raise oefmt(space.w_TypeError,
                 "ufunc %s not supported for the input type", self.name)
-        calc_dtype = self._calc_dtype(space, dtype)
+        dt_in, dt_out = self._calc_dtype(space, dtype, out)
 
         if out is not None:
             res_dtype = out.get_dtype()
@@ -381,25 +381,32 @@
         elif self.bool_result:
             res_dtype = get_dtype_cache(space).w_booldtype
         else:
-            res_dtype = calc_dtype
-            if self.complex_to_float and calc_dtype.is_complex():
-                if calc_dtype.num == NPY.CFLOAT:
+            res_dtype = dt_in
+            if self.complex_to_float and dt_in.is_complex():
+                if dt_in.num == NPY.CFLOAT:
                     res_dtype = get_dtype_cache(space).w_float32dtype
                 else:
                     res_dtype = get_dtype_cache(space).w_float64dtype
-        return calc_dtype, res_dtype, self.func
+        return dt_in, res_dtype, self.func
 
-    def _calc_dtype(self, space, arg_dtype):
-        use_min_scalar=False
+    def _calc_dtype(self, space, arg_dtype, out):
+        use_min_scalar = False
         if arg_dtype.is_object():
-            return arg_dtype
+            return arg_dtype, arg_dtype
         for dtype in self.allowed_types(space):
             if use_min_scalar:
-                if can_cast_array(space, w_arg, dtype, casting='safe'):
-                    return dtype
+                if not can_cast_array(space, w_arg, dtype, casting='safe'):
+                    continue
             else:
-                if can_cast_type(space, arg_dtype, dtype, casting='safe'):
-                    return dtype
+                if not can_cast_type(space, arg_dtype, dtype, casting='safe'):
+                    continue
+            dt_out = dtype
+            if out is not None:
+                res_dtype = out.get_dtype()
+                if not can_cast_type(space, dt_out, res_dtype, 'unsafe'):
+                    continue
+            return dtype, dt_out
+
         else:
             raise oefmt(space.w_TypeError,
                 "No loop matching the specified signature was found "


More information about the pypy-commit mailing list