[pypy-commit] pypy fix-result-types: support 'casting' argument in unary ufuncs
rlamy
noreply at buildbot.pypy.org
Wed May 13 21:33:49 CEST 2015
Author: Ronan Lamy <ronan.lamy at gmail.com>
Branch: fix-result-types
Changeset: r77314:e7146ca785d0
Date: 2015-05-13 20:33 +0100
http://bitbucket.org/pypy/pypy/changeset/e7146ca785d0/
Log: support 'casting' argument in unary ufuncs
diff --git a/pypy/module/micronumpy/test/test_ufuncs.py b/pypy/module/micronumpy/test/test_ufuncs.py
--- a/pypy/module/micronumpy/test/test_ufuncs.py
+++ b/pypy/module/micronumpy/test/test_ufuncs.py
@@ -64,6 +64,8 @@
ufunc = W_Ufunc1(None, 'x', promote_to_float=True)
assert ufunc._calc_dtype(space, dt_bool, out=None) == (dt_float16, dt_float16)
+ assert ufunc._calc_dtype(space, dt_bool, casting='same_kind') == (dt_float16, dt_float16)
+ raises(OperationError, ufunc._calc_dtype, space, dt_bool, casting='no')
ufunc = W_Ufunc1(None, 'x')
assert ufunc._calc_dtype(space, dt_int32, out=None) == (dt_int32, dt_int32)
@@ -261,6 +263,14 @@
raises(TypeError, adder_ufunc, *args, extobj=True)
raises(RuntimeError, adder_ufunc, *args, sig='(d,d)->(d)', dtype=int)
+ def test_unary_ufunc_kwargs(self):  # exercises the 'casting' keyword of a unary ufunc (sin) on a bool input
+ from numpy import array, sin, float16
+ bool_array = array([True])
+ raises(TypeError, sin, bool_array, casting='no')  # bool input needs a cast to reach a float loop; 'no' forbids it
+ assert sin(bool_array, casting='same_kind').dtype == float16
+ raises(TypeError, sin, bool_array, out=bool_array, casting='same_kind')  # expected to fail: float16 result into a bool out
+ assert sin(bool_array).dtype == float16  # default casting keeps the bool -> float16 promotion
+
def test_ufunc_attrs(self):
from numpy import add, multiply, sin
diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py
--- a/pypy/module/micronumpy/ufuncs.py
+++ b/pypy/module/micronumpy/ufuncs.py
@@ -17,7 +17,7 @@
from pypy.module.micronumpy.ctors import numpify
from pypy.module.micronumpy.nditer import W_NDIter, coalesce_iter
from pypy.module.micronumpy.strides import shape_agreement
-from pypy.module.micronumpy.support import (_parse_signature, product,
+from pypy.module.micronumpy.support import (_parse_signature, product,
get_storage_as_int, is_rhs_priority_higher)
from .casting import (
find_unaryop_result_dtype, find_binop_result_dtype, can_cast_type)
@@ -35,11 +35,11 @@
If an output argument is provided, then it is wrapped
with its own __array_wrap__ not with the one determined by
the input arguments.
-
+
if the provided output argument is already an array,
the wrapping function is None (which means no wrapping will
be done --- not even PyArray_Return).
-
+
A NULL is placed in output_wrap for outputs that
should just have PyArray_Return called.
'''
@@ -78,7 +78,7 @@
def descr_call(self, space, __args__):
args_w, kwds_w = __args__.unpack()
# sig, extobj are used in generic ufuncs
- w_subok, w_out, sig, casting, extobj = self.parse_kwargs(space, kwds_w)
+ w_subok, w_out, sig, w_casting, extobj = self.parse_kwargs(space, kwds_w)
if space.is_w(w_out, space.w_None):
out = None
else:
@@ -107,6 +107,10 @@
if out is not None and not isinstance(out, W_NDimArray):
raise OperationError(space.w_TypeError, space.wrap(
'output must be an array'))
+ if w_casting is None:
+ casting = 'unsafe'
+ else:
+ casting = space.str_w(w_casting)
retval = self.call(space, args_w, sig, casting, extobj)
keepalive_until_here(args_w)
return retval
@@ -329,8 +333,7 @@
"outer product only supported for binary functions"))
def parse_kwargs(self, space, kwds_w):
- # we don't support casting, change it when we do
- casting = kwds_w.pop('casting', None)
+ w_casting = kwds_w.pop('casting', None)
w_subok = kwds_w.pop('subok', None)
w_out = kwds_w.pop('out', space.w_None)
sig = None
@@ -339,7 +342,7 @@
extobj_w = kwds_w.pop('extobj', get_extobj(space))
if not space.isinstance_w(extobj_w, space.w_list) or space.len_w(extobj_w) != 3:
raise oefmt(space.w_TypeError, "'extobj' must be a list of 3 values")
- return w_subok, w_out, sig, casting, extobj_w
+ return w_subok, w_out, sig, w_casting, extobj_w
def get_extobj(space):
extobj_w = space.newlist([space.wrap(8192), space.wrap(0), space.w_None])
@@ -371,6 +374,12 @@
return False
return space.getattr(w_obj, space.wrap('__' + refops[op] + '__')) is not None
+def safe_casting_mode(casting):  # casting rule used when matching ufunc *input* dtypes (see _calc_dtype)
+ if casting in ('unsafe', 'same_kind'):
+ return 'safe'  # inputs are never matched with a rule looser than 'safe'
+ else:
+ return casting  # presumably 'no'/'equiv'/'safe' (already at least as strict); other strings pass through unvalidated
+
class W_Ufunc1(W_Ufunc):
_immutable_fields_ = ["func", "bool_result"]
nin = 1
@@ -397,7 +406,7 @@
raise oefmt(space.w_TypeError, 'output must be an array')
w_obj = numpify(space, w_obj)
dtype = w_obj.get_dtype(space)
- calc_dtype, res_dtype, func = self.find_specialization(space, dtype, out)
+ calc_dtype, res_dtype, func = self.find_specialization(space, dtype, out, casting)
if w_obj.is_scalar():
return self.call_scalar(space, w_obj.get_scalar_value(),
calc_dtype, res_dtype, out)
@@ -420,7 +429,7 @@
out.fill(space, w_val)
return out
- def find_specialization(self, space, dtype, out):
+ def find_specialization(self, space, dtype, out, casting):
if dtype.is_flexible():
raise oefmt(space.w_TypeError, 'Not implemented for this type')
if (self.int_only and not (dtype.is_int() or dtype.is_object()) or
@@ -428,7 +437,7 @@
not self.allow_complex and dtype.is_complex()):
raise oefmt(space.w_TypeError,
"ufunc %s not supported for the input type", self.name)
- dt_in, dt_out = self._calc_dtype(space, dtype, out)
+ dt_in, dt_out = self._calc_dtype(space, dtype, out, casting)
if out is not None:
res_dtype = out.get_dtype()
@@ -446,21 +455,22 @@
res_dtype = get_dtype_cache(space).w_float64dtype
return dt_in, res_dtype, self.func
- def _calc_dtype(self, space, arg_dtype, out):
+ def _calc_dtype(self, space, arg_dtype, out=None, casting='unsafe'):
use_min_scalar = False
if arg_dtype.is_object():
return arg_dtype, arg_dtype
+ in_casting = safe_casting_mode(casting)
for dtype in self.allowed_types(space):
if use_min_scalar:
- if not can_cast_array(space, w_arg, dtype, casting='safe'):
+ if not can_cast_array(space, w_arg, dtype, in_casting):
continue
else:
- if not can_cast_type(space, arg_dtype, dtype, casting='safe'):
+ if not can_cast_type(space, arg_dtype, dtype, in_casting):
continue
dt_out = dtype
if out is not None:
res_dtype = out.get_dtype()
- if not can_cast_type(space, dt_out, res_dtype, 'unsafe'):
+ if not can_cast_type(space, dt_out, res_dtype, casting):
continue
return dtype, dt_out
@@ -810,7 +820,7 @@
return outargs[0]
def parse_kwargs(self, space, kwargs_w):
- w_subok, w_out, casting, sig, extobj = \
+ w_subok, w_out, sig, w_casting, extobj = \
W_Ufunc.parse_kwargs(self, space, kwargs_w)
# do equivalent of get_ufunc_arguments in numpy's ufunc_object.c
dtype_w = kwargs_w.pop('dtype', None)
@@ -837,7 +847,7 @@
parsed_kw.append(kw)
for kw in parsed_kw:
kwargs_w.pop(kw)
- return w_subok, w_out, sig, casting, extobj
+ return w_subok, w_out, sig, w_casting, extobj
def type_resolver(self, space, inargs, outargs, type_tup, _dtypes):
# Find a match for the inargs.dtype in _dtypes, like
More information about the pypy-commit mailing list