[pypy-commit] pypy fix-result-types: Use the same logic as cnumpy in W_Ufunc1.find_specialization()
rlamy
noreply at buildbot.pypy.org
Tue May 12 07:16:47 CEST 2015
Author: Ronan Lamy <ronan.lamy at gmail.com>
Branch: fix-result-types
Changeset: r77296:2fc8c1b68f07
Date: 2015-05-12 06:16 +0100
http://bitbucket.org/pypy/pypy/changeset/2fc8c1b68f07/
Log: Use the same logic as cnumpy in W_Ufunc1.find_specialization()
diff --git a/pypy/module/micronumpy/descriptor.py b/pypy/module/micronumpy/descriptor.py
--- a/pypy/module/micronumpy/descriptor.py
+++ b/pypy/module/micronumpy/descriptor.py
@@ -900,17 +900,20 @@
NPY.CDOUBLE: self.w_float64dtype,
NPY.CLONGDOUBLE: self.w_floatlongdtype,
}
- self.builtin_dtypes = [
- self.w_booldtype,
+ integer_dtypes = [  # order matters: W_Ufunc1 picks the first safe cast
self.w_int8dtype, self.w_uint8dtype,
self.w_int16dtype, self.w_uint16dtype,
+ self.w_int32dtype, self.w_uint32dtype,  # before long, so int32 args resolve to int32
self.w_longdtype, self.w_ulongdtype,
- self.w_int32dtype, self.w_uint32dtype,
- self.w_int64dtype, self.w_uint64dtype,
- ] + float_dtypes + complex_dtypes + [
- self.w_stringdtype, self.w_unicodedtype, self.w_voiddtype,
- self.w_objectdtype,
- ]
+ self.w_int64dtype, self.w_uint64dtype]
+ self.builtin_dtypes = ([self.w_booldtype] + integer_dtypes +  # same members as before, int32 reordered
+ float_dtypes + complex_dtypes + [
+ self.w_stringdtype, self.w_unicodedtype, self.w_voiddtype,
+ self.w_objectdtype,
+ ])
+ self.integer_dtypes = integer_dtypes  # read by W_Ufunc1.allowed_types()
+ self.float_dtypes = float_dtypes  # read by W_Ufunc1.allowed_types()
+ self.complex_dtypes = complex_dtypes  # read by W_Ufunc1.allowed_types()
self.float_dtypes_by_num_bytes = sorted(
(dtype.elsize, dtype)
for dtype in float_dtypes
diff --git a/pypy/module/micronumpy/test/test_ufuncs.py b/pypy/module/micronumpy/test/test_ufuncs.py
--- a/pypy/module/micronumpy/test/test_ufuncs.py
+++ b/pypy/module/micronumpy/test/test_ufuncs.py
@@ -1,5 +1,5 @@
from pypy.module.micronumpy.test.test_base import BaseNumpyAppTest
-from pypy.module.micronumpy.ufuncs import W_UfuncGeneric
+from pypy.module.micronumpy.ufuncs import W_UfuncGeneric, W_Ufunc1
from pypy.module.micronumpy.support import _parse_signature
from pypy.module.micronumpy.descriptor import get_dtype_cache
from pypy.module.micronumpy.base import W_NDimArray
@@ -54,6 +54,20 @@
exc = raises(OperationError, ufunc.type_resolver, space, [f32_array], [None],
'i->i', ufunc.dtypes)
+ def test_allowed_types(self, space):  # loop-dtype selection in W_Ufunc1._calc_dtype
+ dt_bool = get_dtype_cache(space).w_booldtype
+ dt_float16 = get_dtype_cache(space).w_float16dtype
+ dt_int32 = get_dtype_cache(space).w_int32dtype
+ ufunc = W_Ufunc1(None, 'x', int_only=True)
+ assert ufunc._calc_dtype(space, dt_bool) == dt_bool  # no promotion flags: bool is an allowed loop type
+ assert ufunc.allowed_types(space) # XXX: shouldn't contain too much stuff
+
+ ufunc = W_Ufunc1(None, 'x', promote_to_float=True)
+ assert ufunc._calc_dtype(space, dt_bool) == dt_float16  # bool and integer loops are skipped
+
+ ufunc = W_Ufunc1(None, 'x')
+ assert ufunc._calc_dtype(space, dt_int32) == dt_int32  # int32 is tried before long
+
class AppTestUfuncs(BaseNumpyAppTest):
def test_constants(self):
import numpy as np
diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py
--- a/pypy/module/micronumpy/ufuncs.py
+++ b/pypy/module/micronumpy/ufuncs.py
@@ -18,7 +18,8 @@
from pypy.module.micronumpy.nditer import W_NDIter, coalesce_iter
from pypy.module.micronumpy.strides import shape_agreement
from pypy.module.micronumpy.support import _parse_signature, product, get_storage_as_int
-from .casting import find_unaryop_result_dtype, find_binop_result_dtype
+from .casting import (
+ find_unaryop_result_dtype, find_binop_result_dtype, can_cast_type)
def done_if_true(dtype, val):
return dtype.itemtype.bool(val)
@@ -384,12 +385,45 @@
not self.allow_complex and dtype.is_complex()):
raise oefmt(space.w_TypeError,
"ufunc %s not supported for the input type", self.name)
- calc_dtype = find_unaryop_result_dtype(space,
- dtype,
- promote_to_float=self.promote_to_float,
- promote_bools=self.promote_bools)
+ calc_dtype = self._calc_dtype(space, dtype)
return calc_dtype, self.func
+ def _calc_dtype(self, space, arg_dtype):
+ """Return the loop dtype to use for an argument of dtype arg_dtype.
+
+ Object dtypes are returned unchanged. Otherwise the result is the
+ first entry of allowed_types() that arg_dtype can be cast to with
+ casting='safe'. Raises TypeError if there is no such entry.
+ """
+ # NOTE: only the dtype is considered; value-based (min-scalar)
+ # casting of scalar arguments is not implemented here.
+ if arg_dtype.is_object():
+ return arg_dtype
+ for dtype in self.allowed_types(space):
+ if can_cast_type(space, arg_dtype, dtype, casting='safe'):
+ return dtype
+ raise oefmt(space.w_TypeError,
+ "No loop matching the specified signature was found "
+ "for ufunc %s", self.name)
+
+ def allowed_types(self, space):
+ """Return the candidate loop dtypes for this ufunc, in search order.
+
+ The order matters: _calc_dtype() picks the first entry that the
+ argument dtype can be safely cast to. Bool is included only when
+ neither promote_bools nor promote_to_float is set, and integer
+ dtypes only when promote_to_float is not set.
+ """
+ dtypes = []
+ cache = get_dtype_cache(space)
+ if not self.promote_bools and not self.promote_to_float:
+ dtypes.append(cache.w_booldtype)
+ if not self.promote_to_float:
+ dtypes.extend(cache.integer_dtypes)
+ dtypes.extend(cache.float_dtypes)
+ dtypes.extend(cache.complex_dtypes)
+ return dtypes
+
class W_Ufunc2(W_Ufunc):
_immutable_fields_ = ["func", "comparison_func", "done_func"]
More information about the pypy-commit mailing list