[Python-checkins] bpo-42972: Track sqlite3 statement objects (GH-26475)

vstinner webhook-mailer at python.org
Tue Jun 1 06:47:42 EDT 2021


https://github.com/python/cpython/commit/fffa0f92adaaed0bcb3907d982506f78925e9052
commit: fffa0f92adaaed0bcb3907d982506f78925e9052
branch: main
author: Erlend Egeberg Aasland <erlend.aasland at innova.no>
committer: vstinner <vstinner at python.org>
date: 2021-06-01T12:47:37+02:00
summary:

bpo-42972: Track sqlite3 statement objects (GH-26475)

Allocate and track statement objects in pysqlite_statement_create.

By allocating and tracking the statement object in
pysqlite_statement_create(), the caller no longer needs to worry about GC
synchronization, and the possibility of getting a badly created object
is eliminated. All related fault handling is moved to
pysqlite_statement_create().

Co-authored-by: Victor Stinner <vstinner at python.org>

files:
M Modules/_sqlite/connection.c
M Modules/_sqlite/cursor.c
M Modules/_sqlite/statement.c
M Modules/_sqlite/statement.h

diff --git a/Modules/_sqlite/connection.c b/Modules/_sqlite/connection.c
index 54e477739bd45..7252ccab10b4b 100644
--- a/Modules/_sqlite/connection.c
+++ b/Modules/_sqlite/connection.c
@@ -1337,7 +1337,6 @@ pysqlite_connection_call(pysqlite_Connection *self, PyObject *args,
     PyObject* sql;
     pysqlite_Statement* statement;
     PyObject* weakref;
-    int rc;
 
     if (!pysqlite_check_thread(self) || !pysqlite_check_connection(self)) {
         return NULL;
@@ -1351,31 +1350,11 @@ pysqlite_connection_call(pysqlite_Connection *self, PyObject *args,
 
     _pysqlite_drop_unused_statement_references(self);
 
-    statement = PyObject_GC_New(pysqlite_Statement, pysqlite_StatementType);
-    if (!statement) {
+    statement = pysqlite_statement_create(self, sql);
+    if (statement == NULL) {
         return NULL;
     }
 
-    statement->db = NULL;
-    statement->st = NULL;
-    statement->sql = NULL;
-    statement->in_use = 0;
-    statement->in_weakreflist = NULL;
-
-    rc = pysqlite_statement_create(statement, self, sql);
-    if (rc != SQLITE_OK) {
-        if (rc == PYSQLITE_TOO_MUCH_SQL) {
-            PyErr_SetString(pysqlite_Warning, "You can only execute one statement at a time.");
-        } else if (rc == PYSQLITE_SQL_WRONG_TYPE) {
-            if (PyErr_ExceptionMatches(PyExc_TypeError))
-                PyErr_SetString(pysqlite_Warning, "SQL is of wrong type. Must be string.");
-        } else {
-            (void)pysqlite_statement_reset(statement);
-            _pysqlite_seterror(self->db);
-        }
-        goto error;
-    }
-
     weakref = PyWeakref_NewRef((PyObject*)statement, NULL);
     if (weakref == NULL)
         goto error;
diff --git a/Modules/_sqlite/cursor.c b/Modules/_sqlite/cursor.c
index 4eb9c6d28b103..7656c903a4acf 100644
--- a/Modules/_sqlite/cursor.c
+++ b/Modules/_sqlite/cursor.c
@@ -507,13 +507,8 @@ _pysqlite_query_execute(pysqlite_Cursor* self, int multiple, PyObject* operation
 
     if (self->statement->in_use) {
         Py_SETREF(self->statement,
-                  PyObject_GC_New(pysqlite_Statement, pysqlite_StatementType));
-        if (!self->statement) {
-            goto error;
-        }
-        rc = pysqlite_statement_create(self->statement, self->connection, operation);
-        if (rc != SQLITE_OK) {
-            Py_CLEAR(self->statement);
+                  pysqlite_statement_create(self->connection, operation));
+        if (self->statement == NULL) {
             goto error;
         }
     }
diff --git a/Modules/_sqlite/statement.c b/Modules/_sqlite/statement.c
index 332bf030a9b90..3a18ad8331f69 100644
--- a/Modules/_sqlite/statement.c
+++ b/Modules/_sqlite/statement.c
@@ -48,7 +48,8 @@ typedef enum {
     TYPE_UNKNOWN
 } parameter_type;
 
-int pysqlite_statement_create(pysqlite_Statement* self, pysqlite_Connection* connection, PyObject* sql)
+pysqlite_Statement *
+pysqlite_statement_create(pysqlite_Connection *connection, PyObject *sql)
 {
     const char* tail;
     int rc;
@@ -56,27 +57,36 @@ int pysqlite_statement_create(pysqlite_Statement* self, pysqlite_Connection* con
     Py_ssize_t sql_cstr_len;
     const char* p;
 
-    self->st = NULL;
-    self->in_use = 0;
-
     assert(PyUnicode_Check(sql));
 
     sql_cstr = PyUnicode_AsUTF8AndSize(sql, &sql_cstr_len);
     if (sql_cstr == NULL) {
-        rc = PYSQLITE_SQL_WRONG_TYPE;
-        return rc;
+        PyErr_Format(pysqlite_Warning,
+                     "SQL is of wrong type ('%s'). Must be string.",
+                     Py_TYPE(sql)->tp_name);
+        return NULL;
     }
     if (strlen(sql_cstr) != (size_t)sql_cstr_len) {
-        PyErr_SetString(PyExc_ValueError, "the query contains a null character");
-        return PYSQLITE_SQL_WRONG_TYPE;
+        PyErr_SetString(PyExc_ValueError,
+                        "the query contains a null character");
+        return NULL;
     }
 
-    self->in_weakreflist = NULL;
+    pysqlite_Statement *self = PyObject_GC_New(pysqlite_Statement,
+                                               pysqlite_StatementType);
+    if (self == NULL) {
+        return NULL;
+    }
+
+    self->db = connection->db;
+    self->st = NULL;
     self->sql = Py_NewRef(sql);
+    self->in_use = 0;
+    self->is_dml = 0;
+    self->in_weakreflist = NULL;
 
     /* Determine if the statement is a DML statement.
        SELECT is the only exception. See #9924. */
-    self->is_dml = 0;
     for (p = sql_cstr; *p != 0; p++) {
         switch (*p) {
             case ' ':
@@ -94,22 +104,33 @@ int pysqlite_statement_create(pysqlite_Statement* self, pysqlite_Connection* con
     }
 
     Py_BEGIN_ALLOW_THREADS
-    rc = sqlite3_prepare_v2(connection->db,
+    rc = sqlite3_prepare_v2(self->db,
                             sql_cstr,
                             -1,
                             &self->st,
                             &tail);
     Py_END_ALLOW_THREADS
 
-    self->db = connection->db;
+    PyObject_GC_Track(self);
+
+    if (rc != SQLITE_OK) {
+        _pysqlite_seterror(self->db);
+        goto error;
+    }
 
     if (rc == SQLITE_OK && pysqlite_check_remaining_sql(tail)) {
         (void)sqlite3_finalize(self->st);
         self->st = NULL;
-        rc = PYSQLITE_TOO_MUCH_SQL;
+        PyErr_SetString(pysqlite_Warning,
+                        "You can only execute one statement at a time.");
+        goto error;
     }
 
-    return rc;
+    return self;
+
+error:
+    Py_DECREF(self);
+    return NULL;
 }
 
 int pysqlite_statement_bind_parameter(pysqlite_Statement* self, int pos, PyObject* parameter)
diff --git a/Modules/_sqlite/statement.h b/Modules/_sqlite/statement.h
index 56ff7271448d1..e8c86a0ec963f 100644
--- a/Modules/_sqlite/statement.h
+++ b/Modules/_sqlite/statement.h
@@ -29,9 +29,6 @@
 #include "connection.h"
 #include "sqlite3.h"
 
-#define PYSQLITE_TOO_MUCH_SQL (-100)
-#define PYSQLITE_SQL_WRONG_TYPE (-101)
-
 typedef struct
 {
     PyObject_HEAD
@@ -45,7 +42,7 @@ typedef struct
 
 extern PyTypeObject *pysqlite_StatementType;
 
-int pysqlite_statement_create(pysqlite_Statement* self, pysqlite_Connection* connection, PyObject* sql);
+pysqlite_Statement *pysqlite_statement_create(pysqlite_Connection *connection, PyObject *sql);
 
 int pysqlite_statement_bind_parameter(pysqlite_Statement* self, int pos, PyObject* parameter);
 void pysqlite_statement_bind_parameters(pysqlite_Statement* self, PyObject* parameters);



More information about the Python-checkins mailing list