[pypy-commit] pypy utf8-unicode2: Add support for __eq__, __ne__, __add__ and __mul__ to RPython

waedt noreply at buildbot.pypy.org
Tue Jul 29 16:17:02 CEST 2014


Author: Tyler Wade <wayedt at gmail.com>
Branch: utf8-unicode2
Changeset: r72607:92302cdd34ec
Date: 2014-07-28 13:02 -0500
http://bitbucket.org/pypy/pypy/changeset/92302cdd34ec/

Log:	Add support for __eq__, __ne__, __add__ and __mul__ to RPython

diff --git a/rpython/annotator/binaryop.py b/rpython/annotator/binaryop.py
--- a/rpython/annotator/binaryop.py
+++ b/rpython/annotator/binaryop.py
@@ -718,6 +718,19 @@
             thistype = pairtype(SomeInstance, SomeInstance)
             return super(thistype, pair(ins1, ins2)).improve()
 
+    def eq((s_obj1, s_obj2)):
+        # Try s_obj1.__eq__(s_obj2), then reflected s_obj2.__eq__(s_obj1).
+        for s_self, s_arg in [(s_obj1, s_obj2), (s_obj2, s_obj1)]:
+            if s_self.classdef.classdesc.lookup('__eq__'):
+                return s_self._emulate_call("__eq__", s_arg)
+        return s_Bool
+
+    def ne((s_obj1, s_obj2)):
+        # Try s_obj1.__ne__(s_obj2), then reflected s_obj2.__ne__(s_obj1).
+        for s_self, s_arg in [(s_obj1, s_obj2), (s_obj2, s_obj1)]:
+            if s_self.classdef.classdesc.lookup('__ne__'):
+                return s_self._emulate_call("__ne__", s_arg)
+        return s_Bool
 
 class __extend__(pairtype(SomeInstance, SomeObject)):
     def getitem((s_ins, s_idx)):
@@ -726,6 +739,45 @@
     def setitem((s_ins, s_idx), s_value):
         return s_ins._emulate_call("__setitem__", s_idx, s_value)
 
+    def add((s_ins, s_right)):
+        # Forwarded to the __add__ defined by the instance's class.
+        return s_ins._emulate_call("__add__", s_right)
+
+    def mul((s_ins, s_right)):
+        # Forwarded to the __mul__ defined by the instance's class.
+        return s_ins._emulate_call("__mul__", s_right)
+
+    def eq((s_ins, s_obj)):
+        if not s_ins.classdef.classdesc.lookup('__eq__'):
+            parent = super(pairtype(SomeInstance, SomeObject),
+                           pair(s_ins, s_obj))
+            return parent.eq()
+        return s_ins._emulate_call("__eq__", s_obj)
+
+    def ne((s_ins, s_obj)):
+        if not s_ins.classdef.classdesc.lookup('__ne__'):
+            parent = super(pairtype(SomeInstance, SomeObject),
+                           pair(s_ins, s_obj))
+            return parent.ne()
+        return s_ins._emulate_call("__ne__", s_obj)
+
+class __extend__(pairtype(SomeObject, SomeInstance)):
+    # Comparisons whose right-hand operand is the instance: dispatch to
+    # the instance's (reflected) special method when it defines one.
+    def eq((s_obj, s_ins)):
+        if not s_ins.classdef.classdesc.lookup('__eq__'):
+            parent = super(pairtype(SomeObject, SomeInstance),
+                           pair(s_obj, s_ins))
+            return parent.eq()
+        return s_ins._emulate_call("__eq__", s_obj)
+
+    def ne((s_obj, s_ins)):
+        if not s_ins.classdef.classdesc.lookup('__ne__'):
+            parent = super(pairtype(SomeObject, SomeInstance),
+                           pair(s_obj, s_ins))
+            return parent.ne()
+        return s_ins._emulate_call("__ne__", s_obj)
+
 
 class __extend__(pairtype(SomeIterator, SomeIterator)):
 
diff --git a/rpython/annotator/description.py b/rpython/annotator/description.py
--- a/rpython/annotator/description.py
+++ b/rpython/annotator/description.py
@@ -476,6 +476,18 @@
                 if self.pyobj not in classdef.FORCE_ATTRIBUTES_INTO_CLASSES:
                     self.all_enforced_attrs = []    # no attribute allowed
 
+        # The eq/ne operations look up __eq__/__ne__ on the annotated
+        # classdef, which may be a common base class of the actual
+        # instances; a subclass-only override would then be silently
+        # skipped.  Reject class hierarchies where that could happen.
+        for special in ('__eq__', '__ne__'):
+            if (self.lookup(special) and
+                not all(b.lookup(special) for b in self.getallbases())):
+                raise AnnotatorError(
+                    "A class may only define a %s method if the class at "
+                    "the base of its hierarchy also has a %s method."
+                    % (special, special))
+
     def add_source_attribute(self, name, value, mixin=False):
         if isinstance(value, types.FunctionType):
             # for debugging
