[pypy-commit] pypy fix-result-types: first step towards computing the loop's output type in _calc_dtype()
rlamy
noreply at buildbot.pypy.org
Wed May 13 19:52:33 CEST 2015
Author: Ronan Lamy <ronan.lamy at gmail.com>
Branch: fix-result-types
Changeset: r77312:e46d1376e0d6
Date: 2015-05-13 18:33 +0100
http://bitbucket.org/pypy/pypy/changeset/e46d1376e0d6/
Log: first step towards computing the loop's output type in _calc_dtype()
diff --git a/pypy/module/micronumpy/test/test_ufuncs.py b/pypy/module/micronumpy/test/test_ufuncs.py
--- a/pypy/module/micronumpy/test/test_ufuncs.py
+++ b/pypy/module/micronumpy/test/test_ufuncs.py
@@ -59,14 +59,14 @@
dt_float16 = get_dtype_cache(space).w_float16dtype
dt_int32 = get_dtype_cache(space).w_int32dtype
ufunc = W_Ufunc1(None, 'x', int_only=True)
- assert ufunc._calc_dtype(space, dt_bool) == dt_bool
+ assert ufunc._calc_dtype(space, dt_bool, out=None) == (dt_bool, dt_bool)
assert ufunc.allowed_types(space) # XXX: shouldn't contain too much stuff
ufunc = W_Ufunc1(None, 'x', promote_to_float=True)
- assert ufunc._calc_dtype(space, dt_bool) == dt_float16
+ assert ufunc._calc_dtype(space, dt_bool, out=None) == (dt_float16, dt_float16)
ufunc = W_Ufunc1(None, 'x')
- assert ufunc._calc_dtype(space, dt_int32) == dt_int32
+ assert ufunc._calc_dtype(space, dt_int32, out=None) == (dt_int32, dt_int32)
class AppTestUfuncs(BaseNumpyAppTest):
def test_constants(self):
diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py
--- a/pypy/module/micronumpy/ufuncs.py
+++ b/pypy/module/micronumpy/ufuncs.py
@@ -371,7 +371,7 @@
not self.allow_complex and dtype.is_complex()):
raise oefmt(space.w_TypeError,
"ufunc %s not supported for the input type", self.name)
- calc_dtype = self._calc_dtype(space, dtype)
+ dt_in, dt_out = self._calc_dtype(space, dtype, out)
if out is not None:
res_dtype = out.get_dtype()
@@ -381,25 +381,32 @@
elif self.bool_result:
res_dtype = get_dtype_cache(space).w_booldtype
else:
- res_dtype = calc_dtype
- if self.complex_to_float and calc_dtype.is_complex():
- if calc_dtype.num == NPY.CFLOAT:
+ res_dtype = dt_in
+ if self.complex_to_float and dt_in.is_complex():
+ if dt_in.num == NPY.CFLOAT:
res_dtype = get_dtype_cache(space).w_float32dtype
else:
res_dtype = get_dtype_cache(space).w_float64dtype
- return calc_dtype, res_dtype, self.func
+ return dt_in, res_dtype, self.func
- def _calc_dtype(self, space, arg_dtype):
- use_min_scalar=False
+ def _calc_dtype(self, space, arg_dtype, out):
+ use_min_scalar = False
if arg_dtype.is_object():
- return arg_dtype
+ return arg_dtype, arg_dtype
for dtype in self.allowed_types(space):
if use_min_scalar:
- if can_cast_array(space, w_arg, dtype, casting='safe'):
- return dtype
+ if not can_cast_array(space, w_arg, dtype, casting='safe'):
+ continue
else:
- if can_cast_type(space, arg_dtype, dtype, casting='safe'):
- return dtype
+ if not can_cast_type(space, arg_dtype, dtype, casting='safe'):
+ continue
+ dt_out = dtype
+ if out is not None:
+ res_dtype = out.get_dtype()
+ if not can_cast_type(space, dt_out, res_dtype, 'unsafe'):
+ continue
+ return dtype, dt_out
+
else:
raise oefmt(space.w_TypeError,
"No loop matching the specified signature was found "
More information about the pypy-commit mailing list