[pypy-svn] pypy default: Implement aggregate functions in sqlite

amauryfa commits-noreply at bitbucket.org
Wed Feb 9 15:54:06 CET 2011


Author: Amaury Forgeot d'Arc <amauryfa at gmail.com>
Branch: 
Changeset: r41735:253b4ff37ba6
Date: 2011-02-09 14:08 +0100
http://bitbucket.org/pypy/pypy/changeset/253b4ff37ba6/

Log:	Implement aggregate functions in sqlite

diff --git a/lib_pypy/_sqlite3.py b/lib_pypy/_sqlite3.py
--- a/lib_pypy/_sqlite3.py
+++ b/lib_pypy/_sqlite3.py
@@ -23,6 +23,7 @@
 
 from ctypes import c_void_p, c_int, c_double, c_int64, c_char_p, cdll
 from ctypes import POINTER, byref, string_at, CFUNCTYPE, cast
+from ctypes import sizeof, c_ssize_t
 import datetime
 import sys
 import time
@@ -256,6 +257,7 @@
         self.NotSupportedError = NotSupportedError
 
         self.func_cache = {}
+        self.aggregate_instances = {}
         self.thread_ident = thread_get_ident()
 
     def _get_exception(self, error_code = None):
@@ -334,6 +336,14 @@
             cur.row_factory = self.row_factory
         return cur.executescript(*args)
 
+    def __call__(self, sql):
+        self._check_closed()
+        cur = Cursor(self)
+        if not isinstance(sql, (str, unicode)):
+            raise Warning("SQL is of wrong type. Must be string or unicode.")
+        statement = Statement(cur, sql, self.row_factory)
+        return statement
+
     def _get_isolation_level(self):
         return self._isolation_level
     def _set_isolation_level(self, val):
@@ -458,7 +468,75 @@
             raise self._get_exception(ret)
 
     def create_aggregate(self, name, num_args, cls):
-        raise NotImplementedError
+        self._check_thread()
+        self._check_closed()
+
+        try:
+            c_step_callback, c_final_callback = self.func_cache[cls]
+        except KeyError:
+            def step_callback(context, argc, c_params):
+
+                aggregate_ptr = cast(
+                    sqlite.sqlite3_aggregate_context(
+                    context, sizeof(c_ssize_t)),
+                    POINTER(c_ssize_t))
+
+                if not aggregate_ptr[0]:
+                    try:
+                        aggregate = cls()
+                    except Exception, e:
+                        msg = ("user-defined aggregate's '__init__' "
+                               "method raised error")
+                        sqlite.sqlite3_result_error(context, msg, len(msg))
+                        return
+                    aggregate_id = id(aggregate)
+                    self.aggregate_instances[aggregate_id] = aggregate
+                    aggregate_ptr[0] = aggregate_id
+                else:
+                    aggregate = self.aggregate_instances[aggregate_ptr[0]]
+
+                params = _convert_params(context, argc, c_params)
+                try:
+                    aggregate.step(*params)
+                except Exception, e:
+                    msg = ("user-defined aggregate's 'step' "
+                           "method raised error")
+                    sqlite.sqlite3_result_error(context, msg, len(msg))
+                return 0
+
+            def final_callback(context):
+
+                aggregate_ptr = cast(
+                    sqlite.sqlite3_aggregate_context(
+                    context, sizeof(c_ssize_t)),
+                    POINTER(c_ssize_t))
+
+                if aggregate_ptr[0]:
+                    aggregate = self.aggregate_instances[aggregate_ptr[0]]
+                    try:
+                        val = aggregate.finalize()
+                    except Exception, e:
+                        msg = ("user-defined aggregate's 'finalize' "
+                               "method raised error")
+                        sqlite.sqlite3_result_error(context, msg, len(msg))
+                    else:
+                        _convert_result(context, val)
+                    finally:
+                        del self.aggregate_instances[aggregate_ptr[0]]
+                return 0
+
+            c_step_callback = STEP(step_callback)
+            c_final_callback = FINAL(final_callback)
+
+            self.func_cache[cls] = c_step_callback, c_final_callback
+
+        ret = sqlite.sqlite3_create_function(self.db, name, num_args,
+                                             SQLITE_UTF8, None,
+                                             cast(None, FUNC),
+                                             c_step_callback,
+                                             c_final_callback)
+        if ret != SQLITE_OK:
+            raise self._get_exception(ret)
 
 class Cursor(object):
     def __init__(self, con):
@@ -936,14 +1014,16 @@
         sqlite.sqlite3_result_error(context, msg, len(msg))
     else:
         _convert_result(context, val)
-    return 0
 
-FUNC = CFUNCTYPE(c_int, c_void_p, c_int, POINTER(c_void_p))
-STEP = CFUNCTYPE(c_int, c_void_p, c_int, POINTER(c_void_p))
-FINAL = CFUNCTYPE(c_int, c_void_p)
+FUNC = CFUNCTYPE(None, c_void_p, c_int, POINTER(c_void_p))
+STEP = CFUNCTYPE(None, c_void_p, c_int, POINTER(c_void_p))
+FINAL = CFUNCTYPE(None, c_void_p)
 sqlite.sqlite3_create_function.argtypes = [c_void_p, c_char_p, c_int, c_int, c_void_p, FUNC, STEP, FINAL]
 sqlite.sqlite3_create_function.restype = c_int
 
+sqlite.sqlite3_aggregate_context.argtypes = [c_void_p, c_int]
+sqlite.sqlite3_aggregate_context.restype = c_void_p
+
 converters = {}
 adapters = {}
 


More information about the Pypy-commit mailing list