[pypy-commit] pypy utf8-unicode2: Add support for __eq__, __ne__, __add__ and __mul__ to RPython
waedt
noreply at buildbot.pypy.org
Tue Jul 29 16:17:02 CEST 2014
Author: Tyler Wade <wayedt at gmail.com>
Branch: utf8-unicode2
Changeset: r72607:92302cdd34ec
Date: 2014-07-28 13:02 -0500
http://bitbucket.org/pypy/pypy/changeset/92302cdd34ec/
Log: Add support for __eq__, __ne__, __add__ and __mul__ to RPython
diff --git a/rpython/annotator/binaryop.py b/rpython/annotator/binaryop.py
--- a/rpython/annotator/binaryop.py
+++ b/rpython/annotator/binaryop.py
@@ -718,6 +718,19 @@
thistype = pairtype(SomeInstance, SomeInstance)
return super(thistype, pair(ins1, ins2)).improve()
+ def eq((s_obj1, s_obj2)):
+ if s_obj1.classdef.classdesc.lookup('__eq__'):
+ return s_obj1._emulate_call("__eq__", s_obj2)
+ elif s_obj2.classdef.classdesc.lookup('__eq__'):
+ return s_obj2._emulate_call("__eq__", s_obj1)
+ return s_Bool
+
+ def ne((s_obj1, s_obj2)):
+ if s_obj1.classdef.classdesc.lookup('__ne__'):
+ return s_obj1._emulate_call("__ne__", s_obj2)
+ elif s_obj2.classdef.classdesc.lookup('__ne__'):
+ return s_obj2._emulate_call("__ne__", s_obj1)
+ return s_Bool
class __extend__(pairtype(SomeInstance, SomeObject)):
def getitem((s_ins, s_idx)):
@@ -726,6 +739,33 @@
def setitem((s_ins, s_idx), s_value):
return s_ins._emulate_call("__setitem__", s_idx, s_value)
+ def add((s_ins, s_other)):
+ return s_ins._emulate_call("__add__", s_other)
+
+ def mul((s_ins, s_other)):
+ return s_ins._emulate_call("__mul__", s_other)
+
+ def eq((s_ins, s_obj)):
+ if s_ins.classdef.classdesc.lookup('__eq__'):
+ return s_ins._emulate_call("__eq__", s_obj)
+ return super(pairtype(SomeInstance, SomeObject), pair(s_ins, s_obj)).eq()
+
+ def ne((s_ins, s_obj)):
+ if s_ins.classdef.classdesc.lookup('__ne__'):
+ return s_ins._emulate_call("__ne__", s_obj)
+ return super(pairtype(SomeInstance, SomeObject), pair(s_ins, s_obj)).ne()
+
+class __extend__(pairtype(SomeObject, SomeInstance)):
+ def eq((s_obj, s_ins)):
+ if s_ins.classdef.classdesc.lookup('__eq__'):
+ return s_ins._emulate_call("__eq__", s_obj)
+ return super(pairtype(SomeObject, SomeInstance), pair(s_obj, s_ins)).eq()
+
+ def ne((s_obj, s_ins)):
+ if s_ins.classdef.classdesc.lookup('__ne__'):
+ return s_ins._emulate_call("__ne__", s_obj)
+ return super(pairtype(SomeObject, SomeInstance), pair(s_obj, s_ins)).ne()
+
class __extend__(pairtype(SomeIterator, SomeIterator)):
diff --git a/rpython/annotator/description.py b/rpython/annotator/description.py
--- a/rpython/annotator/description.py
+++ b/rpython/annotator/description.py
@@ -476,6 +476,17 @@
if self.pyobj not in classdef.FORCE_ATTRIBUTES_INTO_CLASSES:
self.all_enforced_attrs = [] # no attribute allowed
+ if (self.lookup('__eq__') and
+ not all(b.lookup('__eq__') for b in self.getallbases())):
+ raise AnnotatorError("A class may only define a __eq__ method if "
+ "the class at the base of its heirarchy also "
+ "has a __eq__ method.")
+ if (self.lookup('__ne__') and
+ not all(b.lookup('__ne__') for b in self.getallbases())):
+ raise AnnotatorError("A class may only define a __ne__ method if "
+ "the class at the base of its heirarchy also "
+ "has a __ne__ method.")
+
def add_source_attribute(self, name, value, mixin=False):
if isinstance(value, types.FunctionType):
# for debugging
diff --git a/rpython/annotator/test/test_annrpython.py b/rpython/annotator/test/test_annrpython.py
--- a/rpython/annotator/test/test_annrpython.py
+++ b/rpython/annotator/test/test_annrpython.py
@@ -2780,6 +2780,42 @@
s = a.build_types(f, [])
assert s.knowntype == int
+ def test__eq__in_sub_class(self):
+ class Base(object):
+ pass
+ class A(Base):
+ def __eq__(self, other):
+ return True
+
+ def f(a):
+ if a:
+ o = Base()
+ else:
+ o = A()
+
+ return o == Base()
+
+ a = self.RPythonAnnotator()
+ py.test.raises(annmodel.AnnotatorError, a.build_types, f, [int])
+
+ def test__ne__in_sub_class(self):
+ class Base(object):
+ pass
+ class A(Base):
+ def __ne__(self, other):
+ return True
+
+ def f(a):
+ if a:
+ o = Base()
+ else:
+ o = A()
+
+ return o != Base()
+
+ a = self.RPythonAnnotator()
+ py.test.raises(annmodel.AnnotatorError, a.build_types, f, [int])
+
def test_chr_out_of_bounds(self):
def g(n, max):
if n < max:
diff --git a/rpython/rtyper/lltypesystem/rclass.py b/rpython/rtyper/lltypesystem/rclass.py
--- a/rpython/rtyper/lltypesystem/rclass.py
+++ b/rpython/rtyper/lltypesystem/rclass.py
@@ -657,10 +657,8 @@
r_ins = getinstancerepr(r_ins1.rtyper, basedef, r_ins1.gcflavor)
return pairtype(Repr, Repr).rtype_is_(pair(r_ins, r_ins), hop)
- rtype_eq = rtype_is_
-
- def rtype_ne(rpair, hop):
- v = rpair.rtype_eq(hop)
+ def _rtype_ne(rpair, hop):
+ v = rpair.rtype_is_(hop)
return hop.genop("bool_not", [v], resulttype=Bool)
# ____________________________________________________________
diff --git a/rpython/rtyper/rclass.py b/rpython/rtyper/rclass.py
--- a/rpython/rtyper/rclass.py
+++ b/rpython/rtyper/rclass.py
@@ -7,7 +7,7 @@
from rpython.rtyper.lltypesystem.lltype import Void
from rpython.rtyper.rmodel import Repr, getgcflavor, inputconst
from rpython.rlib.objectmodel import UnboxedValue
-from rpython.tool.pairtype import pairtype
+from rpython.tool.pairtype import pair, pairtype
class FieldListAccessor(object):
@@ -471,14 +471,77 @@
break
+def create_forwarding_func(name):
+ def f((r_ins, r_obj), hop):
+ return r_ins._emulate_call(hop, name)
+ return f
+
class __extend__(pairtype(AbstractInstanceRepr, Repr)):
- def rtype_getitem((r_ins, r_obj), hop):
- return r_ins._emulate_call(hop, "__getitem__")
+ rtype_getitem = create_forwarding_func('__getitem__')
+ rtype_setitem = create_forwarding_func('__setitem__')
+ rtype_add = create_forwarding_func('__add__')
+ rtype_mul = create_forwarding_func('__mul__')
- def rtype_setitem((r_ins, r_obj), hop):
- return r_ins._emulate_call(hop, "__setitem__")
+ rtype_inplace_add = rtype_add
+ rtype_inplace_mul = rtype_mul
+ def rtype_eq((r_ins, r_other), hop):
+ if r_ins.classdef.classdesc.lookup('__eq__'):
+ return r_ins._emulate_call(hop, '__eq__')
+ return super(pairtype(AbstractInstanceRepr, Repr),
+ pair(r_ins, r_other)).rtype_eq(hop)
+ def rtype_ne((r_ins, r_other), hop):
+ if r_ins.classdef.classdesc.lookup('__ne__'):
+ return r_ins._emulate_call(hop, '__ne__')
+ return super(pairtype(AbstractInstanceRepr, Repr),
+ pair(r_ins, r_other)).rtype_ne(hop)
+
+class __extend__(pairtype(AbstractInstanceRepr, AbstractInstanceRepr)):
+ def rtype_eq((r_ins, r_other), hop):
+ if r_ins.classdef.classdesc.lookup('__eq__'):
+ return r_ins._emulate_call(hop, '__eq__')
+ elif r_other.classdef.classdesc.lookup('__eq__'):
+ # Reverse the order of the arguments before the call to __eq__
+ hop2 = hop.copy()
+ hop2.args_r = hop.args_r[::-1]
+ hop2.args_s = hop.args_s[::-1]
+ hop2.args_v = hop.args_v[::-1]
+ return r_other._emulate_call(hop2, '__eq__')
+ return pair(r_ins, r_other).rtype_is_(hop)
+
+ def rtype_ne((r_ins, r_other), hop):
+ if r_ins.classdef.classdesc.lookup('__ne__'):
+ return r_ins._emulate_call(hop, '__ne__')
+ elif r_other.classdef.classdesc.lookup('__ne__'):
+ # Reverse the order of the arguments before the call to __ne__
+ hop2 = hop.copy()
+ hop2.args_r = hop.args_r[::-1]
+ hop2.args_s = hop.args_s[::-1]
+ hop2.args_v = hop.args_v[::-1]
+ return r_other._emulate_call(hop2, '__ne__')
+ return pair(r_ins, r_other)._rtype_ne(hop)
+
+class __extend__(pairtype(Repr, AbstractInstanceRepr)):
+ def rtype_eq((r_other, r_ins), hop):
+ if r_ins.classdef.classdesc.lookup('__eq__'):
+ hop2 = hop.copy()
+ hop2.args_r = hop.args_r[::-1]
+ hop2.args_s = hop.args_s[::-1]
+ hop2.args_v = hop.args_v[::-1]
+ return r_ins._emulate_call(hop2, '__eq__')
+ return super(pairtype(Repr, AbstractInstanceRepr),
+ pair(r_other, r_ins)).rtype_eq(hop)
+
+ def rtype_ne((r_other, r_ins), hop):
+ if r_ins.classdef.classdesc.lookup('__ne__'):
+ hop2 = hop.copy()
+ hop2.args_r = hop.args_r[::-1]
+ hop2.args_s = hop.args_s[::-1]
+ hop2.args_v = hop.args_v[::-1]
+ return r_ins._emulate_call(hop2, '__ne__')
+ return super(pairtype(Repr, AbstractInstanceRepr),
+ pair(r_other, r_ins)).rtype_ne(hop)
# ____________________________________________________________
diff --git a/rpython/rtyper/test/test_rclass.py b/rpython/rtyper/test/test_rclass.py
--- a/rpython/rtyper/test/test_rclass.py
+++ b/rpython/rtyper/test/test_rclass.py
@@ -1271,3 +1271,87 @@
return cls[k](a, b).b
assert self.interpret(f, [1, 4, 7]) == 7
+
+ def test_overriding_eq(self):
+ class Base(object):
+ def __eq__(self, other):
+ return self is other
+ class A(Base):
+ def __eq__(self, other):
+ return True
+
+ def f(a):
+ if a:
+ o = Base()
+ else:
+ o = A()
+
+ return o == Base()
+
+ assert self.interpret(f, [0]) == f(0)
+ assert self.interpret(f, [1]) == f(1)
+
+ def test_eq_reversed(self):
+ class A(object):
+ def __eq__(self, other):
+ return not bool(other)
+
+ def f(a):
+ return (a == A()) == (A() == a)
+ assert self.interpret(f, [0]) == f(0)
+ assert self.interpret(f, [1]) == f(1)
+
+ def test_eq_without_ne(self):
+ class A(object):
+ def __eq__(self, other):
+ return False
+
+ def f():
+ a = A()
+ return a != A()
+
+ assert self.interpret(f, []) == f()
+
+ def test_overriding_ne(self):
+ class Base(object):
+ def __ne__(self, other):
+ return self is other
+ class A(Base):
+ def __ne__(self, other):
+ return True
+
+ def f(a):
+ if a:
+ o = Base()
+ else:
+ o = A()
+
+ return o != Base()
+
+ assert self.interpret(f, [0]) == f(0)
+ assert self.interpret(f, [1]) == f(1)
+
+ def test_ne_reversed(self):
+ class A(object):
+ def __ne__(self, other):
+ return not bool(other)
+
+ def f(a):
+ return (a != A()) == (A() != a)
+ assert self.interpret(f, [0]) == f(0)
+ assert self.interpret(f, [1]) == f(1)
+
+ def test_arithmetic_ops(self):
+ class A(object):
+ def __add__(self, other):
+ return other + other
+
+ def __mul__(self, other):
+ return other * other
+
+ def f(a):
+ o = A()
+ return (o + a) + (o * a)
+
+ for i in range(10):
+ assert self.interpret(f, [i]) == f(i)
More information about the pypy-commit mailing list