[Python-checkins] cpython (3.4): Issue 8743: Improve interoperability between sets and the collections.Set abstract base class

raymond.hettinger python-checkins at python.org
Mon May 26 09:14:23 CEST 2014


http://hg.python.org/cpython/rev/cd8b5b5b6356
changeset:   90840:cd8b5b5b6356
branch:      3.4
parent:      90834:ebeade01bd8e
user:        Raymond Hettinger <python at rcn.com>
date:        Mon May 26 00:09:04 2014 -0700
summary:
  Issue 8743: Improve interoperability between sets and the collections.Set abstract base class.

files:
  Lib/_collections_abc.py      |   23 +++-
  Lib/test/test_collections.py |  160 ++++++++++++++++++++++-
  Misc/NEWS                    |    3 +
  3 files changed, 180 insertions(+), 6 deletions(-)


diff --git a/Lib/_collections_abc.py b/Lib/_collections_abc.py
--- a/Lib/_collections_abc.py
+++ b/Lib/_collections_abc.py
@@ -207,12 +207,17 @@
     def __gt__(self, other):
         if not isinstance(other, Set):
             return NotImplemented
-        return other.__lt__(self)
+        return len(self) > len(other) and self.__ge__(other)
 
     def __ge__(self, other):
         if not isinstance(other, Set):
             return NotImplemented
-        return other.__le__(self)
+        if len(self) < len(other):
+            return False
+        for elem in other:
+            if elem not in self:
+                return False
+        return True
 
     def __eq__(self, other):
         if not isinstance(other, Set):
@@ -236,6 +241,8 @@
             return NotImplemented
         return self._from_iterable(value for value in other if value in self)
 
+    __rand__ = __and__
+
     def isdisjoint(self, other):
         'Return True if two sets have a null intersection.'
         for value in other:
@@ -249,6 +256,8 @@
         chain = (e for s in (self, other) for e in s)
         return self._from_iterable(chain)
 
+    __ror__ = __or__
+
     def __sub__(self, other):
         if not isinstance(other, Set):
             if not isinstance(other, Iterable):
@@ -257,6 +266,14 @@
         return self._from_iterable(value for value in self
                                    if value not in other)
 
+    def __rsub__(self, other):
+        if not isinstance(other, Set):
+            if not isinstance(other, Iterable):
+                return NotImplemented
+            other = self._from_iterable(other)
+        return self._from_iterable(value for value in other
+                                   if value not in self)
+
     def __xor__(self, other):
         if not isinstance(other, Set):
             if not isinstance(other, Iterable):
@@ -264,6 +281,8 @@
             other = self._from_iterable(other)
         return (self - other) | (other - self)
 
