[pypy-commit] pypy numpy-comparison: Initial implementation (tests pass, translation fails)

snus_mumrik noreply at buildbot.pypy.org
Fri Sep 2 14:45:55 CEST 2011


Author: Ilya Osadchiy <osadchiy.ilya at gmail.com>
Branch: numpy-comparison
Changeset: r47014:edb6c31894de
Date: 2011-09-02 10:37 +0300
http://bitbucket.org/pypy/pypy/changeset/edb6c31894de/

Log:	Initial implementation (tests pass, translation fails)

diff --git a/pypy/module/micronumpy/__init__.py b/pypy/module/micronumpy/__init__.py
--- a/pypy/module/micronumpy/__init__.py
+++ b/pypy/module/micronumpy/__init__.py
@@ -25,17 +25,23 @@
         'floor': 'interp_ufuncs.floor',
         'maximum': 'interp_ufuncs.maximum',
         'minimum': 'interp_ufuncs.minimum',
         'multiply': 'interp_ufuncs.multiply',
         'negative': 'interp_ufuncs.negative',
         'reciprocal': 'interp_ufuncs.reciprocal',
         'sign': 'interp_ufuncs.sign',
         'subtract': 'interp_ufuncs.subtract',
         'sin': 'interp_ufuncs.sin',
         'cos': 'interp_ufuncs.cos',
         'tan': 'interp_ufuncs.tan',
         'arcsin': 'interp_ufuncs.arcsin',
         'arccos': 'interp_ufuncs.arccos',
         'arctan': 'interp_ufuncs.arctan',
+        'equal': 'interp_ufuncs.equal',
+        'not_equal': 'interp_ufuncs.not_equal',
+        'less': 'interp_ufuncs.less',
+        'less_equal': 'interp_ufuncs.less_equal',
+        'greater': 'interp_ufuncs.greater',
+        'greater_equal': 'interp_ufuncs.greater_equal',
     }
 
     appleveldefs = {
diff --git a/pypy/module/micronumpy/interp_dtype.py b/pypy/module/micronumpy/interp_dtype.py
--- a/pypy/module/micronumpy/interp_dtype.py
+++ b/pypy/module/micronumpy/interp_dtype.py
@@ -125,6 +125,19 @@
         ))
     return impl
 
+def bool_binop(func):
+    # Compare the two operands in self's computation type and box the
+    # resulting Python bool with self; callers convert it to their result
+    # dtype afterwards.  NOTE(review): boxing a bool with a non-bool dtype may
+    # be why translation fails (see commit log) -- TODO confirm.
+    @functools.wraps(func)
+    def impl(self, v1, v2):
+        return self.box(func(self,
+            self.for_computation(self.unbox(v1)),
+            self.for_computation(self.unbox(v2)),
+        ))
+    return impl
+
 def unaryop(func):
     @functools.wraps(func)
     def impl(self, v):
@@ -147,6 +160,30 @@
     def div(self, v1, v2):
         return v1 / v2
 
+    @bool_binop
+    def eq(self, v1, v2):
+        return v1 == v2
+
+    @bool_binop
+    def ne(self, v1, v2):
+        return v1 != v2
+
+    @bool_binop
+    def lt(self, v1, v2):
+        return v1 < v2
+
+    @bool_binop
+    def le(self, v1, v2):
+        return v1 <= v2
+
+    @bool_binop
+    def gt(self, v1, v2):
+        return v1 > v2
+
+    @bool_binop
+    def ge(self, v1, v2):
+        return v1 >= v2
+
     @unaryop
     def pos(self, v):
         return +v
@@ -166,8 +203,6 @@
 
     def bool(self, v):
         return bool(self.for_computation(self.unbox(v)))
-    def ne(self, v1, v2):
-        return self.for_computation(self.unbox(v1)) != self.for_computation(self.unbox(v2))
 
 
 class FloatArithmeticDtype(ArithmaticTypeMixin):
@@ -355,4 +390,4 @@
     num = interp_attrproperty("num", cls=W_Dtype),
     kind = interp_attrproperty("kind", cls=W_Dtype),
     shape = GetSetProperty(W_Dtype.descr_get_shape),
-)
\ No newline at end of file
+)
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -74,6 +74,13 @@
     descr_pow = _binop_impl(interp_ufuncs.power)
     descr_mod = _binop_impl(interp_ufuncs.mod)
 
