[pypy-commit] pypy default: implement comparison funcs for record types

bdkearns noreply at buildbot.pypy.org
Sat Feb 22 22:15:04 CET 2014


Author: Brian Kearns <bdkearns at gmail.com>
Branch: 
Changeset: r69273:9ba1d3bb478e
Date: 2014-02-22 15:52 -0500
http://bitbucket.org/pypy/pypy/changeset/9ba1d3bb478e/

Log:	implement comparison funcs for record types

diff --git a/pypy/module/micronumpy/compile.py b/pypy/module/micronumpy/compile.py
--- a/pypy/module/micronumpy/compile.py
+++ b/pypy/module/micronumpy/compile.py
@@ -5,6 +5,7 @@
 
 import re
 
+from pypy.interpreter import special
 from pypy.interpreter.baseobjspace import InternalSpaceCache, W_Root
 from pypy.interpreter.error import OperationError
 from pypy.module.micronumpy import interp_boxes
@@ -74,6 +75,7 @@
     def __init__(self):
         """NOT_RPYTHON"""
         self.fromcache = InternalSpaceCache(self).getorbuild
+        self.w_NotImplemented = special.NotImplemented(self)
 
     def _freeze_(self):
         return True
@@ -194,6 +196,9 @@
     def is_w(self, w_obj, w_what):
         return w_obj is w_what
 
+    def eq_w(self, w_obj, w_what):
+        return w_obj == w_what
+
     def issubtype(self, w_type1, w_type2):
         return BoolObject(True)
 
diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py
--- a/pypy/module/micronumpy/interp_ufuncs.py
+++ b/pypy/module/micronumpy/interp_ufuncs.py
@@ -371,17 +371,23 @@
         w_ldtype = w_lhs.get_dtype()
         w_rdtype = w_rhs.get_dtype()
         if w_ldtype.is_str_type() and w_rdtype.is_str_type() and \
-           self.comparison_func:
+                self.comparison_func:
             pass
         elif (w_ldtype.is_str_type() or w_rdtype.is_str_type()) and \
-            self.comparison_func and w_out is None:
+                self.comparison_func and w_out is None:
             return space.wrap(False)
-        elif (w_ldtype.is_flexible_type() or \
-                w_rdtype.is_flexible_type()):
-            raise OperationError(space.w_TypeError, space.wrap(
-                 'unsupported operand dtypes %s and %s for "%s"' % \
-                 (w_rdtype.get_name(), w_ldtype.get_name(),
-                  self.name)))
+        elif w_ldtype.is_flexible_type() or w_rdtype.is_flexible_type():
+            if self.comparison_func:
+                if self.name == 'equal' or self.name == 'not_equal':
+                    res = w_ldtype.eq(space, w_rdtype)
+                    if not res:
+                        return space.wrap(self.name == 'not_equal')
+                else:
+                    return space.w_NotImplemented
+            else:
+                raise oefmt(space.w_TypeError,
+                    'unsupported operand dtypes %s and %s for "%s"',
+                    w_rdtype.name, w_ldtype.name, self.name)
 
         if self.are_common_types(w_ldtype, w_rdtype):
             if not w_lhs.is_scalar() and w_rhs.is_scalar():
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -3573,6 +3573,28 @@
         exc = raises(ValueError, "a.view(('float32', 2))")
         assert exc.value[0] == 'new type not compatible with array.'
 
+    def test_record_ufuncs(self):
+        import numpy as np
+        a = np.zeros(3, dtype=[('a', 'i8'), ('b', 'i8')])
+        b = np.zeros(3, dtype=[('a', 'i8'), ('b', 'i8')])
+        c = np.zeros(3, dtype=[('a', 'f8'), ('b', 'f8')])
+        d = np.ones(3, dtype=[('a', 'i8'), ('b', 'i8')])
+        e = np.ones(3, dtype=[('a', 'i8'), ('b', 'i8'), ('c', 'i8')])
+        exc = raises(TypeError, abs, a)
+        assert exc.value[0] == 'Not implemented for this type'
+        assert (a == a).all()
+        assert not (a != a).any()
+        assert (a == b).all()
+        assert not (a != b).any()
+        assert a != c
+        assert not a == c
+        assert (a != d).all()
+        assert not (a == d).any()
+        assert a != e
+        assert not a == e
+        assert np.greater(a, a) is NotImplemented
+        assert np.less_equal(a, a) is NotImplemented
+
 class AppTestPyPy(BaseNumpyAppTest):
     def setup_class(cls):
         if option.runappdirect and '__pypy__' not in sys.builtin_module_names:
diff --git a/pypy/module/micronumpy/types.py b/pypy/module/micronumpy/types.py
--- a/pypy/module/micronumpy/types.py
+++ b/pypy/module/micronumpy/types.py
@@ -1944,6 +1944,20 @@
         pieces.append(")")
         return "".join(pieces)
 
+    def eq(self, v1, v2):
+        assert isinstance(v1, interp_boxes.W_VoidBox)
+        assert isinstance(v2, interp_boxes.W_VoidBox)
+        s1 = v1.dtype.get_size()
+        s2 = v2.dtype.get_size()
+        assert s1 == s2
+        for i in range(s1):
+            if v1.arr.storage[v1.ofs + i] != v2.arr.storage[v2.ofs + i]:
+                return False
+        return True
+
+    def ne(self, v1, v2):
+        return not self.eq(v1, v2)
+
 for tp in [Int32, Int64]:
     if tp.T == lltype.Signed:
         IntP = tp


More information about the pypy-commit mailing list