[Python-checkins] bpo-43977: Properly update the tp_flags of existing subclasses when their parents are registered (GH-26864)

brandtbucher webhook-mailer at python.org
Fri Jun 25 11:21:09 EDT 2021


https://github.com/python/cpython/commit/ca2009d72a52a98bf43aafa9ad270a4fcfabfc89
commit: ca2009d72a52a98bf43aafa9ad270a4fcfabfc89
branch: main
author: Brandt Bucher <brandt at python.org>
committer: brandtbucher <brandtbucher at gmail.com>
date: 2021-06-25T08:20:43-07:00
summary:

bpo-43977: Properly update the tp_flags of existing subclasses when their parents are registered (GH-26864)

files:
A Misc/NEWS.d/next/Library/2021-06-22-16-45-48.bpo-43977.bamAGF.rst
M Doc/library/dis.rst
M Lib/test/test_patma.py
M Modules/_abc.c

diff --git a/Doc/library/dis.rst b/Doc/library/dis.rst
index 65fbb00bd6597..645e94a669fd1 100644
--- a/Doc/library/dis.rst
+++ b/Doc/library/dis.rst
@@ -772,17 +772,20 @@ iterations of the loop.
 
 .. opcode:: MATCH_MAPPING
 
-   If TOS is an instance of :class:`collections.abc.Mapping`, push ``True`` onto
-   the stack.  Otherwise, push ``False``.
+   If TOS is an instance of :class:`collections.abc.Mapping` (or, more technically: if
+   it has the :const:`Py_TPFLAGS_MAPPING` flag set in its
+   :c:member:`~PyTypeObject.tp_flags`), push ``True`` onto the stack.  Otherwise, push
+   ``False``.
 
    .. versionadded:: 3.10
 
 
 .. opcode:: MATCH_SEQUENCE
 
-   If TOS is an instance of :class:`collections.abc.Sequence` and is *not* an
-   instance of :class:`str`/:class:`bytes`/:class:`bytearray`, push ``True``
-   onto the stack.  Otherwise, push ``False``.
+   If TOS is an instance of :class:`collections.abc.Sequence` and is *not* an instance
+   of :class:`str`/:class:`bytes`/:class:`bytearray` (or, more technically: if it has
+   the :const:`Py_TPFLAGS_SEQUENCE` flag set in its :c:member:`~PyTypeObject.tp_flags`),
+   push ``True`` onto the stack.  Otherwise, push ``False``.
 
    .. versionadded:: 3.10
 
diff --git a/Lib/test/test_patma.py b/Lib/test/test_patma.py
index 91b1f7e2aa4c7..a8889cae1091f 100644
--- a/Lib/test/test_patma.py
+++ b/Lib/test/test_patma.py
@@ -24,7 +24,43 @@ def test_refleaks(self):
 
 class TestInheritance(unittest.TestCase):
 
-    def test_multiple_inheritance(self):
+    @staticmethod
+    def check_sequence_then_mapping(x):
+        match x:
+            case [*_]:
+                return "seq"
+            case {}:
+                return "map"
+
+    @staticmethod
+    def check_mapping_then_sequence(x):
+        match x:
+            case {}:
+                return "map"
+            case [*_]:
+                return "seq"
+
+    def test_multiple_inheritance_mapping(self):
+        class C:
+            pass
+        class M1(collections.UserDict, collections.abc.Sequence):
+            pass
+        class M2(C, collections.UserDict, collections.abc.Sequence):
+            pass
+        class M3(collections.UserDict, C, list):
+            pass
+        class M4(dict, collections.abc.Sequence, C):
+            pass
+        self.assertEqual(self.check_sequence_then_mapping(M1()), "map")
+        self.assertEqual(self.check_sequence_then_mapping(M2()), "map")
+        self.assertEqual(self.check_sequence_then_mapping(M3()), "map")
+        self.assertEqual(self.check_sequence_then_mapping(M4()), "map")
+        self.assertEqual(self.check_mapping_then_sequence(M1()), "map")
+        self.assertEqual(self.check_mapping_then_sequence(M2()), "map")
+        self.assertEqual(self.check_mapping_then_sequence(M3()), "map")
+        self.assertEqual(self.check_mapping_then_sequence(M4()), "map")
+
+    def test_multiple_inheritance_sequence(self):
         class C:
             pass
         class S1(collections.UserList, collections.abc.Mapping):