+    __rxor__ = __xor__
+
     def _hash(self):
         """Compute the hash value of a set.
 
diff --git a/Lib/test/test_collections.py b/Lib/test/test_collections.py
--- a/Lib/test/test_collections.py
+++ b/Lib/test/test_collections.py
@@ -720,14 +720,166 @@
 
         cs = MyComparableSet()
         ncs = MyNonComparableSet()
+        self.assertFalse(ncs < cs)
+        self.assertTrue(ncs <= cs)
+        self.assertFalse(ncs > cs)
+        self.assertTrue(ncs >= cs)
+
+    def assertSameSet(self, s1, s2):
+        # coerce both to a real set then check equality
+        self.assertSetEqual(set(s1), set(s2))
+
+    def test_Set_interoperability_with_real_sets(self):
+        # Issue: 8743
+        class ListSet(Set):
+            def __init__(self, elements=()):
+                self.data = []
+                for elem in elements:
+                    if elem not in self.data:
+                        self.data.append(elem)
+            def __contains__(self, elem):
+                return elem in self.data
+            def __iter__(self):
+                return iter(self.data)
+            def __len__(self):
+                return len(self.data)
+            def __repr__(self):
+                return 'Set({!r})'.format(self.data)
+
+        r1 = set('abc')
+        r2 = set('bcd')
+        r3 = set('abcde')
+        f1 = ListSet('abc')
+        f2 = ListSet('bcd')
+        f3 = ListSet('abcde')
+        l1 = list('abccba')
+        l2 = list('bcddcb')
+        l3 = list('abcdeedcba')
+
+        target = r1 & r2
+        self.assertSameSet(f1 & f2, target)
+        self.assertSameSet(f1 & r2, target)
+        self.assertSameSet(r2 & f1, target)
+        self.assertSameSet(f1 & l2, target)
+
+        target = r1 | r2
+        self.assertSameSet(f1 | f2, target)
+        self.assertSameSet(f1 | r2, target)
+        self.assertSameSet(r2 | f1, target)
+        self.assertSameSet(f1 | l2, target)
+
+        fwd_target = r1 - r2
+        rev_target = r2 - r1
+        self.assertSameSet(f1 - f2, fwd_target)
+        self.assertSameSet(f2 - f1, rev_target)
+        self.assertSameSet(f1 - r2, fwd_target)
+        self.assertSameSet(f2 - r1, rev_target)
+        self.assertSameSet(r1 - f2, fwd_target)
+        self.assertSameSet(r2 - f1, rev_target)
+        self.assertSameSet(f1 - l2, fwd_target)
+        self.assertSameSet(f2 - l1, rev_target)
+
+        target = r1 ^ r2
+        self.assertSameSet(f1 ^ f2, target)
+        self.assertSameSet(f1 ^ r2, target)
+        self.assertSameSet(r2 ^ f1, target)
+        self.assertSameSet(f1 ^ l2, target)
+
+        # Don't change the following to use assertLess or other
+        # "more specific" unittest assertions.  The current
+        # assertTrue/assertFalse style makes the pattern of test
+        # case combinations clear and allows us to know for sure
+        # the exact operator being invoked.
+
+        # proper subset
+        self.assertTrue(f1 < f3)
+        self.assertFalse(f1 < f1)
+        self.assertFalse(f1 < f2)
+        self.assertTrue(r1 < f3)
+        self.assertFalse(r1 < f1)
+        self.assertFalse(r1 < f2)
+        self.assertTrue(r1 < r3)
+        self.assertFalse(r1 < r1)
+        self.assertFalse(r1 < r2)
         with self.assertRaises(TypeError):
-            ncs < cs
+            f1 < l3
         with self.assertRaises(TypeError):
-            ncs <= cs
+            f1 < l1
         with self.assertRaises(TypeError):
-            cs > ncs
+            f1 < l2
+
+        # any subset
+        self.assertTrue(f1 <= f3)
+        self.assertTrue(f1 <= f1)
+        self.assertFalse(f1 <= f2)
+        self.assertTrue(r1 <= f3)
+        self.assertTrue(r1 <= f1)
+        self.assertFalse(r1 <= f2)
+        self.assertTrue(r1 <= r3)
+        self.assertTrue(r1 <= r1)
+        self.assertFalse(r1 <= r2)
         with self.assertRaises(TypeError):
-            cs >= ncs
+            f1 <= l3
+        with self.assertRaises(TypeError):
+            f1 <= l1
+        with self.assertRaises(TypeError):
+            f1 <= l2
+
+        # proper superset
+        self.assertTrue(f3 > f1)
+        self.assertFalse(f1 > f1)
+        self.assertFalse(f2 > f1)
+        self.assertTrue(r3 > r1)
+        self.assertFalse(f1 > r1)
+        self.assertFalse(f2 > r1)
+        self.assertTrue(r3 > r1)
+        self.assertFalse(r1 > r1)
+        self.assertFalse(r2 > r1)
+        with self.assertRaises(TypeError):
+            f1 > l3
+        with self.assertRaises(TypeError):
+            f1 > l1
+        with self.assertRaises(TypeError):
+            f1 > l2
+
+        # any superset
+        self.assertTrue(f3 >= f1)
+        self.assertTrue(f1 >= f1)
+        self.assertFalse(f2 >= f1)
+        self.assertTrue(r3 >= r1)
+        self.assertTrue(f1 >= r1)
+        self.assertFalse(f2 >= r1)
+        self.assertTrue(r3 >= r1)
+        self.assertTrue(r1 >= r1)
+        self.assertFalse(r2 >= r1)
+        with self.assertRaises(TypeError):
+            f1 >= l3
+        with self.assertRaises(TypeError):
+            f1 >=l1
+        with self.assertRaises(TypeError):
+            f1 >= l2
+
+        # equality
+        self.assertTrue(f1 == f1)
+        self.assertTrue(r1 == f1)
+        self.assertTrue(f1 == r1)
+        self.assertFalse(f1 == f3)
+        self.assertFalse(r1 == f3)
+        self.assertFalse(f1 == r3)
+        self.assertFalse(f1 == l3)
+        self.assertFalse(f1 == l1)
+        self.assertFalse(f1 == l2)
+
+        # inequality
+        self.assertFalse(f1 != f1)
+        self.assertFalse(r1 != f1)
+        self.assertFalse(f1 != r1)
+        self.assertTrue(f1 != f3)
+        self.assertTrue(r1 != f3)
+        self.assertTrue(f1 != r3)
+        self.assertTrue(f1 != l3)
+        self.assertTrue(f1 != l1)
+        self.assertTrue(f1 != l2)
 
     def test_Mapping(self):
         for sample in [dict]:
diff --git a/Misc/NEWS b/Misc/NEWS
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -24,6 +24,9 @@
 - Issue #14710: pkgutil.find_loader() no longer raises an exception when a
   module doesn't exist.
 
+- Issue #8743: Fix interoperability between set objects and the
+  collections.Set abstract base class.
+
 - Issue #13355: random.triangular() no longer fails with a ZeroDivisionError
   when low equals high.
 

-- 
Repository URL: http://hg.python.org/cpython


More information about the Python-checkins mailing list