[Python-checkins] bpo-46615: Don't crash when set operations mutate the sets (GH-31120)

sweeneyde webhook-mailer at python.org
Fri Feb 11 11:25:18 EST 2022


https://github.com/python/cpython/commit/4a66615ba736f84eadf9456bfd5d32a94cccf117
commit: 4a66615ba736f84eadf9456bfd5d32a94cccf117
branch: main
author: Dennis Sweeney <36520290+sweeneyde at users.noreply.github.com>
committer: sweeneyde <36520290+sweeneyde at users.noreply.github.com>
date: 2022-02-11T11:25:08-05:00
summary:

bpo-46615: Don't crash when set operations mutate the sets (GH-31120)

Ensure strong references are acquired whenever using `set_next()`. Added randomized test cases for `__eq__` methods that sometimes mutate sets when called.

files:
A Misc/NEWS.d/next/Core and Builtins/2022-02-04-04-33-18.bpo-46615.puArY9.rst
M Lib/test/test_set.py
M Objects/setobject.c

diff --git a/Lib/test/test_set.py b/Lib/test/test_set.py
index 77f3da40c063a..03b911920e179 100644
--- a/Lib/test/test_set.py
+++ b/Lib/test/test_set.py
@@ -1815,6 +1815,192 @@ def __eq__(self, o):
         s = {0}
         s.update(other)
 
