[Python-checkins] bpo-44859: Improve error handling in sqlite3 and raise more accurate exceptions. (GH-27654)

serhiy-storchaka webhook-mailer at python.org
Sun Aug 8 01:50:01 EDT 2021


https://github.com/python/cpython/commit/0eec6276fdcdde5221370d92b50ea95851760c72
commit: 0eec6276fdcdde5221370d92b50ea95851760c72
branch: main
author: Serhiy Storchaka <storchaka at gmail.com>
committer: serhiy-storchaka <storchaka at gmail.com>
date: 2021-08-08T08:49:44+03:00
summary:

bpo-44859: Improve error handling in sqlite3 and raise more accurate exceptions. (GH-27654)

* MemoryError is now raised instead of sqlite3.Warning when
  memory is not enough for encoding a statement to UTF-8
  in Connection.__call__() and Cursor.execute().
* UnicodeEncodeError is now raised instead of sqlite3.Warning when
  the statement contains surrogate characters
  in Connection.__call__() and Cursor.execute().
* TypeError is now raised instead of ValueError for non-string
  script argument in Cursor.executescript().
* ValueError is now raised for script containing the null
  character instead of truncating it in Cursor.executescript().
* Correctly handle exceptions raised when getting boolean value
  of the result of the progress handler.
* Add many tests covering different corner cases.

Co-authored-by: Erlend Egeberg Aasland <erlend.aasland at innova.no>

files:
A Misc/NEWS.d/next/Library/2021-08-07-17-28-56.bpo-44859.CCopjk.rst
M Lib/sqlite3/test/dbapi.py
M Lib/sqlite3/test/hooks.py
M Lib/sqlite3/test/regression.py
M Lib/sqlite3/test/types.py
M Lib/sqlite3/test/userfunctions.py
M Modules/_sqlite/clinic/cursor.c.h
M Modules/_sqlite/connection.c
M Modules/_sqlite/cursor.c
M Modules/_sqlite/statement.c

diff --git a/Lib/sqlite3/test/dbapi.py b/Lib/sqlite3/test/dbapi.py
index 408f9945f2c970..5d7e5bba05bc45 100644
--- a/Lib/sqlite3/test/dbapi.py
+++ b/Lib/sqlite3/test/dbapi.py
@@ -26,7 +26,7 @@
 import threading
 import unittest
 
-from test.support import check_disallow_instantiation, threading_helper
+from test.support import check_disallow_instantiation, threading_helper, bigmemtest
 from test.support.os_helper import TESTFN, unlink
 
 
@@ -758,9 +758,35 @@ def test_script_error_normal(self):
     def test_cursor_executescript_as_bytes(self):
         con = sqlite.connect(":memory:")
         cur = con.cursor()
-        with self.assertRaises(ValueError) as cm:
+        with self.assertRaises(TypeError):
             cur.executescript(b"create table test(foo); insert into test(foo) values (5);")
-        self.assertEqual(str(cm.exception), 'script argument must be unicode.')
+
+    def test_cursor_executescript_with_null_characters(self):
+        con = sqlite.connect(":memory:")
+        cur = con.cursor()
+        with self.assertRaises(ValueError):
+            cur.executescript("""
+                create table a(i);\0
+                insert into a(i) values (5);
+                """)
+
+    def test_cursor_executescript_with_surrogates(self):
+        con = sqlite.connect(":memory:")
+        cur = con.cursor()
+        with self.assertRaises(UnicodeEncodeError):
+            cur.executescript("""
+                create table a(s);
+                insert into a(s) values ('\ud8ff');
+                """)
+
+    @unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform')
+    @bigmemtest(size=2**31, memuse=3, dry_run=False)
+    def test_cursor_executescript_too_large_script(self, maxsize):
+        con = sqlite.connect(":memory:")
+        cur = con.cursor()
+        for size in 2**31-1, 2**31:
+            with self.assertRaises(sqlite.DataError):
+                cur.executescript("create table a(s);".ljust(size))
 
     def test_connection_execute(self):
         con = sqlite.connect(":memory:")
