[Python-checkins] cpython (2.7): Fix some set algebra methods of WeakSet objects.

antoine.pitrou python-checkins at python.org
Sun Mar 4 20:51:07 CET 2012


http://hg.python.org/cpython/rev/428bfb58e3b3
changeset:   75397:428bfb58e3b3
branch:      2.7
parent:      75395:5f79d68ba087
user:        Antoine Pitrou <solipsis at pitrou.net>
date:        Sun Mar 04 20:47:05 2012 +0100
summary:
  Fix some set algebra methods of WeakSet objects.

files:
  Lib/_weakrefset.py       |  41 ++++++++-------------------
  Lib/test/test_weakset.py |  22 ++++++++++++--
  2 files changed, 30 insertions(+), 33 deletions(-)


diff --git a/Lib/_weakrefset.py b/Lib/_weakrefset.py
--- a/Lib/_weakrefset.py
+++ b/Lib/_weakrefset.py
@@ -123,26 +123,14 @@
         self.update(other)
         return self
 
-    # Helper functions for simple delegating methods.
-    def _apply(self, other, method):
-        if not isinstance(other, self.__class__):
-            other = self.__class__(other)
-        newdata = method(other.data)
-        newset = self.__class__()
-        newset.data = newdata
+    def difference(self, other):
+        newset = self.copy()
+        newset.difference_update(other)
         return newset
-
-    def difference(self, other):
-        return self._apply(other, self.data.difference)
     __sub__ = difference
 
     def difference_update(self, other):
-        if self._pending_removals:
-            self._commit_removals()
-        if self is other:
-            self.data.clear()
-        else:
-            self.data.difference_update(ref(item) for item in other)
+        self.__isub__(other)
     def __isub__(self, other):
         if self._pending_removals:
             self._commit_removals()
@@ -153,13 +141,11 @@
         return self
 
     def intersection(self, other):
-        return self._apply(other, self.data.intersection)
+        return self.__class__(item for item in other if item in self)
     __and__ = intersection
 
     def intersection_update(self, other):
-        if self._pending_removals:
-            self._commit_removals()
-        self.data.intersection_update(ref(item) for item in other)
+        self.__iand__(other)
     def __iand__(self, other):
         if self._pending_removals:
             self._commit_removals()
@@ -186,27 +172,24 @@
         return self.data == set(ref(item) for item in other)
 
     def symmetric_difference(self, other):
-        return self._apply(other, self.data.symmetric_difference)
+        newset = self.copy()
+        newset.symmetric_difference_update(other)
+        return newset
     __xor__ = symmetric_difference
 
     def symmetric_difference_update(self, other):
-        if self._pending_removals:
-            self._commit_removals()
-        if self is other:
-            self.data.clear()
-        else:
-            self.data.symmetric_difference_update(ref(item) for item in other)
+        self.__ixor__(other)
     def __ixor__(self, other):
         if self._pending_removals:
             self._commit_removals()
         if self is other:
             self.data.clear()
         else:
-            self.data.symmetric_difference_update(ref(item) for item in other)
+            self.data.symmetric_difference_update(ref(item, self._remove) for item in other)
         return self
 
     def union(self, other):
-        return self._apply(other, self.data.union)
+        return self.__class__(e for s in (self, other) for e in s)
     __or__ = union
 
     def isdisjoint(self, other):
diff --git a/Lib/test/test_weakset.py b/Lib/test/test_weakset.py
--- a/Lib/test/test_weakset.py
+++ b/Lib/test/test_weakset.py
@@ -83,6 +83,11 @@
             x = WeakSet(self.items + self.items2)
             c = C(self.items2)
             self.assertEqual(self.s.union(c), x)
+            del c
+        self.assertEqual(len(u), len(self.items) + len(self.items2))
+        self.items2.pop()
+        gc.collect()
+        self.assertEqual(len(u), len(self.items) + len(self.items2))
 
     def test_or(self):
         i = self.s.union(self.items2)
@@ -90,14 +95,19 @@
         self.assertEqual(self.s | frozenset(self.items2), i)
 
     def test_intersection(self):
-        i = self.s.intersection(self.items2)
+        s = WeakSet(self.letters)
+        i = s.intersection(self.items2)
         for c in self.letters:
-            self.assertEqual(c in i, c in self.d and c in self.items2)
-        self.assertEqual(self.s, WeakSet(self.items))
+            self.assertEqual(c in i, c in self.items2 and c in self.letters)
+        self.assertEqual(s, WeakSet(self.letters))
         self.assertEqual(type(i), WeakSet)
         for C in set, frozenset, dict.fromkeys, list, tuple:
             x = WeakSet([])
-            self.assertEqual(self.s.intersection(C(self.items2)), x)
+            self.assertEqual(i.intersection(C(self.items)), x)
+        self.assertEqual(len(i), len(self.items2))
+        self.items2.pop()
+        gc.collect()
+        self.assertEqual(len(i), len(self.items2))
 
     def test_isdisjoint(self):
         self.assertTrue(self.s.isdisjoint(WeakSet(self.items2)))
@@ -128,6 +138,10 @@
         self.assertEqual(self.s, WeakSet(self.items))
         self.assertEqual(type(i), WeakSet)
         self.assertRaises(TypeError, self.s.symmetric_difference, [[]])
+        self.assertEqual(len(i), len(self.items) + len(self.items2))
+        self.items2.pop()
+        gc.collect()
+        self.assertEqual(len(i), len(self.items) + len(self.items2))
 
     def test_xor(self):
         i = self.s.symmetric_difference(self.items2)

-- 
Repository URL: http://hg.python.org/cpython


More information about the Python-checkins mailing list