[Python-checkins] bpo-41194: The _ast module cannot be loaded more than once (GH-21290)

Victor Stinner webhook-mailer at python.org
Fri Jul 3 08:16:04 EDT 2020


https://github.com/python/cpython/commit/91e1bc18bd467a13bceb62e16fbc435b33381c82
commit: 91e1bc18bd467a13bceb62e16fbc435b33381c82
branch: master
author: Victor Stinner <vstinner at python.org>
committer: GitHub <noreply at github.com>
date: 2020-07-03T14:15:53+02:00
summary:

bpo-41194: The _ast module cannot be loaded more than once (GH-21290)

Fix a crash in the _ast module by preventing the module from being
loaded more than once. It now uses a global state rather than a
per-module state.

* Replace the per-module _ast state with a single global state.
* Set _astmodule.m_size to -1, so the extension cannot be loaded more
  than once.

files:
A Misc/NEWS.d/next/Library/2020-07-03-13-15-08.bpo-41194.djrKjs.rst
M Parser/asdl_c.py
M Python/Python-ast.c

diff --git a/Misc/NEWS.d/next/Library/2020-07-03-13-15-08.bpo-41194.djrKjs.rst b/Misc/NEWS.d/next/Library/2020-07-03-13-15-08.bpo-41194.djrKjs.rst
new file mode 100644
index 0000000000000..d63a0e5222ba9
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2020-07-03-13-15-08.bpo-41194.djrKjs.rst
@@ -0,0 +1,2 @@
+Fix a crash in the ``_ast`` module: it can no longer be loaded more than once.
+It now uses a global state rather than a module state.
diff --git a/Parser/asdl_c.py b/Parser/asdl_c.py
index 39e216b8192b1..f029ca6618f93 100755
--- a/Parser/asdl_c.py
+++ b/Parser/asdl_c.py
@@ -691,7 +691,7 @@ def visitModule(self, mod):
     Py_ssize_t i, numfields = 0;
     int res = -1;
     PyObject *key, *value, *fields;
-    astmodulestate *state = astmodulestate_global;
+    astmodulestate *state = get_global_ast_state();
     if (_PyObject_LookupAttr((PyObject*)Py_TYPE(self), state->_fields, &fields) < 0) {
         goto cleanup;
     }
