[pypy-commit] pypy fix-result-types: support 'casting' argument in unary ufuncs

rlamy noreply at buildbot.pypy.org
Wed May 13 21:33:49 CEST 2015


Author: Ronan Lamy <ronan.lamy at gmail.com>
Branch: fix-result-types
Changeset: r77314:e7146ca785d0
Date: 2015-05-13 20:33 +0100
http://bitbucket.org/pypy/pypy/changeset/e7146ca785d0/

Log:	support 'casting' argument in unary ufuncs

diff --git a/pypy/module/micronumpy/test/test_ufuncs.py b/pypy/module/micronumpy/test/test_ufuncs.py
--- a/pypy/module/micronumpy/test/test_ufuncs.py
+++ b/pypy/module/micronumpy/test/test_ufuncs.py
@@ -64,6 +64,8 @@
 
         ufunc = W_Ufunc1(None, 'x', promote_to_float=True)
         assert ufunc._calc_dtype(space, dt_bool, out=None) == (dt_float16, dt_float16)
+        assert ufunc._calc_dtype(space, dt_bool, casting='same_kind') == (dt_float16, dt_float16)
+        raises(OperationError, ufunc._calc_dtype, space, dt_bool, casting='no')
 
         ufunc = W_Ufunc1(None, 'x')
         assert ufunc._calc_dtype(space, dt_int32, out=None) == (dt_int32, dt_int32)
@@ -261,6 +263,14 @@
         raises(TypeError, adder_ufunc, *args, extobj=True)
         raises(RuntimeError, adder_ufunc, *args, sig='(d,d)->(d)', dtype=int)
 
+    def test_unary_ufunc_kwargs(self):
+        from numpy import array, sin, float16
+        bool_array = array([True])
+        raises(TypeError, sin, bool_array, casting='no')
+        assert sin(bool_array, casting='same_kind').dtype == float16
+        raises(TypeError, sin, bool_array, out=bool_array, casting='same_kind')
+        assert sin(bool_array).dtype == float16
+
     def test_ufunc_attrs(self):
         from numpy import add, multiply, sin
 
diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py
--- a/pypy/module/micronumpy/ufuncs.py
+++ b/pypy/module/micronumpy/ufuncs.py
@@ -17,7 +17,7 @@
 from pypy.module.micronumpy.ctors import numpify
 from pypy.module.micronumpy.nditer import W_NDIter, coalesce_iter
 from pypy.module.micronumpy.strides import shape_agreement