+    descr_eq = _binop_impl(interp_ufuncs.equal)
+    descr_ne = _binop_impl(interp_ufuncs.not_equal)
+    descr_lt = _binop_impl(interp_ufuncs.less)
+    descr_le = _binop_impl(interp_ufuncs.less_equal)
+    descr_gt = _binop_impl(interp_ufuncs.greater)
+    descr_ge = _binop_impl(interp_ufuncs.greater_equal)
+
     def _binop_right_impl(w_ufunc):
         def impl(self, space, w_other):
             w_other = scalar_w(space,
@@ -152,7 +159,7 @@
                                               size=size, i=i, result=result,
                                               cur_best=cur_best)
                 new_best = getattr(dtype, op_name)(cur_best, self.eval(i))
-                if dtype.ne(new_best, cur_best):
+                if dtype.unbox(dtype.ne(new_best, cur_best)):
                     result = i
                     cur_best = new_best
                 i += 1
@@ -350,11 +357,16 @@
     """
     Class for representing virtual arrays, such as binary ops or ufuncs
     """
-    def __init__(self, signature, res_dtype):
+    def __init__(self, signature, res_dtype, calc_dtype=None):
         BaseArray.__init__(self)
         self.forced_result = None
         self.signature = signature
         self.res_dtype = res_dtype
+        # calc_dtype is the dtype operands are converted to before the
+        # operation runs; when omitted it is the same as res_dtype.
+        if calc_dtype is None:
+            calc_dtype = res_dtype
+        self.calc_dtype = calc_dtype
 
     def _del_sources(self):
         # Function for deleting references to source arrays, to allow garbage-collecting them
@@ -427,8 +439,8 @@
     """
     Intermediate class for performing binary operations.
     """
-    def __init__(self, signature, res_dtype, left, right):
-        VirtualArray.__init__(self, signature, res_dtype)
+    def __init__(self, signature, res_dtype, calc_dtype, left, right):
+        VirtualArray.__init__(self, signature, res_dtype, calc_dtype)
         self.left = left
         self.right = right
 
@@ -444,14 +456,14 @@
         return self.right.find_size()
 
     def _eval(self, i):
-        lhs = self.left.eval(i).convert_to(self.res_dtype)
-        rhs = self.right.eval(i).convert_to(self.res_dtype)
+        lhs = self.left.eval(i).convert_to(self.calc_dtype)
+        rhs = self.right.eval(i).convert_to(self.calc_dtype)
 
         sig = jit.promote(self.signature)
         assert isinstance(sig, signature.Signature)
         call_sig = sig.components[0]
         assert isinstance(call_sig, signature.Call2)
