[pypy-commit] pypy numpy-dtype-alt: fix for type coercion on unary ufuncs

alex_gaynor noreply at buildbot.pypy.org
Sat Aug 20 23:15:04 CEST 2011


Author: Alex Gaynor <alex.gaynor at gmail.com>
Branch: numpy-dtype-alt
Changeset: r46674:612e65b4b4cb
Date: 2011-08-20 16:19 -0500
http://bitbucket.org/pypy/pypy/changeset/612e65b4b4cb/

Log:	fix for type coercion on unary ufuncs

diff --git a/TODO.txt b/TODO.txt
--- a/TODO.txt
+++ b/TODO.txt
@@ -1,9 +1,9 @@
 TODO for mering numpy-dtype-alt
 ===============================
 
-* Fix sin(<bool array>)
 * Fix for raw memory with the JIT under llgraph.
 * More operations on more dtypes
+* dtype guessing
 * Any more attributes that need to be exposed at app-level
 
 For later
diff --git a/pypy/module/micronumpy/interp_dtype.py b/pypy/module/micronumpy/interp_dtype.py
--- a/pypy/module/micronumpy/interp_dtype.py
+++ b/pypy/module/micronumpy/interp_dtype.py
@@ -69,7 +69,6 @@
         def unerase(self, storage):
             return rffi.cast(TP, storage)
 
-        @specialize.argtype(1)
         @enforceargs(None, valtype)
         def box(self, value):
             return Box(value)
@@ -286,21 +285,10 @@
     def str_format(self, item):
         return float2string(self.unbox(item), 'g', rfloat.DTSF_STR_PRECISION)
 
