[Python-checkins] r86721 - in python/branches/dmalcolm-ast-optimization-branch: Lib/__optimizer__.py Lib/test/test_optimize.py Python/ceval.c Python/compile.c

david.malcolm python-checkins at python.org
Tue Nov 23 22:51:41 CET 2010


Author: david.malcolm
Date: Tue Nov 23 22:51:41 2010
New Revision: 86721

Log:
Work toward inlining of method calls


Modified:
   python/branches/dmalcolm-ast-optimization-branch/Lib/__optimizer__.py
   python/branches/dmalcolm-ast-optimization-branch/Lib/test/test_optimize.py
   python/branches/dmalcolm-ast-optimization-branch/Python/ceval.c
   python/branches/dmalcolm-ast-optimization-branch/Python/compile.c

Modified: python/branches/dmalcolm-ast-optimization-branch/Lib/__optimizer__.py
==============================================================================
--- python/branches/dmalcolm-ast-optimization-branch/Lib/__optimizer__.py	(original)
+++ python/branches/dmalcolm-ast-optimization-branch/Lib/__optimizer__.py	Tue Nov 23 22:51:41 2010
@@ -92,16 +92,78 @@
     p.communicate(dot.encode('utf-8'))
     p.wait()
 
-def get_ste_for_path(path):
+class NodePathEntry:
+    __slots__ = ('node',  # the parent ast.AST node
+                 'field', # name of the field on node that holds the child
+                 'index', # position within that field if it is a list, else None
+                 )
+    def __init__(self, node, field, index):
+        self.node = node
+        self.field = field
+        self.index = index
+
+    def __str__(self):
+        if self.index is not None:
+            return '%s.%s[%i]' % (self.node, self.field, self.index)
+        else:
+            return '%s.%s' % (self.node, self.field)
+
+class NodePath:
     '''
-    Given a list of (node, field, index) triples, obtain a list of
-    symbol table entries
+    A list of NodePathEntry objects, root first; extend() returns a new NodePath
     '''
-    result = []
-    for node, field, index in path:
-        if hasattr(node, 'ste'):
-            result.append(node.ste)
-    return result
+    __slots__ = ('entries', )
+
+    def __init__(self, entries):
+        self.entries = entries
+
+    def __str__(self):
+        return '/'.join([str(entry) for entry in self.entries])
+
+    def __repr__(self):
+        return '/'.join([str(entry) for entry in self.entries])
+
+    def extend(self, node, field, index):
+        return NodePath(self.entries + [NodePathEntry(node, field, index)])
+
+    def get_dotted_name(self, childnode=None):
+        nsp = NamespacePath.from_node_path(self, childnode)
+        return nsp.as_dotted_str()
+
+class NamespacePath:
+    '''
+    A list of symbol table entries, ordered outermost scope first
+    '''
+    __slots__ = ('_stes',)  # NOTE(review): no __repr__, so %r logging shows the default repr
+
+    def __init__(self, stes):
+        self._stes = stes
+
+    @classmethod
+    def from_node_path(cls, nodepath, childnode=None):
+        result = []
+        for npe in nodepath.entries:
+            if hasattr(npe.node, 'ste'):
+                result.append(npe.node.ste)
+        if childnode is not None:
+            if hasattr(childnode, 'ste'):
+                result.append(childnode.ste)
+        return NamespacePath(result)
+
+    def as_dotted_str(self):
+        '''
+        Generate a dotted string representing the namespace e.g. "SomeClass.some_method"
+        '''
+        # Start at 1: don't include the "top" STE:
+        return '.'.join([ste.name for ste in self._stes[1:]])
+
+    def get_parent_path(self):
+        return NamespacePath(self._stes[:-1])
+
+    def get_innermost_scope(self):
+        return self._stes[-1]
+
+
 
 class PathTransformer:
     """
@@ -109,10 +171,12 @@
 
     The path is passed in as a list of (node, field, index) triples
     """
-    def visit(self, node, path=[]):
+    def visit(self, node, path=None):  # path: NodePath from the root down to node
         """Visit a node."""
         method = 'visit_' + node.__class__.__name__
         visitor = getattr(self, method, self.generic_visit)