+
+class TestOperationsMutating:
+    """Regression test for bpo-46615"""
+
+    constructor1 = None
+    constructor2 = None
+
+    def make_sets_of_bad_objects(self):
+        class Bad:
+            def __eq__(self, other):
+                if not enabled:
+                    return False
+                if randrange(20) == 0:
+                    set1.clear()
+                if randrange(20) == 0:
+                    set2.clear()
+                return bool(randrange(2))
+            def __hash__(self):
+                return randrange(2)
+        # Don't behave poorly during construction.
+        enabled = False
+        set1 = self.constructor1(Bad() for _ in range(randrange(50)))
+        set2 = self.constructor2(Bad() for _ in range(randrange(50)))
+        # Now start behaving poorly
+        enabled = True
+        return set1, set2
+
+    def check_set_op_does_not_crash(self, function):
+        for _ in range(100):
+            set1, set2 = self.make_sets_of_bad_objects()
+            try:
+                function(set1, set2)
+            except RuntimeError as e:
+                # Just make sure we don't crash here.
+                self.assertIn("changed size during iteration", str(e))
+
+
+class TestBinaryOpsMutating(TestOperationsMutating):
+
+    def test_eq_with_mutation(self):
+        self.check_set_op_does_not_crash(lambda a, b: a == b)
+
+    def test_ne_with_mutation(self):
+        self.check_set_op_does_not_crash(lambda a, b: a != b)
+
+    def test_lt_with_mutation(self):
+        self.check_set_op_does_not_crash(lambda a, b: a < b)
+
+    def test_le_with_mutation(self):
+        self.check_set_op_does_not_crash(lambda a, b: a <= b)
+
+    def test_gt_with_mutation(self):
+        self.check_set_op_does_not_crash(lambda a, b: a > b)
+
+    def test_ge_with_mutation(self):
+        self.check_set_op_does_not_crash(lambda a, b: a >= b)
+
+    def test_and_with_mutation(self):
+        self.check_set_op_does_not_crash(lambda a, b: a & b)
+
+    def test_or_with_mutation(self):
+        self.check_set_op_does_not_crash(lambda a, b: a | b)
+
+    def test_sub_with_mutation(self):
+        self.check_set_op_does_not_crash(lambda a, b: a - b)
+
+    def test_xor_with_mutation(self):
+        self.check_set_op_does_not_crash(lambda a, b: a ^ b)
+
+    def test_iand_with_mutation(self):
+        def f(a, b):
+            a &= b  # in-place intersection (__iand__), hence "iand"
+        self.check_set_op_does_not_crash(f)
+
+    def test_ior_with_mutation(self):
+        def f(a, b):
+            a |= b
+        self.check_set_op_does_not_crash(f)
+
+    def test_isub_with_mutation(self):
+        def f(a, b):
+            a -= b
+        self.check_set_op_does_not_crash(f)
+
+    def test_ixor_with_mutation(self):
+        def f(a, b):
+            a ^= b
+        self.check_set_op_does_not_crash(f)
+
+    def test_iteration_with_mutation(self):
+        def f1(a, b):
+            for x in a:
+                pass
+            for y in b:
+                pass
+        def f2(a, b):
+            for y in b:
+                pass
+            for x in a:
+                pass
+        def f3(a, b):
+            for x, y in zip(a, b):
+                pass
+        self.check_set_op_does_not_crash(f1)
+        self.check_set_op_does_not_crash(f2)
+        self.check_set_op_does_not_crash(f3)
+
+
+class TestBinaryOpsMutating_Set_Set(TestBinaryOpsMutating, unittest.TestCase):
+    constructor1 = set
+    constructor2 = set
+
+class TestBinaryOpsMutating_Subclass_Subclass(TestBinaryOpsMutating, unittest.TestCase):
+    constructor1 = SetSubclass
+    constructor2 = SetSubclass
+
+class TestBinaryOpsMutating_Set_Subclass(TestBinaryOpsMutating, unittest.TestCase):
+    constructor1 = set
+    constructor2 = SetSubclass
+
+class TestBinaryOpsMutating_Subclass_Set(TestBinaryOpsMutating, unittest.TestCase):
+    constructor1 = SetSubclass
+    constructor2 = set
+
+
+class TestMethodsMutating(TestOperationsMutating):
+
+    def test_issubset_with_mutation(self):
+        self.check_set_op_does_not_crash(set.issubset)
+
+    def test_issuperset_with_mutation(self):
+        self.check_set_op_does_not_crash(set.issuperset)
+
+    def test_intersection_with_mutation(self):
+        self.check_set_op_does_not_crash(set.intersection)
+
+    def test_union_with_mutation(self):
+        self.check_set_op_does_not_crash(set.union)
+
+    def test_difference_with_mutation(self):
+        self.check_set_op_does_not_crash(set.difference)
+
+    def test_symmetric_difference_with_mutation(self):
+        self.check_set_op_does_not_crash(set.symmetric_difference)
+
+    def test_isdisjoint_with_mutation(self):
+        self.check_set_op_does_not_crash(set.isdisjoint)
+
+    def test_difference_update_with_mutation(self):
+        self.check_set_op_does_not_crash(set.difference_update)
+
+    def test_intersection_update_with_mutation(self):
+        self.check_set_op_does_not_crash(set.intersection_update)
+
+    def test_symmetric_difference_update_with_mutation(self):
+        self.check_set_op_does_not_crash(set.symmetric_difference_update)
+
+    def test_update_with_mutation(self):
+        self.check_set_op_does_not_crash(set.update)
+
+
+class TestMethodsMutating_Set_Set(TestMethodsMutating, unittest.TestCase):
+    constructor1 = set
+    constructor2 = set
+
+class TestMethodsMutating_Subclass_Subclass(TestMethodsMutating, unittest.TestCase):
+    constructor1 = SetSubclass
+    constructor2 = SetSubclass
+
+class TestMethodsMutating_Set_Subclass(TestMethodsMutating, unittest.TestCase):
+    constructor1 = set
+    constructor2 = SetSubclass
+
+class TestMethodsMutating_Subclass_Set(TestMethodsMutating, unittest.TestCase):
+    constructor1 = SetSubclass
+    constructor2 = set
+
+class TestMethodsMutating_Set_Dict(TestMethodsMutating, unittest.TestCase):
+    constructor1 = set
+    constructor2 = dict.fromkeys
+
+class TestMethodsMutating_Set_List(TestMethodsMutating, unittest.TestCase):
+    constructor1 = set
+    constructor2 = list
+
+
 # Application tests (based on David Eppstein's graph recipes ====================================
 
 def powerset(U):
diff --git a/Misc/NEWS.d/next/Core and Builtins/2022-02-04-04-33-18.bpo-46615.puArY9.rst b/Misc/NEWS.d/next/Core and Builtins/2022-02-04-04-33-18.bpo-46615.puArY9.rst
new file mode 100644
index 0000000000000..6dee92a546e33
--- /dev/null
+++ b/Misc/NEWS.d/next/Core and Builtins/2022-02-04-04-33-18.bpo-46615.puArY9.rst	
@@ -0,0 +1 @@
+When iterating over sets internally in ``setobject.c``, acquire strong references to the resulting items from the set.  This prevents crashes in corner-cases of various set operations where the set gets mutated.
diff --git a/Objects/setobject.c b/Objects/setobject.c
index fe124945b1c7e..0dd28402afbd2 100644
--- a/Objects/setobject.c
+++ b/Objects/setobject.c
@@ -1205,17 +1205,21 @@ set_intersection(PySetObject *so, PyObject *other)
         while (set_next((PySetObject *)other, &pos, &entry)) {
             key = entry->key;
             hash = entry->hash;
+            Py_INCREF(key);
             rv = set_contains_entry(so, key, hash);
             if (rv < 0) {
                 Py_DECREF(result);
+                Py_DECREF(key);
                 return NULL;
             }
             if (rv) {
                 if (set_add_entry(result, key, hash)) {
                     Py_DECREF(result);
+                    Py_DECREF(key);
                     return NULL;
                 }
             }
+            Py_DECREF(key);
         }
         return (PyObject *)result;
     }
