[Python-checkins] cpython: Close #10042: functools.total_ordering now handles NotImplemented

nick.coghlan python-checkins at python.org
Tue Oct 1 16:02:26 CEST 2013


http://hg.python.org/cpython/rev/ad9f207645ab
changeset:   85913:ad9f207645ab
user:        Nick Coghlan <ncoghlan at gmail.com>
date:        Wed Oct 02 00:02:03 2013 +1000
summary:
  Close #10042: functools.total_ordering now handles NotImplemented

(Patch by Katie Miller)

files:
  Doc/library/functools.rst  |   19 ++++
  Lib/functools.py           |   94 ++++++++++++++++++--
  Lib/test/test_functools.py |  108 +++++++++++++++++++++++-
  Misc/ACKS                  |    1 +
  Misc/NEWS                  |    4 +
  5 files changed, 207 insertions(+), 19 deletions(-)


diff --git a/Doc/library/functools.rst b/Doc/library/functools.rst
--- a/Doc/library/functools.rst
+++ b/Doc/library/functools.rst
@@ -134,15 +134,34 @@
 
        @total_ordering
        class Student:
+           def _is_valid_operand(self, other):
+               return (hasattr(other, "lastname") and
+                       hasattr(other, "firstname"))
            def __eq__(self, other):
+               if not self._is_valid_operand(other):
+                   return NotImplemented
                return ((self.lastname.lower(), self.firstname.lower()) ==
                        (other.lastname.lower(), other.firstname.lower()))
            def __lt__(self, other):
+               if not self._is_valid_operand(other):
+                   return NotImplemented
                return ((self.lastname.lower(), self.firstname.lower()) <
                        (other.lastname.lower(), other.firstname.lower()))
 
+   .. note::
+
+      While this decorator makes it easy to create well-behaved totally
+      ordered types, it *does* come at the cost of slower execution and
+      more complex stack traces for the derived comparison methods. If
+      performance benchmarking indicates this is a bottleneck for a given
+      application, implementing all six rich comparison methods instead is
+      likely to provide an easy speed boost.
+
    .. versionadded:: 3.2
 
+   .. versionchanged:: 3.4
+      Returning NotImplemented from the underlying comparison function for
+      unrecognised types is now supported.
 
 .. function:: partial(func, *args, **keywords)
 
diff --git a/Lib/functools.py b/Lib/functools.py
--- a/Lib/functools.py
+++ b/Lib/functools.py
@@ -89,21 +89,91 @@
 ### total_ordering class decorator
 ################################################################################
 
