[pypy-commit] pypy numpy-dtype-alt: Added dtype guessing, this also fixes the return type on things like numpy.maximum(1, 2), which used to be buggy and return a float.

alex_gaynor noreply at buildbot.pypy.org
Tue Aug 23 14:26:14 CEST 2011


Author: Alex Gaynor <alex.gaynor at gmail.com>
Branch: numpy-dtype-alt
Changeset: r46729:3a0ecdbc1565
Date: 2011-08-23 07:31 -0500
http://bitbucket.org/pypy/pypy/changeset/3a0ecdbc1565/

Log:	Added dtype guessing, this also fixes the return type on things like
	numpy.maximum(1, 2), which used to be buggy and return a float.

diff --git a/TODO.txt b/TODO.txt
--- a/TODO.txt
+++ b/TODO.txt
@@ -1,5 +1,4 @@
 TODO for mering numpy-dtype-alt
 ===============================
 
-* More operations on more dtypes
-* dtype guessing
+* More operations on more dtypes (including copy-paste reduction)
\ No newline at end of file
diff --git a/pypy/module/micronumpy/interp_dtype.py b/pypy/module/micronumpy/interp_dtype.py
--- a/pypy/module/micronumpy/interp_dtype.py
+++ b/pypy/module/micronumpy/interp_dtype.py
@@ -218,9 +218,29 @@
 class IntegerArithmeticDtype(object):
     _mixin_ = True
 
+    # XXX: reduce the copy paste
     @binop
     def add(self, v1, v2):
         return widen(v1) + widen(v2)
+    @binop
+    def sub(self, v1, v2):
+        return widen(v1) - widen(v2)
+    @binop
+    def mul(self, v1, v2):
+        return widen(v1) * widen(v2)
+    @binop
+    def div(self, v1, v2):
+        return widen(v1) / widen(v2)
+    @binop
+    def mod(self, v1, v2):
+        return widen(v1) % widen(v2)
+
+    @binop
+    def max(self, v1, v2):
+        return max(widen(v1), widen(v2))
+
+    def bool(self, v):
+        return bool(widen(self.unbox(v)))
 
     def str_format(self, item):
         return str(widen(self.unbox(item)))
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -37,10 +37,19 @@
         self.invalidates.append(other)
 
     def descr__new__(space, w_subtype, w_size_or_iterable, w_dtype=None):
+        l = space.listview(w_size_or_iterable)
+        if space.is_w(w_dtype, space.w_None):
+            w_dtype = None
+            for w_item in l:
+                w_dtype = interp_ufuncs.find_dtype_for_scalar(space, w_item, w_dtype)
+                if w_dtype is space.fromcache(interp_dtype.W_Float64Dtype):
+                    break
+            if w_dtype is None:
+                w_dtype = space.w_None
+
         dtype = space.interp_w(interp_dtype.W_Dtype,
             space.call_function(space.gettypefor(interp_dtype.W_Dtype), w_dtype)
         )
-        l = space.listview(w_size_or_iterable)
         arr = SingleDimArray(len(l), dtype=dtype)
         i = 0
         for w_elem in l:
@@ -71,7 +80,10 @@
 
     def _binop_right_impl(w_ufunc):
         def impl(self, space, w_other):
-            w_other = scalar_w(space, interp_dtype.W_Float64Dtype, w_other)
+            w_other = scalar_w(space,
+                interp_ufuncs.find_dtype_for_scalar(space, w_other, self.find_dtype()),
+                w_other
+            )
             return w_ufunc(space, w_other, self)
         return func_with_new_name(impl, "binop_right_%s_impl" % w_ufunc.__name__)
 
@@ -295,11 +307,11 @@
             slice_driver.jit_merge_point(signature=source.signature, step=step,
                                          stop=stop, i=i, j=j, source=source,
                                          dest=dest)
-            dest.setitem(i, source.eval(j))
+            dest.setitem(i, source.eval(j).convert_to(dest.find_dtype()))
             j += 1
             i += step
 
-def convert_to_array (space, w_obj):
+def convert_to_array(space, w_obj):
     if isinstance(w_obj, BaseArray):
         return w_obj
     elif space.issequence_w(w_obj):
@@ -309,16 +321,11 @@
         return w_obj
     else:
         # If it's a scalar
-        return scalar_w(space, interp_dtype.W_Float64Dtype, w_obj)
+        dtype = interp_ufuncs.find_dtype_for_scalar(space, w_obj)
+        return scalar_w(space, dtype, w_obj)
 
-@specialize.arg(1)
 def scalar_w(space, dtype, w_obj):
