[pypy-svn] pypy default: SELECT statements should execute the query in execute(), and not wait for the user to call fetchone()

amauryfa commits-noreply at bitbucket.org
Wed Feb 9 19:13:31 CET 2011


Author: Amaury Forgeot d'Arc <amauryfa at gmail.com>
Branch: 
Changeset: r41751:a45980624e0e
Date: 2011-02-09 19:12 +0100
http://bitbucket.org/pypy/pypy/changeset/a45980624e0e/

Log:	SELECT statements now execute the query in execute() instead of
	waiting for the user to call fetchone()

diff --git a/lib_pypy/_sqlite3.py b/lib_pypy/_sqlite3.py
--- a/lib_pypy/_sqlite3.py
+++ b/lib_pypy/_sqlite3.py
@@ -257,6 +257,7 @@
         self.NotSupportedError = NotSupportedError
 
         self.func_cache = {}
+        self._aggregates = {}
         self.aggregate_instances = {}
         self._collations = {}
         self.thread_ident = thread_get_ident()
@@ -378,6 +379,12 @@
         self._check_closed()
         if sqlite.sqlite3_get_autocommit(self.db):
             return
+
+        for statement in self.statements:
+            obj = statement()
+            if obj is not None:
+                obj.reset()
+
         try:
             sql = "COMMIT"
             statement = c_void_p()
@@ -396,6 +403,12 @@
         self._check_closed()
         if sqlite.sqlite3_get_autocommit(self.db):
             return
+
+        for statement in self.statements:
+            obj = statement()
+            if obj is not None:
+                obj.reset()
+
         try:
             sql = "ROLLBACK"
             statement = c_void_p()
@@ -501,7 +514,7 @@
         self._check_closed()
 
         try:
-            c_authorizer, _ = self.func_cache[callable]
+            c_authorizer, _ = self.func_cache[callback]
         except KeyError:
             def authorizer(userdata, action, arg1, arg2, dbname, source):
                 try:
@@ -510,7 +523,7 @@
                     return SQLITE_DENY
             c_authorizer = AUTHORIZER(authorizer)
 
-            self.func_cache[callable] = c_authorizer, authorizer
+            self.func_cache[callback] = c_authorizer, authorizer
 
         ret = sqlite.sqlite3_set_authorizer(self.db,
                                             c_authorizer,
@@ -541,7 +554,7 @@
         self._check_closed()
 
         try:
-            c_step_callback, c_final_callback = self.func_cache[cls]
+            c_step_callback, c_final_callback, _, _ = self._aggregates[cls]
         except KeyError:
             def step_callback(context, argc, c_params):
 
@@ -565,8 +578,9 @@
                     aggregate = self.aggregate_instances[aggregate_ptr[0]]
 
                 params = _convert_params(context, argc, c_params)
+                step = aggregate.step
                 try:
-                    aggregate.step(*params)
+                    step(*params)
                 except Exception, e:
                     msg = ("user-defined aggregate's 'step' "
                            "method raised error")
@@ -595,7 +609,8 @@
             c_step_callback = STEP(step_callback)
             c_final_callback = FINAL(final_callback)
 
-            self.func_cache[cls] = c_step_callback, c_final_callback
+            self._aggregates[cls] = (c_step_callback, c_final_callback,
+                                     step_callback, final_callback)
 
         ret = sqlite.sqlite3_create_function(self.db, name, num_args,
                                              SQLITE_UTF8, None,
@@ -639,10 +654,21 @@
                 self.connection._begin()
 
         self.statement.set_params(params)
+
+        # Actually execute the SQL statement
+        ret = sqlite.sqlite3_step(self.statement.statement)
+        if ret not in (SQLITE_DONE, SQLITE_ROW):
+            self.statement.reset()
+            raise self.connection._get_exception(ret)
+
+        if self.statement.kind == "DQL":
+            self.statement._readahead()
+            self.statement._build_row_cast_map()
+
         if self.statement.kind in ("DML", "DDL"):
-            ret = sqlite.sqlite3_step(self.statement.statement)
-            if ret != SQLITE_DONE:
-                raise self.connection._get_exception(ret)
+            self.statement.reset()
+
+        self.rowcount = -1
         if self.statement.kind == "DML":
             self.rowcount = sqlite.sqlite3_changes(self.connection.db)
 
@@ -701,11 +727,14 @@
         self._check_closed()
         if self.statement is None:
             return None
+
         try:
             return self.statement.next()
         except StopIteration:
             return None
 
+        return nextrow
+
     def fetchmany(self, size=None):
         self._check_closed()
         if self.statement is None:
@@ -785,8 +814,6 @@
 
         self._build_row_cast_map()
 
-        self.started = False
-
     def _build_row_cast_map(self):
         self.row_cast_map = []
         for i in range(sqlite.sqlite3_column_count(self.statement)):
@@ -880,27 +907,23 @@
     def next(self):
         self.con._check_closed()
         self.con._check_thread()
-        if not self.started:
-            self.item = self._readahead()
-            self.started = True
         if self.exhausted:
             raise StopIteration
         item = self.item
-        self.item = self._readahead()
+
+        ret = sqlite.sqlite3_step(self.statement)
+        if ret == SQLITE_DONE:
+            self.exhausted = True
+            self.item = None
+        elif ret != SQLITE_ROW:
+            exc = self.con._get_exception(ret)
+            sqlite.sqlite3_reset(self.statement)
+            raise exc
+
+        self._readahead()
         return item
 
     def _readahead(self):
-        ret = sqlite.sqlite3_step(self.statement)
-        if ret == SQLITE_DONE:
-            self.exhausted = True
-            return
-        elif ret == SQLITE_ERROR:
-            sqlite.sqlite3_reset(self.statement)
-            exc = self.con._get_exception(ret)
-            raise exc
-        else:
-            assert ret == SQLITE_ROW
-
         self.column_count = sqlite.sqlite3_column_count(self.statement)
         row = []
         for i in xrange(self.column_count):
@@ -935,11 +958,11 @@
         row = tuple(row)
         if self.row_factory is not None:
             row = self.row_factory(self.cur(), row)
-        return row
+        self.item = row
 
     def reset(self):
         self.row_cast_map = None
-        sqlite.sqlite3_reset(self.statement)
+        return sqlite.sqlite3_reset(self.statement)
 
     def finalize(self):
         sqlite.sqlite3_finalize(self.statement)


More information about the Pypy-commit mailing list