@@ -760,7 +760,7 @@ def visitModule(self, mod):
 static PyObject *
 ast_type_reduce(PyObject *self, PyObject *unused)
 {
-    astmodulestate *state = astmodulestate_global;
+    astmodulestate *state = get_global_ast_state();
     PyObject *dict;
     if (_PyObject_LookupAttr(self, state->__dict__, &dict) < 0) {
         return NULL;
@@ -971,19 +971,7 @@ def visitModule(self, mod):
 
         self.emit("static int init_types(void)",0)
         self.emit("{", 0)
-        self.emit("PyObject *module = PyState_FindModule(&_astmodule);", 1)
-        self.emit("if (module == NULL) {", 1)
-        self.emit("module = PyModule_Create(&_astmodule);", 2)
-        self.emit("if (!module) {", 2)
-        self.emit("return 0;", 3)
-        self.emit("}", 2)
-        self.emit("if (PyState_AddModule(module, &_astmodule) < 0) {", 2)
-        self.emit("return 0;", 3)
-        self.emit("}", 2)
-        self.emit("}", 1)
-        self.emit("", 0)
-
-        self.emit("astmodulestate *state = get_ast_state(module);", 1)
+        self.emit("astmodulestate *state = get_global_ast_state();", 1)
         self.emit("if (state->initialized) return 1;", 1)
         self.emit("if (init_identifiers(state) < 0) return 0;", 1)
         self.emit("state->AST_type = PyType_FromSpec(&AST_type_spec);", 1)
@@ -1061,13 +1049,16 @@ def visitModule(self, mod):
         self.emit("PyMODINIT_FUNC", 0)
         self.emit("PyInit__ast(void)", 0)
         self.emit("{", 0)
-        self.emit("PyObject *m;", 1)
-        self.emit("if (!init_types()) return NULL;", 1)
-        self.emit('m = PyState_FindModule(&_astmodule);', 1)
-        self.emit("if (!m) return NULL;", 1)
+        self.emit("PyObject *m = PyModule_Create(&_astmodule);", 1)
+        self.emit("if (!m) {", 1)
+        self.emit("return NULL;", 2)
+        self.emit("}", 1)
         self.emit('astmodulestate *state = get_ast_state(m);', 1)
         self.emit('', 1)
 
+        self.emit("if (!init_types()) {", 1)
+        self.emit("goto error;", 2)
+        self.emit("}", 1)
         self.emit('if (PyModule_AddObject(m, "AST", state->AST_type) < 0) {', 1)
         self.emit('goto error;', 2)
         self.emit('}', 1)
@@ -1084,6 +1075,7 @@ def visitModule(self, mod):
         for dfn in mod.dfns:
             self.visit(dfn)
         self.emit("return m;", 1)
+        self.emit("", 0)
         self.emit("error:", 0)
         self.emit("Py_DECREF(m);", 1)
         self.emit("return NULL;", 1)
@@ -1263,9 +1255,11 @@ class PartingShots(StaticVisitor):
     CODE = """
 PyObject* PyAST_mod2obj(mod_ty t)
 {
-    if (!init_types())
+    if (!init_types()) {
         return NULL;
-    astmodulestate *state = astmodulestate_global;
+    }
+
+    astmodulestate *state = get_global_ast_state();
     return ast2obj_mod(state, t);
 }
 
@@ -1279,7 +1273,7 @@ class PartingShots(StaticVisitor):
         return NULL;
     }
 
-    astmodulestate *state = astmodulestate_global;
+    astmodulestate *state = get_global_ast_state();
     PyObject *req_type[3];
     req_type[0] = state->Module_type;
     req_type[1] = state->Expression_type;
@@ -1287,8 +1281,9 @@ class PartingShots(StaticVisitor):
 
     assert(0 <= mode && mode <= 2);
 
-    if (!init_types())
+    if (!init_types()) {
         return NULL;
+    }
 
     isinstance = PyObject_IsInstance(ast, req_type[mode]);
     if (isinstance == -1)
@@ -1308,9 +1303,11 @@ class PartingShots(StaticVisitor):
 
 int PyAST_Check(PyObject* obj)
 {
-    if (!init_types())
+    if (!init_types()) {
         return -1;
-    astmodulestate *state = astmodulestate_global;
+    }
+
+    astmodulestate *state = get_global_ast_state();
     return PyObject_IsInstance(obj, state->AST_type);
 }
 """
@@ -1361,13 +1358,12 @@ def generate_module_def(f, mod):
         f.write('    PyObject *' + s + ';\n')
     f.write('} astmodulestate;\n\n')
     f.write("""
+static astmodulestate global_ast_state;
+
 static astmodulestate *
-get_ast_state(PyObject *module)
+get_ast_state(PyObject *Py_UNUSED(module))
 {
-    assert(module != NULL);
-    void *state = PyModule_GetState(module);
-    assert(state != NULL);
-    return (astmodulestate *)state;
+    return &global_ast_state;
 }
 
 static int astmodule_clear(PyObject *module)
@@ -1396,17 +1392,14 @@ def generate_module_def(f, mod):
 
 static struct PyModuleDef _astmodule = {
     PyModuleDef_HEAD_INIT,
-    "_ast",
-    NULL,
-    sizeof(astmodulestate),
-    NULL,
-    NULL,
-    astmodule_traverse,
-    astmodule_clear,
-    astmodule_free,
+    .m_name = "_ast",
+    .m_size = -1,
+    .m_traverse = astmodule_traverse,
+    .m_clear = astmodule_clear,
+    .m_free = astmodule_free,
 };
 
-#define astmodulestate_global get_ast_state(PyState_FindModule(&_astmodule))
+#define get_global_ast_state() (&global_ast_state)
 
 """)
     f.write('static int init_identifiers(astmodulestate *state)\n')
diff --git a/Python/Python-ast.c b/Python/Python-ast.c
index 3296b067584d0..844033ee37e28 100644
--- a/Python/Python-ast.c
+++ b/Python/Python-ast.c
@@ -224,13 +224,12 @@ typedef struct {
 } astmodulestate;
 
 
+static astmodulestate global_ast_state;
+
 static astmodulestate *
-get_ast_state(PyObject *module)
+get_ast_state(PyObject *Py_UNUSED(module))
 {
-    assert(module != NULL);
-    void *state = PyModule_GetState(module);
-    assert(state != NULL);
-    return (astmodulestate *)state;
+    return &global_ast_state;
 }
 
 static int astmodule_clear(PyObject *module)
@@ -679,17 +678,14 @@ static void astmodule_free(void* module) {
 
 static struct PyModuleDef _astmodule = {
     PyModuleDef_HEAD_INIT,
-    "_ast",
-    NULL,
-    sizeof(astmodulestate),
-    NULL,
-    NULL,
-    astmodule_traverse,
-    astmodule_clear,
-    astmodule_free,
+    .m_name = "_ast",
+    .m_size = -1,
+    .m_traverse = astmodule_traverse,
+    .m_clear = astmodule_clear,
+    .m_free = astmodule_free,
 };
 
-#define astmodulestate_global get_ast_state(PyState_FindModule(&_astmodule))
+#define get_global_ast_state() (&global_ast_state)
 
 static int init_identifiers(astmodulestate *state)
 {
@@ -1135,7 +1131,7 @@ ast_type_init(PyObject *self, PyObject *args, PyObject *kw)
     Py_ssize_t i, numfields = 0;
     int res = -1;
     PyObject *key, *value, *fields;
-    astmodulestate *state = astmodulestate_global;
+    astmodulestate *state = get_global_ast_state();
     if (_PyObject_LookupAttr((PyObject*)Py_TYPE(self), state->_fields, &fields) < 0) {
         goto cleanup;
     }
@@ -1204,7 +1200,7 @@ ast_type_init(PyObject *self, PyObject *args, PyObject *kw)
 static PyObject *
 ast_type_reduce(PyObject *self, PyObject *unused)
 {
-    astmodulestate *state = astmodulestate_global;
+    astmodulestate *state = get_global_ast_state();
     PyObject *dict;
     if (_PyObject_LookupAttr(self, state->__dict__, &dict) < 0) {
         return NULL;
@@ -1414,18 +1410,7 @@ static int add_ast_fields(astmodulestate *state)
 
 static int init_types(void)
 {
-    PyObject *module = PyState_FindModule(&_astmodule);
-    if (module == NULL) {
-        module = PyModule_Create(&_astmodule);
-        if (!module) {
-            return 0;
-        }
-        if (PyState_AddModule(module, &_astmodule) < 0) {
-            return 0;
-        }
-    }
-
-    astmodulestate *state = get_ast_state(module);
+    astmodulestate *state = get_global_ast_state();
     if (state->initialized) return 1;
     if (init_identifiers(state) < 0) return 0;
     state->AST_type = PyType_FromSpec(&AST_type_spec);
@@ -9857,12 +9842,15 @@ obj2ast_type_ignore(astmodulestate *state, PyObject* obj, type_ignore_ty* out,
 PyMODINIT_FUNC
 PyInit__ast(void)
 {
-    PyObject *m;
-    if (!init_types()) return NULL;
-    m = PyState_FindModule(&_astmodule);
-    if (!m) return NULL;
+    PyObject *m = PyModule_Create(&_astmodule);
+    if (!m) {
+        return NULL;
+    }
     astmodulestate *state = get_ast_state(m);
 
+    if (!init_types()) {
+        goto error;
+    }
     if (PyModule_AddObject(m, "AST", state->AST_type) < 0) {
         goto error;
     }
@@ -10303,6 +10291,7 @@ PyInit__ast(void)
     }
     Py_INCREF(state->TypeIgnore_type);
     return m;
+
 error:
     Py_DECREF(m);
     return NULL;
@@ -10311,9 +10300,11 @@ PyInit__ast(void)
 
 PyObject* PyAST_mod2obj(mod_ty t)
 {
-    if (!init_types())
+    if (!init_types()) {
         return NULL;
-    astmodulestate *state = astmodulestate_global;
+    }
+
+    astmodulestate *state = get_global_ast_state();
     return ast2obj_mod(state, t);
 }
 
@@ -10327,7 +10318,7 @@ mod_ty PyAST_obj2mod(PyObject* ast, PyArena* arena, int mode)
         return NULL;
     }
 
-    astmodulestate *state = astmodulestate_global;
+    astmodulestate *state = get_global_ast_state();
     PyObject *req_type[3];
     req_type[0] = state->Module_type;
     req_type[1] = state->Expression_type;
@@ -10335,8 +10326,9 @@ mod_ty PyAST_obj2mod(PyObject* ast, PyArena* arena, int mode)
 
     assert(0 <= mode && mode <= 2);
 
-    if (!init_types())
+    if (!init_types()) {
         return NULL;
+    }
 
     isinstance = PyObject_IsInstance(ast, req_type[mode]);
     if (isinstance == -1)
@@ -10356,9 +10348,11 @@ mod_ty PyAST_obj2mod(PyObject* ast, PyArena* arena, int mode)
 
 int PyAST_Check(PyObject* obj)
 {
-    if (!init_types())
+    if (!init_types()) {
         return -1;
-    astmodulestate *state = astmodulestate_global;
+    }
+
+    astmodulestate *state = get_global_ast_state();
     return PyObject_IsInstance(obj, state->AST_type);
 }
 



More information about the Python-checkins mailing list