+# The correct way to indicate that a comparison operation doesn't
+# recognise the other type is to return NotImplemented and let the
+# interpreter handle raising TypeError if both operands return
+# NotImplemented from their respective comparison methods.
+#
+# This makes the implementation of total_ordering more complicated, since
+# we need to be careful not to trigger infinite recursion when two
+# different types that both use this decorator encounter each other.
+#
+# For example, if a type implements __lt__, it's natural to define
+# __gt__ as something like:
+#
+#    lambda self, other: not self < other and not self == other
+#
+# However, using the operator syntax like that ends up invoking the full
+# type checking machinery again and means we can end up bouncing back and
+# forth between the two operands until we run out of stack space.
+#
+# The solution is to define helper functions that invoke the appropriate
+# magic methods directly, ensuring we only try each operand once, and
+# return NotImplemented immediately if it is returned from the
+# underlying user provided method. Using this scheme, the __gt__ derived
+# from a user provided __lt__ becomes:
+#
+#    lambda self, other: _not_op_and_not_eq(self.__lt__, self, other)
+
+def _not_op(op, other):
+    # "not a < b" handles "a >= b"
+    # "not a <= b" handles "a > b"
+    # "not a >= b" handles "a < b"
+    # "not a > b" handles "a <= b"
+    op_result = op(other)
+    if op_result is NotImplemented:
+        return NotImplemented
+    return not op_result
+
+def _op_or_eq(op, self, other):
+    # "a < b or a == b" handles "a <= b"
+    # "a > b or a == b" handles "a >= b"
+    op_result = op(other)
+    if op_result is NotImplemented:
+        return NotImplemented
+    return op_result or self == other
+
+def _not_op_and_not_eq(op, self, other):
+    # "not (a < b or a == b)" handles "a > b"
+    # "not a < b and a != b" is equivalent
+    # "not (a > b or a == b)" handles "a < b"
+    # "not a > b and a != b" is equivalent
+    op_result = op(other)
+    if op_result is NotImplemented:
+        return NotImplemented
+    return not op_result and self != other
+
+def _not_op_or_eq(op, self, other):
+    # "not a <= b or a == b" handles "a >= b"
+    # "not a >= b or a == b" handles "a <= b"
+    op_result = op(other)
+    if op_result is NotImplemented:
+        return NotImplemented
+    return not op_result or self == other
+
+def _op_and_not_eq(op, self, other):
+    # "a <= b and not a == b" handles "a < b"
+    # "a >= b and not a == b" handles "a > b"
+    op_result = op(other)
+    if op_result is NotImplemented:
+        return NotImplemented
+    return op_result and self != other
+
 def total_ordering(cls):
     """Class decorator that fills in missing ordering methods"""
     convert = {
-        '__lt__': [('__gt__', lambda self, other: not (self < other or self == other)),
-                   ('__le__', lambda self, other: self < other or self == other),
-                   ('__ge__', lambda self, other: not self < other)],
-        '__le__': [('__ge__', lambda self, other: not self <= other or self == other),
-                   ('__lt__', lambda self, other: self <= other and not self == other),
-                   ('__gt__', lambda self, other: not self <= other)],
-        '__gt__': [('__lt__', lambda self, other: not (self > other or self == other)),
-                   ('__ge__', lambda self, other: self > other or self == other),
-                   ('__le__', lambda self, other: not self > other)],
-        '__ge__': [('__le__', lambda self, other: (not self >= other) or self == other),
-                   ('__gt__', lambda self, other: self >= other and not self == other),
-                   ('__lt__', lambda self, other: not self >= other)]
+        '__lt__': [('__gt__', lambda self, other: _not_op_and_not_eq(self.__lt__, self, other)),
+                   ('__le__', lambda self, other: _op_or_eq(self.__lt__, self, other)),
+                   ('__ge__', lambda self, other: _not_op(self.__lt__, other))],
+        '__le__': [('__ge__', lambda self, other: _not_op_or_eq(self.__le__, self, other)),
+                   ('__lt__', lambda self, other: _op_and_not_eq(self.__le__, self, other)),
+                   ('__gt__', lambda self, other: _not_op(self.__le__, other))],
+        '__gt__': [('__lt__', lambda self, other: _not_op_and_not_eq(self.__gt__, self, other)),
+                   ('__ge__', lambda self, other: _op_or_eq(self.__gt__, self, other)),
+                   ('__le__', lambda self, other: _not_op(self.__gt__, other))],
+        '__ge__': [('__le__', lambda self, other: _not_op_or_eq(self.__ge__, self, other)),
+                   ('__gt__', lambda self, other: _op_and_not_eq(self.__ge__, self, other)),
+                   ('__lt__', lambda self, other: _not_op(self.__ge__, other))]
     }
     # Find user-defined comparisons (not those inherited from object).
     roots = [op for op in convert if getattr(cls, op, None) is not getattr(object, op, None)]
diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py
--- a/Lib/test/test_functools.py
+++ b/Lib/test/test_functools.py
@@ -584,6 +584,7 @@
         self.assertTrue(A(2) >= A(1))
         self.assertTrue(A(2) <= A(2))
         self.assertTrue(A(2) >= A(2))
+        self.assertFalse(A(1) > A(2))
 
     def test_total_ordering_le(self):
         @functools.total_ordering
@@ -600,6 +601,7 @@
         self.assertTrue(A(2) >= A(1))
         self.assertTrue(A(2) <= A(2))
         self.assertTrue(A(2) >= A(2))
+        self.assertFalse(A(1) >= A(2))
 
     def test_total_ordering_gt(self):
         @functools.total_ordering
@@ -616,6 +618,7 @@
         self.assertTrue(A(2) >= A(1))
         self.assertTrue(A(2) <= A(2))
         self.assertTrue(A(2) >= A(2))
+        self.assertFalse(A(2) < A(1))
 
     def test_total_ordering_ge(self):
         @functools.total_ordering
@@ -632,6 +635,7 @@
         self.assertTrue(A(2) >= A(1))
         self.assertTrue(A(2) <= A(2))
         self.assertTrue(A(2) >= A(2))
+        self.assertFalse(A(2) <= A(1))
 
     def test_total_ordering_no_overwrite(self):
         # new methods should not overwrite existing
@@ -651,22 +655,112 @@
             class A:
                 pass
 
-    def test_bug_10042(self):
+    def test_type_error_when_not_implemented(self):
+        # bug 10042; ensure stack overflow does not occur
+        # when decorated types return NotImplemented
         @functools.total_ordering
