[pypy-commit] pypy numpy-dtype-alt: fix for sum/prod with various dtypes. breaks test_zjit.

alex_gaynor noreply at buildbot.pypy.org
Mon Aug 22 19:36:38 CEST 2011


Author: Alex Gaynor <alex.gaynor at gmail.com>
Branch: numpy-dtype-alt
Changeset: r46706:5377b6e0918b
Date: 2011-08-22 12:41 -0500
http://bitbucket.org/pypy/pypy/changeset/5377b6e0918b/

Log:	fix for sum/prod with various dtypes. breaks test_zjit.

diff --git a/pypy/module/micronumpy/interp_dtype.py b/pypy/module/micronumpy/interp_dtype.py
--- a/pypy/module/micronumpy/interp_dtype.py
+++ b/pypy/module/micronumpy/interp_dtype.py
@@ -218,6 +218,10 @@
 class IntegerArithmeticDtype(object):
     _mixin_ = True
 
+    @binop
+    def add(self, v1, v2):
+        return v1 + v2
+
     def str_format(self, item):
         return str(widen(self.unbox(item)))
 
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -92,13 +92,19 @@
                 reduce_driver.jit_merge_point(signature=self.signature,
                                               self=self, res_dtype=res_dtype,
                                               size=size, i=i, result=result)
-                result = getattr(res_dtype, op_name)(result, self.eval(i))
+                result = getattr(res_dtype, op_name)(
+                    result,
+                    self.eval(i).convert_to(res_dtype)
+                )
                 i += 1
             return result
 
         def impl(self, space):
-            result = space.fromcache(interp_dtype.W_Float64Dtype).box(init).convert_to(self.find_dtype())
-            return loop(self, self.find_dtype(), result, self.find_size()).wrap(space)
+            dtype = interp_ufuncs.find_unaryop_result_dtype(
+                space, self.find_dtype(), promote_to_largest=True
+            )
+            result = dtype.adapt_val(init)
+            return loop(self, dtype, result, self.find_size()).wrap(space)
         return func_with_new_name(impl, "reduce_%s_impl" % op_name)
 
     def _reduce_max_min_impl(op_name):
@@ -178,8 +184,8 @@
     def descr_any(self, space):
         return space.wrap(self._any())
 
-    descr_sum = _reduce_sum_prod_impl("add", 0.0)
-    descr_prod = _reduce_sum_prod_impl("mul", 1.0)
+    descr_sum = _reduce_sum_prod_impl("add", 0)
+    descr_prod = _reduce_sum_prod_impl("mul", 1)
     descr_max = _reduce_max_min_impl("max")
     descr_min = _reduce_max_min_impl("min")
     descr_argmax = _reduce_argmax_argmin_impl("max")
diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py
--- a/pypy/module/micronumpy/interp_ufuncs.py
+++ b/pypy/module/micronumpy/interp_ufuncs.py
@@ -73,11 +73,19 @@
 
     assert False
 
-def find_unaryop_result_dtype(space, dt, promote_to_float=False):
+def find_unaryop_result_dtype(space, dt, promote_to_float=False,
+    promote_to_largest=False):
     if promote_to_float:
         for bytes, dtype in interp_dtype.dtypes_by_num_bytes:
             if dtype.kind == interp_dtype.FLOATINGLTR and dtype.num_bytes >= dt.num_bytes:
                 return space.fromcache(dtype)
+    if promote_to_largest:
+        if dt.kind == interp_dtype.BOOLLTR or dt.kind == interp_dtype.SIGNEDLTR:
+            return space.fromcache(interp_dtype.W_Int64Dtype)
+        elif dt.kind == interp_dtype.FLOATINGLTR:
+            return space.fromcache(interp_dtype.W_Float64Dtype)
+        else:
+            assert False
     return dt
 
 
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -412,6 +412,9 @@
         assert a.sum() == 10.0
         assert a[:4].sum() == 6.0
 
+        a = array([True] * 5, bool)
+        assert a.sum() == 5
+
     def test_prod(self):
         from numpy import array
         a = array(range(1,6))
diff --git a/pypy/module/micronumpy/test/test_zjit.py b/pypy/module/micronumpy/test/test_zjit.py
--- a/pypy/module/micronumpy/test/test_zjit.py
+++ b/pypy/module/micronumpy/test/test_zjit.py
@@ -14,7 +14,7 @@
 class TestNumpyJIt(LLJitMixin):
     def setup_class(cls):
         cls.space = FakeSpace()
-        cls.float64_dtype = W_Float64Dtype(cls.space)
+        cls.float64_dtype = cls.space.fromcache(W_Float64Dtype)
 
     def test_add(self):
         def f(i):


More information about the pypy-commit mailing list