[Python-checkins] bpo-42064: Pass module state to trace, progress, and authorizer callbacks (GH-27940)

encukou webhook-mailer at python.org
Tue Sep 7 09:06:26 EDT 2021


https://github.com/python/cpython/commit/979336de34e3d3f40cf6e666b72a618f6330f3c1
commit: 979336de34e3d3f40cf6e666b72a618f6330f3c1
branch: main
author: Erlend Egeberg Aasland <erlend.aasland at innova.no>
committer: encukou <encukou at gmail.com>
date: 2021-09-07T15:06:17+02:00
summary:

bpo-42064: Pass module state to trace, progress, and authorizer callbacks (GH-27940)

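- add a print-or-clear traceback helper (prints the pending exception when callback tracebacks are enabled, otherwise clears it)
- add helpers to clear and visit the saved callback contexts
- modify the trace, progress, and authorizer callbacks to use the new callback_context struct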

files:
M Modules/_sqlite/connection.c
M Modules/_sqlite/connection.h

diff --git a/Modules/_sqlite/connection.c b/Modules/_sqlite/connection.c
index 0780d41c5a0c3..bf803370c0571 100644
--- a/Modules/_sqlite/connection.c
+++ b/Modules/_sqlite/connection.c
@@ -55,6 +55,9 @@ static const char * const begin_statements[] = {
 
 static int pysqlite_connection_set_isolation_level(pysqlite_Connection* self, PyObject* isolation_level, void *Py_UNUSED(ignored));
 static void _pysqlite_drop_unused_cursor_references(pysqlite_Connection* self);
+static void free_callback_context(callback_context *ctx);
+static void set_callback_context(callback_context **ctx_pp,
+                                 callback_context *ctx);
 
 static PyObject *
 new_statement_cache(pysqlite_Connection *self, int maxsize)
@@ -170,9 +173,9 @@ pysqlite_connection_init_impl(pysqlite_Connection *self,
     self->thread_ident = PyThread_get_thread_ident();
     self->check_same_thread = check_same_thread;
 
-    self->function_pinboard_trace_callback = NULL;
-    self->function_pinboard_progress_handler = NULL;
-    self->function_pinboard_authorizer_cb = NULL;
+    set_callback_context(&self->trace_ctx, NULL);
+    set_callback_context(&self->progress_ctx, NULL);
+    set_callback_context(&self->authorizer_ctx, NULL);
 
     self->Warning               = state->Warning;
     self->Error                 = state->Error;
@@ -216,6 +219,13 @@ pysqlite_do_all_statements(pysqlite_Connection *self)
     }
 }
 
+#define VISIT_CALLBACK_CONTEXT(ctx) \
+do {                                \
+    if (ctx) {                      \
+        Py_VISIT(ctx->callable);    \
+    }                               \
+} while (0)
+
 static int
 connection_traverse(pysqlite_Connection *self, visitproc visit, void *arg)
 {
@@ -225,12 +235,21 @@ connection_traverse(pysqlite_Connection *self, visitproc visit, void *arg)
     Py_VISIT(self->cursors);
     Py_VISIT(self->row_factory);
     Py_VISIT(self->text_factory);
-    Py_VISIT(self->function_pinboard_trace_callback);
-    Py_VISIT(self->function_pinboard_progress_handler);
-    Py_VISIT(self->function_pinboard_authorizer_cb);
+    VISIT_CALLBACK_CONTEXT(self->trace_ctx);
+    VISIT_CALLBACK_CONTEXT(self->progress_ctx);
+    VISIT_CALLBACK_CONTEXT(self->authorizer_ctx);
+#undef VISIT_CALLBACK_CONTEXT
     return 0;
 }
 
+static inline void
+clear_callback_context(callback_context *ctx)
+{
+    if (ctx != NULL) {
+        Py_CLEAR(ctx->callable);
+    }
+}
+
 static int
 connection_clear(pysqlite_Connection *self)
 {
@@ -239,9 +258,9 @@ connection_clear(pysqlite_Connection *self)
     Py_CLEAR(self->cursors);
     Py_CLEAR(self->row_factory);
     Py_CLEAR(self->text_factory);
-    Py_CLEAR(self->function_pinboard_trace_callback);
-    Py_CLEAR(self->function_pinboard_progress_handler);
-    Py_CLEAR(self->function_pinboard_authorizer_cb);
+    clear_callback_context(self->trace_ctx);
+    clear_callback_context(self->progress_ctx);
+    clear_callback_context(self->authorizer_ctx);
     return 0;
 }
 
@@ -255,6 +274,14 @@ connection_close(pysqlite_Connection *self)
     }
 }
 
+static void
+free_callback_contexts(pysqlite_Connection *self)
+{
+    set_callback_context(&self->trace_ctx, NULL);
+    set_callback_context(&self->progress_ctx, NULL);
+    set_callback_context(&self->authorizer_ctx, NULL);
+}
+
 static void
 connection_dealloc(pysqlite_Connection *self)
 {
@@ -264,6 +291,7 @@ connection_dealloc(pysqlite_Connection *self)
 
     /* Clean up if user has not called .close() explicitly. */
     connection_close(self);
+    free_callback_contexts(self);
 
     tp->tp_free(self);
     Py_DECREF(tp);
@@ -600,6 +628,19 @@ _pysqlite_build_py_params(sqlite3_context *context, int argc,
     return NULL;
 }
 
+static void
+print_or_clear_traceback(callback_context *ctx)
+{
+    assert(ctx != NULL);
+    assert(ctx->state != NULL);
+    if (ctx->state->enable_callback_tracebacks) {
+        PyErr_Print();
+    }
+    else {
+        PyErr_Clear();
+    }
+}
+
 // Checks the Python exception and sets the appropriate SQLite error code.
 static void
 set_sqlite_error(sqlite3_context *context, const char *msg)
@@ -615,14 +656,7 @@ set_sqlite_error(sqlite3_context *context, const char *msg)
         sqlite3_result_error(context, msg, -1);
     }
     callback_context *ctx = (callback_context *)sqlite3_user_data(context);
-    assert(ctx != NULL);
-    assert(ctx->state != NULL);
-    if (ctx->state->enable_callback_tracebacks) {
-        PyErr_Print();
-    }
-    else {
-        PyErr_Clear();
-    }
+    print_or_clear_traceback(ctx);
 }
 
 static void
@@ -796,10 +830,21 @@ static void
 free_callback_context(callback_context *ctx)
 {
     assert(ctx != NULL);
-    Py_DECREF(ctx->callable);
+    Py_XDECREF(ctx->callable);
     PyMem_Free(ctx);
 }
 
+static void
+set_callback_context(callback_context **ctx_pp, callback_context *ctx)
+{
+    assert(ctx_pp != NULL);
+    callback_context *tmp = *ctx_pp;
+    *ctx_pp = ctx;
+    if (tmp != NULL) {
+        free_callback_context(tmp);
+    }
+}
+
 static void
 destructor_callback(void *ctx)
 {
@@ -917,33 +962,22 @@ authorizer_callback(void *ctx, int action, const char *arg1,
     PyGILState_STATE gilstate = PyGILState_Ensure();
 
     PyObject *ret;
-    int rc;
+    int rc = SQLITE_DENY;
 
-    ret = PyObject_CallFunction((PyObject*)ctx, "issss", action, arg1, arg2,
-                                dbname, access_attempt_source);
+    assert(ctx != NULL);
+    PyObject *callable = ((callback_context *)ctx)->callable;
+    ret = PyObject_CallFunction(callable, "issss", action, arg1, arg2, dbname,
+                                access_attempt_source);
 
     if (ret == NULL) {
-        pysqlite_state *state = pysqlite_get_state(NULL);
-        if (state->enable_callback_tracebacks) {
-            PyErr_Print();
-        }
-        else {
-            PyErr_Clear();
-        }
-
+        print_or_clear_traceback(ctx);
         rc = SQLITE_DENY;
     }
     else {
         if (PyLong_Check(ret)) {
             rc = _PyLong_AsInt(ret);
             if (rc == -1 && PyErr_Occurred()) {
-                pysqlite_state *state = pysqlite_get_state(NULL);
-                if (state->enable_callback_tracebacks) {
-                    PyErr_Print();
-                }
-                else {
-                    PyErr_Clear();
-                }
+                print_or_clear_traceback(ctx);
                 rc = SQLITE_DENY;
             }
         }
@@ -964,8 +998,10 @@ progress_callback(void *ctx)
 
     int rc;
     PyObject *ret;
-    ret = _PyObject_CallNoArg((PyObject*)ctx);
 
+    assert(ctx != NULL);
+    PyObject *callable = ((callback_context *)ctx)->callable;
+    ret = _PyObject_CallNoArg(callable);
     if (!ret) {
         /* abort query if error occurred */
         rc = -1;
@@ -975,13 +1011,7 @@ progress_callback(void *ctx)
         Py_DECREF(ret);
     }
     if (rc < 0) {
-        pysqlite_state *state = pysqlite_get_state(NULL);
-        if (state->enable_callback_tracebacks) {
-            PyErr_Print();
-        }
-        else {
-            PyErr_Clear();
-        }
+        print_or_clear_traceback(ctx);
     }
 
     PyGILState_Release(gilstate);
@@ -1015,21 +1045,18 @@ trace_callback(void *ctx, const char *statement_string)
     PyObject *ret = NULL;
     py_statement = PyUnicode_DecodeUTF8(statement_string,
             strlen(statement_string), "replace");
+    assert(ctx != NULL);
     if (py_statement) {
-        ret = PyObject_CallOneArg((PyObject*)ctx, py_statement);
+        PyObject *callable = ((callback_context *)ctx)->callable;
+        ret = PyObject_CallOneArg(callable, py_statement);
         Py_DECREF(py_statement);
     }
 
     if (ret) {
         Py_DECREF(ret);
-    } else {
-        pysqlite_state *state = pysqlite_get_state(NULL);
-        if (state->enable_callback_tracebacks) {
-            PyErr_Print();
-        }
-        else {
-            PyErr_Clear();
-        }
+    }
+    else {
+        print_or_clear_traceback(ctx);
     }
 
     PyGILState_Release(gilstate);
@@ -1058,17 +1085,20 @@ pysqlite_connection_set_authorizer_impl(pysqlite_Connection *self,
     int rc;
     if (callable == Py_None) {
         rc = sqlite3_set_authorizer(self->db, NULL, NULL);
-        Py_XSETREF(self->function_pinboard_authorizer_cb, NULL);
+        set_callback_context(&self->authorizer_ctx, NULL);
     }
     else {
-        Py_INCREF(callable);
-        Py_XSETREF(self->function_pinboard_authorizer_cb, callable);
-        rc = sqlite3_set_authorizer(self->db, authorizer_callback, callable);
+        callback_context *ctx = create_callback_context(self->state, callable);
+        if (ctx == NULL) {
+            return NULL;
+        }
+        rc = sqlite3_set_authorizer(self->db, authorizer_callback, ctx);
+        set_callback_context(&self->authorizer_ctx, ctx);
     }
     if (rc != SQLITE_OK) {
         PyErr_SetString(self->OperationalError,
                         "Error setting authorizer callback");
-        Py_XSETREF(self->function_pinboard_authorizer_cb, NULL);
+        set_callback_context(&self->authorizer_ctx, NULL);
         return NULL;
     }
     Py_RETURN_NONE;
@@ -1095,11 +1125,15 @@ pysqlite_connection_set_progress_handler_impl(pysqlite_Connection *self,
     if (callable == Py_None) {
         /* None clears the progress handler previously set */
         sqlite3_progress_handler(self->db, 0, 0, (void*)0);
-        Py_XSETREF(self->function_pinboard_progress_handler, NULL);
-    } else {
-        sqlite3_progress_handler(self->db, n, progress_callback, callable);
-        Py_INCREF(callable);
-        Py_XSETREF(self->function_pinboard_progress_handler, callable);
+        set_callback_context(&self->progress_ctx, NULL);
+    }
+    else {
+        callback_context *ctx = create_callback_context(self->state, callable);
+        if (ctx == NULL) {
+            return NULL;
+        }
+        sqlite3_progress_handler(self->db, n, progress_callback, ctx);
+        set_callback_context(&self->progress_ctx, ctx);
     }
     Py_RETURN_NONE;
 }
@@ -1136,15 +1170,19 @@ pysqlite_connection_set_trace_callback_impl(pysqlite_Connection *self,
 #else
         sqlite3_trace(self->db, 0, (void*)0);
 #endif
-        Py_XSETREF(self->function_pinboard_trace_callback, NULL);
-    } else {
+        set_callback_context(&self->trace_ctx, NULL);
+    }
+    else {
+        callback_context *ctx = create_callback_context(self->state, callable);
+        if (ctx == NULL) {
+            return NULL;
+        }
 #ifdef HAVE_TRACE_V2
-        sqlite3_trace_v2(self->db, SQLITE_TRACE_STMT, trace_callback, callable);
+        sqlite3_trace_v2(self->db, SQLITE_TRACE_STMT, trace_callback, ctx);
 #else
-        sqlite3_trace(self->db, trace_callback, callable);
+        sqlite3_trace(self->db, trace_callback, ctx);
 #endif
-        Py_INCREF(callable);
-        Py_XSETREF(self->function_pinboard_trace_callback, callable);
+        set_callback_context(&self->trace_ctx, ctx);
     }
 
     Py_RETURN_NONE;
diff --git a/Modules/_sqlite/connection.h b/Modules/_sqlite/connection.h
index 11b3a8005e1f9..c4cec857ddbfe 100644
--- a/Modules/_sqlite/connection.h
+++ b/Modules/_sqlite/connection.h
@@ -82,10 +82,10 @@ typedef struct
      */
     PyObject* text_factory;
 
-    /* remember references to object used in trace_callback/progress_handler/authorizer_cb */
-    PyObject* function_pinboard_trace_callback;
-    PyObject* function_pinboard_progress_handler;
-    PyObject* function_pinboard_authorizer_cb;
+    // Remember contexts used by the trace, progress, and authorizer callbacks
+    callback_context *trace_ctx;
+    callback_context *progress_ctx;
+    callback_context *authorizer_ctx;
 
     /* Exception objects: borrowed refs. */
     PyObject* Warning;



More information about the Python-checkins mailing list