diff --git a/rpython/annotator/test/test_annrpython.py b/rpython/annotator/test/test_annrpython.py
--- a/rpython/annotator/test/test_annrpython.py
+++ b/rpython/annotator/test/test_annrpython.py
@@ -2780,6 +2780,38 @@
         s = a.build_types(f, [])
         assert s.knowntype == int
 
+    def test__eq__in_sub_class(self):
+        # A subclass may only define __eq__ if its root base class does.
+        class Base(object):
+            pass
+        class A(Base):
+            def __eq__(self, other):
+                return True
+
+        def f(a):
+            o = Base() if a else A()
+            return o == Base()
+
+        annotator = self.RPythonAnnotator()
+        py.test.raises(annmodel.AnnotatorError,
+                       annotator.build_types, f, [int])
+
+    def test__ne__in_sub_class(self):
+        # A subclass may only define __ne__ if its root base class does.
+        class Base(object):
+            pass
+        class A(Base):
+            def __ne__(self, other):
+                return True
+
+        def f(a):
+            o = Base() if a else A()
+            return o != Base()
+
+        annotator = self.RPythonAnnotator()
+        py.test.raises(annmodel.AnnotatorError,
+                       annotator.build_types, f, [int])
+
     def test_chr_out_of_bounds(self):
         def g(n, max):
             if n < max:
diff --git a/rpython/rtyper/lltypesystem/rclass.py b/rpython/rtyper/lltypesystem/rclass.py
--- a/rpython/rtyper/lltypesystem/rclass.py
+++ b/rpython/rtyper/lltypesystem/rclass.py
@@ -657,10 +657,8 @@
         r_ins = getinstancerepr(r_ins1.rtyper, basedef, r_ins1.gcflavor)
         return pairtype(Repr, Repr).rtype_is_(pair(r_ins, r_ins), hop)
 
-    rtype_eq = rtype_is_
-
-    def rtype_ne(rpair, hop):
-        v = rpair.rtype_eq(hop)
+    def _rtype_ne(rpair, hop):  # "x is not y": bool_not of rtype_is_
+        v = rpair.rtype_is_(hop)
         return hop.genop("bool_not", [v], resulttype=Bool)
 
 # ____________________________________________________________
diff --git a/rpython/rtyper/rclass.py b/rpython/rtyper/rclass.py
--- a/rpython/rtyper/rclass.py
+++ b/rpython/rtyper/rclass.py
@@ -7,7 +7,7 @@
 from rpython.rtyper.lltypesystem.lltype import Void
 from rpython.rtyper.rmodel import Repr, getgcflavor, inputconst
 from rpython.rlib.objectmodel import UnboxedValue
-from rpython.tool.pairtype import pairtype
+from rpython.tool.pairtype import pair, pairtype
 
 
 class FieldListAccessor(object):
@@ -471,14 +471,77 @@
                 break
 
 
+def create_forwarding_func(name):
+    """Return a binary rtype_* method that forwards the operation to the
+    special method `name` of the left-hand instance."""
+    def f((r_ins, r_obj), hop):
+        return r_ins._emulate_call(hop, name)
+    return f
+
+def _reversed_hop(hop):
+    """Return a copy of the binary `hop` with its two operands swapped, so
+    that a reflected special method receives the instance as `self`."""
+    hop2 = hop.copy()
+    hop2.args_r = hop.args_r[::-1]
+    hop2.args_s = hop.args_s[::-1]
+    hop2.args_v = hop.args_v[::-1]
+    return hop2
+
 class __extend__(pairtype(AbstractInstanceRepr, Repr)):
-    def rtype_getitem((r_ins, r_obj), hop):
-        return r_ins._emulate_call(hop, "__getitem__")
+    rtype_getitem = create_forwarding_func('__getitem__')
+    rtype_setitem = create_forwarding_func('__setitem__')
+    rtype_add = create_forwarding_func('__add__')
+    rtype_mul = create_forwarding_func('__mul__')
 
-    def rtype_setitem((r_ins, r_obj), hop):
-        return r_ins._emulate_call(hop, "__setitem__")
+    rtype_inplace_add = rtype_add
+    rtype_inplace_mul = rtype_mul
 
+    def rtype_eq((r_ins, r_other), hop):
+        if r_ins.classdef.classdesc.lookup('__eq__'):
+            return r_ins._emulate_call(hop, '__eq__')
+        return super(pairtype(AbstractInstanceRepr, Repr),
+                     pair(r_ins, r_other)).rtype_eq(hop)
 