+        if path is None:  # fresh empty NodePath per call; avoids a shared mutable default
+            path = NodePath([])
         return visitor(node, path)
 
     def generic_visit(self, node, path):
@@ -122,7 +186,7 @@
                 new_values = []
                 for idx, value in enumerate(old_value):
                     if isinstance(value, ast.AST):
-                        value = self.visit(value, path + [(node, field, idx)])
+                        value = self.visit(value, path.extend(node, field, idx))
                         if value is None:
                             continue
                         elif not isinstance(value, ast.AST):
@@ -131,7 +195,7 @@
                     new_values.append(value)
                 old_value[:] = new_values
             elif isinstance(old_value, ast.AST):
-                new_node = self.visit(old_value, path + [(node, field, None)])
+                new_node = self.visit(old_value, path.extend(node, field, None))
                 if new_node is None:
                     delattr(node, field)
                 else:
@@ -187,77 +251,93 @@
         return make_assignment(self.varprefix + "__returnval__", node.value, node)
 
 class FunctionInliner(PathTransformer):
-    def __init__(self, tree, defn):
+    def __init__(self, tree, funcdef, dotted_name):
         self.tree = tree
-        self.defn = defn
-        assert hasattr(defn, 'ste')
+        self.funcdef = funcdef
+        self.dotted_name = dotted_name  # e.g. "Foo.simple_method" (see NodePath.get_dotted_name)
+        assert hasattr(funcdef, 'ste')  # symbol table entry; passed to InlineBodyFixups later
         self.num_callsites = 0
-        log('ste for body: %r' % defn.ste)
+        self.log('inlining calls to %r' % dotted_name)
+        #self.log('ste for body: %r' % funcdef.ste)
 
-    def visit_Call(self, node, path):
+    def log(self, msg):
+        if 0:  # debug tracing disabled; change to 1 to enable
+            print('%s: %s' % (self.__class__.__name__, msg))
+
+    def is_inlinable_callsite(self, call, path):
+        # Return a candidate dotted name for the callee (e.g. "Foo.simple_method"), or None
+        assert isinstance(call, ast.Call)
+
+        if isinstance(call.func, ast.Name):
+            # Name must match:
+            if call.func.id == self.dotted_name:
+                return self.dotted_name
+
+        if isinstance(call.func, ast.Attribute):
+            # Handle simple "self.METHOD_NAME" case:
+            attr = call.func
+            value = attr.value
+            if isinstance(value, ast.Name) and isinstance(value.ctx, ast.Load):
+                if value.id == 'self':
+                    # Guess the callee's dotted name: drop the innermost scope at the
+                    # callsite (presumably the calling method) and append the attribute.
+                    # Not compared against self.dotted_name here; visit_Call does that.
+                    parent_nsp = NamespacePath.from_node_path(path).get_parent_path()
+                    return parent_nsp.as_dotted_str() + '.' + attr.attr
+            # FIXME: only makes sense to traverse within this class and within subclasses
+
+        # Don't try to inline where the function is a non-trivial
+        # expression e.g. "f()()", or for other awkward cases
+        return None
+
+
+    def visit_Call(self, call, path):
         # Stop inlining beyond an arbitrary cutoff
         # (bm_simplecall was exploding):
         if self.num_callsites > 10:
-            return node
+            return call
 
         # Visit children:
-        self.generic_visit(node, path)
+        self.generic_visit(call, path)
 
-        if isinstance(node.func, ast.Attribute):
+        stored_name = self.is_inlinable_callsite(call, path)
+        if stored_name is None:
+            return call
+
+        self.log('Got inlinable callsite of %r' % stored_name)
+
+        if isinstance(call.func, ast.Attribute):
             # Don't try to inline method calls yet:
-            return node
+            self.log('Got attribute %s' % ast.dump(call.func))
+            return call
 
-        if not isinstance(node.func, ast.Name):
+        if not isinstance(call.func, (ast.Name, ast.Attribute)):
             # Don't try to inline where the function is a non-trivial
             # expression e.g. "f()()"
-            return node
+            self.log('Unexpected callable: %s' % ast.dump(call.func))
+            return call
 
