[pypy-commit] pypy stdlib-2.7.3: CPython Issue #10811: sqlite: Fix recursive usage of cursors.

amauryfa noreply at buildbot.pypy.org
Thu Jun 14 23:15:07 CEST 2012


Author: Amaury Forgeot d'Arc <amauryfa at gmail.com>
Branch: stdlib-2.7.3
Changeset: r55670:6749c2482195
Date: 2012-06-14 23:13 +0200
http://bitbucket.org/pypy/pypy/changeset/6749c2482195/

Log:	CPython Issue #10811: sqlite: Fix recursive usage of cursors.

diff --git a/lib_pypy/_sqlite3.py b/lib_pypy/_sqlite3.py
--- a/lib_pypy/_sqlite3.py
+++ b/lib_pypy/_sqlite3.py
@@ -722,6 +722,19 @@
 
 DML, DQL, DDL = range(3)
 
+class CursorLock(object):  # context manager: sets cursor.locked for the duration of a with-block
+    def __init__(self, cursor):
+        self.cursor = cursor  # the flag lives on the cursor as a plain bool attribute, not a threading lock
+
+    def __enter__(self):  # refuse nested entry: raises if the cursor is already locked
+        if self.cursor.locked:
+            raise ProgrammingError("Recursive use of cursors not allowed.")
+        self.cursor.locked = True
+
+    def __exit__(self, *args):  # clears the flag on normal or exceptional exit; returns None, so exceptions propagate
+        self.cursor.locked = False
+
 class Cursor(object):
     def __init__(self, con):
         if not isinstance(con, Connection):
@@ -736,6 +749,7 @@
         self.rowcount = -1
         self.statement = None
         self.reset = False
+        self.locked = False
 
     def _check_closed(self):
         if not getattr(self, 'connection', None):
@@ -743,64 +757,72 @@
         self.connection._check_thread()
         self.connection._check_closed()
 
+    def _check_and_lock(self):  # validate cursor/connection state, then return a CursorLock for use in `with`
+        self._check_closed()  # runs before locking, so a closed cursor fails without touching the flag
+        return CursorLock(self)  # the flag is only set when the caller enters the with-block
+
     def execute(self, sql, params=None):
-        self._description = None
-        self.reset = False
         if type(sql) is unicode:
             sql = sql.encode("utf-8")
-        self._check_closed()
-        self.statement = self.connection.statement_cache.get(sql, self, self.row_factory)
 
-        if self.connection._isolation_level is not None:
-            if self.statement.kind == DDL:
-                self.connection.commit()
-            elif self.statement.kind == DML:
-                self.connection._begin()
+        with self._check_and_lock():  # whole execution runs with self.locked set; re-entering execute/executemany on this cursor raises ProgrammingError
+            self._description = None
+            self.reset = False
+            self.statement = self.connection.statement_cache.get(
+                sql, self, self.row_factory)
 
-        self.statement.set_params(params)
+            if self.connection._isolation_level is not None:  # DDL triggers commit(), DML triggers _begin()
+                if self.statement.kind == DDL:
+                    self.connection.commit()
+                elif self.statement.kind == DML:
+                    self.connection._begin()
 
-        # Actually execute the SQL statement
-        ret = sqlite.sqlite3_step(self.statement.statement)
-        if ret not in (SQLITE_DONE, SQLITE_ROW):
-            self.statement.reset()
-            raise self.connection._get_exception(ret)
+            self.statement.set_params(params)
 
-        if self.statement.kind == DQL and ret == SQLITE_ROW:
-            self.statement._build_row_cast_map()
-            self.statement._readahead(self)
-        else:
-            self.statement.item = None
-            self.statement.exhausted = True
+            # Actually execute the SQL statement
+            ret = sqlite.sqlite3_step(self.statement.statement)
+            if ret not in (SQLITE_DONE, SQLITE_ROW):  # error: reset the statement before raising
+                self.statement.reset()
+                raise self.connection._get_exception(ret)
 
-        if self.statement.kind == DML:
-            self.statement.reset()
+            if self.statement.kind == DQL and ret == SQLITE_ROW:
+                self.statement._build_row_cast_map()
+                self.statement._readahead(self)
+            else:  # no row to read ahead: mark the statement exhausted
+                self.statement.item = None
+                self.statement.exhausted = True
 
-        self.rowcount = -1
-        if self.statement.kind == DML:
-            self.rowcount = sqlite.sqlite3_changes(self.connection.db)
+            if self.statement.kind == DML:
+                self.statement.reset()
+
+            self.rowcount = -1  # only DML reports a row count (via sqlite3_changes); other kinds stay -1
+            if self.statement.kind == DML:
+                self.rowcount = sqlite.sqlite3_changes(self.connection.db)
 
         return self
 
     def executemany(self, sql, many_params):
-        self._description = None
-        self.reset = False
         if type(sql) is unicode:
             sql = sql.encode("utf-8")
-        self._check_closed()
-        self.statement = self.connection.statement_cache.get(sql, self, self.row_factory)
 
-        if self.statement.kind == DML:
-            self.connection._begin()
-        else:
-            raise ProgrammingError("executemany is only for DML statements")
+        with self._check_and_lock():  # lock spans the loop: if iterating many_params re-enters execute/executemany on this cursor, ProgrammingError is raised
+            self._description = None
+            self.reset = False
+            self.statement = self.connection.statement_cache.get(
+                sql, self, self.row_factory)
 
-        self.rowcount = 0
-        for params in many_params:
-            self.statement.set_params(params)
-            ret = sqlite.sqlite3_step(self.statement.statement)
-            if ret != SQLITE_DONE:
-                raise self.connection._get_exception(ret)
-            self.rowcount += sqlite.sqlite3_changes(self.connection.db)
+            if self.statement.kind == DML:
+                self.connection._begin()
+            else:
+                raise ProgrammingError("executemany is only for DML statements")
+
+            self.rowcount = 0  # accumulates sqlite3_changes over every parameter set
+            for params in many_params:
+                self.statement.set_params(params)
+                ret = sqlite.sqlite3_step(self.statement.statement)
+                if ret != SQLITE_DONE:  # anything else (SQLITE_ROW included) aborts the batch; NOTE(review): unlike execute(), the statement is not reset before raising — confirm intended
+                    raise self.connection._get_exception(ret)
+                self.rowcount += sqlite.sqlite3_changes(self.connection.db)
 
         return self
 


More information about the pypy-commit mailing list