@@ -35,32 +71,70 @@ class S3(list, C, collections.abc.Mapping):
             pass
         class S4(collections.UserList, dict, C):
             pass
-        class M1(collections.UserDict, collections.abc.Sequence):
+        self.assertEqual(self.check_sequence_then_mapping(S1()), "seq")
+        self.assertEqual(self.check_sequence_then_mapping(S2()), "seq")
+        self.assertEqual(self.check_sequence_then_mapping(S3()), "seq")
+        self.assertEqual(self.check_sequence_then_mapping(S4()), "seq")
+        self.assertEqual(self.check_mapping_then_sequence(S1()), "seq")
+        self.assertEqual(self.check_mapping_then_sequence(S2()), "seq")
+        self.assertEqual(self.check_mapping_then_sequence(S3()), "seq")
+        self.assertEqual(self.check_mapping_then_sequence(S4()), "seq")
+
+    def test_late_registration_mapping(self):
+        class Parent:
             pass
-        class M2(C, collections.UserDict, collections.abc.Sequence):
+        class ChildPre(Parent):
             pass
-        class M3(collections.UserDict, C, list):
+        class GrandchildPre(ChildPre):
             pass
-        class M4(dict, collections.abc.Sequence, C):
+        # Siblings too: registration must update every existing subclass.
+        class ChildPre2(Parent):
+            pass
+        collections.abc.Mapping.register(Parent)
+        class ChildPost(Parent):
             pass
-        def f(x):
-            match x:
-                case []:
-                    return "seq"
-                case {}:
-                    return "map"
-        def g(x):
-            match x:
-                case {}:
-                    return "map"
-                case []:
-                    return "seq"
-        for Seq in (S1, S2, S3, S4):
-            self.assertEqual(f(Seq()), "seq")
-            self.assertEqual(g(Seq()), "seq")
-        for Map in (M1, M2, M3, M4):
-            self.assertEqual(f(Map()), "map")
-            self.assertEqual(g(Map()), "map")
+        class GrandchildPost(ChildPost):
+            pass
+        self.assertEqual(self.check_sequence_then_mapping(Parent()), "map")
+        self.assertEqual(self.check_sequence_then_mapping(ChildPre()), "map")
+        self.assertEqual(self.check_sequence_then_mapping(GrandchildPre()), "map")
+        self.assertEqual(self.check_sequence_then_mapping(ChildPre2()), "map")
+        self.assertEqual(self.check_sequence_then_mapping(ChildPost()), "map")
+        self.assertEqual(self.check_sequence_then_mapping(GrandchildPost()), "map")
+        self.assertEqual(self.check_mapping_then_sequence(Parent()), "map")
+        self.assertEqual(self.check_mapping_then_sequence(ChildPre()), "map")
+        self.assertEqual(self.check_mapping_then_sequence(GrandchildPre()), "map")
+        self.assertEqual(self.check_mapping_then_sequence(ChildPre2()), "map")
+        self.assertEqual(self.check_mapping_then_sequence(ChildPost()), "map")
+        self.assertEqual(self.check_mapping_then_sequence(GrandchildPost()), "map")
+
+    def test_late_registration_sequence(self):
+        class Parent:
+            pass
+        class ChildPre(Parent):
+            pass
+        class GrandchildPre(ChildPre):
+            pass
+        # Siblings too: registration must update every existing subclass.
+        class ChildPre2(Parent):
+            pass
+        collections.abc.Sequence.register(Parent)
+        class ChildPost(Parent):
+            pass
+        class GrandchildPost(ChildPost):
+            pass
+        self.assertEqual(self.check_sequence_then_mapping(Parent()), "seq")
+        self.assertEqual(self.check_sequence_then_mapping(ChildPre()), "seq")
+        self.assertEqual(self.check_sequence_then_mapping(GrandchildPre()), "seq")
+        self.assertEqual(self.check_sequence_then_mapping(ChildPre2()), "seq")
+        self.assertEqual(self.check_sequence_then_mapping(ChildPost()), "seq")
+        self.assertEqual(self.check_sequence_then_mapping(GrandchildPost()), "seq")
+        self.assertEqual(self.check_mapping_then_sequence(Parent()), "seq")
+        self.assertEqual(self.check_mapping_then_sequence(ChildPre()), "seq")
+        self.assertEqual(self.check_mapping_then_sequence(GrandchildPre()), "seq")
+        self.assertEqual(self.check_mapping_then_sequence(ChildPre2()), "seq")
+        self.assertEqual(self.check_mapping_then_sequence(ChildPost()), "seq")
+        self.assertEqual(self.check_mapping_then_sequence(GrandchildPost()), "seq")
 
 
 class TestPatma(unittest.TestCase):