-        if node.func.id != self.defn.name:
-            return node
+        # Only ast.Name callees reach this point (attribute calls bail out above)
 
-        log('Considering call to: %r' % node.func.id)
-        log('Path: %r' % path)
-        log('ste for callsite: %r' % get_ste_for_path(path))
-
-        # We need to find the insertion point for statements:
-        # Walk up the ancestors until you find a statement-holding node:
-        for ancestor in path[::-1]:
-            #log('foo:', ancestor)
-            if isinstance(ancestor[0], (ast.Module, ast.Interactive,
-                                        ast.FunctionDef, ast.ClassDef,
-                                        ast.For, ast.With,
-                                        ast.TryExcept, ast.TryFinally,
-                                        ast.ExceptHandler)):
-                break
-
-            # Can we inline predicates?
-            if isinstance(ancestor[0], (ast.While)):
-                if ancestor[1] == 'test':
-                    # No good place to insert code for an inlined "while"
-                    # predicate; bail:
-                    return node
-
-            if isinstance(ancestor[0], (ast.If)):
-                if ancestor[1] == 'test':
-                    # We may be able to inline a predicate before the "if":
-                    continue
-                else:
-                    # Inline within the body or the orelse:
-                    break
+        self.log('Considering call to: %s' % ast.dump(call.func))
+        self.log('NodePath: %r' % path)
+        nsp = NamespacePath.from_node_path(path)
+        self.log('NamespacePath for callsite: %r' % nsp)
+
+        if stored_name != self.dotted_name:
+            return call
+        # Each FunctionInliner handles exactly one callee (self.dotted_name);
+        # calls to other inlinable functions are left to their own inliner.
 
         # Locate innermost scope at callsite:
-        ste = get_ste_for_path(path)[-1]
+        ste = nsp.get_innermost_scope()
 
-        #print('Inlining call to: %r within %r' % (node.func.id, ste.name))
+        self.log('Inlining call to: %r within %r' % (self.dotted_name, ste.name))
         self.num_callsites += 1
-        log('ancestor: %r' % (ancestor, ))
 
-        assert ancestor[2] is not None
-
-
-        #log(ast.dump(self.defn))
-        varprefix = '__inline_%s_%x__' % (node.func.id, id(node))
-        #log('varprefix: %s' % varprefix)
+        self.log(ast.dump(self.funcdef))
+        varprefix = '__inline_%s_%x__' % (self.dotted_name, id(call))
+        self.log('varprefix: %s' % varprefix)
 
         # Generate a body of specialized statements that can replace the call:
         specialized = []
@@ -266,13 +346,13 @@
         #    __inline__x = expr for x
         # for each parameter
         # We will insert before the callsite
-        for formal, actual in zip(self.defn.args.args, node.args):
+        for formal, actual in zip(self.funcdef.args.args, call.args):
             #log('formal: %s' % ast.dump(formal))
             #log('actual: %s' % ast.dump(actual))
             # FIXME: ste
             add_local(ste, varprefix+formal.arg)
 
-            assign = make_assignment(varprefix+formal.arg, actual, node)
+            assign = make_assignment(varprefix+formal.arg, actual, call)
             specialized.append(assign)
 
             # FIXME: these seem to be being done using LOAD_NAME/STORE_NAME; why isn't it using _FAST?
@@ -284,7 +364,7 @@
         add_local(ste, returnval)
         # FIXME: this requires "None", how to do this in AST?
         assign = make_assignment(returnval,
-                                 make_load_name('None', node), node)
+                                 make_load_name('None', call), call)
         # FIXME: this leads to LOAD_NAME None, when it should be LOAD_CONST, surely?
         specialized.append(assign)
 
@@ -292,52 +372,42 @@
         # ending with:
         #    __inline____returnval = expr
         inline_body = []
-        fixer = InlineBodyFixups(varprefix, self.defn.ste)
-        for stmt in self.defn.body:
+        fixer = InlineBodyFixups(varprefix, self.funcdef.ste)
+        for stmt in self.funcdef.body:
             assert isinstance(stmt, ast.AST)
             inline_body.append(fixer.visit(ast_clone(stmt)))
         #log('inline_body:', inline_body)
         specialized += inline_body
 