-    return Scalar(space.fromcache(dtype), scalar(space, dtype, w_obj))
-
-@specialize.arg(1)
-def scalar(space, dtype, w_obj):
-    dtype = space.fromcache(dtype)
-    return dtype.unwrap(space, w_obj)
+    return Scalar(dtype, dtype.unwrap(space, w_obj))
 
 class Scalar(BaseArray):
     """
diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py
--- a/pypy/module/micronumpy/interp_ufuncs.py
+++ b/pypy/module/micronumpy/interp_ufuncs.py
@@ -13,15 +13,14 @@
             convert_to_array, Scalar)
 
         w_obj = convert_to_array(space, w_obj)
-        if isinstance(w_obj, Scalar):
-            res_dtype = space.fromcache(interp_dtype.W_Float64Dtype)
-            return func(res_dtype, w_obj.value).wrap(space)
-
-        new_sig = signature.Signature.find_sig([call_sig, w_obj.signature])
         res_dtype = find_unaryop_result_dtype(space,
             w_obj.find_dtype(),
             promote_to_float=promote_to_float,
         )
+        if isinstance(w_obj, Scalar):
+            return func(res_dtype, w_obj.value.convert_to(res_dtype)).wrap(space)
+
+        new_sig = signature.Signature.find_sig([call_sig, w_obj.signature])
         w_res = Call1(new_sig, res_dtype, w_obj)
         w_obj.add_invalidates(w_res)
         return w_res
@@ -38,15 +37,16 @@
 
         w_lhs = convert_to_array(space, w_lhs)
         w_rhs = convert_to_array(space, w_rhs)
-        if isinstance(w_lhs, Scalar) and isinstance(w_rhs, Scalar):
-            res_dtype = space.fromcache(interp_dtype.W_Float64Dtype)
-            return func(res_dtype, w_lhs.value, w_rhs.value).wrap(space)
-
-        new_sig = signature.Signature.find_sig([call_sig, w_lhs.signature, w_rhs.signature])
         res_dtype = find_binop_result_dtype(space,
             w_lhs.find_dtype(), w_rhs.find_dtype(),
             promote_to_float=promote_to_float,
         )
+        if isinstance(w_lhs, Scalar) and isinstance(w_rhs, Scalar):
+            return func(res_dtype, w_lhs.value, w_rhs.value).wrap(space)
+
+        new_sig = signature.Signature.find_sig([
+            call_sig, w_lhs.signature, w_rhs.signature
+        ])
         w_res = Call2(new_sig, res_dtype, w_lhs, w_rhs)
         w_lhs.add_invalidates(w_res)
         w_rhs.add_invalidates(w_res)
@@ -87,6 +87,21 @@
             assert False
     return dt
 
+def find_dtype_for_scalar(space, w_obj, current_guess=None):
+    w_type = space.type(w_obj)
+
+    bool_dtype = space.fromcache(interp_dtype.W_BoolDtype)
+    int64_dtype = space.fromcache(interp_dtype.W_Int64Dtype)
+
+    if space.is_w(w_type, space.w_bool):
+        if current_guess is None:
+            return bool_dtype
+    elif space.is_w(w_type, space.w_int):
+        if (current_guess is None or current_guess is bool_dtype or
+            current_guess is int64_dtype):
+            return int64_dtype
+    return space.fromcache(interp_dtype.W_Float64Dtype)
+
 
 def ufunc_dtype_caller(ufunc_name, op_name, argcount, **kwargs):
     if argcount == 1:
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -52,7 +52,7 @@
 
     def test_repr(self):
         from numpy import array, zeros
-        a = array(range(5))
+        a = array(range(5), float)
         assert repr(a) == "array([0.0, 1.0, 2.0, 3.0, 4.0])"
         a = zeros(1001)
         assert repr(a) == "array([0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0])"
@@ -63,7 +63,7 @@
 
     def test_repr_slice(self):
         from numpy import array, zeros
-        a = array(range(5))
+        a = array(range(5), float)
         b = a[1::2]
         assert repr(b) == "array([1.0, 3.0])"
         a = zeros(2002)
@@ -72,7 +72,7 @@
 
     def test_str(self):
         from numpy import array, zeros
-        a = array(range(5))
+        a = array(range(5), float)
         assert str(a) == "[0.0 1.0 2.0 3.0 4.0]"
         assert str((2*a)[:]) == "[0.0 2.0 4.0 6.0 8.0]"
         a = zeros(1001)
@@ -88,7 +88,7 @@
 
     def test_str_slice(self):
         from numpy import array, zeros
-        a = array(range(5))
+        a = array(range(5), float)
         b = a[1::2]
         assert str(b) == "[1.0 3.0]"
         a = zeros(2002)
@@ -144,7 +144,7 @@
 
     def test_setslice_list(self):
         from numpy import array
-        a = array(range(5))
+        a = array(range(5), float)
         b = [0., 1.]
         a[1:4:2] = b
         assert a[1] == 0.
@@ -152,7 +152,7 @@
 
     def test_setslice_constant(self):
         from numpy import array
-        a = array(range(5))
+        a = array(range(5), float)
         a[1:4:2] = 0.
         assert a[1] == 0.
         assert a[3] == 0.
@@ -261,7 +261,7 @@
     def test_div_other(self):
         from numpy import array
         a = array(range(5))
-        b = array([2, 2, 2, 2, 2])
+        b = array([2, 2, 2, 2, 2], float)
         c = a / b
         for i in range(5):
             assert c[i] == i / 2.0
@@ -275,7 +275,7 @@
 
     def test_pow(self):
         from numpy import array
-        a = array(range(5))
+        a = array(range(5), float)
         b = a ** a
         for i in range(5):
             print b[i], i**i
@@ -283,7 +283,7 @@
 
     def test_pow_other(self):
         from numpy import array
-        a = array(range(5))
+        a = array(range(5), float)
         b = array([2, 2, 2, 2, 2])
         c = a ** b
         for i in range(5):
@@ -291,7 +291,7 @@
 
     def test_pow_constant(self):
         from numpy import array
-        a = array(range(5))
+        a = array(range(5), float)
         b = a ** 2
         for i in range(5):
             assert b[i] == i ** 2
@@ -484,6 +484,16 @@
         for i in xrange(5):
             assert b[i] == 2.5 * a[i]
 
+    def test_dtype_guessing(self):
+        from numpy import array, dtype
+
+        assert array([True]).dtype is dtype(bool)
+        assert array([True, 1]).dtype is dtype(long)
+        assert array([1, 2, 3]).dtype is dtype(long)
+        assert array([1.2, True]).dtype is dtype(float)
+        assert array([1.2, 5]).dtype is dtype(float)
+        assert array([]).dtype is dtype(float)
+
 
 class AppTestSupport(object):
     def setup_class(cls):
@@ -496,5 +506,4 @@
         a = fromstring(self.data)
         for i in range(4):
             assert a[i] == i + 1
-        raises(ValueError, fromstring, "abc")
-
+        raises(ValueError, fromstring, "abc")
\ No newline at end of file
diff --git a/pypy/module/micronumpy/test/test_ufuncs.py b/pypy/module/micronumpy/test/test_ufuncs.py
--- a/pypy/module/micronumpy/test/test_ufuncs.py
+++ b/pypy/module/micronumpy/test/test_ufuncs.py
@@ -110,6 +110,10 @@
         for i in range(3):
             assert c[i] == max(a[i], b[i])
 
+        x = maximum(2, 3)
+        assert x == 3
+        assert type(x) is int
+
     def test_multiply(self):
         from numpy import array, multiply
 
diff --git a/pypy/module/micronumpy/test/test_zjit.py b/pypy/module/micronumpy/test/test_zjit.py
--- a/pypy/module/micronumpy/test/test_zjit.py
+++ b/pypy/module/micronumpy/test/test_zjit.py
@@ -34,7 +34,7 @@
             ar = SingleDimArray(i, dtype=self.float64_dtype)
             v = interp_ufuncs.add(self.space,
                 ar,
-                scalar_w(self.space, W_Float64Dtype, self.space.wrap(4.5))
+                scalar_w(self.space, self.float64_dtype, self.space.wrap(4.5))
             )
             assert isinstance(v, BaseArray)
             return v.get_concrete().eval(3).val
@@ -181,9 +181,9 @@
 
         def f(i):
             ar = SingleDimArray(i, dtype=self.float64_dtype)
-            v1 = interp_ufuncs.add(space, ar, scalar_w(space, W_Float64Dtype, space.wrap(4.5)))
+            v1 = interp_ufuncs.add(space, ar, scalar_w(space, self.float64_dtype, space.wrap(4.5)))
             assert isinstance(v1, BaseArray)
-            v2 = interp_ufuncs.multiply(space, v1, scalar_w(space, W_Float64Dtype, space.wrap(4.5)))
+            v2 = interp_ufuncs.multiply(space, v1, scalar_w(space, self.float64_dtype, space.wrap(4.5)))
             v1.force_if_needed()
             assert isinstance(v2, BaseArray)
             return v2.get_concrete().eval(3).val


More information about the pypy-commit mailing list