diff --git a/Misc/NEWS.d/next/Library/2021-06-22-16-45-48.bpo-43977.bamAGF.rst b/Misc/NEWS.d/next/Library/2021-06-22-16-45-48.bpo-43977.bamAGF.rst
new file mode 100644
index 0000000000000..5f8cb7b7ea729
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2021-06-22-16-45-48.bpo-43977.bamAGF.rst
@@ -0,0 +1,3 @@
+Set the proper :const:`Py_TPFLAGS_MAPPING` and :const:`Py_TPFLAGS_SEQUENCE`
+flags for subclasses created before a parent has been registered as a
+:class:`collections.abc.Mapping` or :class:`collections.abc.Sequence`.
diff --git a/Modules/_abc.c b/Modules/_abc.c
index 7720d4051fe9e..8aa68359039e7 100644
--- a/Modules/_abc.c
+++ b/Modules/_abc.c
@@ -481,6 +481,42 @@ _abc__abc_init(PyObject *module, PyObject *self)
     Py_RETURN_NONE;
 }
 
+/* Set `flag` (Py_TPFLAGS_MAPPING or Py_TPFLAGS_SEQUENCE) on `child`,
+ * replacing any collection flag it already had, then do the same for every
+ * subclass reachable through tp_subclasses.  Immutable types are skipped,
+ * and the walk does not descend past a type that already has exactly `flag`.
+ */
+static void
+set_collection_flag_recursive(PyTypeObject *child, unsigned long flag)
+{
+    assert(flag == Py_TPFLAGS_MAPPING || flag == Py_TPFLAGS_SEQUENCE);
+    if (PyType_HasFeature(child, Py_TPFLAGS_IMMUTABLETYPE) ||
+        (child->tp_flags & COLLECTION_FLAGS) == flag)
+    {
+        return;
+    }
+    child->tp_flags &= ~COLLECTION_FLAGS;
+    child->tp_flags |= flag;
+    PyObject *grandchildren = child->tp_subclasses;
+    if (grandchildren == NULL) {
+        return;
+    }
+    assert(PyDict_CheckExact(grandchildren));
+    Py_ssize_t i = 0;
+    /* Each value must go into its own variable: writing it back into
+       `grandchildren` would replace the dict being iterated with a weakref,
+       so subclasses after the first one would never be visited. */
+    PyObject *ref;
+    while (PyDict_Next(grandchildren, &i, NULL, &ref)) {
+        assert(PyWeakref_CheckRef(ref));
+        /* A dead weakref yields Py_None, which PyType_Check() rejects. */
+        PyObject *grandchild = PyWeakref_GET_OBJECT(ref);
+        if (PyType_Check(grandchild)) {
+            set_collection_flag_recursive((PyTypeObject *)grandchild, flag);
+        }
+    }
+}
+
 /*[clinic input]
 _abc._abc_register
 
@@ -532,12 +568,11 @@ _abc__abc_register_impl(PyObject *module, PyObject *self, PyObject *subclass)
     get_abc_state(module)->abc_invalidation_counter++;
 
     /* Set Py_TPFLAGS_SEQUENCE  or Py_TPFLAGS_MAPPING flag */
-    if (PyType_Check(self) &&
-        !PyType_HasFeature((PyTypeObject *)subclass, Py_TPFLAGS_IMMUTABLETYPE) &&
-        ((PyTypeObject *)self)->tp_flags & COLLECTION_FLAGS)
-    {
-        ((PyTypeObject *)subclass)->tp_flags &= ~COLLECTION_FLAGS;
-        ((PyTypeObject *)subclass)->tp_flags |= (((PyTypeObject *)self)->tp_flags & COLLECTION_FLAGS);
+    if (PyType_Check(self)) {
+        unsigned long collection_flag = ((PyTypeObject *)self)->tp_flags & COLLECTION_FLAGS;
+        if (collection_flag) {
+            set_collection_flag_recursive((PyTypeObject *)subclass, collection_flag);
+        }
     }
     Py_INCREF(subclass);
     return subclass;



More information about the Python-checkins mailing list