-        #log('Parent: %s' % ast.dump(find_parent(self.tree, node)))
-
-        #seq = getattr(ancestor[0], ancestor[1])
-
-        # Splice the compound statements into place:
-        #seq = seq[:ancestor[2]] + compound + seq[ancestor[2]:] # FIXME
-        #setattr(ancestor[0], ancestor[1], seq)
-
-        #log(seq)
-
-        #log(ast.dump(ancestor[0]))
+        #log('Parent: %s' % ast.dump(find_parent(self.tree, call)))
 
         # FIXME: need some smarts about the value of the "Specialize":
         # it's the final Expr within the body
         specialized_result = ast.copy_location(ast.Name(id=returnval,
                                                         ctx=ast.Load()),
-                                               node)
+                                               call)
 
-        return ast.copy_location(ast.Specialize(name=node.func,
-                                                generalized=node,
+        return ast.copy_location(ast.Specialize(name=call.func,
+                                                generalized=call,
                                                 specialized_body=specialized,
                                                 specialized_result=specialized_result),
-                                 node)
+                                 call)
 
         # Replace the call with a load from __inline____returnval__
         return ast.copy_location(ast.Name(id=returnval,
-                                          ctx=ast.Load()), node)
+                                          ctx=ast.Load()), call)
 
 class NotInlinable(Exception):
     pass
 
 class CheckInlinableVisitor(PathTransformer):
-    def __init__(self, defn):
-        self.defn = defn
+    def __init__(self, funcdef):
+        self.funcdef = funcdef  # the only FunctionDef allowed; any other def raises NotInlinable
         self.returns = []
 
     # Various nodes aren't handlable:
     def visit_FunctionDef(self, node, path):
-        if node != self.defn:
+        if node != self.funcdef:  # a nested function definition can't be inlined
             raise NotInlinable()
         self.generic_visit(node, path)
         return node
@@ -354,25 +424,25 @@
         self.returns.append(path)
         return node
 
-def fn_is_inlinable(defn, mod):
+def fn_is_inlinable(funcdef, mod):
     # Should we inline calls to the given FunctionDef ?
-    assert(isinstance(defn, ast.FunctionDef))
+    assert(isinstance(funcdef, ast.FunctionDef))
 
     # Only inline "simple" calling conventions for now:
-    if len(defn.decorator_list) > 0:
+    if len(funcdef.decorator_list) > 0:
         return False
 
-    if (defn.args.vararg is not None or
-        defn.args.kwarg is not None or
-        defn.args.kwonlyargs != [] or
-        defn.args.defaults != [] or
-        defn.args.kw_defaults != []):
+    if (funcdef.args.vararg is not None or
+        funcdef.args.kwarg is not None or
+        funcdef.args.kwonlyargs != [] or
+        funcdef.args.defaults != [] or
+        funcdef.args.kw_defaults != []):
         return False
 
     # Don't try to inline generators and other awkward cases:
-    v = CheckInlinableVisitor(defn)
+    v = CheckInlinableVisitor(funcdef)
     try:
-        v.visit(defn)
+        v.visit(funcdef)
     except NotInlinable:
         return False
 
@@ -383,7 +453,7 @@
     # for each return:
     # FIXME: for now, only inline functions which have a single, final
     # explicit "return" at the end, or no returns:
-    log('returns of %s: %r' % (defn.name, v.returns))
+    log('returns of %s: %r' % (funcdef.name, v.returns))
     if len(v.returns)>1:
         return False
 
@@ -391,19 +461,19 @@
     if len(v.returns) == 1:
         rpath = v.returns[0]
         # Must be at toplevel:
-        if len(rpath) != 1:
+        if len(rpath.entries) != 1:
             return False
 
         # Must be at the end of that level
-        if rpath[0][2] != len(defn.body)-1:
+        if rpath.entries[0].index != len(funcdef.body)-1:
             return False
 
 
 
     # Don't inline functions that use free or cell vars
     # (just locals and globals):