@@ -969,6 +995,7 @@ def suite():
         CursorTests,
         ExtensionTests,
         ModuleTests,
+        OpenTests,
         SqliteOnConflictTests,
         ThreadTests,
         UninitialisedConnectionTests,
diff --git a/Lib/sqlite3/test/hooks.py b/Lib/sqlite3/test/hooks.py
index 1be6d380abd20a..43e3810d13df18 100644
--- a/Lib/sqlite3/test/hooks.py
+++ b/Lib/sqlite3/test/hooks.py
@@ -24,7 +24,7 @@
 import sqlite3 as sqlite
 
 from test.support.os_helper import TESTFN, unlink
-
+from .userfunctions import with_tracebacks
 
 class CollationTests(unittest.TestCase):
     def test_create_collation_not_string(self):
@@ -145,7 +145,6 @@ def progress():
             """)
         self.assertTrue(progress_calls)
 
-
     def test_opcode_count(self):
         """
         Test that the opcode argument is respected.
@@ -198,6 +197,32 @@ def progress():
         con.execute("select 1 union select 2 union select 3").fetchall()
         self.assertEqual(action, 0, "progress handler was not cleared")
 
+    @with_tracebacks(['bad_progress', 'ZeroDivisionError'])
+    def test_error_in_progress_handler(self):
+        con = sqlite.connect(":memory:")
+        def bad_progress():
+            1 / 0
+        con.set_progress_handler(bad_progress, 1)
+        with self.assertRaises(sqlite.OperationalError):
+            con.execute("""
+                create table foo(a, b)
+                """)
+
+    @with_tracebacks(['__bool__', 'ZeroDivisionError'])
+    def test_error_in_progress_handler_result(self):
+        con = sqlite.connect(":memory:")
+        class BadBool:
+            def __bool__(self):
+                1 / 0
+        def bad_progress():
+            return BadBool()
+        con.set_progress_handler(bad_progress, 1)
+        with self.assertRaises(sqlite.OperationalError):
+            con.execute("""
+                create table foo(a, b)
+                """)
+
+
 class TraceCallbackTests(unittest.TestCase):
     def test_trace_callback_used(self):
         """
diff --git a/Lib/sqlite3/test/regression.py b/Lib/sqlite3/test/regression.py
index 6c093d7c2c36e0..ddf36e71819445 100644
--- a/Lib/sqlite3/test/regression.py
+++ b/Lib/sqlite3/test/regression.py
@@ -21,6 +21,7 @@
 # 3. This notice may not be removed or altered from any source distribution.
 
 import datetime
+import sys
 import unittest
 import sqlite3 as sqlite
 import weakref
@@ -273,7 +274,7 @@ def test_connection_call(self):
         Call a connection with a non-string SQL request: check error handling
         of the statement constructor.
         """
-        self.assertRaises(TypeError, self.con, 1)
+        self.assertRaises(TypeError, self.con, b"select 1")
 
     def test_collation(self):
         def collation_cb(a, b):
@@ -344,6 +345,26 @@ def test_null_character(self):
         self.assertRaises(ValueError, cur.execute, " \0select 2")
         self.assertRaises(ValueError, cur.execute, "select 2\0")
 