+    def rtype_ne((r_ins, r_other), hop):
+        if r_ins.classdef.classdesc.lookup('__ne__'):
+            return r_ins._emulate_call(hop, '__ne__')
+        return super(pairtype(AbstractInstanceRepr, Repr),
+                     pair(r_ins, r_other)).rtype_ne(hop)
+
+class __extend__(pairtype(AbstractInstanceRepr, AbstractInstanceRepr)):
+    def rtype_eq((r_ins, r_other), hop):
+        if r_ins.classdef.classdesc.lookup('__eq__'):
+            return r_ins._emulate_call(hop, '__eq__')
+        elif r_other.classdef.classdesc.lookup('__eq__'):
+            # reflected call: r_other.__eq__(r_ins)
+            return r_other._emulate_call(_reversed_hop(hop), '__eq__')
+        # neither side defines __eq__: compare by identity
+        return pair(r_ins, r_other).rtype_is_(hop)
+
+    def rtype_ne((r_ins, r_other), hop):
+        if r_ins.classdef.classdesc.lookup('__ne__'):
+            return r_ins._emulate_call(hop, '__ne__')
+        elif r_other.classdef.classdesc.lookup('__ne__'):
+            # reflected call: r_other.__ne__(r_ins)
+            return r_other._emulate_call(_reversed_hop(hop), '__ne__')
+        # neither side defines __ne__: negated identity, even if __eq__
+        # exists (Python 2 does not derive != from ==)
+        return pair(r_ins, r_other)._rtype_ne(hop)
+
+class __extend__(pairtype(Repr, AbstractInstanceRepr)):
+    def rtype_eq((r_other, r_ins), hop):
+        if r_ins.classdef.classdesc.lookup('__eq__'):
+            # reflected call: r_ins.__eq__(r_other)
+            return r_ins._emulate_call(_reversed_hop(hop), '__eq__')
+        return super(pairtype(Repr, AbstractInstanceRepr),
+                     pair(r_other, r_ins)).rtype_eq(hop)
+
+    def rtype_ne((r_other, r_ins), hop):
+        if r_ins.classdef.classdesc.lookup('__ne__'):
+            # reflected call: r_ins.__ne__(r_other)
+            return r_ins._emulate_call(_reversed_hop(hop), '__ne__')
+        return super(pairtype(Repr, AbstractInstanceRepr),
+                     pair(r_other, r_ins)).rtype_ne(hop)
 
 # ____________________________________________________________
 
diff --git a/rpython/rtyper/test/test_rclass.py b/rpython/rtyper/test/test_rclass.py
--- a/rpython/rtyper/test/test_rclass.py
+++ b/rpython/rtyper/test/test_rclass.py
@@ -1271,3 +1271,87 @@
             return cls[k](a, b).b
 
         assert self.interpret(f, [1, 4, 7]) == 7
+
+    def test_overriding_eq(self):
+        # An __eq__ override in a subclass is used even when the receiver
+        # is only known to be some instance of the base class.
+        class Base(object):
+            def __eq__(self, other):
+                return self is other
+        class A(Base):
+            def __eq__(self, other):
+                return True
+
+        def f(a):
+            o = Base() if a else A()
+            return o == Base()
+
+        for arg in [0, 1]:
+            assert self.interpret(f, [arg]) == f(arg)
+
+    def test_eq_reversed(self):
+        # With an int on either side, == dispatches to A.__eq__.
+        class A(object):
+            def __eq__(self, other):
+                return not bool(other)
+
+        def f(a):
+            return (a == A()) == (A() == a)
+
+        for arg in [0, 1]:
+            assert self.interpret(f, [arg]) == f(arg)
+
+    def test_eq_without_ne(self):
+        # Defining only __eq__ leaves != as a plain identity test.
+        class A(object):
+            def __eq__(self, other):
+                return False
+
+        def f():
+            return A() != A()
+
+        assert self.interpret(f, []) == f()
+
+    def test_overriding_ne(self):
+        # Same scenario as test_overriding_eq, for __ne__.
+        class Base(object):
+            def __ne__(self, other):
+                return self is other
+        class A(Base):
+            def __ne__(self, other):
+                return True
+
+        def f(a):
+            o = Base() if a else A()
+            return o != Base()
+
+        for arg in [0, 1]:
+            assert self.interpret(f, [arg]) == f(arg)
+
+    def test_ne_reversed(self):
+        # With an int on either side, != dispatches to A.__ne__.
+        class A(object):
+            def __ne__(self, other):
+                return not bool(other)
+
+        def f(a):
+            return (a != A()) == (A() != a)
+
+        for arg in [0, 1]:
+            assert self.interpret(f, [arg]) == f(arg)
+
+    def test_arithmetic_ops(self):
+        # __add__ and __mul__ of the left operand are used for + and *.
+        class A(object):
+            def __add__(self, other):
+                return other + other
+
+            def __mul__(self, other):
+                return other * other
+
+        def f(a):
+            o = A()
+            return (o + a) + (o * a)
+
+        for arg in range(10):
+            assert self.interpret(f, [arg]) == f(arg)


More information about the pypy-commit mailing list