-    assert hasattr(defn, 'ste')
-    ste = defn.ste
+    assert hasattr(funcdef, 'ste')
+    ste = funcdef.ste
     for varname in ste.varnames:
         scope = get_scope(ste, varname)
         #log('%r: %r' % (varname, scope))
@@ -419,7 +489,7 @@
         if isinstance(node, ast.Assign):
             for target in node.targets:
                 if isinstance(target, ast.Name):
-                    if target.id == defn.name:
+                    if target.id == funcdef.name:
                         if isinstance(target.ctx, ast.Store):
                             return False
 
@@ -427,38 +497,48 @@
 
 
 
+class InlinableFunctionFinder(PathTransformer):
+    # Record inlinable FunctionDefs in self.funcdefs by dotted name; save each in an "__internal__.saved.*" global
+    def __init__(self, mod):
+        self.mod = mod
+        self.funcdefs = {}
+
+    def log(self, msg):
+        if 0:
+            print('%s: %s' % (self.__class__.__name__, msg))
+
+    def visit_FunctionDef(self, funcdef, path):
+        self.log('got function def: %r %r' % (funcdef.name, path))
+        self.generic_visit(funcdef, path)
+        if fn_is_inlinable(funcdef, self.mod):
+            dotted_name = path.get_dotted_name(funcdef)
+            self.log('using dotted name: %r for %s' % (dotted_name, path))
+            self.funcdefs[dotted_name] = funcdef
+
+            storedname = '__internal__.saved.' + dotted_name
+            global_ = ast.copy_location(ast.Global(names=[storedname]),
+                                            funcdef)
+            assign = ast.copy_location(ast.Assign(targets=[ast.Name(id=storedname, ctx=ast.Store())],
+                                                  value=ast.Name(id=funcdef.name, ctx=ast.Load())),
+                                       funcdef)
+            ast.fix_missing_locations(assign)
+
+            return [funcdef, global_, assign]  # NOTE(review): relies on generic_visit splicing lists into the parent body
+        else:
+            return funcdef
+
+
 def _inline_function_calls(t):
-    # Locate top-level function defs:
-    inlinable_function_defs = {}
-    for s in t.body:
-        if isinstance(s, ast.FunctionDef):
-            if fn_is_inlinable(s, t):
-                inlinable_function_defs[s.name] = s
-
-    # We need to be able to look up functions by name to detect "monkeypatching"
-    # Keep a stored copy of each function as it's built, for later comparison
-    # A better way to do this would be to store the const function somewhere as it's compiled
-    # and then go in and fixup all of the JUMP_IF_SPECIALIZABLE entrypoints
-    # But for now, this will do for a prototype:
-    new_body = []
-    for s in t.body:
-        new_body.append(s)
-        if isinstance(s, ast.FunctionDef):
-            storedname = '__saved__' + s.name
-            if s.name in inlinable_function_defs:
-                assign = ast.copy_location(ast.Assign(targets=[ast.Name(id=storedname, ctx=ast.Store())],
-                                                      value=ast.Name(id=s.name, ctx=ast.Load())),
-                                           s)
-                ast.fix_missing_locations(assign)
-                new_body.append(assign)
-    t.body = new_body
+    v = InlinableFunctionFinder(t)
+    v.visit(t)  # mutates t: adds the "__internal__.saved.*" globals after each inlinable def
+    inlinable_function_defs = v.funcdefs
 
-    #print('inlinable_function_defs:%r' % inlinable_function_defs)
+    # Maps dotted name (e.g. "Foo.simple_method") -> inlinable FunctionDef
 
     # Locate call sites:
-    for name in inlinable_function_defs:
-        log('inlining calls to %r' % name)
-        inliner = FunctionInliner(t, inlinable_function_defs[name])
+    for dotted_name in inlinable_function_defs:
+        inliner = FunctionInliner(t, inlinable_function_defs[dotted_name],
+                                  dotted_name)
         inliner.visit(t)
 
     return t