+    def test_surrogates(self):
+        con = sqlite.connect(":memory:")
+        self.assertRaises(UnicodeEncodeError, con, "select '\ud8ff'")
+        self.assertRaises(UnicodeEncodeError, con, "select '\udcff'")
+        cur = con.cursor()
+        self.assertRaises(UnicodeEncodeError, cur.execute, "select '\ud8ff'")
+        self.assertRaises(UnicodeEncodeError, cur.execute, "select '\udcff'")
+
+    @unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform')
+    @support.bigmemtest(size=2**31, memuse=4, dry_run=False)
+    def test_large_sql(self, maxsize):
+        # Test two cases: size+1 > INT_MAX and size+1 <= INT_MAX.
+        for size in (2**31, 2**31-2):
+            con = sqlite.connect(":memory:")
+            sql = "select 1".ljust(size)
+            self.assertRaises(sqlite.DataError, con, sql)
+            cur = con.cursor()
+            self.assertRaises(sqlite.DataError, cur.execute, sql)
+            del sql
+
     def test_commit_cursor_reset(self):
         """
         Connection.commit() did reset cursors, which made sqlite3
diff --git a/Lib/sqlite3/test/types.py b/Lib/sqlite3/test/types.py
index 4f0e4f6d268392..b8926ffee22e87 100644
--- a/Lib/sqlite3/test/types.py
+++ b/Lib/sqlite3/test/types.py
@@ -23,11 +23,14 @@
 import datetime
 import unittest
 import sqlite3 as sqlite
+import sys
 try:
     import zlib
 except ImportError:
     zlib = None
 
+from test import support
+
 
 class SqliteTypeTests(unittest.TestCase):
     def setUp(self):
@@ -45,6 +48,12 @@ def test_string(self):
         row = self.cur.fetchone()
         self.assertEqual(row[0], "Österreich")
 
+    def test_string_with_null_character(self):
+        self.cur.execute("insert into test(s) values (?)", ("a\0b",))
+        self.cur.execute("select s from test")
+        row = self.cur.fetchone()
+        self.assertEqual(row[0], "a\0b")
+
     def test_small_int(self):
         self.cur.execute("insert into test(i) values (?)", (42,))
         self.cur.execute("select i from test")
@@ -52,7 +61,7 @@ def test_small_int(self):
         self.assertEqual(row[0], 42)
 
     def test_large_int(self):
-        num = 2**40
+        num = 123456789123456789
         self.cur.execute("insert into test(i) values (?)", (num,))
         self.cur.execute("select i from test")
         row = self.cur.fetchone()
@@ -78,6 +87,45 @@ def test_unicode_execute(self):
         row = self.cur.fetchone()
         self.assertEqual(row[0], "Österreich")
 
+    def test_too_large_int(self):
+        for value in 2**63, -2**63-1, 2**64:
+            with self.assertRaises(OverflowError):
+                self.cur.execute("insert into test(i) values (?)", (value,))
+        self.cur.execute("select i from test")
+        row = self.cur.fetchone()
+        self.assertIsNone(row)
+
+    def test_string_with_surrogates(self):
+        for value in 0xd8ff, 0xdcff:
+            with self.assertRaises(UnicodeEncodeError):
+                self.cur.execute("insert into test(s) values (?)", (chr(value),))
+        self.cur.execute("select s from test")
+        row = self.cur.fetchone()
+        self.assertIsNone(row)
+
+    @unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform')
+    @support.bigmemtest(size=2**31, memuse=4, dry_run=False)
+    def test_too_large_string(self, maxsize):
+        with self.assertRaises(sqlite.InterfaceError):
+            self.cur.execute("insert into test(s) values (?)", ('x'*(2**31-1),))
+        with self.assertRaises(OverflowError):
+            self.cur.execute("insert into test(s) values (?)", ('x'*(2**31),))
+        self.cur.execute("select 1 from test")
+        row = self.cur.fetchone()
+        self.assertIsNone(row)
+
+    @unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform')
+    @support.bigmemtest(size=2**31, memuse=3, dry_run=False)
+    def test_too_large_blob(self, maxsize):
+        with self.assertRaises(sqlite.InterfaceError):
+            self.cur.execute("insert into test(s) values (?)", (b'x'*(2**31-1),))
+        with self.assertRaises(OverflowError):
+            self.cur.execute("insert into test(s) values (?)", (b'x'*(2**31),))
+        self.cur.execute("select 1 from test")
+        row = self.cur.fetchone()
+        self.assertIsNone(row)
+
+
 class DeclTypesTests(unittest.TestCase):
     class Foo:
         def __init__(self, _val):
@@ -163,7 +211,7 @@ def test_small_int(self):
 
     def test_large_int(self):
         # default
-        num = 2**40
+        num = 123456789123456789
         self.cur.execute("insert into test(i) values (?)", (num,))
         self.cur.execute("select i from test")
         row = self.cur.fetchone()
diff --git a/Lib/sqlite3/test/userfunctions.py b/Lib/sqlite3/test/userfunctions.py
index 9681dbdde2b092..b4d5181777ebdf 100644
--- a/Lib/sqlite3/test/userfunctions.py
+++ b/Lib/sqlite3/test/userfunctions.py
@@ -33,28 +33,37 @@
 from test.support import bigmemtest
 
 
-def with_tracebacks(strings):
+def with_tracebacks(strings, traceback=True):
     """Convenience decorator for testing callback tracebacks."""
-    strings.append('Traceback')
+    if traceback:
+        strings.append('Traceback')
 
     def decorator(func):
         @functools.wraps(func)
         def wrapper(self, *args, **kwargs):
             # First, run the test with traceback enabled.
-            sqlite.enable_callback_tracebacks(True)
-            buf = io.StringIO()
-            with contextlib.redirect_stderr(buf):
+            with check_tracebacks(self, strings):
                 func(self, *args, **kwargs)
-            tb = buf.getvalue()
-            for s in strings:
-                self.assertIn(s, tb)
 
             # Then run the test with traceback disabled.
-            sqlite.enable_callback_tracebacks(False)
             func(self, *args, **kwargs)
         return wrapper
     return decorator
 
+@contextlib.contextmanager
+def check_tracebacks(self, strings):
+    """Convenience context manager for testing callback tracebacks."""
+    sqlite.enable_callback_tracebacks(True)
+    try:
+        buf = io.StringIO()
+        with contextlib.redirect_stderr(buf):
+            yield
+        tb = buf.getvalue()
+        for s in strings:
+            self.assertIn(s, tb)
+    finally:
+        sqlite.enable_callback_tracebacks(False)
+
 def func_returntext():
     return "foo"
 def func_returntextwithnull():
@@ -408,9 +417,26 @@ def md5sum(t):
         del x,y
         gc.collect()
 
+    def test_func_return_too_large_int(self):
+        cur = self.con.cursor()
+        for value in 2**63, -2**63-1, 2**64:
+            self.con.create_function("largeint", 0, lambda value=value: value)
+            with check_tracebacks(self, ['OverflowError']):
+                with self.assertRaises(sqlite.DataError):
+                    cur.execute("select largeint()")
+
+    def test_func_return_text_with_surrogates(self):
+        cur = self.con.cursor()
+        self.con.create_function("pychr", 1, chr)
+        for value in 0xd8ff, 0xdcff:
+            with check_tracebacks(self,
+                    ['UnicodeEncodeError', 'surrogates not allowed']):
+                with self.assertRaises(sqlite.OperationalError):
+                    cur.execute("select pychr(?)", (value,))
+
     @unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform')
     @bigmemtest(size=2**31, memuse=3, dry_run=False)
-    def test_large_text(self, size):
+    def test_func_return_too_large_text(self, size):
         cur = self.con.cursor()
         for size in 2**31-1, 2**31:
             self.con.create_function("largetext", 0, lambda size=size: "b" * size)
@@ -419,7 +445,7 @@ def test_large_text(self, size):
 
     @unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform')
     @bigmemtest(size=2**31, memuse=2, dry_run=False)
-    def test_large_blob(self, size):
+    def test_func_return_too_large_blob(self, size):
         cur = self.con.cursor()
         for size in 2**31-1, 2**31:
             self.con.create_function("largeblob", 0, lambda size=size: b"b" * size)
diff --git a/Misc/NEWS.d/next/Library/2021-08-07-17-28-56.bpo-44859.CCopjk.rst b/Misc/NEWS.d/next/Library/2021-08-07-17-28-56.bpo-44859.CCopjk.rst
new file mode 100644
index 00000000000000..ec9f774d66b8c4
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2021-08-07-17-28-56.bpo-44859.CCopjk.rst
@@ -0,0 +1,8 @@
+Improve error handling in :mod:`sqlite3` and raise more accurate exceptions.
+
+* :exc:`MemoryError` is now raised instead of :exc:`sqlite3.Warning` when memory is not enough for encoding a statement to UTF-8 in ``Connection.__call__()`` and ``Cursor.execute()``.
+* :exc:`UnicodeEncodeError` is now raised instead of :exc:`sqlite3.Warning` when the statement contains surrogate characters in ``Connection.__call__()`` and ``Cursor.execute()``.
+* :exc:`TypeError` is now raised instead of :exc:`ValueError` for non-string script argument in ``Cursor.executescript()``.
+* :exc:`ValueError` is now raised for script containing the null character instead of truncating it in ``Cursor.executescript()``.
+* Correctly handle exceptions raised when getting boolean value of the result of the progress handler.
+* Add many tests covering different corner cases.
diff --git a/Modules/_sqlite/clinic/cursor.c.h b/Modules/_sqlite/clinic/cursor.c.h
index d2c453b38b4b9e..07e15870146cf7 100644
--- a/Modules/_sqlite/clinic/cursor.c.h
+++ b/Modules/_sqlite/clinic/cursor.c.h
@@ -119,6 +119,35 @@ PyDoc_STRVAR(pysqlite_cursor_executescript__doc__,
 #define PYSQLITE_CURSOR_EXECUTESCRIPT_METHODDEF    \
     {"executescript", (PyCFunction)pysqlite_cursor_executescript, METH_O, pysqlite_cursor_executescript__doc__},
 
+static PyObject *
+pysqlite_cursor_executescript_impl(pysqlite_Cursor *self,
+                                   const char *sql_script);
+
+static PyObject *
+pysqlite_cursor_executescript(pysqlite_Cursor *self, PyObject *arg)
+{
+    PyObject *return_value = NULL;
+    const char *sql_script;
+
+    if (!PyUnicode_Check(arg)) {
+        _PyArg_BadArgument("executescript", "argument", "str", arg);
+        goto exit;
+    }
+    Py_ssize_t sql_script_length;
+    sql_script = PyUnicode_AsUTF8AndSize(arg, &sql_script_length);
+    if (sql_script == NULL) {
+        goto exit;
+    }
+    if (strlen(sql_script) != (size_t)sql_script_length) {
+        PyErr_SetString(PyExc_ValueError, "embedded null character");
+        goto exit;
+    }
+    return_value = pysqlite_cursor_executescript_impl(self, sql_script);
+
+exit:
+    return return_value;
+}
+
 PyDoc_STRVAR(pysqlite_cursor_fetchone__doc__,
 "fetchone($self, /)\n"
 "--\n"
@@ -270,4 +299,4 @@ pysqlite_cursor_close(pysqlite_Cursor *self, PyTypeObject *cls, PyObject *const
 exit:
     return return_value;
 }
-/*[clinic end generated code: output=7b216aba2439f5cf input=a9049054013a1b77]*/
+/*[clinic end generated code: output=ace31a7481aa3f41 input=a9049054013a1b77]*/
diff --git a/Modules/_sqlite/connection.c b/Modules/_sqlite/connection.c
index 0dab3e85160e82..67160c4c449aa1 100644
--- a/Modules/_sqlite/connection.c
+++ b/Modules/_sqlite/connection.c
@@ -997,6 +997,14 @@ static int _progress_handler(void* user_arg)
     ret = _PyObject_CallNoArg((PyObject*)user_arg);
 
     if (!ret) {
+        /* abort query if error occurred */
+        rc = -1;
+    }
+    else {
+        rc = PyObject_IsTrue(ret);
+        Py_DECREF(ret);
+    }
+    if (rc < 0) {
         pysqlite_state *state = pysqlite_get_state(NULL);
         if (state->enable_callback_tracebacks) {
             PyErr_Print();
@@ -1004,12 +1012,6 @@ static int _progress_handler(void* user_arg)
         else {
             PyErr_Clear();
         }
-
-        /* abort query if error occurred */
-        rc = 1;
-    } else {
-        rc = (int)PyObject_IsTrue(ret);
-        Py_DECREF(ret);
     }
 
     PyGILState_Release(gilstate);
diff --git a/Modules/_sqlite/cursor.c b/Modules/_sqlite/cursor.c
index 2f4494690f9557..7308f3062da4b9 100644
--- a/Modules/_sqlite/cursor.c
+++ b/Modules/_sqlite/cursor.c
@@ -728,21 +728,21 @@ pysqlite_cursor_executemany_impl(pysqlite_Cursor *self, PyObject *sql,
 /*[clinic input]
 _sqlite3.Cursor.executescript as pysqlite_cursor_executescript
 
-    sql_script as script_obj: object
+    sql_script: str
     /
 
 Executes multiple SQL statements at once. Non-standard.
 [clinic start generated code]*/
 
 static PyObject *
-pysqlite_cursor_executescript(pysqlite_Cursor *self, PyObject *script_obj)
-/*[clinic end generated code: output=115a8132b0f200fe input=ba3ec59df205e362]*/
+pysqlite_cursor_executescript_impl(pysqlite_Cursor *self,
+                                   const char *sql_script)
+/*[clinic end generated code: output=8fd726dde1c65164 input=1ac0693dc8db02a8]*/
 {
     _Py_IDENTIFIER(commit);
-    const char* script_cstr;
     sqlite3_stmt* statement;
     int rc;
-    Py_ssize_t sql_len;
+    size_t sql_len;
     PyObject* result;
 
     if (!check_cursor(self)) {
@@ -751,21 +751,12 @@ pysqlite_cursor_executescript(pysqlite_Cursor *self, PyObject *script_obj)
 
     self->reset = 0;
 
-    if (PyUnicode_Check(script_obj)) {
-        script_cstr = PyUnicode_AsUTF8AndSize(script_obj, &sql_len);
-        if (!script_cstr) {
-            return NULL;
-        }
-
-        int max_length = sqlite3_limit(self->connection->db,
-                                       SQLITE_LIMIT_LENGTH, -1);
-        if (sql_len >= max_length) {
-            PyErr_SetString(self->connection->DataError,
-                            "query string is too large");
-            return NULL;
-        }
-    } else {
-        PyErr_SetString(PyExc_ValueError, "script argument must be unicode.");
+    sql_len = strlen(sql_script);
+    int max_length = sqlite3_limit(self->connection->db,
+                                   SQLITE_LIMIT_LENGTH, -1);
+    if (sql_len >= (unsigned)max_length) {
+        PyErr_SetString(self->connection->DataError,
+                        "query string is too large");
         return NULL;
     }
 
@@ -782,7 +773,7 @@ pysqlite_cursor_executescript(pysqlite_Cursor *self, PyObject *script_obj)
 
         Py_BEGIN_ALLOW_THREADS
         rc = sqlite3_prepare_v2(self->connection->db,
-                                script_cstr,
+                                sql_script,
                                 (int)sql_len + 1,
                                 &statement,
                                 &tail);
@@ -816,8 +807,8 @@ pysqlite_cursor_executescript(pysqlite_Cursor *self, PyObject *script_obj)
         if (*tail == (char)0) {
             break;
         }
-        sql_len -= (tail - script_cstr);
-        script_cstr = tail;
+        sql_len -= (tail - sql_script);
+        sql_script = tail;
     }
 
 error:
diff --git a/Modules/_sqlite/statement.c b/Modules/_sqlite/statement.c
index 983df2d50c975d..2d5c72d13b7edb 100644
--- a/Modules/_sqlite/statement.c
+++ b/Modules/_sqlite/statement.c
@@ -56,9 +56,6 @@ pysqlite_statement_create(pysqlite_Connection *connection, PyObject *sql)
     Py_ssize_t size;
     const char *sql_cstr = PyUnicode_AsUTF8AndSize(sql, &size);
     if (sql_cstr == NULL) {
-        PyErr_Format(connection->Warning,
-                     "SQL is of wrong type ('%s'). Must be string.",
-                     Py_TYPE(sql)->tp_name);
         return NULL;
     }
 



More information about the Python-checkins mailing list