[pypy-svn] r27066 - pypy/dist/pypy/objspace/std

arigo at codespeak.net arigo at codespeak.net
Thu May 11 12:16:42 CEST 2006


Author: arigo
Date: Thu May 11 12:16:36 2006
New Revision: 27066

Modified:
   pypy/dist/pypy/objspace/std/multimethod.py
Log:
Fixed multimethod installer version 2 so that the functions it generates
check the types of their arguments.  With this change, the (very limited)
test_multimethod now passes with this installer too.


Modified: pypy/dist/pypy/objspace/std/multimethod.py
==============================================================================
--- pypy/dist/pypy/objspace/std/multimethod.py	(original)
+++ pypy/dist/pypy/objspace/std/multimethod.py	Thu May 11 12:16:36 2006
@@ -278,7 +278,7 @@
 # ____________________________________________________________
 # Installer version 2
 
-class MMDispatcher:
+class MMDispatcher(object):
     """NOT_RPYTHON
     Explicit dispatcher class.  This is not used in normal execution, which
     uses the complex Installer below to install single-dispatch methods to
@@ -361,12 +361,16 @@
         self.arguments = arguments
 
 
-class CompressedArray:
+class CompressedArray(object):
     def __init__(self, null_value, reserved_count):
         self.null_value = null_value
         self.reserved_count = reserved_count
         self.items = [null_value] * reserved_count
 
+    def ensure_length(self, newlen):
+        if newlen > len(self.items):
+            self.items.extend([self.null_value] * (newlen - len(self.items)))
+
     def insert_subarray(self, array):
         # insert the given array of numbers into the indexlist,
         # allowing null values to become non-null
@@ -377,8 +381,7 @@
             initial_nulls += 1
         test = max(self.reserved_count - initial_nulls, 0)
         while True:
-            while test+len(array) > len(self.items):
-                self.items.append(self.null_value)
+            self.ensure_length(test+len(array))
             for i in range(len(array)):
                 if not (array[i] == self.items[test+i] or
                         array[i] == self.null_value or
@@ -396,7 +399,7 @@
         return True
 
 
-class MRDTable:
+class MRDTable(object):
     # Multi-Method Dispatch Using Multiple Row Displacement,
     # Candy Pang, Wade Holst, Yuri Leontiev, and Duane Szafron
     # University of Alberta, Edmonton AB T6G 2H1 Canada
@@ -412,9 +415,103 @@
         for t1, num in self.typenum.items():
             setattr(t1, self.attrname, num)
         self.indexarray = CompressedArray(0, 1)
+        self.allarrays = [self.indexarray]
+
+    def normalize_length(self, next_array):
+        self.allarrays.append(next_array)
+        length = max([len(a.items) for a in self.allarrays])
+        for a in self.allarrays:
+            a.ensure_length(length)
+
+
+class FuncEntry(object):
+
+    def __init__(self, bodylines, miniglobals, fallback):
+        self.body = '\n    '.join(bodylines)
+        self.miniglobals = miniglobals
+        self.fallback = fallback
+        self.possiblenames = []
+        self.typetree = {}
+        self._function = None
+
+    def key(self):
+        lst = self.miniglobals.items()
+        lst.sort()
+        return self.body, tuple(lst)
 
+    def make_function(self, fnargs, nbargs_before, mrdtable):
+        if self._function is not None:
+            return self._function
+        name = min(self.possiblenames)   # pick a random one, but consistently
+        self.compress_typechecks(mrdtable)
+        checklines = self.generate_typechecks(fnargs[nbargs_before:], mrdtable)
+        source = 'def %s(%s):\n    %s\n    %s\n' % (name, ', '.join(fnargs),
+                                                    '\n    '.join(checklines),
+                                                    self.body)
+        print '_' * 60
+        print self.possiblenames
+        print source
+        exec compile2(source) in self.miniglobals
+        self._function = self.miniglobals[name]
+        return self._function
+
+    def register_valid_types(self, typenums):
+        node = self.typetree
+        for n in typenums[:-1]:
+            if node is True:
+                return
+            node = node.setdefault(n, {})
+        if node is True:
+            return
+        node[typenums[-1]] = True
+
+    def no_typecheck(self):
+        self.typetree = True
+
+    def compress_typechecks(self, mrdtable):
+        def full(node):
+            if node is True:
+                return 1
+            fulls = 0
+            for key, subnode in node.items():
+                if full(subnode):
+                    node[key] = True
+                    fulls += 1
+            if fulls == types_total:
+                return 1
+            return 0
+
+        types_total = len(mrdtable.list_of_types)
+        if full(self.typetree):
+            self.typetree = True
+
+    def generate_typechecks(self, args, mrdtable):
+        def generate(node, level=0):
+            indent = '    '*level
+            if node is True:
+                result.append('%spass' % (indent,))
+                return
+            if not node:
+                result.append('%sraise FailedToImplement' % (indent,))
+                return
+            keyword = 'if'
+            for key, subnode in node.items():
+                result.append('%s%s %s == %d:' % (indent, keyword,
+                                                  typeidexprs[level], key))
+                generate(subnode, level+1)
+                keyword = 'elif'
+            result.append('%selse:' % (indent,))
+            result.append('%s    raise FailedToImplement' % (indent,))
+
+        typeidexprs = ['%s.%s' % (arg, mrdtable.attrname) for arg in args]
+        result = []
+        generate(self.typetree)
+        if result == ['pass']:
+            del result[:]
+        return result
 
-class InstallerVersion2:
+
+class InstallerVersion2(object):
     """NOT_RPYTHON"""
 
     mrdtables = {}
@@ -458,26 +555,36 @@
         return len(self.table) == 0
 
     def install(self):
-        null_func = self.build_function(self.prefix + '_fail', [], True)
+        nskip = len(self.multimethod.argnames_before)
+        null_entry = self.build_funcentry(self.prefix + '0fail', [])
+        null_entry.no_typecheck()
         if self.is_empty():
-            return null_func
+            return self.answer(null_entry)
 
-        funcarray = CompressedArray(null_func, 1)
+        entryarray = CompressedArray(null_entry, 1)
         indexarray = self.mrdtable.indexarray
         lst = self.mrdtable.list_of_types
         indexline = []
-        for t0 in lst:
+        for num0, t0 in enumerate(lst):
             flatline = []
-            for t1 in lst:
+            for num1, t1 in enumerate(lst):
                 calllist = self.table.get((t0, t1), [])
                 funcname = '_'.join([self.prefix, t0.__name__, t1.__name__])
-                fn = self.build_function(funcname, calllist)
-                flatline.append(fn)
-            index = funcarray.insert_subarray(flatline)
+                entry = self.build_funcentry(funcname, calllist)
+                entry.register_valid_types((num0, num1))
+                flatline.append(entry)
+            index = entryarray.insert_subarray(flatline)
             indexline.append(index)
 
         master_index = indexarray.insert_subarray(indexline)
 
+        null_func = null_entry.make_function(self.fnargs, nskip, self.mrdtable)
+        funcarray = CompressedArray(null_func, 0)
+        for entry in entryarray.items:
+            func = entry.make_function(self.fnargs, nskip, self.mrdtable)
+            funcarray.items.append(func)
+        self.mrdtable.normalize_length(funcarray)
+
         print master_index
         print indexarray.items
         print funcarray.items
@@ -486,13 +593,23 @@
         exprfn = "funcarray.items[indexarray.items[%d + arg0.%s] + arg1.%s]" %(
             master_index, attrname, attrname)
         expr = Call(exprfn, self.fnargs)
-        return self.build_function(self.prefix + '_perform_call',
-                                   [expr], True,
-                                   indexarray = indexarray,
-                                   funcarray = funcarray)
+        entry = self.build_funcentry(self.prefix + '0perform_call',
+                                     [expr],
+                                     indexarray = indexarray,
+                                     funcarray = funcarray)
+        entry.no_typecheck()
+        return self.answer(entry)
+
+    def answer(self, entry):
+        if self.baked_perform_call:
+            nskip = len(self.multimethod.argnames_before)
+            return entry.make_function(self.fnargs, nskip, self.mrdtable)
+        else:
+            assert entry.body.startswith('return ')
+            expr = entry.body[len('return '):]
+            return self.fnargs, expr, entry.miniglobals, entry.fallback
 
-    def build_function(self, funcname, calllist, is_perform_call=False,
-                       **extranames):
+    def build_funcentry(self, funcname, calllist, **extranames):
         def invent_name(obj):
             if isinstance(obj, str):
                 return obj
@@ -526,30 +643,14 @@
                 bodylines.append('    pass')
             bodylines.append('return %s' % expr(calllist[-1]))
 
-        if is_perform_call and not self.baked_perform_call:
-            return self.fnargs, bodylines[0][len('return '):], miniglobals, fallback
-
-        # indent mode
-        bodylines = ['    ' + line for line in bodylines]
-        bodylines.append('')
-        source = '\n'.join(bodylines)
-
-        # XXX find a better place (or way) to avoid duplicate functions 
-        l = miniglobals.items()
-        l.sort()
-        l = tuple(l)
-        key = (source, l)
-        try: 
-            func = self.mmfunccache[key]
-        except KeyError: 
-            source = 'def %s(%s):\n%s' % (funcname, ', '.join(self.fnargs),
-                                          source)
-            exec compile2(source) in miniglobals
-            func = miniglobals[funcname]
-            self.mmfunccache[key] = func 
-        #else: 
-        #    print "avoided duplicate function", func
-        return func
+        entry = FuncEntry(bodylines, miniglobals, fallback)
+        key = entry.key()
+        try:
+            entry = self.mmfunccache[key]
+        except KeyError:
+            self.mmfunccache[key] = entry
+        entry.possiblenames.append(funcname)
+        return entry
 
 # ____________________________________________________________
 # Selection of the version to use



More information about the Pypy-commit mailing list