Modified: python/branches/dmalcolm-ast-optimization-branch/Lib/test/test_optimize.py
==============================================================================
--- python/branches/dmalcolm-ast-optimization-branch/Lib/test/test_optimize.py	(original)
+++ python/branches/dmalcolm-ast-optimization-branch/Lib/test/test_optimize.py	Tue Nov 23 22:51:41 2010
@@ -121,7 +121,7 @@
 ''')
         callsite = self.compile_to_code(src, 'callsite')
         asm = disassemble(callsite)
-        self.assertHasLineWith(asm, ('LOAD_GLOBAL', '(__saved__function_to_be_inlined)'))
+        self.assertHasLineWith(asm, ('LOAD_GLOBAL', '(__internal__.saved.function_to_be_inlined)'))
         self.assertIn('JUMP_IF_SPECIALIZABLE', asm)
         self.assertHasLineWith(asm, ('LOAD_CONST', "('feefifofum')"))
 
@@ -386,6 +386,43 @@
         asm = disassemble(fn)
         #print(asm)
 
+    def test_method(self):
+        src = '''
+class Foo:
+    def simple_method(self, a):
+         return self.bar + a + self.baz
+
+    def user_of_method(self):
+         print(self.simple_method(1))
+         print(self.simple_method(2))
+         print(self.simple_method(3))
+'''
+
+        # "Foo.simple_method" should be inlinable
+
+        # Ensure that we're saving the inlinable function as a global for use
+        # by JUMP_IF_SPECIALIZABLE at callsites:
+        fn = self.compile_to_code(src, 'Foo')
+        asm = disassemble(fn)
+        self.assertHasLineWith(asm,
+            ('STORE_GLOBAL', '(__internal__.saved.Foo.simple_method)'))
+        #print(asm)
+
+        self.assertIsInlinable(src, fnname='Foo.simple_method')
+        fn = self.compile_to_code(src, 'Foo.user_of_method')
+        asm = disassemble(fn)
+        # FIXME: method calls ("self.X(...)") aren't inlined yet, so nothing to assert here
+
+
+    def test_namespaces(self):
+        src = '''
+class Foo:
+    def simple_method(self, a):
+         return self.bar + a + self.baz
+'''
+        fn = self.compile_to_code(src, 'Foo.simple_method')
+        asm = disassemble(fn)
+        # FIXME: no assertions yet; dotted names now come from NodePath.get_dotted_name
 
 
 def function_to_be_inlined():
@@ -406,7 +443,7 @@
         #print(asm)
         # Should have logic for detecting if it can specialize:
         self.assertIn('JUMP_IF_SPECIALIZABLE', asm)
-        self.assertIn('(__saved__function_to_be_inlined)', asm)
+        self.assertIn('(__internal__.saved.function_to_be_inlined)', asm)
         # Should have inlined constant value:
         self.assertIn("('I am the original implementation')", asm)
 

Modified: python/branches/dmalcolm-ast-optimization-branch/Python/ceval.c
==============================================================================
--- python/branches/dmalcolm-ast-optimization-branch/Python/ceval.c	(original)
+++ python/branches/dmalcolm-ast-optimization-branch/Python/ceval.c	Tue Nov 23 22:51:41 2010
@@ -2866,10 +2866,10 @@
 
             if (is_specializable(u, v)) {
                 /* Jump to specialized implementation, popping u */
-                JUMPTO(oparg);
                 Py_DECREF(v);
                 STACKADJ(-1);
                 Py_DECREF(u);
+                JUMPTO(oparg);
                 FAST_DISPATCH();
             } else {
                 /* Generalized implementation; fall through to next opcode,

Modified: python/branches/dmalcolm-ast-optimization-branch/Python/compile.c
==============================================================================
--- python/branches/dmalcolm-ast-optimization-branch/Python/compile.c	(original)
+++ python/branches/dmalcolm-ast-optimization-branch/Python/compile.c	Tue Nov 23 22:51:41 2010
@@ -3272,7 +3272,7 @@
         return 0;
     id = e->v.Specialize.name->v.Name.id;
     assert(id);
-    saved_name = PyUnicode_FromFormat("__saved__%U", id);
+    saved_name = PyUnicode_FromFormat("__internal__.saved.%U", id);
     if (!saved_name)
         return 0;
     ADDOP_O(c, LOAD_GLOBAL, saved_name, names); /* takes ownership of the reference */


More information about the Python-checkins mailing list