-from pypy.module.micronumpy.support import (_parse_signature, product, 
+from pypy.module.micronumpy.support import (_parse_signature, product,
         get_storage_as_int, is_rhs_priority_higher)
 from .casting import (
     find_unaryop_result_dtype, find_binop_result_dtype, can_cast_type)
@@ -35,11 +35,11 @@
       If an output argument is provided, then it is wrapped
       with its own __array_wrap__ not with the one determined by
       the input arguments.
-     
+
       if the provided output argument is already an array,
       the wrapping function is None (which means no wrapping will
       be done --- not even PyArray_Return).
-     
+
       A NULL is placed in output_wrap for outputs that
       should just have PyArray_Return called.
     '''
@@ -78,7 +78,7 @@
     def descr_call(self, space, __args__):
         args_w, kwds_w = __args__.unpack()
         # sig, extobj are used in generic ufuncs
-        w_subok, w_out, sig, casting, extobj = self.parse_kwargs(space, kwds_w)
+        w_subok, w_out, sig, w_casting, extobj = self.parse_kwargs(space, kwds_w)
         if space.is_w(w_out, space.w_None):
             out = None
         else:
@@ -107,6 +107,10 @@
         if out is not None and not isinstance(out, W_NDimArray):
             raise OperationError(space.w_TypeError, space.wrap(
                                             'output must be an array'))
+        if w_casting is None:
+            casting = 'unsafe'
+        else:
+            casting = space.str_w(w_casting)
         retval = self.call(space, args_w, sig, casting, extobj)
         keepalive_until_here(args_w)
         return retval
@@ -329,8 +333,7 @@
             "outer product only supported for binary functions"))
 
     def parse_kwargs(self, space, kwds_w):
-        # we don't support casting, change it when we do
-        casting = kwds_w.pop('casting', None)
+        w_casting = kwds_w.pop('casting', None)
         w_subok = kwds_w.pop('subok', None)
         w_out = kwds_w.pop('out', space.w_None)
         sig = None
@@ -339,7 +342,7 @@
         extobj_w = kwds_w.pop('extobj', get_extobj(space))
         if not space.isinstance_w(extobj_w, space.w_list) or space.len_w(extobj_w) != 3:
             raise oefmt(space.w_TypeError, "'extobj' must be a list of 3 values")
-        return w_subok, w_out, sig, casting, extobj_w
+        return w_subok, w_out, sig, w_casting, extobj_w
 
 def get_extobj(space):
         extobj_w = space.newlist([space.wrap(8192), space.wrap(0), space.w_None])
@@ -371,6 +374,12 @@
         return False
     return space.getattr(w_obj, space.wrap('__' + refops[op] + '__')) is not None
 
+def safe_casting_mode(casting):
+    if casting in ('unsafe', 'same_kind'):
+        return 'safe'
+    else:
+        return casting
+
 class W_Ufunc1(W_Ufunc):
     _immutable_fields_ = ["func", "bool_result"]
     nin = 1
@@ -397,7 +406,7 @@
                 raise oefmt(space.w_TypeError, 'output must be an array')
         w_obj = numpify(space, w_obj)
         dtype = w_obj.get_dtype(space)
-        calc_dtype, res_dtype, func = self.find_specialization(space, dtype, out)
+        calc_dtype, res_dtype, func = self.find_specialization(space, dtype, out, casting)
         if w_obj.is_scalar():
             return self.call_scalar(space, w_obj.get_scalar_value(),
                                     calc_dtype, res_dtype, out)
@@ -420,7 +429,7 @@
             out.fill(space, w_val)
         return out
 
-    def find_specialization(self, space, dtype, out):
+    def find_specialization(self, space, dtype, out, casting):
         if dtype.is_flexible():
             raise oefmt(space.w_TypeError, 'Not implemented for this type')
         if (self.int_only and not (dtype.is_int() or dtype.is_object()) or
@@ -428,7 +437,7 @@
                 not self.allow_complex and dtype.is_complex()):
             raise oefmt(space.w_TypeError,
                 "ufunc %s not supported for the input type", self.name)
-        dt_in, dt_out = self._calc_dtype(space, dtype, out)
+        dt_in, dt_out = self._calc_dtype(space, dtype, out, casting)
 
         if out is not None:
             res_dtype = out.get_dtype()
@@ -446,21 +455,22 @@
                     res_dtype = get_dtype_cache(space).w_float64dtype
         return dt_in, res_dtype, self.func
 
-    def _calc_dtype(self, space, arg_dtype, out):
+    def _calc_dtype(self, space, arg_dtype, out=None, casting='unsafe'):
         use_min_scalar = False
         if arg_dtype.is_object():
             return arg_dtype, arg_dtype
+        in_casting = safe_casting_mode(casting)
         for dtype in self.allowed_types(space):
             if use_min_scalar:
-                if not can_cast_array(space, w_arg, dtype, casting='safe'):
+                if not can_cast_array(space, w_arg, dtype, in_casting):
                     continue
             else:
-                if not can_cast_type(space, arg_dtype, dtype, casting='safe'):
+                if not can_cast_type(space, arg_dtype, dtype, in_casting):
                     continue
             dt_out = dtype
             if out is not None:
                 res_dtype = out.get_dtype()
-                if not can_cast_type(space, dt_out, res_dtype, 'unsafe'):
+                if not can_cast_type(space, dt_out, res_dtype, casting):
                     continue
             return dtype, dt_out
 
@@ -810,7 +820,7 @@
         return outargs[0]
 
     def parse_kwargs(self, space, kwargs_w):
-        w_subok, w_out, casting, sig, extobj = \
+        w_subok, w_out, sig, w_casting, extobj = \
                     W_Ufunc.parse_kwargs(self, space, kwargs_w)
         # do equivalent of get_ufunc_arguments in numpy's ufunc_object.c
         dtype_w = kwargs_w.pop('dtype', None)
@@ -837,7 +847,7 @@
                 parsed_kw.append(kw)
         for kw in parsed_kw:
             kwargs_w.pop(kw)
-        return w_subok, w_out, sig, casting, extobj
+        return w_subok, w_out, sig, w_casting, extobj
 
     def type_resolver(self, space, inargs, outargs, type_tup, _dtypes):
         # Find a match for the inargs.dtype in _dtypes, like


More information about the pypy-commit mailing list