-W_Float16Dtype = create_low_level_dtype(
-    num = 23, kind = FLOATINGLTR, name = "float16",
-    aliases = [],
-    applevel_types =[],
-    T = rffi.USHORT,
-    valtype = rffi.USHORT._type,
-)
-class W_Float16Dtype(W_Float16Dtype):
-    def unwrap(self, space, w_item):
-        return self.adapt_val(space.float_w(space.float(w_item)))
-
 ALL_DTYPES = [
     W_BoolDtype,
     W_Int8Dtype, W_Int32Dtype, W_Int64Dtype,
-    W_Float64Dtype, W_Float16Dtype
+    W_Float64Dtype
 ]
 
 dtypes_by_alias = unrolling_iterable([
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -310,7 +310,7 @@
 
 @specialize.arg(1)
 def scalar_w(space, dtype, w_obj):
-    return Scalar(scalar(space, dtype, w_obj))
+    return Scalar(space.fromcache(dtype), scalar(space, dtype, w_obj))
 
 @specialize.arg(1)
 def scalar(space, dtype, w_obj):
@@ -323,15 +323,16 @@
     """
     signature = signature.BaseSignature()
 
-    def __init__(self, value):
+    def __init__(self, dtype, value):
         BaseArray.__init__(self)
+        self.dtype = dtype
         self.value = value
 
     def find_size(self):
         raise ValueError
 
     def find_dtype(self):
-        raise ValueError
+        return self.dtype
 
     def eval(self, i):
         return self.value
@@ -340,10 +341,11 @@
     """
     Class for representing virtual arrays, such as binary ops or ufuncs
     """
-    def __init__(self, signature):
+    def __init__(self, signature, res_dtype):
         BaseArray.__init__(self)
         self.forced_result = None
         self.signature = signature
+        self.res_dtype = res_dtype
 
     def _del_sources(self):
         # Function for deleting references to source arrays, to allow garbage-collecting them
@@ -383,14 +385,12 @@
         return self._find_size()
 
     def find_dtype(self):
-        if self.forced_result is not None:
-            return self.forced_result.find_dtype()
-        return self._find_dtype()
+        return self.res_dtype
 
 
 class Call1(VirtualArray):
-    def __init__(self, signature, values):
-        VirtualArray.__init__(self, signature)
+    def __init__(self, signature, res_dtype, values):
+        VirtualArray.__init__(self, signature, res_dtype)
         self.values = values
 
     def _del_sources(self):
@@ -400,42 +400,23 @@
         return self.values.find_size()
 
     def _find_dtype(self):
-        return self.values.find_dtype()
+        return self.res_dtype
 
     def _eval(self, i):
         call_sig = self.signature.components[0]
         assert isinstance(call_sig, signature.Call1)
-        return call_sig.func(self.find_dtype(), self.values.eval(i))
+        val = self.values.eval(i).convert_to(self.res_dtype)
+        return call_sig.func(self.res_dtype, val)
 
 class Call2(VirtualArray):
     """
     Intermediate class for performing binary operations.
     """
-    def __init__(self, space, signature, left, right):
-        VirtualArray.__init__(self, signature)
+    def __init__(self, signature, res_dtype, left, right):
+        VirtualArray.__init__(self, signature, res_dtype)
         self.left = left
         self.right = right
 
-        lhs_dtype = rhs_dtype = None
-        try:
-            lhs_dtype = self.left.find_dtype()
-        except ValueError:
-            pass
-        try:
-            rhs_dtype = self.right.find_dtype()
-        except ValueError:
-            pass
-        if lhs_dtype is not None and rhs_dtype is not None:
-            self.res_dtype = interp_ufuncs.find_binop_result_dtype(space,
-                lhs_dtype, rhs_dtype
-            )
-        elif lhs_dtype is not None:
-            self.res_dtype = lhs_dtype
-        elif rhs_dtype is not None:
-            self.res_dtype = rhs_dtype
-        else:
-            self.res_dtype = None
-
     def _del_sources(self):
         self.left = None
         self.right = None
@@ -448,17 +429,11 @@
         return self.right.find_size()
 
     def _eval(self, i):
-        dtype = self.find_dtype()
-        lhs, rhs = self.left.eval(i), self.right.eval(i)
-        lhs, rhs = lhs.convert_to(dtype), rhs.convert_to(dtype)
+        lhs = self.left.eval(i).convert_to(self.res_dtype)
+        rhs = self.right.eval(i).convert_to(self.res_dtype)
         call_sig = self.signature.components[0]
         assert isinstance(call_sig, signature.Call2)
-        return call_sig.func(dtype, lhs, rhs)
-
-    def _find_dtype(self):
-        if self.res_dtype is not None:
-            return self.res_dtype
-        raise ValueError
+        return call_sig.func(self.res_dtype, lhs, rhs)
 
 class ViewArray(BaseArray):
     """
diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py
--- a/pypy/module/micronumpy/interp_ufuncs.py
+++ b/pypy/module/micronumpy/interp_ufuncs.py
@@ -19,7 +19,11 @@
             return func(res_dtype, w_obj.value).wrap(space)
 
         new_sig = signature.Signature.find_sig([call_sig, w_obj.signature])
-        w_res = Call1(new_sig, w_obj)
+        res_dtype = find_unaryop_result_dtype(space,
+            w_obj.find_dtype(),
+            promote_to_float=promote_to_float,
+        )
+        w_res = Call1(new_sig, res_dtype, w_obj)
         w_obj.add_invalidates(w_res)
         return w_res
     return func_with_new_name(impl, "%s_dispatcher" % func.__name__)
@@ -40,7 +44,11 @@
             return func(res_dtype, w_lhs.value, w_rhs.value).wrap(space)
 
         new_sig = signature.Signature.find_sig([call_sig, w_lhs.signature, w_rhs.signature])
-        w_res = Call2(space, new_sig, w_lhs, w_rhs)
+        res_dtype = find_binop_result_dtype(space,
+            w_lhs.find_dtype(), w_rhs.find_dtype(),
+            promote_to_float=promote_to_float,
+        )
+        w_res = Call2(new_sig, res_dtype, w_lhs, w_rhs)
         w_lhs.add_invalidates(w_res)
         w_rhs.add_invalidates(w_res)
         return w_res
diff --git a/pypy/module/micronumpy/test/test_base.py b/pypy/module/micronumpy/test/test_base.py
--- a/pypy/module/micronumpy/test/test_base.py
+++ b/pypy/module/micronumpy/test/test_base.py
@@ -11,11 +11,13 @@
 
 class TestSignature(object):
     def test_binop_signature(self, space):
-        ar = SingleDimArray(10, dtype=space.fromcache(interp_dtype.W_Float64Dtype))
+        float64_dtype = space.fromcache(interp_dtype.W_Float64Dtype)
+
+        ar = SingleDimArray(10, dtype=float64_dtype)
         v1 = ar.descr_add(space, ar)
-        v2 = ar.descr_add(space, Scalar(2.0))
+        v2 = ar.descr_add(space, Scalar(float64_dtype, 2.0))
         assert v1.signature is not v2.signature
-        v3 = ar.descr_add(space, Scalar(1.0))
+        v3 = ar.descr_add(space, Scalar(float64_dtype, 1.0))
         assert v2.signature is v3.signature
         v4 = ar.descr_add(space, ar)
         assert v1.signature is v4.signature
@@ -63,7 +65,6 @@
         bool_dtype = space.fromcache(interp_dtype.W_BoolDtype)
         int8_dtype = space.fromcache(interp_dtype.W_Int8Dtype)
         int32_dtype = space.fromcache(interp_dtype.W_Int32Dtype)
-        float16_dtype = space.fromcache(interp_dtype.W_Float16Dtype)
         float64_dtype = space.fromcache(interp_dtype.W_Float64Dtype)
 
         # Normal rules, everythign returns itself
@@ -72,8 +73,9 @@
         assert find_unaryop_result_dtype(space, int32_dtype) is int32_dtype
         assert find_unaryop_result_dtype(space, float64_dtype) is float64_dtype
 
-        # Coerce to floats
-        assert find_unaryop_result_dtype(space, bool_dtype, promote_to_float=True) is float16_dtype
-        assert find_unaryop_result_dtype(space, int8_dtype, promote_to_float=True) is float16_dtype
+        # Coerce to floats, some of these will eventually be float16, or
+        # whatever our smallest float type is.
+        assert find_unaryop_result_dtype(space, bool_dtype, promote_to_float=True) is float64_dtype
+        assert find_unaryop_result_dtype(space, int8_dtype, promote_to_float=True) is float64_dtype
         assert find_unaryop_result_dtype(space, int32_dtype, promote_to_float=True) is float64_dtype
         assert find_unaryop_result_dtype(space, float64_dtype, promote_to_float=True) is float64_dtype
\ No newline at end of file
diff --git a/pypy/module/micronumpy/test/test_ufuncs.py b/pypy/module/micronumpy/test/test_ufuncs.py
--- a/pypy/module/micronumpy/test/test_ufuncs.py
+++ b/pypy/module/micronumpy/test/test_ufuncs.py
@@ -193,6 +193,10 @@
         for i in range(len(a)):
             assert b[i] == math.sin(a[i])
 
+        a = sin(array([True, False], dtype=bool))
+        assert a[0] == sin(1)
+        assert a[1] == 0.0
+
     def test_cos(self):
         import math
         from numpy import array, cos


More information about the pypy-commit mailing list