[Python-checkins] bpo-1635741: Fix PyInit_pyexpat() error handling (GH-22489)

vstinner webhook-mailer at python.org
Wed Nov 4 12:37:30 EST 2020


https://github.com/python/cpython/commit/7184218e186811e75be663be78e57d5302bf8af6
commit: 7184218e186811e75be663be78e57d5302bf8af6
branch: master
author: Mohamed Koubaa <koubaa.m at gmail.com>
committer: vstinner <vstinner at python.org>
date: 2020-11-04T18:37:23+01:00
summary:

bpo-1635741: Fix PyInit_pyexpat() error handling (GH-22489)

Split PyInit_pyexpat() into sub-functions and fix reference leaks
on error paths.

files:
M Modules/pyexpat.c

diff --git a/Modules/pyexpat.c b/Modules/pyexpat.c
index 73ea51385ee80..7d7da568972a2 100644
--- a/Modules/pyexpat.c
+++ b/Modules/pyexpat.c
@@ -1587,18 +1587,6 @@ PyDoc_STRVAR(pyexpat_module_documentation,
 #define MODULE_INITFUNC PyInit_pyexpat
 #endif
 
-static struct PyModuleDef pyexpatmodule = {
-        PyModuleDef_HEAD_INIT,
-        MODULE_NAME,
-        pyexpat_module_documentation,
-        -1,
-        pyexpat_methods,
-        NULL,
-        NULL,
-        NULL,
-        NULL
-};
-
 static int init_handler_descrs(void)
 {
     int i;
@@ -1623,210 +1611,182 @@ static int init_handler_descrs(void)
     return 0;
 }
 
-PyMODINIT_FUNC
-MODULE_INITFUNC(void)
+static PyObject *
+add_submodule(PyObject *mod, const char *fullname)
 {
-    PyObject *m, *d;
-    PyObject *errmod_name = PyUnicode_FromString(MODULE_NAME ".errors");
-    PyObject *errors_module;
-    PyObject *modelmod_name;
-    PyObject *model_module;
-    PyObject *tmpnum, *tmpstr;
-    PyObject *codes_dict;
-    PyObject *rev_codes_dict;
-    int res;
-    static struct PyExpat_CAPI capi;
-    PyObject *capi_object;
+    const char *name = strrchr(fullname, '.') + 1;
 
-    if (errmod_name == NULL)
-        return NULL;
-    modelmod_name = PyUnicode_FromString(MODULE_NAME ".model");
-    if (modelmod_name == NULL)
+    PyObject *submodule = PyModule_New(fullname);
+    if (submodule == NULL) {
         return NULL;
+    }
 
-    if (PyType_Ready(&Xmlparsetype) < 0 || init_handler_descrs() < 0)
+    PyObject *mod_name = PyUnicode_FromString(fullname);
+    if (mod_name == NULL) {
+        Py_DECREF(submodule);
         return NULL;
+    }
 
-    /* Create the module and add the functions */
-    m = PyModule_Create(&pyexpatmodule);
-    if (m == NULL)
+    if (_PyImport_SetModule(mod_name, submodule) < 0) {
+        Py_DECREF(submodule);
+        Py_DECREF(mod_name);
         return NULL;
+    }
+    Py_DECREF(mod_name);
 
-    /* Add some symbolic constants to the module */
-    if (ErrorObject == NULL) {
-        ErrorObject = PyErr_NewException("xml.parsers.expat.ExpatError",
-                                         NULL, NULL);
-        if (ErrorObject == NULL)
-            return NULL;
+    /* gives away the reference to the submodule */
+    if (PyModule_AddObject(mod, name, submodule) < 0) {
+        Py_DECREF(submodule);
+        return NULL;
     }
-    Py_INCREF(ErrorObject);
-    PyModule_AddObject(m, "error", ErrorObject);
-    Py_INCREF(ErrorObject);
-    PyModule_AddObject(m, "ExpatError", ErrorObject);
-    Py_INCREF(&Xmlparsetype);
-    PyModule_AddObject(m, "XMLParserType", (PyObject *) &Xmlparsetype);
 
-    PyModule_AddStringConstant(m, "EXPAT_VERSION",
-                               XML_ExpatVersion());
-    {
-        XML_Expat_Version info = XML_ExpatVersionInfo();
-        PyModule_AddObject(m, "version_info",
-                           Py_BuildValue("(iii)", info.major,
-                                         info.minor, info.micro));
+    return submodule;
+}
+
+static int
+add_error(PyObject *errors_module, PyObject *codes_dict,
+          PyObject *rev_codes_dict, const char *name, int value)
+{
+    const char *error_string = XML_ErrorString(value);
+    if (PyModule_AddStringConstant(errors_module, name, error_string) < 0) {
+        return -1;
     }
-    /* XXX When Expat supports some way of figuring out how it was
-       compiled, this should check and set native_encoding
-       appropriately.
-    */
-    PyModule_AddStringConstant(m, "native_encoding", "UTF-8");
 
-    d = PyModule_GetDict(m);
-    if (d == NULL) {
-        Py_DECREF(m);
-        return NULL;
+    PyObject *num = PyLong_FromLong(value);
+    if (num == NULL) {
+        return -1;
     }
-    errors_module = PyDict_GetItemWithError(d, errmod_name);
-    if (errors_module == NULL && !PyErr_Occurred()) {
-        errors_module = PyModule_New(MODULE_NAME ".errors");
-        if (errors_module != NULL) {
-            _PyImport_SetModule(errmod_name, errors_module);
-            /* gives away the reference to errors_module */
-            PyModule_AddObject(m, "errors", errors_module);
-        }
+
+    if (PyDict_SetItemString(codes_dict, error_string, num) < 0) {
+        Py_DECREF(num);
+        return -1;
     }
-    Py_DECREF(errmod_name);
-    model_module = PyDict_GetItemWithError(d, modelmod_name);
-    if (model_module == NULL && !PyErr_Occurred()) {
-        model_module = PyModule_New(MODULE_NAME ".model");
-        if (model_module != NULL) {
-            _PyImport_SetModule(modelmod_name, model_module);
-            /* gives away the reference to model_module */
-            PyModule_AddObject(m, "model", model_module);
-        }
+
+    PyObject *str = PyUnicode_FromString(error_string);
+    if (str == NULL) {
+        Py_DECREF(num);
+        return -1;
     }
-    Py_DECREF(modelmod_name);
-    if (errors_module == NULL || model_module == NULL) {
-        /* Don't core dump later! */
-        Py_DECREF(m);
-        return NULL;
+
+    int res = PyDict_SetItem(rev_codes_dict, num, str);
+    Py_DECREF(str);
+    Py_DECREF(num);
+    if (res < 0) {
+        return -1;
     }
 
-#if XML_COMBINED_VERSION > 19505
-    {
-        const XML_Feature *features = XML_GetFeatureList();
-        PyObject *list = PyList_New(0);
-        if (list == NULL)
-            /* just ignore it */
-            PyErr_Clear();
-        else {
-            int i = 0;
-            for (; features[i].feature != XML_FEATURE_END; ++i) {
-                int ok;
-                PyObject *item = Py_BuildValue("si", features[i].name,
-                                               features[i].value);
-                if (item == NULL) {
-                    Py_DECREF(list);
-                    list = NULL;
-                    break;
-                }
-                ok = PyList_Append(list, item);
-                Py_DECREF(item);
-                if (ok < 0) {
-                    PyErr_Clear();
-                    break;
-                }
-            }
-            if (list != NULL)
-                PyModule_AddObject(m, "features", list);
-        }
+    return 0;
+}
+
+static int
+add_errors_module(PyObject *mod)
+{
+    PyObject *errors_module = add_submodule(mod, MODULE_NAME ".errors");
+    if (errors_module == NULL) {
+        return -1;
     }
-#endif
 
-    codes_dict = PyDict_New();
-    rev_codes_dict = PyDict_New();
+    PyObject *codes_dict = PyDict_New();
+    PyObject *rev_codes_dict = PyDict_New();
     if (codes_dict == NULL || rev_codes_dict == NULL) {
-        Py_XDECREF(codes_dict);
-        Py_XDECREF(rev_codes_dict);
-        return NULL;
+        goto error;
     }
 
-#define MYCONST(name) \
-    if (PyModule_AddStringConstant(errors_module, #name,               \
-                                   XML_ErrorString(name)) < 0)         \
-        return NULL;                                                   \
-    tmpnum = PyLong_FromLong(name);                                    \
-    if (tmpnum == NULL) return NULL;                                   \
-    res = PyDict_SetItemString(codes_dict,                             \
-                               XML_ErrorString(name), tmpnum);         \
-    if (res < 0) return NULL;                                          \
-    tmpstr = PyUnicode_FromString(XML_ErrorString(name));              \
-    if (tmpstr == NULL) return NULL;                                   \
-    res = PyDict_SetItem(rev_codes_dict, tmpnum, tmpstr);              \
-    Py_DECREF(tmpstr);                                                 \
-    Py_DECREF(tmpnum);                                                 \
-    if (res < 0) return NULL;                                          \
-
-    MYCONST(XML_ERROR_NO_MEMORY);
-    MYCONST(XML_ERROR_SYNTAX);
-    MYCONST(XML_ERROR_NO_ELEMENTS);
-    MYCONST(XML_ERROR_INVALID_TOKEN);
-    MYCONST(XML_ERROR_UNCLOSED_TOKEN);
-    MYCONST(XML_ERROR_PARTIAL_CHAR);
-    MYCONST(XML_ERROR_TAG_MISMATCH);
-    MYCONST(XML_ERROR_DUPLICATE_ATTRIBUTE);
-    MYCONST(XML_ERROR_JUNK_AFTER_DOC_ELEMENT);
-    MYCONST(XML_ERROR_PARAM_ENTITY_REF);
-    MYCONST(XML_ERROR_UNDEFINED_ENTITY);
-    MYCONST(XML_ERROR_RECURSIVE_ENTITY_REF);
-    MYCONST(XML_ERROR_ASYNC_ENTITY);
-    MYCONST(XML_ERROR_BAD_CHAR_REF);
-    MYCONST(XML_ERROR_BINARY_ENTITY_REF);
-    MYCONST(XML_ERROR_ATTRIBUTE_EXTERNAL_ENTITY_REF);
-    MYCONST(XML_ERROR_MISPLACED_XML_PI);
-    MYCONST(XML_ERROR_UNKNOWN_ENCODING);
-    MYCONST(XML_ERROR_INCORRECT_ENCODING);
-    MYCONST(XML_ERROR_UNCLOSED_CDATA_SECTION);
-    MYCONST(XML_ERROR_EXTERNAL_ENTITY_HANDLING);
-    MYCONST(XML_ERROR_NOT_STANDALONE);
-    MYCONST(XML_ERROR_UNEXPECTED_STATE);
-    MYCONST(XML_ERROR_ENTITY_DECLARED_IN_PE);
-    MYCONST(XML_ERROR_FEATURE_REQUIRES_XML_DTD);
-    MYCONST(XML_ERROR_CANT_CHANGE_FEATURE_ONCE_PARSING);
+#define ADD_CONST(name) do {                                        \
+        if (add_error(errors_module, codes_dict, rev_codes_dict,    \
+                      #name, name) < 0) {                           \
+            goto error;                                             \
+        }                                                           \
+    } while(0)
+
+    ADD_CONST(XML_ERROR_NO_MEMORY);
+    ADD_CONST(XML_ERROR_SYNTAX);
+    ADD_CONST(XML_ERROR_NO_ELEMENTS);
+    ADD_CONST(XML_ERROR_INVALID_TOKEN);
+    ADD_CONST(XML_ERROR_UNCLOSED_TOKEN);
+    ADD_CONST(XML_ERROR_PARTIAL_CHAR);
+    ADD_CONST(XML_ERROR_TAG_MISMATCH);
+    ADD_CONST(XML_ERROR_DUPLICATE_ATTRIBUTE);
+    ADD_CONST(XML_ERROR_JUNK_AFTER_DOC_ELEMENT);
+    ADD_CONST(XML_ERROR_PARAM_ENTITY_REF);
+    ADD_CONST(XML_ERROR_UNDEFINED_ENTITY);
+    ADD_CONST(XML_ERROR_RECURSIVE_ENTITY_REF);
+    ADD_CONST(XML_ERROR_ASYNC_ENTITY);
+    ADD_CONST(XML_ERROR_BAD_CHAR_REF);
+    ADD_CONST(XML_ERROR_BINARY_ENTITY_REF);
+    ADD_CONST(XML_ERROR_ATTRIBUTE_EXTERNAL_ENTITY_REF);
+    ADD_CONST(XML_ERROR_MISPLACED_XML_PI);
+    ADD_CONST(XML_ERROR_UNKNOWN_ENCODING);
+    ADD_CONST(XML_ERROR_INCORRECT_ENCODING);
+    ADD_CONST(XML_ERROR_UNCLOSED_CDATA_SECTION);
+    ADD_CONST(XML_ERROR_EXTERNAL_ENTITY_HANDLING);
+    ADD_CONST(XML_ERROR_NOT_STANDALONE);
+    ADD_CONST(XML_ERROR_UNEXPECTED_STATE);
+    ADD_CONST(XML_ERROR_ENTITY_DECLARED_IN_PE);
+    ADD_CONST(XML_ERROR_FEATURE_REQUIRES_XML_DTD);
+    ADD_CONST(XML_ERROR_CANT_CHANGE_FEATURE_ONCE_PARSING);
     /* Added in Expat 1.95.7. */
-    MYCONST(XML_ERROR_UNBOUND_PREFIX);
+    ADD_CONST(XML_ERROR_UNBOUND_PREFIX);
     /* Added in Expat 1.95.8. */
-    MYCONST(XML_ERROR_UNDECLARING_PREFIX);
-    MYCONST(XML_ERROR_INCOMPLETE_PE);
-    MYCONST(XML_ERROR_XML_DECL);
-    MYCONST(XML_ERROR_TEXT_DECL);
-    MYCONST(XML_ERROR_PUBLICID);
-    MYCONST(XML_ERROR_SUSPENDED);
-    MYCONST(XML_ERROR_NOT_SUSPENDED);
-    MYCONST(XML_ERROR_ABORTED);
-    MYCONST(XML_ERROR_FINISHED);
-    MYCONST(XML_ERROR_SUSPEND_PE);
+    ADD_CONST(XML_ERROR_UNDECLARING_PREFIX);
+    ADD_CONST(XML_ERROR_INCOMPLETE_PE);
+    ADD_CONST(XML_ERROR_XML_DECL);
+    ADD_CONST(XML_ERROR_TEXT_DECL);
+    ADD_CONST(XML_ERROR_PUBLICID);
+    ADD_CONST(XML_ERROR_SUSPENDED);
+    ADD_CONST(XML_ERROR_NOT_SUSPENDED);
+    ADD_CONST(XML_ERROR_ABORTED);
+    ADD_CONST(XML_ERROR_FINISHED);
+    ADD_CONST(XML_ERROR_SUSPEND_PE);
+#undef ADD_CONST
 
     if (PyModule_AddStringConstant(errors_module, "__doc__",
                                    "Constants used to describe "
-                                   "error conditions.") < 0)
-        return NULL;
+                                   "error conditions.") < 0) {
+        goto error;
+    }
 
-    if (PyModule_AddObject(errors_module, "codes", codes_dict) < 0)
-        return NULL;
-    if (PyModule_AddObject(errors_module, "messages", rev_codes_dict) < 0)
-        return NULL;
+    Py_INCREF(codes_dict);
+    if (PyModule_AddObject(errors_module, "codes", codes_dict) < 0) {
+        Py_DECREF(codes_dict);
+        goto error;
+    }
+    Py_CLEAR(codes_dict);
 
-#undef MYCONST
+    Py_INCREF(rev_codes_dict);
+    if (PyModule_AddObject(errors_module, "messages", rev_codes_dict) < 0) {
+        Py_DECREF(rev_codes_dict);
+        goto error;
+    }
+    Py_CLEAR(rev_codes_dict);
 
-#define MYCONST(c) PyModule_AddIntConstant(m, #c, c)
-    MYCONST(XML_PARAM_ENTITY_PARSING_NEVER);
-    MYCONST(XML_PARAM_ENTITY_PARSING_UNLESS_STANDALONE);
-    MYCONST(XML_PARAM_ENTITY_PARSING_ALWAYS);
-#undef MYCONST
+    return 0;
+
+error:
+    Py_XDECREF(codes_dict);
+    Py_XDECREF(rev_codes_dict);
+    return -1;
+}
+
+static int
+add_model_module(PyObject *mod)
+{
+    PyObject *model_module = add_submodule(mod, MODULE_NAME ".model");
+    if (model_module == NULL) {
+        return -1;
+    }
+
+#define MYCONST(c)  do {                                        \
+        if (PyModule_AddIntConstant(model_module, #c, c) < 0) { \
+            return -1;                                          \
+        }                                                       \
+    } while(0)
 
-#define MYCONST(c) PyModule_AddIntConstant(model_module, #c, c)
-    PyModule_AddStringConstant(model_module, "__doc__",
-                     "Constants used to interpret content model information.");
+    if (PyModule_AddStringConstant(
+        model_module, "__doc__",
+        "Constants used to interpret content model information.") < 0) {
+        return -1;
+    }
 
     MYCONST(XML_CTYPE_EMPTY);
     MYCONST(XML_CTYPE_ANY);
@@ -1840,7 +1800,128 @@ MODULE_INITFUNC(void)
     MYCONST(XML_CQUANT_REP);
     MYCONST(XML_CQUANT_PLUS);
 #undef MYCONST
+    return 0;
+}
 
+#if XML_COMBINED_VERSION > 19505
+static int
+add_features(PyObject *mod)
+{
+    PyObject *list = PyList_New(0);
+    if (list == NULL) {
+        return -1;
+    }
+
+    const XML_Feature *features = XML_GetFeatureList();
+    for (size_t i = 0; features[i].feature != XML_FEATURE_END; ++i) {
+        PyObject *item = Py_BuildValue("si", features[i].name,
+                                       features[i].value);
+        if (item == NULL) {
+            goto error;
+        }
+        int ok = PyList_Append(list, item);
+        Py_DECREF(item);
+        if (ok < 0) {
+            goto error;
+        }
+    }
+    if (PyModule_AddObject(mod, "features", list) < 0) {
+        goto error;
+    }
+    return 0;
+
+error:
+    Py_DECREF(list);
+    return -1;
+}
+#endif
+
+static int
+pyexpat_exec(PyObject *mod)
+{
+    if (PyType_Ready(&Xmlparsetype) < 0) {
+        return -1;
+    }
+
+    if (init_handler_descrs() < 0) {
+        return -1;
+    }
+
+    /* Add some symbolic constants to the module */
+    if (ErrorObject == NULL) {
+        ErrorObject = PyErr_NewException("xml.parsers.expat.ExpatError",
+                                         NULL, NULL);
+    }
+    if (ErrorObject == NULL) {
+        return -1;
+    }
+
+    Py_INCREF(ErrorObject);
+    if (PyModule_AddObject(mod, "error", ErrorObject) < 0) {
+        Py_DECREF(ErrorObject);
+        return -1;
+    }
+    Py_INCREF(ErrorObject);
+    if (PyModule_AddObject(mod, "ExpatError", ErrorObject) < 0) {
+        Py_DECREF(ErrorObject);
+        return -1;
+    }
+    Py_INCREF(&Xmlparsetype);
+    if (PyModule_AddObject(mod, "XMLParserType",
+                           (PyObject *) &Xmlparsetype) < 0) {
+        Py_DECREF(&Xmlparsetype);
+        return -1;
+    }
+
+    if (PyModule_AddStringConstant(mod, "EXPAT_VERSION",
+                                   XML_ExpatVersion()) < 0) {
+        return -1;
+    }
+    {
+        XML_Expat_Version info = XML_ExpatVersionInfo();
+        PyObject *versionInfo = Py_BuildValue("(iii)",
+                                              info.major,
+                                              info.minor,
+                                              info.micro);
+        if (PyModule_AddObject(mod, "version_info", versionInfo) < 0) {
+            Py_DECREF(versionInfo);
+            return -1;
+        }
+    }
+    /* XXX When Expat supports some way of figuring out how it was
+       compiled, this should check and set native_encoding
+       appropriately.
+    */
+    if (PyModule_AddStringConstant(mod, "native_encoding", "UTF-8") < 0) {
+        return -1;
+    }
+
+    if (add_errors_module(mod) < 0) {
+        return -1;
+    }
+
+    if (add_model_module(mod) < 0) {
+        return -1;
+    }
+
+#if XML_COMBINED_VERSION > 19505
+    if (add_features(mod) < 0) {
+        return -1;
+    }
+#endif
+
+#define MYCONST(c) do {                                 \
+        if (PyModule_AddIntConstant(mod, #c, c) < 0) {  \
+            return -1;                                  \
+        }                                               \
+    } while(0)
+
+    MYCONST(XML_PARAM_ENTITY_PARSING_NEVER);
+    MYCONST(XML_PARAM_ENTITY_PARSING_UNLESS_STANDALONE);
+    MYCONST(XML_PARAM_ENTITY_PARSING_ALWAYS);
+#undef MYCONST
+
+    static struct PyExpat_CAPI capi;
     /* initialize pyexpat dispatch table */
     capi.size = sizeof(capi);
     capi.magic = PyExpat_CAPI_MAGIC;
@@ -1872,10 +1953,39 @@ MODULE_INITFUNC(void)
 #endif
 
     /* export using capsule */
-    capi_object = PyCapsule_New(&capi, PyExpat_CAPSULE_NAME, NULL);
-    if (capi_object)
-        PyModule_AddObject(m, "expat_CAPI", capi_object);
-    return m;
+    PyObject *capi_object = PyCapsule_New(&capi, PyExpat_CAPSULE_NAME, NULL);
+    if (capi_object == NULL) {
+        return -1;
+    }
+
+    if (PyModule_AddObject(mod, "expat_CAPI", capi_object) < 0) {
+        Py_DECREF(capi_object);
+        return -1;
+    }
+
+    return 0;
+}
+
+static struct PyModuleDef pyexpatmodule = {
+    PyModuleDef_HEAD_INIT,
+    .m_name = MODULE_NAME,
+    .m_doc = pyexpat_module_documentation,
+    .m_size = -1,
+    .m_methods = pyexpat_methods,
+};
+
+PyMODINIT_FUNC
+PyInit_pyexpat(void)
+{
+    PyObject *mod = PyModule_Create(&pyexpatmodule);
+    if (mod == NULL)
+        return NULL;
+
+    if (pyexpat_exec(mod) < 0) {
+        Py_DECREF(mod);
+        return NULL;
+    }
+    return mod;
 }
 
 static void



More information about the Python-checkins mailing list