[pypy-commit] pypy fix-result-types: move more stuff inside find_specialization()

rlamy noreply at buildbot.pypy.org
Wed May 13 19:52:32 CEST 2015


Author: Ronan Lamy <ronan.lamy at gmail.com>
Branch: fix-result-types
Changeset: r77311:65d85120736d
Date: 2015-05-13 04:42 +0100
http://bitbucket.org/pypy/pypy/changeset/65d85120736d/

Log:	move more stuff inside find_specialization()

diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py
--- a/pypy/module/micronumpy/ufuncs.py
+++ b/pypy/module/micronumpy/ufuncs.py
@@ -336,25 +336,11 @@
             out = args_w[1]
             if space.is_w(out, space.w_None):
                 out = None
+            elif out is not None and not isinstance(out, W_NDimArray):
+                raise oefmt(space.w_TypeError, 'output must be an array')
         w_obj = numpify(space, w_obj)
         dtype = w_obj.get_dtype(space)
-        calc_dtype, func = self.find_specialization(space, dtype)
-        if out is not None:
-            if not isinstance(out, W_NDimArray):
-                raise oefmt(space.w_TypeError, 'output must be an array')
-            res_dtype = out.get_dtype()
-            #if not w_obj.get_dtype().can_cast_to(res_dtype):
-            #    raise oefmt(space.w_TypeError,
-            #        "Cannot cast ufunc %s output from dtype('%s') to dtype('%s') with casting rule 'same_kind'", self.name, w_obj.get_dtype().name, res_dtype.name)
-        elif self.bool_result:
-            res_dtype = get_dtype_cache(space).w_booldtype
-        else:
-            res_dtype = calc_dtype
-            if self.complex_to_float and calc_dtype.is_complex():
-                if calc_dtype.num == NPY.CFLOAT:
-                    res_dtype = get_dtype_cache(space).w_float32dtype
-                else:
-                    res_dtype = get_dtype_cache(space).w_float64dtype
+        calc_dtype, res_dtype, func = self.find_specialization(space, dtype, out)
         if w_obj.is_scalar():
             return self.call_scalar(space, w_obj.get_scalar_value(),
                                     calc_dtype, res_dtype, out)
@@ -377,7 +363,7 @@
             out.fill(space, w_val)
         return out
 
-    def find_specialization(self, space, dtype):
+    def find_specialization(self, space, dtype, out):
         if dtype.is_flexible():
             raise oefmt(space.w_TypeError, 'Not implemented for this type')
         if (self.int_only and not (dtype.is_int() or dtype.is_object()) or
@@ -386,7 +372,22 @@
             raise oefmt(space.w_TypeError,
                 "ufunc %s not supported for the input type", self.name)
         calc_dtype = self._calc_dtype(space, dtype)
-        return calc_dtype, self.func
+
+        if out is not None:
+            res_dtype = out.get_dtype()
+            #if not w_obj.get_dtype().can_cast_to(res_dtype):
+            #    raise oefmt(space.w_TypeError,
+            #        "Cannot cast ufunc %s output from dtype('%s') to dtype('%s') with casting rule 'same_kind'", self.name, w_obj.get_dtype().name, res_dtype.name)
+        elif self.bool_result:
+            res_dtype = get_dtype_cache(space).w_booldtype
+        else:
+            res_dtype = calc_dtype
+            if self.complex_to_float and calc_dtype.is_complex():
+                if calc_dtype.num == NPY.CFLOAT:
+                    res_dtype = get_dtype_cache(space).w_float32dtype
+                else:
+                    res_dtype = get_dtype_cache(space).w_float64dtype
+        return calc_dtype, res_dtype, self.func
 
     def _calc_dtype(self, space, arg_dtype):
         use_min_scalar=False


More information about the pypy-commit mailing list