@@ -1355,11 +1359,16 @@ set_isdisjoint(PySetObject *so, PyObject *other)
             other = tmp;
         }
         while (set_next((PySetObject *)other, &pos, &entry)) {
-            rv = set_contains_entry(so, entry->key, entry->hash);
-            if (rv < 0)
+            PyObject *key = entry->key;
+            Py_INCREF(key);
+            rv = set_contains_entry(so, key, entry->hash);
+            Py_DECREF(key);
+            if (rv < 0) {
                 return NULL;
-            if (rv)
+            }
+            if (rv) {
                 Py_RETURN_FALSE;
+            }
         }
         Py_RETURN_TRUE;
     }
@@ -1418,11 +1427,16 @@ set_difference_update_internal(PySetObject *so, PyObject *other)
             Py_INCREF(other);
         }
 
-        while (set_next((PySetObject *)other, &pos, &entry))
-            if (set_discard_entry(so, entry->key, entry->hash) < 0) {
+        while (set_next((PySetObject *)other, &pos, &entry)) {
+            PyObject *key = entry->key;
+            Py_INCREF(key);
+            if (set_discard_entry(so, key, entry->hash) < 0) {
                 Py_DECREF(other);
+                Py_DECREF(key);
                 return -1;
             }
+            Py_DECREF(key);
+        }
 
         Py_DECREF(other);
     } else {
@@ -1513,17 +1527,21 @@ set_difference(PySetObject *so, PyObject *other)
         while (set_next(so, &pos, &entry)) {
             key = entry->key;
             hash = entry->hash;
+            Py_INCREF(key);
             rv = _PyDict_Contains_KnownHash(other, key, hash);
             if (rv < 0) {
                 Py_DECREF(result);
+                Py_DECREF(key);
                 return NULL;
             }
             if (!rv) {
                 if (set_add_entry((PySetObject *)result, key, hash)) {
                     Py_DECREF(result);
+                    Py_DECREF(key);
                     return NULL;
                 }
             }
+            Py_DECREF(key);
         }
         return result;
     }
@@ -1532,17 +1550,21 @@ set_difference(PySetObject *so, PyObject *other)
     while (set_next(so, &pos, &entry)) {
         key = entry->key;
         hash = entry->hash;
+        Py_INCREF(key);
         rv = set_contains_entry((PySetObject *)other, key, hash);
         if (rv < 0) {
             Py_DECREF(result);
+            Py_DECREF(key);
             return NULL;
         }
         if (!rv) {
             if (set_add_entry((PySetObject *)result, key, hash)) {
                 Py_DECREF(result);
+                Py_DECREF(key);
                 return NULL;
             }
         }
+        Py_DECREF(key);
     }
     return result;
 }
@@ -1639,17 +1661,21 @@ set_symmetric_difference_update(PySetObject *so, PyObject *other)
     while (set_next(otherset, &pos, &entry)) {
         key = entry->key;
         hash = entry->hash;
+        Py_INCREF(key);
         rv = set_discard_entry(so, key, hash);
         if (rv < 0) {
             Py_DECREF(otherset);
+            Py_DECREF(key);
             return NULL;
         }
         if (rv == DISCARD_NOTFOUND) {
             if (set_add_entry(so, key, hash)) {
                 Py_DECREF(otherset);
+                Py_DECREF(key);
                 return NULL;
             }
         }
+        Py_DECREF(key);
     }
     Py_DECREF(otherset);
     Py_RETURN_NONE;
@@ -1724,11 +1750,16 @@ set_issubset(PySetObject *so, PyObject *other)
         Py_RETURN_FALSE;
 
     while (set_next(so, &pos, &entry)) {
-        rv = set_contains_entry((PySetObject *)other, entry->key, entry->hash);
-        if (rv < 0)
+        PyObject *key = entry->key;
+        Py_INCREF(key);
+        rv = set_contains_entry((PySetObject *)other, key, entry->hash);
+        Py_DECREF(key);
+        if (rv < 0) {
             return NULL;
-        if (!rv)
+        }
+        if (!rv) {
             Py_RETURN_FALSE;
+        }
     }
     Py_RETURN_TRUE;
 }



More information about the Python-checkins mailing list