-        return call_sig.func(self.res_dtype, lhs, rhs)
+        return call_sig.func(self.calc_dtype, lhs, rhs).convert_to(self.res_dtype)
 
 class ViewArray(BaseArray):
     """
@@ -610,6 +622,13 @@
     __repr__ = interp2app(BaseArray.descr_repr),
     __str__ = interp2app(BaseArray.descr_str),
 
+    __eq__ = interp2app(BaseArray.descr_eq),
+    __ne__ = interp2app(BaseArray.descr_ne),
+    __lt__ = interp2app(BaseArray.descr_lt),
+    __le__ = interp2app(BaseArray.descr_le),
+    __gt__ = interp2app(BaseArray.descr_gt),
+    __ge__ = interp2app(BaseArray.descr_ge),
+
     dtype = GetSetProperty(BaseArray.descr_get_dtype),
     shape = GetSetProperty(BaseArray.descr_get_shape),
 
diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py
--- a/pypy/module/micronumpy/interp_ufuncs.py
+++ b/pypy/module/micronumpy/interp_ufuncs.py
@@ -24,9 +24,9 @@
         return w_res
     return func_with_new_name(impl, "%s_dispatcher" % func.__name__)
 
-def ufunc2(func=None, promote_to_float=False):
+def ufunc2(func=None, promote_to_float=False, bool_result=False):
     if func is None:
-        return lambda func: ufunc2(func, promote_to_float)
+        return lambda func: ufunc2(func, promote_to_float, bool_result)
 
     call_sig = signature.Call2(func)
     def impl(space, w_lhs, w_rhs):
@@ -35,17 +35,25 @@

         w_lhs = convert_to_array(space, w_lhs)
         w_rhs = convert_to_array(space, w_rhs)
-        res_dtype = find_binop_result_dtype(space,
+        calc_dtype = find_binop_result_dtype(space,
             w_lhs.find_dtype(), w_rhs.find_dtype(),
             promote_to_float=promote_to_float,
         )
+        # Operands are combined in calc_dtype; comparison ufuncs then report
+        # their result as bool instead of calc_dtype.
+        res_dtype = calc_dtype
+        if bool_result:
+            res_dtype = space.fromcache(interp_dtype.W_BoolDtype)
         if isinstance(w_lhs, Scalar) and isinstance(w_rhs, Scalar):
-            return func(res_dtype, w_lhs.value, w_rhs.value).wrap(space)
+            boxed = func(calc_dtype,
+                         w_lhs.value.convert_to(calc_dtype),
+                         w_rhs.value.convert_to(calc_dtype))
+            return boxed.convert_to(res_dtype).wrap(space)
 
         new_sig = signature.Signature.find_sig([
             call_sig, w_lhs.signature, w_rhs.signature
         ])
-        w_res = Call2(new_sig, res_dtype, w_lhs, w_rhs)
+        w_res = Call2(new_sig, res_dtype, calc_dtype, w_lhs, w_rhs)
         w_lhs.add_invalidates(w_res)
         w_rhs.add_invalidates(w_res)
         return w_res
@@ -123,6 +131,13 @@
     ("maximum", "max", 2),
     ("minimum", "min", 2),
 
+    ("equal", "eq", 2, {"bool_result": True}),
+    ("not_equal", "ne", 2, {"bool_result": True}),
+    ("less", "lt", 2, {"bool_result": True}),
+    ("less_equal", "le", 2, {"bool_result": True}),
+    ("greater", "gt", 2, {"bool_result": True}),
+    ("greater_equal", "ge", 2, {"bool_result": True}),
+
     ("copysign", "copysign", 2, {"promote_to_float": True}),
 
     ("positive", "pos", 1),
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -510,6 +510,34 @@
         assert array([1.2, 5]).dtype is dtype(float)
         assert array([]).dtype is dtype(float)
 
+    def test_comparison(self):
+        from numpy import array, dtype
+        a = array(range(5))
+        b = array(range(5), dtype=float)
+        for func in [
+                lambda x, y: x == y,
+                lambda x, y: x != y,
+                lambda x, y: x <  y,
+                lambda x, y: x <= y,
+                lambda x, y: x >  y,
+                lambda x, y: x >= y,
+                ]:
+            # Array against scalar, in both operand orders, int and float.
+            for arr in [a, b]:
+                res = func(arr, 3)
+                assert res.dtype is dtype(bool)
+                for i in xrange(5):
+                    assert res[i] == bool(func(arr[i], 3))
+                res = func(3, arr)
+                assert res.dtype is dtype(bool)
+                for i in xrange(5):
+                    assert res[i] == bool(func(3, arr[i]))
+            # Array against array, mixing int and float operands.
+            res = func(a, b)
+            assert res.dtype is dtype(bool)
+            for i in xrange(5):
+                assert res[i] == bool(func(a[i], b[i]))
+
 
 class AppTestSupport(object):
     def setup_class(cls):
@@ -522,4 +550,4 @@
         a = fromstring(self.data)
         for i in range(4):
             assert a[i] == i + 1
-        raises(ValueError, fromstring, "abc")
\ No newline at end of file
+        raises(ValueError, fromstring, "abc")
diff --git a/pypy/module/micronumpy/test/test_ufuncs.py b/pypy/module/micronumpy/test/test_ufuncs.py
--- a/pypy/module/micronumpy/test/test_ufuncs.py
+++ b/pypy/module/micronumpy/test/test_ufuncs.py
@@ -267,3 +267,15 @@
         b = arctan(a)
         assert math.isnan(b[0])
 
+    def test_comparison(self):
+        from numpy import array, dtype, equal
+        assert equal(3, 3) is True
+        assert equal(3, 4) is False
+        assert equal(3.0, 3.0) is True
+        assert equal(3.0, 3.5) is False
+        assert equal(3.0, 3) is True
+        assert equal(3.0, 4) is False
+        # Array operands of different dtypes still give a bool array.
+        c = equal(array([1, 2, 3]), array([1.0, 0.0, 3.0]))
+        assert c.dtype is dtype(bool)
+        assert c[0] and not c[1] and c[2]


More information about the pypy-commit mailing list