[pypy-commit] pypy numpy-dtype-alt: Write dtype coercion logic, and a failing test.

Thu Aug 18 01:39:17 CEST 2011

```
Author: Alex Gaynor <alex.gaynor at gmail.com>
Branch: numpy-dtype-alt
Changeset: r46584:5aa559a16cbe
Date: 2011-08-17 18:43 -0500
http://bitbucket.org/pypy/pypy/changeset/5aa559a16cbe/

Log:	Write dtype coercion logic, and a failing test.

diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py
--- a/pypy/module/micronumpy/interp_ufuncs.py
+++ b/pypy/module/micronumpy/interp_ufuncs.py
@@ -41,6 +41,22 @@
).wrap(space)
return func_with_new_name(impl, "%s_dispatcher" % func.__name__)

+def find_binop_result_dtype(space, dt1, dt2, promote_bools=False):
+    """Return the dtype a binary ufunc on dt1 and dt2 should produce.
+
+    The pair is first ordered by ``num`` so that dt2 is the "larger" dtype;
+    the rules below are written in terms of that ordering.  When
+    ``promote_bools`` is true, op(bool, bool) yields int8 instead of bool.
+
+    Raises NotImplementedError for any pair of kinds not covered by the
+    rules below, instead of silently returning None.
+    """
+    # dt1.num should be <= dt2.num; the rules below rely on this ordering.
+    if dt1.num > dt2.num:
+        dt1, dt2 = dt2, dt1
+    # Some operations promote op(bool, bool) to return int8, rather than bool
+    if promote_bools and (dt1.kind == dt2.kind == interp_dtype.BOOLLTR):
+        return space.fromcache(interp_dtype.W_Int8Dtype)
+    # If they're the same kind, choose the greater one.
+    if dt1.kind == dt2.kind:
+        return dt2
+
+    # Everything promotes to float, and bool promotes to everything.
+    if dt2.kind == interp_dtype.FLOATINGLTR or dt1.kind == interp_dtype.BOOLLTR:
+        return dt2
+
+    # No rule above matched this pair of kinds (presumably e.g. signed vs.
+    # unsigned integers -- TODO confirm once those dtypes exist).  Fail
+    # loudly here rather than falling off the end and returning None,
+    # which would only surface later as an obscure error in the caller.
+    raise NotImplementedError("no dtype promotion rule for kinds %s and %s"
+                              % (dt1.kind, dt2.kind))
+
+
def ufunc_dtype_caller(ufunc_name, op_name, argcount):
if argcount == 1:
@ufunc
diff --git a/pypy/module/micronumpy/test/test_base.py b/pypy/module/micronumpy/test/test_base.py
--- a/pypy/module/micronumpy/test/test_base.py
+++ b/pypy/module/micronumpy/test/test_base.py
@@ -1,15 +1,16 @@
from pypy.conftest import gettestobjspace
-from pypy.module.micronumpy.interp_dtype import W_Float64Dtype
+from pypy.module.micronumpy import interp_dtype
from pypy.module.micronumpy.interp_numarray import SingleDimArray, Scalar
+from pypy.module.micronumpy.interp_ufuncs import find_binop_result_dtype

class BaseNumpyAppTest(object):
def setup_class(cls):
-        cls.space = gettestobjspace(usemodules=('micronumpy',))
+        cls.space = gettestobjspace(usemodules=['micronumpy'])

class TestSignature(object):
def test_binop_signature(self, space):
-        ar = SingleDimArray(10, dtype=space.fromcache(W_Float64Dtype))
+        ar = SingleDimArray(10, dtype=space.fromcache(interp_dtype.W_Float64Dtype))
assert v1.signature is not v2.signature
@@ -19,7 +20,7 @@
assert v1.signature is v4.signature

def test_slice_signature(self, space):
-        ar = SingleDimArray(10, dtype=space.fromcache(W_Float64Dtype))
+        ar = SingleDimArray(10, dtype=space.fromcache(interp_dtype.W_Float64Dtype))
v1 = ar.descr_getitem(space, space.wrap(slice(1, 5, 1)))
v2 = ar.descr_getitem(space, space.wrap(slice(4, 6, 1)))
assert v1.signature is v2.signature
@@ -27,3 +28,22 @@
assert v3.signature is v4.signature
+
+class TestUfuncCoerscion(object):
+    """Check which result dtype find_binop_result_dtype picks for binops."""
+
+    def test_binops(self, space):
+        get = space.fromcache
+        bool_dt = get(interp_dtype.W_BoolDtype)
+        int8_dt = get(interp_dtype.W_Int8Dtype)
+        int32_dt = get(interp_dtype.W_Int32Dtype)
+        float64_dt = get(interp_dtype.W_Float64Dtype)
+
+        # Each entry: (lhs dtype, rhs dtype, promote_bools, expected dtype)
+        cases = [
+            # Basic pairing
+            (bool_dt, bool_dt, False, bool_dt),
+            (bool_dt, float64_dt, False, float64_dt),
+            (float64_dt, bool_dt, False, float64_dt),
+            (int32_dt, int8_dt, False, int32_dt),
+            (int32_dt, bool_dt, False, int32_dt),
+            # With promote_bools (happens on div), bool op bool gives int8
+            (bool_dt, bool_dt, True, int8_dt),
+            (bool_dt, float64_dt, True, float64_dt),
+        ]
+        for lhs, rhs, promote, expected in cases:
+            result = find_binop_result_dtype(space, lhs, rhs,
+                                             promote_bools=promote)
+            assert result is expected
\ No newline at end of file
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -175,6 +175,12 @@
for i in range(5):
assert b[i] == i + i

+        a = array([True, False, True, False], dtype="?")
+        b = array([True, True, False, False], dtype="?")
+        c = a + b
+        for i in range(4):
+            assert c[i] == bool(a[i] + b[i])
+
from numpy import array
a = array(range(5))
@@ -463,7 +469,7 @@
a = array(range(5))
b = a.dot(2.5)
for i in xrange(5):
-            assert b[i] == 2.5*a[i]
+            assert b[i] == 2.5 * a[i]

class AppTestSupport(object):
```