[Python-checkins] gh-79579: Improve DML query detection in sqlite3 (#93623)

erlend-aasland webhook-mailer at python.org
Tue Jun 14 07:56:41 EDT 2022


https://github.com/python/cpython/commit/46740073ef32bf83964c39609c7a7a4772c51ce3
commit: 46740073ef32bf83964c39609c7a7a4772c51ce3
branch: main
author: Erlend Egeberg Aasland <erlend.aasland at protonmail.com>
committer: erlend-aasland <erlend.aasland at protonmail.com>
date: 2022-06-14T13:56:36+02:00
summary:

gh-79579: Improve DML query detection in sqlite3 (#93623)

The fix involves using pysqlite_check_remaining_sql(), not only to check
for multiple statements, but now also to strip leading comments and
whitespace from SQL statements, so we can improve DML query detection.

pysqlite_check_remaining_sql() is renamed lstrip_sql(), to more
accurately reflect its function, and hardened to handle more SQL comment
corner cases.

files:
A Misc/NEWS.d/next/Library/2022-06-06-12-58-27.gh-issue-79579.e8rB-M.rst
M Lib/test/test_sqlite3/test_dbapi.py
M Modules/_sqlite/statement.c

diff --git a/Lib/test/test_sqlite3/test_dbapi.py b/Lib/test/test_sqlite3/test_dbapi.py
index 05180a3616c5d..18e84e9e58632 100644
--- a/Lib/test/test_sqlite3/test_dbapi.py
+++ b/Lib/test/test_sqlite3/test_dbapi.py
@@ -746,22 +746,44 @@ def test_execute_illegal_sql(self):
         with self.assertRaises(sqlite.OperationalError):
             self.cu.execute("select asdf")
 
-    def test_execute_too_much_sql(self):
-        self.assertRaisesRegex(sqlite.ProgrammingError,
-                               "You can only execute one statement at a time",
-                               self.cu.execute, "select 5+4; select 4+5")
-
-    def test_execute_too_much_sql2(self):
-        self.cu.execute("select 5+4; -- foo bar")
+    def test_execute_multiple_statements(self):
+        msg = "You can only execute one statement at a time"
+        dataset = (
+            "select 1; select 2",
+            "select 1; // c++ comments are not allowed",
+            "select 1; *not a comment",
+            "select 1; -*not a comment",
+            "select 1; /* */ a",
+            "select 1; /**/a",
+            "select 1; -",
+            "select 1; /",
+            "select 1; -\n- select 2",
+            """select 1;
+               -- comment
+               select 2
+            """,
+        )
+        for query in dataset:
+            with self.subTest(query=query):
+                with self.assertRaisesRegex(sqlite.ProgrammingError, msg):
+                    self.cu.execute(query)
 
-    def test_execute_too_much_sql3(self):
-        self.cu.execute("""
+    def test_execute_with_appended_comments(self):
+        dataset = (
+            "select 1; -- foo bar",
+            "select 1; --",
+            "select 1; /*",  # Unclosed comments ending in \0 are skipped.
+            """
             select 5+4;
 
             /*
             foo
             */
-            """)
+            """,
+        )
+        for query in dataset:
+            with self.subTest(query=query):
+                self.cu.execute(query)
 
     def test_execute_wrong_sql_arg(self):
         with self.assertRaises(TypeError):
@@ -906,6 +928,30 @@ def test_rowcount_update_returning(self):
         self.assertEqual(self.cu.fetchone()[0], 1)
         self.assertEqual(self.cu.rowcount, 1)
 
+    def test_rowcount_prefixed_with_comment(self):
+        # gh-79579: rowcount is updated even if query is prefixed with comments
+        self.cu.execute("""
+            -- foo
+            insert into test(name) values ('foo'), ('foo')
+        """)
+        self.assertEqual(self.cu.rowcount, 2)
+        self.cu.execute("""
+            /* -- messy *r /* /* ** *- *--
+            */
+            /* one more */ insert into test(name) values ('messy')
+        """)
+        self.assertEqual(self.cu.rowcount, 1)
+        self.cu.execute("/* bar */ update test set name='bar' where name='foo'")
+        self.assertEqual(self.cu.rowcount, 3)
+
+    def test_rowcount_vaccuum(self):
+        data = ((1,), (2,), (3,))
+        self.cu.executemany("insert into test(income) values(?)", data)
+        self.assertEqual(self.cu.rowcount, 3)
+        self.cx.commit()
+        self.cu.execute("vacuum")
+        self.assertEqual(self.cu.rowcount, -1)
+
     def test_total_changes(self):
         self.cu.execute("insert into test(name) values ('foo')")
         self.cu.execute("insert into test(name) values ('foo')")
diff --git a/Misc/NEWS.d/next/Library/2022-06-06-12-58-27.gh-issue-79579.e8rB-M.rst b/Misc/NEWS.d/next/Library/2022-06-06-12-58-27.gh-issue-79579.e8rB-M.rst
new file mode 100644
index 0000000000000..82b1a1c28a600
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2022-06-06-12-58-27.gh-issue-79579.e8rB-M.rst
@@ -0,0 +1,2 @@
+:mod:`sqlite3` now correctly detects DML queries with leading comments.
+Patch by Erlend E. Aasland.
diff --git a/Modules/_sqlite/statement.c b/Modules/_sqlite/statement.c
index f9cb70f0ef146..aee460747b45f 100644
--- a/Modules/_sqlite/statement.c
+++ b/Modules/_sqlite/statement.c
@@ -26,16 +26,7 @@
 #include "util.h"
 
 /* prototypes */
-static int pysqlite_check_remaining_sql(const char* tail);
-
-typedef enum {
-    LINECOMMENT_1,
-    IN_LINECOMMENT,
-    COMMENTSTART_1,
-    IN_COMMENT,
-    COMMENTEND_1,
-    NORMAL
-} parse_remaining_sql_state;
+static const char *lstrip_sql(const char *sql);
 
 pysqlite_Statement *
 pysqlite_statement_create(pysqlite_Connection *connection, PyObject *sql)
@@ -73,7 +64,7 @@ pysqlite_statement_create(pysqlite_Connection *connection, PyObject *sql)
         return NULL;
     }
 
-    if (pysqlite_check_remaining_sql(tail)) {
+    if (lstrip_sql(tail) != NULL) {
         PyErr_SetString(connection->ProgrammingError,
                         "You can only execute one statement at a time.");
         goto error;
@@ -82,20 +73,12 @@ pysqlite_statement_create(pysqlite_Connection *connection, PyObject *sql)
     /* Determine if the statement is a DML statement.
        SELECT is the only exception. See #9924. */
     int is_dml = 0;
-    for (const char *p = sql_cstr; *p != 0; p++) {
-        switch (*p) {
-            case ' ':
-            case '\r':
-            case '\n':
-            case '\t':
-                continue;
-        }
-
+    const char *p = lstrip_sql(sql_cstr);
+    if (p != NULL) {
         is_dml = (PyOS_strnicmp(p, "insert", 6) == 0)
                   || (PyOS_strnicmp(p, "update", 6) == 0)
                   || (PyOS_strnicmp(p, "delete", 6) == 0)
                   || (PyOS_strnicmp(p, "replace", 7) == 0);
-        break;
     }
 
     pysqlite_Statement *self = PyObject_GC_New(pysqlite_Statement,
@@ -139,73 +122,61 @@ stmt_traverse(pysqlite_Statement *self, visitproc visit, void *arg)
 }
 
 /*
- * Checks if there is anything left in an SQL string after SQLite compiled it.
- * This is used to check if somebody tried to execute more than one SQL command
- * with one execute()/executemany() command, which the DB-API and we don't
- * allow.
+ * Strip leading whitespace and comments from incoming SQL (null-terminated C
+ * string) and return a pointer to the first non-whitespace, non-comment
+ * character, or NULL if the string contains nothing else.
  *
- * Returns 1 if there is more left than should be. 0 if ok.
+ * This is used to check if somebody tries to execute more than one SQL query
+ * with one execute()/executemany() command, which the DB-API doesn't allow.
+ *
+ * It is also used to harden DML query detection.
  */
-static int pysqlite_check_remaining_sql(const char* tail)
+static inline const char *
+lstrip_sql(const char *sql)
 {
-    const char* pos = tail;
-
-    parse_remaining_sql_state state = NORMAL;
-
-    for (;;) {
+    // This loop is borrowed from the SQLite source code.
+    for (const char *pos = sql; *pos; pos++) {
         switch (*pos) {
-            case 0:
-                return 0;
-            case '-':
-                if (state == NORMAL) {
-                    state  = LINECOMMENT_1;
-                } else if (state == LINECOMMENT_1) {
-                    state = IN_LINECOMMENT;
-                }
-                break;
             case ' ':
             case '\t':
-                break;
+            case '\f':
             case '\n':
-            case 13:
-                if (state == IN_LINECOMMENT) {
-                    state = NORMAL;
-                }
+            case '\r':
+                // Skip whitespace.
                 break;
-            case '/':
-                if (state == NORMAL) {
-                    state = COMMENTSTART_1;
-                } else if (state == COMMENTEND_1) {
-                    state = NORMAL;
-                } else if (state == COMMENTSTART_1) {
-                    return 1;
+            case '-':
+                // Skip line comments.
+                if (pos[1] == '-') {
+                    pos += 2;
+                    while (pos[0] && pos[0] != '\n') {
+                        pos++;
+                    }
+                    if (pos[0] == '\0') {
+                        return NULL;
+                    }
+                    continue;
                 }
-                break;
-            case '*':
-                if (state == NORMAL) {
-                    return 1;
-                } else if (state == LINECOMMENT_1) {
-                    return 1;
-                } else if (state == COMMENTSTART_1) {
-                    state = IN_COMMENT;
-                } else if (state == IN_COMMENT) {
-                    state = COMMENTEND_1;
+                return pos;
+            case '/':
+                // Skip C style comments.
+                if (pos[1] == '*') {
+                    pos += 2;
+                    while (pos[0] && (pos[0] != '*' || pos[1] != '/')) {
+                        pos++;
+                    }
+                    if (pos[0] == '\0') {
+                        return NULL;
+                    }
+                    pos++;
+                    continue;
                 }
-                break;
+                return pos;
             default:
-                if (state == COMMENTEND_1) {
-                    state = IN_COMMENT;
-                } else if (state == IN_LINECOMMENT) {
-                } else if (state == IN_COMMENT) {
-                } else {
-                    return 1;
-                }
+                return pos;
         }
-
-        pos++;
     }
 
-    return 0;
+    return NULL;
 }
 
 static PyType_Slot stmt_slots[] = {



More information about the Python-checkins mailing list