Generator comprehensions -- patch for compiler module

Jeff Epler jepler at unpythonic.net
Tue Aug 26 17:33:17 EDT 2003


Hello.

Recently, Generator Comprehensions were mentioned again on python-list.
I have written an implementation for the compiler module.  To try it
out, however, you must be able to rebuild Python from source, because it
also requires a change to Grammar.

1. Edit Python-2.3/Grammar/Grammar and add an alternative to the
"listmaker" production:
-listmaker: test ( list_for | (',' test)* [','] )
+listmaker: test ( list_for | (',' test)* [','] ) | 'yield' test list_for

1.5. Now [yield None for x in []] parses, but crashes the written-in-C
compiler:
    >>> [yield None for x in []]
    SystemError: com_node: unexpected node type

2. Apply the patch below to Lib/compiler

3. Use compiler.compile to compile code with generator comprehensions:
    from compiler import compile
    import dis

    code = compile("""
    def f():
        gg = [yield (x,y) for x in range(10) for y in range(10) if y > x]
        print gg, type(gg), list(gg)
    """, "<None>", "exec")
    exec code
    dis.dis(f.func_code.co_consts[1])
    f()

4. It's possible to write code so that __import__ uses compiler.compile
instead of the written-in-C compiler, but I don't have this code handy.
Also, a test suite is needed, and presumably a written-in-C implementation
as well.  (Option 2: make the compiler.compile interface the standard
compiler, and either let the builtin compiler support a Python subset
sufficient to bootstrap the written-in-Python compiler, or arrange
to ship .pyc files of the compiler package and get rid of the
written-in-C compiler completely.)

5. The PEP remains rejected by the BDFL anyway.


diff -ur compiler.orig/ast.py compiler/ast.py
--- compiler.orig/ast.py	2002-02-23 16:35:33.000000000 -0600
+++ compiler/ast.py	2003-08-26 06:55:51.000000000 -0500
@@ -1191,6 +1191,53 @@
     def __repr__(self):
         return "If(%s, %s)" % (repr(self.tests), repr(self.else_))
 
+class GenCompInner(Node):
+    nodes["gencompinner"] = "GenCompInner"
+    def __init__(self, expr, quals):
+        self.expr = expr
+        self.quals = quals
+
+    def getChildren(self):
+        children = []
+        children.append(self.expr)
+        children.extend(flatten(self.quals))
+        return tuple(children)
+
+    def getChildNodes(self):
+        nodelist = []
+        nodelist.append(self.expr)
+        nodelist.extend(flatten_nodes(self.quals))
+        return tuple(nodelist)
+
+    def __repr__(self):
+        return "GenCompInner(%s, %s)" % (repr(self.expr), repr(self.quals))
+
+class GenComp(Node):
+    nodes["gencomp"] = "GenComp"
+    def __init__(self, inner):
+        self.argnames = ()
+        self.defaults = ()
+        self.flags = 0
+        self.code = inner
+        self.varargs = self.kwargs = None
+
+    def getChildren(self):
+        children = []
+        children.append(self.argnames)
+        children.extend(flatten(self.defaults))
+        children.append(self.flags)
+        children.append(self.code)
+        return tuple(children)
+
+    def getChildNodes(self):
+        nodelist = []
+        nodelist.extend(flatten_nodes(self.defaults))
+        nodelist.append(self.code)
+        return tuple(nodelist)
+
+    def __repr__(self):
+        return "GenComp(%s)" % (repr(self.code),)
+ 
 class ListComp(Node):
     nodes["listcomp"] = "ListComp"
     def __init__(self, expr, quals):
diff -ur compiler.orig/pycodegen.py compiler/pycodegen.py
--- compiler.orig/pycodegen.py	2002-12-31 12:26:17.000000000 -0600
+++ compiler/pycodegen.py	2003-08-26 06:54:53.000000000 -0500
@@ -563,6 +563,51 @@
     # list comprehensions
     __list_count = 0
 