-        class TestTO:
+        class ImplementsLessThan:
             def __init__(self, value):
                 self.value = value
             def __eq__(self, other):
-                if isinstance(other, TestTO):
+                if isinstance(other, ImplementsLessThan):
                     return self.value == other.value
                 return False
             def __lt__(self, other):
-                if isinstance(other, TestTO):
+                if isinstance(other, ImplementsLessThan):
                     return self.value < other.value
-                raise TypeError
-        with self.assertRaises(TypeError):
-            TestTO(8) <= ()
+                return NotImplemented
 
+        @functools.total_ordering
+        class ImplementsGreaterThan:
+            def __init__(self, value):
+                self.value = value
+            def __eq__(self, other):
+                if isinstance(other, ImplementsGreaterThan):
+                    return self.value == other.value
+                return False
+            def __gt__(self, other):
+                if isinstance(other, ImplementsGreaterThan):
+                    return self.value > other.value
+                return NotImplemented
+
+        @functools.total_ordering
+        class ImplementsLessThanEqualTo:
+            def __init__(self, value):
+                self.value = value
+            def __eq__(self, other):
+                if isinstance(other, ImplementsLessThanEqualTo):
+                    return self.value == other.value
+                return False
+            def __le__(self, other):
+                if isinstance(other, ImplementsLessThanEqualTo):
+                    return self.value <= other.value
+                return NotImplemented
+
+        @functools.total_ordering
+        class ImplementsGreaterThanEqualTo:
+            def __init__(self, value):
+                self.value = value
+            def __eq__(self, other):
+                if isinstance(other, ImplementsGreaterThanEqualTo):
+                    return self.value == other.value
+                return False
+            def __ge__(self, other):
+                if isinstance(other, ImplementsGreaterThanEqualTo):
+                    return self.value >= other.value
+                return NotImplemented
+
+        @functools.total_ordering
+        class ComparatorNotImplemented:
+            def __init__(self, value):
+                self.value = value
+            def __eq__(self, other):
+                if isinstance(other, ComparatorNotImplemented):
+                    return self.value == other.value
+                return False
+            def __lt__(self, other):
+                return NotImplemented
+
+        with self.subTest("LT < 1"), self.assertRaises(TypeError):
+            ImplementsLessThan(-1) < 1
+
+        with self.subTest("LT < LE"), self.assertRaises(TypeError):
+            ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)
+
+        with self.subTest("LT < GT"), self.assertRaises(TypeError):
+            ImplementsLessThan(1) < ImplementsGreaterThan(1)
+
+        with self.subTest("LE <= LT"), self.assertRaises(TypeError):
+            ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)
+
+        with self.subTest("LE <= GE"), self.assertRaises(TypeError):
+            ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)
+
+        with self.subTest("GT > GE"), self.assertRaises(TypeError):
+            ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)
+
+        with self.subTest("GT > LT"), self.assertRaises(TypeError):
+            ImplementsGreaterThan(5) > ImplementsLessThan(5)
+
+        with self.subTest("GE >= GT"), self.assertRaises(TypeError):
+            ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)
+
+        with self.subTest("GE >= LE"), self.assertRaises(TypeError):
+            ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)
+
+        with self.subTest("GE when equal"):
+            a = ComparatorNotImplemented(8)
+            b = ComparatorNotImplemented(8)
+            self.assertEqual(a, b)
+            with self.assertRaises(TypeError):
+                a >= b
+
+        with self.subTest("LE when equal"):
+            a = ComparatorNotImplemented(9)
+            b = ComparatorNotImplemented(9)
+            self.assertEqual(a, b)
+            with self.assertRaises(TypeError):
+                a <= b
 
 class TestLRU(unittest.TestCase):
 
diff --git a/Misc/ACKS b/Misc/ACKS
--- a/Misc/ACKS
+++ b/Misc/ACKS
@@ -862,6 +862,7 @@
 Damien Miller
 Jason V. Miller
 Jay T. Miller
+Katie Miller
 Roman Milner
 Julien Miotte
 Andrii V. Mishkovskyi
diff --git a/Misc/NEWS b/Misc/NEWS
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -13,6 +13,10 @@
 Library
 -------
 
+- Issue #10042: functools.total_ordering now correctly handles
+  NotImplemented being returned by the underlying comparison function (Patch
+  by Katie Miller)
+
 - Issue #19092: contextlib.ExitStack now correctly reraises exceptions
   from the __exit__ callbacks of inner context managers (Patch by Hrvoje
   Nikšić)

-- 
Repository URL: http://hg.python.org/cpython


More information about the Python-checkins mailing list