+    def visitGenCompInner(self, node):
+        self.set_lineno(node)
+        # setup list
+
+        stack = []
+        for i, for_ in zip(range(len(node.quals)), node.quals):
+            start, anchor = self.visit(for_)
+            cont = None
+            for if_ in for_.ifs:
+                if cont is None:
+                    cont = self.newBlock()
+                self.visit(if_, cont)
+            stack.insert(0, (start, cont, anchor))
+
+        self.visit(node.expr)
+        self.emit('YIELD_VALUE')
+
+        for start, cont, anchor in stack:
+            if cont:
+                skip_one = self.newBlock()
+                self.emit('JUMP_FORWARD', skip_one)
+                self.startBlock(cont)
+                self.emit('POP_TOP')
+                self.nextBlock(skip_one)
+            self.emit('JUMP_ABSOLUTE', start)
+            self.startBlock(anchor)
+        self.emit('LOAD_CONST', None)
+
+    def visitGenComp(self, node):
+        gen = GenCompCodeGenerator(node, self.scopes, self.class_name,
+                                   self.get_module())
+        walk(node.code, gen)
+        gen.finish()
+        self.set_lineno(node)
+        frees = gen.scope.get_free_vars()
+        if frees:
+            for name in frees:
+                self.emit('LOAD_CLOSURE', name)
+            self.emit('LOAD_CONST', gen)
+            self.emit('MAKE_CLOSURE', len(node.defaults))
+        else:
+            self.emit('LOAD_CONST', gen)
+            self.emit('MAKE_FUNCTION', len(node.defaults))
+        self.emit('CALL_FUNCTION', 0)
+
     def visitListComp(self, node):
         self.set_lineno(node)
         # setup list
@@ -1245,6 +1290,20 @@
 
     unpackTuple = unpackSequence
 
+class GenCompCodeGenerator(NestedScopeMixin, AbstractFunctionCode,
+                           CodeGenerator):
+    super_init = CodeGenerator.__init__ # call be other init
+
+    __super_init = AbstractFunctionCode.__init__
+
+    def __init__(self, comp, scopes, class_name, mod):
+        self.scopes = scopes
+        self.scope = scopes[comp]
+        self.__super_init(comp, scopes, 1, class_name, mod)
+        self.graph.setFreeVars(self.scope.get_free_vars())
+        self.graph.setCellVars(self.scope.get_cell_vars())
+        self.graph.setFlag(CO_GENERATOR)
+
 class FunctionCodeGenerator(NestedScopeMixin, AbstractFunctionCode,
                             CodeGenerator):
     super_init = CodeGenerator.__init__ # call be other init
diff -ur compiler.orig/symbols.py compiler/symbols.py
--- compiler.orig/symbols.py	2002-12-31 12:17:42.000000000 -0600
+++ compiler/symbols.py	2003-08-25 17:03:27.000000000 -0500
@@ -231,6 +231,15 @@
         self.visit(node.code, scope)
         self.handle_free_vars(scope, parent)
 
+    def visitGenComp(self, node, parent):
+        scope = LambdaScope(self.module, self.klass);
+        if parent.nested or isinstance(parent, FunctionScope):
+            scope.nested = 1
+        self.scopes[node] = scope
+        self._do_args(scope, node.argnames)
+        self.visit(node.code, scope)
+        self.handle_free_vars(scope, parent)
+
     def _do_args(self, scope, args):
         for name in args:
             if type(name) == types.TupleType:
diff -ur compiler.orig/transformer.py compiler/transformer.py
--- compiler.orig/transformer.py	2003-04-06 04:00:45.000000000 -0500
+++ compiler/transformer.py	2003-08-26 06:56:02.000000000 -0500
@@ -1026,18 +1026,25 @@
     if hasattr(symbol, 'list_for'):
         def com_list_constructor(self, nodelist):
             # listmaker: test ( list_for | (',' test)* [','] )
+            #            | 'yield' test list_for
             values = []
+            yield_flag = 0
+            if nodelist[1][1] == 'yield':
+                yield_flag = 1
+                nodelist = nodelist[1:]
             for i in range(1, len(nodelist)):
                 if nodelist[i][0] == symbol.list_for:
                     assert len(nodelist[i:]) == 1
                     return self.com_list_comprehension(values[0],
-                                                       nodelist[i])
+                                                       nodelist[i],
+                                                       yield_flag)
                 elif nodelist[i][0] == token.COMMA:
                     continue
                 values.append(self.com_node(nodelist[i]))
+            assert not yield_flag
             return List(values)
 
-        def com_list_comprehension(self, expr, node):
+        def com_list_comprehension(self, expr, node, yield_flag):
             # list_iter: list_for | list_if
             # list_for: 'for' exprlist 'in' testlist [list_iter]
             # list_if: 'if' test [list_iter]
@@ -1071,7 +1078,10 @@
                     raise SyntaxError, \
                           ("unexpected list comprehension element: %s %d"
                            % (node, lineno))
-            n = ListComp(expr, fors)
+            if yield_flag:
+                n = GenComp(GenCompInner(expr, fors))
+            else:
+                n = ListComp(expr, fors)
             n.lineno = lineno
             return n
 





More information about the Python-list mailing list