[pypy-svn] r65776 - in pypy/branch/parser-compiler/pypy/interpreter/astcompiler: . test

benjamin at codespeak.net benjamin at codespeak.net
Sun Jun 14 22:36:11 CEST 2009


Author: benjamin
Date: Sun Jun 14 22:36:09 2009
New Revision: 65776

Added:
   pypy/branch/parser-compiler/pypy/interpreter/astcompiler/astbuilder.py   (contents, props changed)
   pypy/branch/parser-compiler/pypy/interpreter/astcompiler/test/test_astbuilder.py   (contents, props changed)
Log:
Add astbuilder.py, which transforms parse trees into the new AST.

Added: pypy/branch/parser-compiler/pypy/interpreter/astcompiler/astbuilder.py
==============================================================================
--- (empty file)
+++ pypy/branch/parser-compiler/pypy/interpreter/astcompiler/astbuilder.py	Sun Jun 14 22:36:09 2009
@@ -0,0 +1,1194 @@
+from pypy.interpreter.astcompiler import ast2 as ast
+from pypy.interpreter import error
+from pypy.interpreter.pyparser.pygram import syms, tokens
+from pypy.interpreter.pyparser.error import SyntaxError
+from pypy.interpreter.pyparser import parsestring
+
+
+def ast_from_node(space, n):
+    return ASTBuilder(space, n).build_ast()
+
+
+augassign_operator_map = {
+    '+='  : ast.Add,
+    '-='  : ast.Sub,
+    '/='  : ast.Div,
+    '//=' : ast.FloorDiv,
+    '%='  : ast.Mod,
+    '<<='  : ast.LShift,
+    '>>='  : ast.RShift,
+    '&='  : ast.BitAnd,
+    '|='  : ast.BitOr,
+    '^='  : ast.BitXor,
+    '*='  : ast.Mult,
+    '**=' : ast.Pow
+}
+
+operator_map = {
+    tokens.VBAR : ast.BitOr,
+    tokens.CIRCUMFLEX : ast.BitXor,
+    tokens.AMPER : ast.BitAnd,
+    tokens.LEFTSHIFT : ast.LShift,
+    tokens.RIGHTSHIFT : ast.RShift,
+    tokens.PLUS : ast.Add,
+    tokens.MINUS : ast.Sub,
+    tokens.STAR : ast.Mult,
+    tokens.SLASH : ast.Div,
+    tokens.DOUBLESLASH : ast.FloorDiv,
+    tokens.PERCENT : ast.Mod
+}
+
+
+class ASTBuilder(object):
+
+    def __init__(self, space, n):
+        self.space = space
+        if n.type == syms.encoding_decl:
+            self.encoding = n.value
+            n = n.children[0]
+        else:
+            self.encoding = None
+        self.root_node = n
+
+    def build_ast(self):
+        n = self.root_node
+        if n.type == syms.file_input:
+            stmts = []
+            for i in range(len(n.children) - 1):
+                stmt = n.children[i]
+                if stmt.type == tokens.NEWLINE:
+                    continue
+                sub_stmts_count = self.number_of_statements(stmt)
+                if sub_stmts_count == 1:
+                    stmts.append(self.handle_stmt(stmt))
+                else:
+                    for j in range(sub_stmts_count):
+                        small_stmt = stmt.children[j]
+                        stmts.append(self.handle_stmt(small_stmt))
+            return ast.Module(stmts)
+        elif n.type == syms.eval_input:
+            body = self.handle_testlist(n.children[0])
+            return ast.Expression(body)
+        elif n.type == syms.single_input:
+            first_child = n.children[0]
+            if first_child.type == tokens.NEWLINE:
+                # An empty line.
+                return ast.Interactive([])
+            else:
+                num_stmts = self.number_of_statements(first_child)
+                if num_stmts == 1:
+                    stmts = [self.handle_stmt(first_child)]
+                else:
+                    stmts = []
+                    for i in range(0, len(first_child.children), 2):
+                        stmt = first_child.children[i]
+                        if stmt.type == tokens.NEWLINE:
+                            break
+                        stmts.append(self.handle_stmt(first_child))
+                return ast.Interactive(stmts)
+        else:
+            raise AssertionError("unkown root node")
+
+    def number_of_statements(self, n):
+        stmt_type = n.type
+        if stmt_type == syms.compound_stmt:
+            return 1
+        elif stmt_type == syms.stmt:
+            return self.number_of_statements(n.children[0])
+        elif stmt_type == syms.simple_stmt:
+            # Divide to remove semi-colons.
+            return len(n.children) // 2
+        else:
+            raise AssertionError("non-statement node")
+
+    def error(self, msg, n):
+        """Raise a SyntaxError carrying msg and the position of node n."""
+        lineno, column = n.lineno, n.column
+        raise SyntaxError(msg, lineno, column)
+
+    def check_forbidden_name(self, name, node):
+        if name == "None":
+            self.error("assignment to None", node)
+        if name == "__debug__":
+            self.error("assignment to __debug__", node)
+        # XXX Warn about using True and False
+
+    def set_context(self, expr, ctx, node):
+        """Set the expression context of expr to ctx, validating the target.
+
+        Callers pass ast.Store or ast.Del; any ctx other than ast.Store is
+        reported as "delete" in error messages.  Attribute and Name targets
+        are checked with check_forbidden_name when storing.  Lists and
+        non-empty Tuples get ctx applied recursively to their elements.
+        Expressions that cannot be targets raise SyntaxError positioned at
+        node; an expression class not listed here raises AssertionError.
+        """
+        # Short description used in "can't <action> <error>"; note this
+        # local shadows the module-level `error` import inside this method.
+        error = None
+        # Child expressions that need ctx applied as well (List/Tuple).
+        sequence = None
+        expr_type = expr.__class__
+        if expr_type is ast.Attribute:
+            if ctx is ast.Store:
+                self.check_forbidden_name(expr.attr, node)
+            expr.ctx = ctx
+        elif expr_type is ast.Subscript:
+            expr.ctx = ctx
+        elif expr_type is ast.Name:
+            if ctx is ast.Store:
+                self.check_forbidden_name(expr.id, node)
+            expr.ctx = ctx
+        elif expr_type is ast.List:
+            expr.ctx = ctx
+            sequence = expr.elts
+        elif expr_type is ast.Tuple:
+            if expr.elts:
+                expr.ctx = ctx
+                sequence = expr.elts
+            else:
+                # The empty tuple "()" is never a valid target.
+                error = "()"
+        elif expr_type is ast.Lambda:
+            error = "lambda"
+        elif expr_type is ast.Call:
+            error = "call"
+        elif expr_type is ast.BoolOp or \
+                expr_type is ast.BinOp or \
+                expr_type is ast.UnaryOp:
+            error = "operator"
+        elif expr_type is ast.GeneratorExp:
+            error = "generator expression"
+        elif expr_type is ast.Yield:
+            error = "yield expression"
+        elif expr_type is ast.ListComp:
+            error = "list comprehension"
+        elif expr_type is ast.Dict or \
+                expr_type is ast.Num or \
+                expr_type is ast.Str:
+            error = "literal"
+        elif expr_type is ast.Compare:
+            error = "comparison"
+        elif expr_type is ast.IfExp:
+            error = "conditional expression"
+        elif expr_type is ast.Repr:
+            error = "repr"
+        else:
+            raise AssertionError("unkown expression in set_context()")
+        if error is not None:
+            if ctx is ast.Store:
+                action = "assign to"
+            else:
+                action = "delete"
+            self.error("can't %s %s" % (action, error), node)
+        if sequence:
+            # Propagate the context (and the validity checks) to every
+            # element of a List/Tuple target.
+            for item in sequence:
+                self.set_context(item, ctx, node)
+
+    def handle_print_stmt(self, print_node):
+        dest = None
+        expressions = None
+        newline = True
+        start = 1
+        child_count = len(print_node.children)
+        if child_count > 2 and print_node.children[1].type == tokens.RIGHTSHIFT:
+            dest = self.handle_expr(print_node.children[2])
+            start = 4
+        if (child_count + 1 - start) // 2:
+            expressions = [self.handle_expr(print_node.children[i])
+                           for i in range(start, child_count, 2)]
+        if print_node.children[-1].type == tokens.COMMA:
+            newline = False
+        return ast.Print(dest, expressions, newline, print_node.lineno,
+                         print_node.column)
+
+    def handle_del_stmt(self, del_node):
+        targets = self.handle_exprlist(del_node.children[1], ast.Del)
+        return ast.Delete(targets, del_node.lineno, del_node.column)
+
+    def handle_flow_stmt(self, flow_node):
+        first_child = flow_node.children[0]
+        first_child_type = first_child.type
+        if first_child_type == syms.break_stmt:
+            return ast.Break(flow_node.lineno, flow_node.column)
+        elif first_child_type == syms.continue_stmt:
+            return ast.Continue(flow_node.lineno, flow_node.column)
+        elif first_child_type == syms.yield_stmt:
+            yield_expr = self.handle_expr(first_child.children[0])
+            return ast.Expr(yield_expr, flow_node.lineno, flow_node.column)
+        elif first_child_type == syms.return_stmt:
+            if len(first_child.children) == 1:
+                values = None
+            else:
+                values = self.handle_testlist(first_child.children[1])
+            return ast.Return(values, flow_node.lineno, flow_node.column)
+        elif first_child_type == syms.raise_stmt:
+            exc = None
+            value = None
+            traceback = None
+            child_count = len(first_child.children)
+            if child_count >= 2:
+                exc = self.handle_expr(first_child.children[1])
+            if child_count >= 4:
+                value = self.handle_expr(first_child.children[3])
+            if child_count == 6:
+                traceback = self.handle_expr(first_child.children[5])
+            return ast.Raise(exc, value, traceback, flow_node.lineno,
+                             flow_node.column)
+        else:
+            raise AssertionError("unkown flow statement")
+
+    def alias_for_import_name(self, import_name, store=True):
+        while True:
+            import_name_type = import_name.type
+            if import_name_type == syms.import_as_name:
+                name = import_name.children[0].value
+                if len(import_name.children) == 3:
+                    as_name = import_name.children[2].value
+                    self.check_forbidden_name(as_name, import_name.children[2])
+                else:
+                    as_name = None
+                    self.check_forbidden_name(name, import_name.children[0])
+                return ast.alias(name, as_name)
+            elif import_name_type == syms.dotted_as_name:
+                if len(import_name.children) == 1:
+                    import_name = import_name.children[0]
+                    continue
+                alias = self.alias_for_import_name(import_name.children[0])
+                asname_node = import_name.children[2]
+                alias.asname = asname_node.value
+                self.check_forbidden_name(alias.asname, asname_node)
+                return alias
+            elif import_name_type == syms.dotted_name:
+                if len(import_name.children) == 1:
+                    name = import_name.children[0].value
+                    if store:
+                        self.check_forbidden_name(name, import_name.children[0])
+                    return ast.alias(name, None)
+                name_parts = [import_name.children[i].value
+                              for i in range(0, len(import_name.children), 2)]
+                name = ".".join(name_parts)
+                return ast.alias(name, None)
+            elif import_name_type == tokens.STAR:
+                return ast.alias("*", None)
+            else:
+                raise AssertionError("unkown import name")
+
+    def handle_import_stmt(self, import_node):
+        """Build an ast.Import or ast.ImportFrom from an import_stmt node.
+
+        For "from" imports, the leading '.' tokens are counted and passed
+        to ast.ImportFrom as the relative-import level.  The imported names
+        may be '*', a parenthesised list or a bare import_as_names list;
+        a bare list with a trailing comma raises SyntaxError.  Raises
+        AssertionError for an unexpected node layout.
+        """
+        import_node = import_node.children[0]
+        if import_node.type == syms.import_name:
+            # dotted_as_name nodes sit at even indices, commas in between.
+            dotted_as_names = import_node.children[1]
+            aliases = [self.alias_for_import_name(dotted_as_names.children[i])
+                       for i in range(0, len(dotted_as_names.children), 2)]
+            return ast.Import(aliases, import_node.lineno, import_node.column)
+        elif import_node.type == syms.import_from:
+            child_count = len(import_node.children)
+            module = None
+            modname = None
+            i = 1
+            dot_count = 0
+            # Count leading DOT tokens up to the optional module name.  On
+            # exit i points just past the module name, or at the first
+            # non-DOT child (presumably the 'import' keyword).
+            while i < child_count:
+                child = import_node.children[i]
+                if child.type == syms.dotted_name:
+                    # The module name is not bound, hence store=False.
+                    module = self.alias_for_import_name(child, False)
+                    i += 1
+                    break
+                elif child.type != tokens.DOT:
+                    break
+                i += 1
+                dot_count += 1
+            # Skip the 'import' keyword.
+            i += 1
+            after_import_type = import_node.children[i].type
+            star_import = False
+            if after_import_type == tokens.STAR:
+                names_node = import_node.children[i]
+                star_import = True
+            elif after_import_type == tokens.LPAR:
+                names_node = import_node.children[i + 1]
+            elif after_import_type == syms.import_as_names:
+                names_node = import_node.children[i]
+                # An even child count means the list ends with a comma.
+                if len(names_node.children) % 2 == 0:
+                    self.error("trailing comma is only allowed with "
+                               "surronding parenthesis", names_node)
+            else:
+                raise AssertionError("unkown import node")
+            if star_import:
+                aliases = [self.alias_for_import_name(names_node)]
+            else:
+                # NOTE: under Python 2 this comprehension rebinds i; i is
+                # not read afterwards, so that is harmless.
+                aliases = [self.alias_for_import_name(names_node.children[i])
+                           for i in range(0, len(names_node.children), 2)]
+            if module is not None:
+                modname = module.name
+            return ast.ImportFrom(modname, aliases, dot_count,
+                                  import_node.lineno, import_node.column)
+        else:
+            raise AssertionError("unkown import node")
+
+    def handle_global_stmt(self, global_node):
+        names = [global_node.children[i].value
+                 for i in range(1, len(global_node.children), 2)]
+        return ast.Global(names, global_node.lineno, global_node.column)
+
+    def handle_exec_stmt(self, exec_node):
+        child_count = len(exec_node.children)
+        globs = None
+        locs = None
+        to_execute = self.handle_expr(exec_node.children[1])
+        if child_count >= 4:
+            globs = self.handle_expr(exec_node.children[3])
+        if child_count == 6:
+            locs = self.handle_expr(exec_node.children[5])
+        return ast.Exec(to_execute, globs, locs, exec_node.lineno,
+                        exec_node.column)
+
+    def handle_assert_stmt(self, assert_node):
+        child_count = len(assert_node.children)
+        expr = self.handle_expr(assert_node.children[1])
+        msg = None
+        if len(assert_node.children) == 4:
+            msg = self.handle_expr(assert_node.children[3])
+        return ast.Assert(expr, msg, assert_node.lineno, assert_node.column)
+
+    def handle_suite(self, suite_node):
+        """Return the list of statement ASTs contained in a suite node.
+
+        A suite is either a single simple_stmt (same line as the colon) or
+        a block of statements.
+        """
+        first_child = suite_node.children[0]
+        if first_child.type == syms.simple_stmt:
+            # One-line suite: small statements at even indices, ';' in
+            # between.  The last child is excluded, and so is a trailing ';'.
+            end = len(first_child.children) - 1
+            if first_child.children[end - 1].type == tokens.SEMI:
+                end -= 1
+            stmts = [self.handle_stmt(first_child.children[i])
+                     for i in range(0, end, 2)]
+        else:
+            stmts = []
+            # Block suite: the first two children and the last one are
+            # skipped (presumably NEWLINE, INDENT ... DEDENT).
+            for i in range(2, len(suite_node.children) - 1):
+                stmt = suite_node.children[i]
+                stmt_count = self.number_of_statements(stmt)
+                if stmt_count == 1:
+                    stmts.append(self.handle_stmt(stmt))
+                else:
+                    # Several ';'-separated statements on one line.
+                    simple_stmt = stmt.children[0]
+                    for j in range(0, len(simple_stmt.children), 2):
+                        stmt = simple_stmt.children[j]
+                        # A childless node (the terminating token) ends the
+                        # list of small statements.
+                        if not stmt.children:
+                            break
+                        stmts.append(self.handle_stmt(stmt))
+        return stmts
+
+    def handle_if_stmt(self, if_node):
+        """Build an ast.If, turning elif chains into nested If nodes.
+
+        Each elif becomes an If stored as the single element of the
+        previous test's orelse list; a final else suite becomes the orelse
+        of the innermost one.  Raises AssertionError if child 4 is neither
+        "else" nor "elif".
+        """
+        child_count = len(if_node.children)
+        if child_count == 4:
+            # Plain "if test: suite".
+            test = self.handle_expr(if_node.children[1])
+            suite = self.handle_suite(if_node.children[3])
+            return ast.If(test, suite, None, if_node.lineno, if_node.column)
+        otherwise_string = if_node.children[4].value
+        if otherwise_string == "else":
+            test = self.handle_expr(if_node.children[1])
+            suite = self.handle_suite(if_node.children[3])
+            else_suite = self.handle_suite(if_node.children[6])
+            return ast.If(test, suite, else_suite, if_node.lineno,
+                          if_node.column)
+        elif otherwise_string == "elif":
+            # Each elif clause contributes 4 children; an else clause 3.
+            elif_count = child_count - 4
+            # elif_count + 1 == child_count - 3: where an 'else' keyword
+            # would sit, i.e. third child from the end.
+            after_elif = if_node.children[elif_count + 1]
+            if after_elif.type == tokens.NAME and \
+                    after_elif.value == "else":
+                has_else = True
+                elif_count -= 3
+            else:
+                has_else = False
+            # Integer division under Python 2: number of elif clauses.
+            elif_count /= 4
+            if has_else:
+                # The last elif gets the else suite as its orelse.
+                last_elif = if_node.children[-6]
+                last_elif_test = self.handle_expr(last_elif)
+                elif_body = self.handle_suite(if_node.children[-4])
+                else_body = self.handle_suite(if_node.children[-1])
+                otherwise = [ast.If(last_elif_test, elif_body, else_body,
+                                    last_elif.lineno, last_elif.column)]
+                elif_count -= 1
+            else:
+                otherwise = None
+            # Build the remaining elifs from the last back to the first,
+            # nesting the previously built If in each one's orelse.
+            for i in range(elif_count):
+                offset = 5 + (elif_count - i - 1) * 4
+                elif_test_node = if_node.children[offset]
+                elif_test = self.handle_expr(elif_test_node)
+                elif_body = self.handle_suite(if_node.children[offset + 2])
+                new_if = ast.If(elif_test, elif_body, otherwise,
+                                elif_test_node.lineno, elif_test_node.column)
+                otherwise = [new_if]
+            expr = self.handle_expr(if_node.children[1])
+            body = self.handle_suite(if_node.children[3])
+            return ast.If(expr, body, otherwise, if_node.lineno, if_node.column)
+        else:
+            raise AssertionError("unkown if statement configuration")
+
+    def handle_while_stmt(self, while_node):
+        loop_test = self.handle_expr(while_node.children[1])
+        body = self.handle_suite(while_node.children[3])
+        if len(while_node.children) == 7:
+            otherwise = self.handle_suite(while_node.children[6])
+        else:
+            otherwise = None
+        return ast.While(loop_test, body, otherwise, while_node.lineno,
+                         while_node.column)
+
+    def handle_for_stmt(self, for_node):
+        target_node = for_node.children[1]
+        target_as_exprlist = self.handle_exprlist(target_node, ast.Store)
+        if len(target_node.children) == 1:
+            target = target_as_exprlist[0]
+        else:
+            target = ast.Tuple(target_as_exprlist, ast.Store,
+                               target_node.lineno, target_node.column)
+        expr = self.handle_testlist(for_node.children[3])
+        body = self.handle_suite(for_node.children[5])
+        if len(for_node.children) == 9:
+            otherwise = self.handle_suite(for_node.children[8])
+        else:
+            otherwise = None
+        return ast.For(target, expr, body, otherwise, for_node.lineno,
+                       for_node.column)
+
+    def handle_except_clause(self, exc, body):
+        test = None
+        target = None
+        suite = self.handle_suite(body)
+        child_count = len(exc.children)
+        if child_count >= 2:
+            test = self.handle_expr(exc.children[1])
+        if child_count == 4:
+            target_child = exc.children[3]
+            target = self.handle_expr(target_child)
+            self.set_context(target, ast.Store, target_child)
+        return ast.excepthandler(test, target, suite, exc.lineno, exc.column)
+
+    def handle_try_stmt(self, try_node):
+        """Build the AST for a try statement.
+
+        Returns an ast.TryExcept when there are except clauses and no
+        finally, an ast.TryFinally when there is a finally clause, and a
+        TryFinally whose body is a TryExcept when there are both.
+        """
+        body = self.handle_suite(try_node.children[2])
+        child_count = len(try_node.children)
+        # Each clause takes 3 children (head, separator, suite).  Start by
+        # treating every clause as an except clause and correct below.
+        except_count = (child_count - 3 ) // 3
+        otherwise = None
+        finally_suite = None
+        # A NAME token third from the end is an 'else' or 'finally'
+        # keyword rather than an except_clause node.
+        possible_extra_clause = try_node.children[-3]
+        if possible_extra_clause.type == tokens.NAME:
+            if possible_extra_clause.value == "finally":
+                # With both else and finally, the 'else' keyword is the
+                # sixth child from the end.
+                if child_count >= 9 and \
+                        try_node.children[-6].type == tokens.NAME:
+                    otherwise = self.handle_suite(try_node.children[-4])
+                    except_count -= 1
+                finally_suite = self.handle_suite(try_node.children[-1])
+                except_count -= 1
+            else:
+                # Must be the 'else' clause.
+                otherwise = self.handle_suite(try_node.children[-1])
+                except_count -= 1
+        if except_count:
+            handlers = []
+            for i in range(except_count):
+                base_offset = i * 3
+                exc = try_node.children[3 + base_offset]
+                except_body = try_node.children[5 + base_offset]
+                handlers.append(self.handle_except_clause(exc, except_body))
+            except_ast = ast.TryExcept(body, handlers, otherwise,
+                                       try_node.lineno, try_node.column)
+            if finally_suite is None:
+                return except_ast
+            # try/except/finally: nest the TryExcept inside a TryFinally.
+            body = [except_ast]
+        return ast.TryFinally(body, finally_suite, try_node.lineno,
+                              try_node.column)
+
+    def handle_with_stmt(self, with_node):
+        test = self.handle_expr(with_node.children[1])
+        body = self.handle_suite(with_node.children[-1])
+        if len(with_node.children) == 5:
+            target_node = with_node.children[2]
+            target = self.handle_with_var(target_node)
+            self.set_context(target, ast.Store, target_node)
+        else:
+            target = None
+        return ast.With(test, target, body, with_node.lineno, with_node.column)
+
+    def handle_with_var(self, with_var_node):
+        if with_var_node.children[0].value != "as":
+            self.error("expected \"with [expr] as [var]\"", with_var_node)
+        return self.handle_expr(with_var_node.children[1])
+
+    def handle_classdef(self, classdef_node):
+        name_node = classdef_node.children[1]
+        name = name_node.value
+        self.check_forbidden_name(name, name_node)
+        if len(classdef_node.children) == 4:
+            body = self.handle_suite(classdef_node.children[3])
+            return ast.ClassDef(name, None, body, classdef_node.lineno,
+                                classdef_node.column)
+        if classdef_node.children[3].type == tokens.RPAR:
+            body = self.handle_suite(classdef_node.children[5])
+            return ast.ClassDef(name, None, body, classdef_node.lineno,
+                                classdef_node.column)
+        bases = self.handle_class_bases(classdef_node.children[3])
+        body = self.handle_suite(classdef_node.children[6])
+        return ast.ClassDef(name, bases, body, classdef_node.lineno,
+                            classdef_node.column)
+
+    def handle_class_bases(self, bases_node):
+        if len(bases_node.children) == 1:
+            return [self.handle_expr(bases_node.children[0])]
+        return self.get_expression_list(bases_node)
+
+    def handle_funcdef(self, funcdef_node):
+        if len(funcdef_node.children) == 6:
+            decorators = self.handle_decorators(funcdef_node.children[0])
+            name_index = 2
+        else:
+            decorators = None
+            name_index = 1
+        name_node = funcdef_node.children[name_index]
+        name = name_node.value
+        self.check_forbidden_name(name, name_node)
+        args = self.handle_arguments(funcdef_node.children[name_index + 1])
+        body = self.handle_suite(funcdef_node.children[name_index + 3])
+        return ast.FunctionDef(name, args, body, decorators,
+                               funcdef_node.lineno, funcdef_node.column)
+
+    def handle_decorators(self, decorators_node):
+        return [self.handle_decorator(dec) for dec in decorators_node.children]
+
+    def handle_decorator(self, decorator_node):
+        dec_name = self.handle_dotted_name(decorator_node.children[1])
+        if len(decorator_node.children) == 3:
+            dec = dec_name
+        elif len(decorator_node.children) == 5:
+            dec = ast.Call(dec_name, None, None, None, None,
+                           decorator_node.lineno, decorator_node.column)
+        else:
+            dec = self.handle_call(decorator_node.children[3], dec_name)
+        return dec
+
+    def handle_dotted_name(self, dotted_name_node):
+        base_value = dotted_name_node.children[0].value
+        name = ast.Name(base_value, ast.Load, dotted_name_node.lineno,
+                        dotted_name_node.column)
+        for i in range(2, len(dotted_name_node.children), 2):
+            attr = dotted_name_node.children[i].value
+            name = ast.Attribute(name, attr, ast.Load, dotted_name_node.lineno,
+                                 dotted_name_node.column)
+        return name
+
+    def handle_arguments(self, arguments_node):
+        """Build an ast.arguments node from a parameters or varargslist node.
+
+        Collects positional parameters (simple names as ast.Name with the
+        ast.Store context, tuple parameters via handle_arg_unpacking), their
+        default values, and the names of the '*' and '**' parameters.  Empty
+        lists are returned as None.  Raises SyntaxError for a non-default
+        parameter following a default one or for a forbidden name.
+        """
+        if arguments_node.type == syms.parameters:
+            # Two children: just the parentheses, no parameters at all.
+            if len(arguments_node.children) == 2:
+                return ast.arguments(None, None, None, None)
+            arguments_node = arguments_node.children[1]
+        i = 0
+        child_count = len(arguments_node.children)
+        defaults = []
+        args = []
+        variable_arg = None
+        keywords_arg = None
+        have_default = False
+        while i < child_count:
+            argument = arguments_node.children[i]
+            arg_type = argument.type
+            if arg_type == syms.fpdef:
+                # Loops only to strip redundant parentheses, e.g. "(x)".
+                while True:
+                    # A following EQUAL means this parameter has a default.
+                    # NOTE(review): after a "continue" below this check runs
+                    # again with the already-advanced i.
+                    if i + 1 < child_count and \
+                            arguments_node.children[i + 1].type == tokens.EQUAL:
+                        default_node = arguments_node.children[i + 2]
+                        defaults.append(self.handle_expr(default_node))
+                        i += 2
+                        have_default = True
+                    elif have_default:
+                        self.error("non-default argument after default one",
+                                   arguments_node)
+                    # Three children: a parenthesised sub-list.
+                    if len(argument.children) == 3:
+                        sub_arg = argument.children[1]
+                        if len(sub_arg.children) != 1:
+                            # A real tuple parameter.
+                            args.append(self.handle_arg_unpacking(sub_arg))
+                        else:
+                            # "(x)": unwrap and examine the inner fpdef.
+                            argument = sub_arg.children[0]
+                            continue
+                    if argument.children[0].type == tokens.NAME:
+                        name_node = argument.children[0]
+                        arg_name = name_node.value
+                        self.check_forbidden_name(arg_name, name_node)
+                        name = ast.Name(arg_name, ast.Store, name_node.lineno,
+                                        name_node.column)
+                        args.append(name)
+                    # Skip the parameter and the following comma.
+                    i += 2
+                    break
+            elif arg_type == tokens.STAR:
+                # '*' NAME, then skip past the following comma.
+                name_node = arguments_node.children[i + 1]
+                variable_arg = name_node.value
+                self.check_forbidden_name(variable_arg, name_node)
+                i += 3
+            elif arg_type == tokens.DOUBLESTAR:
+                # '**' NAME, then skip past the following comma.
+                name_node = arguments_node.children[i + 1]
+                keywords_arg = name_node.value
+                self.check_forbidden_name(keywords_arg, name_node)
+                i += 3
+            else:
+                raise AssertionError("unkown node in argument list")
+        if not defaults:
+            defaults = None
+        if not args:
+            args = None
+        return ast.arguments(args, variable_arg, keywords_arg, defaults)
+
+    def handle_arg_unpacking(self, fplist_node):
+        """Build an ast.Tuple (ast.Store context) for a tuple parameter.
+
+        Nested parenthesised sub-lists are handled recursively; redundant
+        parentheses around a single element are stripped.  set_context is
+        applied to the result, which also checks every Name element with
+        check_forbidden_name.
+        """
+        args = []
+        # fpdef nodes sit at even indices; (len + 1) / 2 counts them with or
+        # without a trailing comma (integer division under Python 2).
+        for i in range((len(fplist_node.children) + 1) / 2):
+            fpdef_node = fplist_node.children[i * 2]
+            # Loops only to strip redundant parentheses.
+            while True:
+                child = fpdef_node.children[0]
+                if child.type == tokens.NAME:
+                    arg = ast.Name(child.value, ast.Store, child.lineno,
+                                   child.column)
+                    args.append(arg)
+                else:
+                    # Parenthesised: child 1 is the inner fplist.
+                    child = fpdef_node.children[1]
+                    if len(child.children) == 1:
+                        fpdef_node = child.children[0]
+                        continue
+                    args.append(self.handle_arg_unpacking(child))
+                break
+        tup = ast.Tuple(args, ast.Store, fplist_node.lineno, fplist_node.column)
+        self.set_context(tup, ast.Store, fplist_node)
+        return tup
+
+    def handle_stmt(self, stmt):
+        """Dispatch a statement parse node to the matching handler.
+
+        stmt and simple_stmt wrappers are unwrapped first.  For a
+        simple_stmt only its first child is handled, so callers pass the
+        individual small statements when there are several.  Raises
+        AssertionError for unrecognised node types.
+        """
+        stmt_type = stmt.type
+        if stmt_type == syms.stmt:
+            stmt = stmt.children[0]
+            stmt_type = stmt.type
+        if stmt_type == syms.simple_stmt:
+            stmt = stmt.children[0]
+            stmt_type = stmt.type
+        if stmt_type == syms.small_stmt:
+            stmt = stmt.children[0]
+            stmt_type = stmt.type
+            if stmt_type == syms.expr_stmt:
+                return self.handle_expr_stmt(stmt)
+            elif stmt_type == syms.print_stmt:
+                return self.handle_print_stmt(stmt)
+            elif stmt_type == syms.del_stmt:
+                return self.handle_del_stmt(stmt)
+            elif stmt_type == syms.pass_stmt:
+                return ast.Pass(stmt.lineno, stmt.column)
+            elif stmt_type == syms.flow_stmt:
+                return self.handle_flow_stmt(stmt)
+            elif stmt_type == syms.import_stmt:
+                return self.handle_import_stmt(stmt)
+            elif stmt_type == syms.global_stmt:
+                return self.handle_global_stmt(stmt)
+            elif stmt_type == syms.assert_stmt:
+                return self.handle_assert_stmt(stmt)
+            elif stmt_type == syms.exec_stmt:
+                return self.handle_exec_stmt(stmt)
+            else:
+                raise AssertionError("unhandled small statement")
+        elif stmt_type == syms.compound_stmt:
+            stmt = stmt.children[0]
+            stmt_type = stmt.type
+            if stmt_type == syms.if_stmt:
+                return self.handle_if_stmt(stmt)
+            elif stmt_type == syms.while_stmt:
+                return self.handle_while_stmt(stmt)
+            elif stmt_type == syms.for_stmt:
+                return self.handle_for_stmt(stmt)
+            elif stmt_type == syms.try_stmt:
+                return self.handle_try_stmt(stmt)
+            elif stmt_type == syms.with_stmt:
+                return self.handle_with_stmt(stmt)
+            elif stmt_type == syms.funcdef:
+                return self.handle_funcdef(stmt)
+            elif stmt_type == syms.classdef:
+                return self.handle_classdef(stmt)
+            else:
+                raise AssertionError("unhandled compound statement")
+        else:
+            raise AssertionError("unkown statment type")
+
+    def handle_expr_stmt(self, stmt):
+        if len(stmt.children) == 1:
+            expression = self.handle_testlist(stmt.children[0])
+            return ast.Expr(expression, stmt.lineno, stmt.column)
+        elif stmt.children[1].type == syms.augassign:
+            # Augmented assignment.
+            target_child = stmt.children[0]
+            target_expr = self.handle_testlist(target_child)
+            self.set_context(target_expr, ast.Store, target_child)
+            value_child = stmt.children[2]
+            if value_child.type == syms.testlist:
+                value_expr = self.handle_testlist(value_child)
+            else:
+                value_expr = self.handle_expr(value_child)
+            op_str = stmt.children[1].children[0].value
+            operator = augassign_operator_map[op_str]
+            return ast.AugAssign(target_expr, operator, value_expr,
+                                 stmt.lineno, stmt.column)
+        else:
+            # Normal assignment.
+            targets = []
+            for i in range(0, len(stmt.children) - 2, 2):
+                target_node = stmt.children[i]
+                if target_node.type == syms.yield_expr:
+                    self.error("can't assign to yield expr", target_node)
+                target_expr = self.handle_testlist(target_node)
+                self.set_context(target_expr, ast.Store, target_node)
+                targets.append(target_expr)
+            value_child = stmt.children[-1]
+            if value_child.type == syms.testlist:
+                value_expr = self.handle_testlist(value_child)
+            else:
+                value_expr = self.handle_expr(value_child)
+            return ast.Assign(targets, value_expr, stmt.lineno, stmt.column)
+
+    def get_expression_list(self, tests):
+        return [self.handle_expr(tests.children[i])
+                for i in range(0, len(tests.children), 2)]
+
+    def handle_testlist(self, tests):
+        if len(tests.children) == 1:
+            return self.handle_expr(tests.children[0])
+        else:
+            elts = self.get_expression_list(tests)
+            return ast.Tuple(elts, ast.Load, tests.lineno, tests.column)
+
+    def handle_expr(self, expr_node):
+        # Loop until we return something.
+        while True:
+            expr_node_type = expr_node.type
+            if expr_node_type == syms.test or expr_node_type == syms.old_test:
+                first_child = expr_node.children[0]
+                if first_child.type in (syms.lambdef, syms.old_lambdef):
+                    return self.handle_lambdef(first_child)
+                elif len(expr_node.children) > 1:
+                    return self.handle_ifexp(expr_node)
+                else:
+                    expr_node = first_child
+            elif expr_node_type == syms.or_test or \
+                    expr_node_type == syms.and_test:
+                if len(expr_node.children) == 1:
+                    expr_node = expr_node.children[0]
+                    continue
+                seq = [self.handle_expr(expr_node.children[i])
+                       for i in range(0, len(expr_node.children), 2)]
+                if expr_node_type == syms.or_test:
+                    op = ast.Or
+                else:
+                    op = ast.And
+                return ast.BoolOp(op, seq, expr_node.lineno, expr_node.column)
+            elif expr_node_type == syms.not_test:
+                if len(expr_node.children) == 1:
+                    expr_node = expr_node.children[0]
+                    continue
+                expr = self.handle_expr(expr_node.children[1])
+                return ast.UnaryOp(ast.Not, expr, expr_node.lineno,
+                                   expr_node.column)
+            elif expr_node_type == syms.comparison:
+                if len(expr_node.children) == 1:
+                    expr_node = expr_node.children[0]
+                    continue
+                operators = []
+                operands = []
+                expr = self.handle_expr(expr_node.children[0])
+                for i in range(1, len(expr_node.children), 2):
+                    operators.append(self.handle_comp_op(expr_node.children[i]))
+                    operands.append(self.handle_expr(expr_node.children[i + 1]))
+                return ast.Compare(expr, operators, operands, expr_node.lineno,
+                                   expr_node.column)
+            elif expr_node_type == syms.expr or \
+                    expr_node_type == syms.xor_expr or \
+                    expr_node_type == syms.and_expr or \
+                    expr_node_type == syms.shift_expr or \
+                    expr_node_type == syms.arith_expr or \
+                    expr_node_type == syms.term:
+                if len(expr_node.children) == 1:
+                    expr_node = expr_node.children[0]
+                    continue
+                return self.handle_binop(expr_node)
+            elif expr_node_type == syms.yield_expr:
+                if len(expr_node.children) == 2:
+                    exp = self.handle_testlist(expr_node.children[1])
+                else:
+                    exp = None
+                return ast.Yield(exp, expr_node.lineno, expr_node.column)
+            elif expr_node_type == syms.factor:
+                if len(expr_node.children) == 1:
+                    expr_node = expr_node.children[0]
+                    continue
+                return self.handle_factor(expr_node)
+            elif expr_node_type == syms.power:
+                return self.handle_power(expr_node)
+            else:
+                raise AssertionError("unkown expr")
+
+    def handle_lambdef(self, lambdef_node):
+        expr = self.handle_expr(lambdef_node.children[-1])
+        if len(lambdef_node.children) == 3:
+            args = ast.arguments(None, None, None, None)
+        else:
+            args = self.handle_arguments(lambdef_node.children[1])
+        return ast.Lambda(args, expr, lambdef_node.lineno, lambdef_node.column)
+
+    def handle_ifexp(self, if_expr_node):
+        body = self.handle_expr(if_expr_node.children[0])
+        expression = self.handle_expr(if_expr_node.children[2])
+        otherwise = self.handle_expr(if_expr_node.children[4])
+        return ast.IfExp(expression, body, otherwise, if_expr_node.lineno,
+                         if_expr_node.column)
+
+    def handle_comp_op(self, comp_op_node):
+        comp_node = comp_op_node.children[0]
+        comp_type = comp_node.type
+        if len(comp_op_node.children) == 1:
+            if comp_type == tokens.LESS:
+                return ast.Lt
+            elif comp_type == tokens.GREATER:
+                return ast.Gt
+            elif comp_type == tokens.EQEQUAL:
+                return ast.Eq
+            elif comp_type == tokens.LESSEQUAL:
+                return ast.LtE
+            elif comp_type == tokens.GREATEREQUAL:
+                return ast.GtE
+            elif comp_type == tokens.NOTEQUAL:
+                return ast.NotEq
+            elif comp_type == tokens.NAME:
+                if comp_node.value == "is":
+                    return ast.Is
+                elif comp_node.value == "in":
+                    return ast.In
+                else:
+                    raise AssertionError("invalid comparison")
+            else:
+                raise AssertionError("invalid comparison")
+        else:
+            if comp_op_node.children[1].value == "in":
+                return ast.NotIn
+            elif comp_node.value == "is":
+                return ast.IsNot
+            else:
+                raise AssertionError("invalid comparison")
+
+    def handle_binop(self, binop_node):
+        left = self.handle_expr(binop_node.children[0])
+        right = self.handle_expr(binop_node.children[2])
+        op = operator_map[binop_node.children[1].type]
+        result = ast.BinOp(left, op, right, binop_node.lineno,
+                           binop_node.column)
+        number_of_ops = (len(binop_node.children) - 1) / 2
+        for i in range(1, number_of_ops):
+            op_node = binop_node.children[i * 2 + 1]
+            op = operator_map[op_node.type]
+            sub_right = self.handle_expr(binop_node.children[i * 2 + 2])
+            result = ast.BinOp(result, op, sub_right, op_node.lineno,
+                               op_node.column)
+        return result
+
+    def handle_factor(self, factor_node):
+        expr = self.handle_expr(factor_node.children[1])
+        op_type = factor_node.children[0].type
+        if op_type == tokens.PLUS:
+            op = ast.UAdd
+        elif op_type == tokens.MINUS:
+            op = ast.USub
+        elif op_type == tokens.TILDE:
+            op = ast.Invert
+        else:
+            raise AssertionError("invalid factor node")
+        return ast.UnaryOp(op, expr, factor_node.lineno, factor_node.column)
+
+    def handle_power(self, power_node):
+        atom_expr = self.handle_atom(power_node.children[0])
+        if len(power_node.children) == 1:
+            return atom_expr
+        for i in range(1, len(power_node.children)):
+            trailer = power_node.children[i]
+            if trailer.type != syms.trailer:
+                break
+            tmp_atom_expr = self.handle_trailer(trailer, atom_expr)
+            tmp_atom_expr.lineno = atom_expr.lineno
+            tmp_atom_expr.column = atom_expr.col_offset
+            atom_expr = tmp_atom_expr
+        if power_node.children[-1].type == syms.factor:
+            right = self.handle_expr(power_node.children[-1])
+            atom_expr = ast.BinOp(atom_expr, ast.Pow, right, power_node.lineno,
+                                  power_node.column)
+        return atom_expr
+
+    def handle_slice(self, slice_node):
+        first_child = slice_node.children[0]
+        if first_child.type == tokens.DOT:
+            return ast.Ellipsis()
+        if len(slice_node.children) == 1 and first_child.type == syms.test:
+            index = self.handle_expr(first_child)
+            return ast.Index(index)
+        lower = None
+        upper = None
+        step = None
+        if first_child.type == syms.test:
+            lower = self.handle_expr(first_child)
+        if first_child.type == tokens.COLON:
+            if len(slice_node.children) > 1:
+                second_child = slice_node.children[1]
+                if second_child.type == syms.test:
+                    upper = self.handle_expr(second_child)
+        elif len(slice_node.children) > 2:
+            third_child = slice_node.children[2]
+            if third_child.type == syms.test:
+                upper = self.handle_expr(third_child)
+        last_child = slice_node.children[-1]
+        if last_child.type == syms.sliceop and len(last_child.children) == 2:
+                step_child = last_child.children[1]
+                if step_child.type == syms.test:
+                    step = self.handle_expr(step_child)
+        return ast.Slice(lower, upper, step)
+
+    def handle_trailer(self, trailer_node, left_expr):
+        first_child = trailer_node.children[0]
+        if first_child.type == tokens.LPAR:
+            if len(trailer_node.children) == 2:
+                return ast.Call(left_expr, None, None, None, None,
+                                trailer_node.lineno, trailer_node.column)
+            else:
+                return self.handle_call(trailer_node.children[1], left_expr)
+        elif first_child.type == tokens.DOT:
+            attr = trailer_node.children[1].value
+            return ast.Attribute(left_expr, attr, ast.Load,
+                                 trailer_node.lineno, trailer_node.column)
+        else:
+            middle = trailer_node.children[1]
+            if len(middle.children) == 1:
+                slice = self.handle_slice(middle.children[0])
+                return ast.Subscript(left_expr, slice, ast.Load,
+                                     middle.lineno, middle.column)
+            slices = []
+            simple = True
+            for i in range(0, len(middle.children), 2):
+                slc = self.handle_slice(middle.children[i])
+                if not isinstance(slc, ast.Index):
+                    simple = False
+                slices.append(slc)
+            if not simple:
+                ext_slice = ast.ExtSlice(slices)
+                return ast.Subscript(left_expr, ext_slice, ast.Load,
+                                     middle.lineno, middle.column)
+            elts = [idx.value for idx in slices]
+            tup = ast.Tuple(elts, ast.Load, middle.lineno, middle.column)
+            return ast.Subscript(left_expr, ast.Index(tup), ast.Load,
+                                 middle.lineno, middle.column)
+
+    def handle_call(self, args_node, callable_expr):
+        arg_count = 0
+        keyword_count = 0
+        generator_count = 0
+        for argument in args_node.children:
+            if argument.type == syms.argument:
+                if len(argument.children) == 1:
+                    arg_count += 1
+                elif argument.children[1].type == syms.gen_for:
+                    generator_count += 1
+                else:
+                    keyword_count += 1
+        if generator_count > 1 or \
+                (generator_count and (keyword_count or arg_count)):
+            self.error("Generator expression must be parenthesized "
+                       "if not sole argument", args_node)
+        if arg_count + keyword_count + generator_count > 255:
+            self.error("more than 255 arguments", args_node)
+        args = []
+        keywords = []
+        used_keywords = {}
+        variable_arg = None
+        keywords_arg = None
+        child_count = len(args_node.children)
+        i = 0
+        while i < child_count:
+            argument = args_node.children[i]
+            if argument.type == syms.argument:
+                if len(argument.children) == 1:
+                    expr_node = argument.children[0]
+                    if keywords:
+                        self.error("non-keyword arg after keyword arg",
+                                   expr_node)
+                    args.append(self.handle_expr(expr_node))
+                elif argument.children[1].type == syms.gen_for:
+                    args.append(self.handle_genexp(argument))
+                else:
+                    keyword_node = argument.children[0]
+                    keyword_expr = self.handle_expr(keyword_node)
+                    if isinstance(keyword_expr, ast.Lambda):
+                        self.error("lambda cannot contain assignment",
+                                   keyword_node)
+                    elif not isinstance(keyword_expr, ast.Name):
+                        self.error("keyword can't be an expression",
+                                   keyword_node)
+                    keyword = keyword_expr.id
+                    if keyword in used_keywords:
+                        self.error("keyword argument repeated", keyword_node)
+                    used_keywords[keyword] = None
+                    self.check_forbidden_name(keyword, keyword_node)
+                    keyword_value = self.handle_expr(argument.children[2])
+                    keywords.append(ast.keyword(keyword, keyword_value))
+            elif argument.type == tokens.STAR:
+                variable_arg = self.handle_expr(args_node.children[i + 1])
+                i += 1
+            elif argument.type == tokens.DOUBLESTAR:
+                keywords_arg = self.handle_expr(args_node.children[i + 1])
+                i += 1
+            i += 1
+        if not args:
+            args = None
+        if not keywords:
+            keywords = None
+        return ast.Call(callable_expr, args, keywords, variable_arg,
+                        keywords_arg, callable_expr.lineno,
+                        callable_expr.col_offset)
+
+    def parse_number(self, raw):
+        base = 10
+        w_num_str = self.space.wrap(raw)
+        if raw[-1] in "lL":
+            tp = self.space.w_long
+            return self.space.call_function(tp, w_num_str)
+        elif raw[-1] in "jJ":
+            tp = self.space.w_complex
+            return self.space.call_function(tp, w_num_str)
+        try:
+            return self.space.call_function(self.space.w_int, w_num_str)
+        except error.OperationError, e:
+            if not e.match(self.space, self.space.w_ValueError):
+                raise
+            return self.space.call_function(self.space.w_float, w_num_str)
+
+    def handle_atom(self, atom_node):
+        first_child = atom_node.children[0]
+        first_child_type = first_child.type
+        if first_child_type == tokens.NAME:
+            return ast.Name(first_child.value, ast.Load,
+                            atom_node.lineno, atom_node.column)
+        elif first_child_type == tokens.STRING:
+            space = self.space
+            sub_strings_w = [parsestring.parsestr(space, self.encoding, s.value)
+                             for s in atom_node.children]
+            if len(sub_strings_w) > 1:
+                w_sub_strings = space.newlist(sub_strings_w)
+                w_join = space.getattr(space.wrap(""), space.wrap("join"))
+                final_string = space.call_function(w_join, w_sub_strings)
+            else:
+                final_string = sub_strings_w[0]
+            return ast.Str(final_string, atom_node.lineno, atom_node.column)
+        elif first_child_type == tokens.NUMBER:
+            num_value = self.parse_number(first_child.value)
+            return ast.Num(num_value, atom_node.lineno, atom_node.column)
+        elif first_child_type == tokens.LPAR:
+            second_child = atom_node.children[1]
+            if second_child.type == tokens.RPAR:
+                return ast.Tuple(None, ast.Load, atom_node.lineno,
+                                 atom_node.column)
+            elif second_child.type == syms.yield_expr:
+                return self.handle_expr(second_child)
+            return self.handle_testlist_gexp(second_child)
+        elif first_child_type == tokens.LSQB:
+            second_child = atom_node.children[1]
+            if second_child.type == tokens.RSQB:
+                return ast.List(None, ast.Load, atom_node.lineno,
+                                atom_node.column)
+            if len(second_child.children) == 1 or \
+                    second_child.children[1].type == tokens.COMMA:
+                elts = self.get_expression_list(second_child)
+                return ast.List(elts, ast.Load, atom_node.lineno,
+                                atom_node.column)
+            return self.handle_listcomp(second_child)
+        elif first_child_type == tokens.LBRACE:
+            if len(atom_node.children) == 2:
+                return ast.Dict(None, None, atom_node.lineno, atom_node.column)
+            second_child = atom_node.children[1]
+            keys = []
+            values = []
+            for i in range(0, len(second_child.children), 4):
+                keys.append(self.handle_expr(second_child.children[i]))
+                values.append(self.handle_expr(second_child.children[i + 2]))
+            return ast.Dict(keys, values, atom_node.lineno, atom_node.column)
+        elif first_child_type == tokens.BACKQUOTE:
+            expr = self.handle_testlist(atom_node.children[1])
+            return ast.Repr(expr, atom_node.lineno, atom_node.column)
+        else:
+            raise AssertionError("unkown atom")
+
+    def handle_testlist_gexp(self, gexp_node):
+        if len(gexp_node.children) > 1 and \
+                gexp_node.children[1].type == syms.gen_for:
+            return self.handle_genexp(gexp_node)
+        return self.handle_testlist(gexp_node)
+
+    def count_comp_fors(self, comp_node, for_type, if_type):
+        """Return the number of 'for' clauses in the comprehension comp_node.
+
+        comp_node.children[1] is the first 'for' clause.  A 'for' clause with
+        5 children carries a trailing iterator node (children[4]) whose first
+        child is either another 'for' clause or an 'if' clause; an 'if'
+        clause with 3 children chains on to a further iterator node.
+        for_type and if_type are the grammar symbols of those clauses
+        (gen_for/gen_if or list_for/list_if, depending on the caller).
+        """
+        count = 0
+        current_for = comp_node.children[1]
+        while True:
+            count += 1
+            if len(current_for.children) == 5:
+                current_iter = current_for.children[4]
+            else:
+                # Nothing follows this 'for': it is the last clause.
+                return count
+            # Walk past any 'if' clauses to the next 'for' (or the end).
+            while True:
+                first_child = current_iter.children[0]
+                if first_child.type == for_type:
+                    current_for = current_iter.children[0]
+                    break
+                elif first_child.type == if_type:
+                    if len(first_child.children) == 3:
+                        current_iter = first_child.children[2]
+                    else:
+                        # Trailing 'if' with nothing after it.
+                        return count
+                else:
+                    raise AssertionError("should not reach here")
+
+    def count_comp_ifs(self, iter_node, for_type):
+        count = 0
+        while True:
+            first_child = iter_node.children[0]
+            if first_child.type == for_type:
+                return count
+            count += 1
+            if len(first_child.children) == 2:
+                return count
+            iter_node = first_child.children[2]
+
+    def comprehension_helper(self, comp_node, for_type, if_type, iter_type,
+                             handle_source_expression):
+        """Shared builder for generator expressions and list comprehensions.
+
+        comp_node.children[0] is the element expression and children[1] the
+        first 'for' clause.  for_type, if_type and iter_type are the grammar
+        symbols of the 'for', 'if' and iterator clauses (gen_* or list_*).
+        handle_source_expression converts each 'for' clause's iterable
+        (its children[3]).  Returns (elt, comps), where comps holds one
+        ast.comprehension per 'for' clause, with its 'if' conditions in ifs.
+        """
+        elt = self.handle_expr(comp_node.children[0])
+        fors_count = self.count_comp_fors(comp_node, for_type, if_type)
+        comps = []
+        comp_for = comp_node.children[1]
+        for i in range(fors_count):
+            # children[1] of a 'for' clause holds the loop target(s).
+            for_node = comp_for.children[1]
+            for_targets = self.handle_exprlist(for_node, ast.Store)
+            expr = handle_source_expression(comp_for.children[3])
+            if len(for_node.children) == 1:
+                # A single target is used as-is.
+                comp = ast.comprehension(for_targets[0], expr, None)
+            else:
+                # Several children in the target list: wrap in a Store tuple.
+                target = ast.Tuple(for_targets, ast.Store, comp_for.lineno,
+                                   comp_for.column)
+                comp = ast.comprehension(target, expr, None)
+            if len(comp_for.children) == 5:
+                # An iterator node follows.  Collect the 'if' clauses before
+                # the next 'for', advancing comp_for through them.
+                comp_for = comp_iter = comp_for.children[4]
+                assert comp_iter.type == iter_type
+                ifs_count = self.count_comp_ifs(comp_iter, for_type)
+                if ifs_count:
+                    ifs = []
+                    for j in range(ifs_count):
+                        comp_for = comp_if = comp_iter.children[0]
+                        ifs.append(self.handle_expr(comp_if.children[1]))
+                        if len(comp_if.children) == 3:
+                            comp_for = comp_iter = comp_if.children[2]
+                    comp.ifs = ifs
+                # If comp_for is now an iterator node, step into it so the
+                # next round starts at the following 'for' clause.
+                if comp_for.type == iter_type:
+                    comp_for = comp_for.children[0]
+            comps.append(comp)
+        return elt, comps
+
+    def handle_genexp(self, genexp_node):
+        elt, comps = self.comprehension_helper(genexp_node, syms.gen_for,
+                                               syms.gen_if, syms.gen_iter,
+                                               self.handle_expr)
+        return ast.GeneratorExp(elt, comps, genexp_node.lineno,
+                                genexp_node.column)
+
+    def handle_listcomp(self, listcomp_node):
+        elt, comps = self.comprehension_helper(listcomp_node, syms.list_for,
+                                               syms.list_if, syms.list_iter,
+                                               self.handle_testlist)
+        return ast.ListComp(elt, comps, listcomp_node.lineno,
+                            listcomp_node.column)
+
+    def handle_exprlist(self, exprlist, context):
+        exprs = []
+        for i in range(0, len(exprlist.children), 2):
+            child = exprlist.children[i]
+            expr = self.handle_expr(child)
+            self.set_context(expr, context, child)
+            exprs.append(expr)
+        return exprs

Added: pypy/branch/parser-compiler/pypy/interpreter/astcompiler/test/test_astbuilder.py
==============================================================================
--- (empty file)
+++ pypy/branch/parser-compiler/pypy/interpreter/astcompiler/test/test_astbuilder.py	Sun Jun 14 22:36:09 2009
@@ -0,0 +1,1066 @@
+import random
+import string
+import py
+from pypy.interpreter.baseobjspace import W_Root
+from pypy.interpreter.pyparser import pyparse, pygram
+from pypy.interpreter.pyparser.error import SyntaxError
+from pypy.interpreter.astcompiler.astbuilder import ast_from_node
+from pypy.interpreter.astcompiler import ast2 as ast
+
+
+class TestAstBuilder:
+
+    def setup_class(cls):
+        cls.parser = pyparse.PythonParser(pygram.python_grammar)
+
+    def get_ast(self, source, p_mode="exec"):
+        tree = self.parser.parse_source(source, p_mode)
+        ast_node = ast_from_node(self.space, tree)
+        return ast_node
+
+    def get_first_expr(self, source):
+        mod = self.get_ast(source)
+        assert len(mod.body) == 1
+        expr = mod.body[0]
+        assert isinstance(expr, ast.Expr)
+        return expr.value
+
+    def get_first_stmt(self, source):
+        mod = self.get_ast(source)
+        assert len(mod.body) == 1
+        return mod.body[0]
+
+    def test_top_level(self):
+        mod = self.get_ast("hi = 32")
+        assert isinstance(mod, ast.Module)
+        body = mod.body
+        assert len(body) == 1
+
+        mod = self.get_ast("hi", p_mode="eval")
+        assert isinstance(mod, ast.Expression)
+        assert isinstance(mod.body, ast.expr)
+
+        mod = self.get_ast("x = 23", p_mode="single")
+        assert isinstance(mod, ast.Interactive)
+        assert len(mod.body) == 1
+
+    def test_print(self):
+        pri = self.get_first_stmt("print x")
+        assert isinstance(pri, ast.Print)
+        assert pri.dest is None
+        assert pri.nl
+        assert len(pri.values) == 1
+        assert isinstance(pri.values[0], ast.Name)
+        pri = self.get_first_stmt("print x, 34")
+        assert len(pri.values) == 2
+        assert isinstance(pri.values[0], ast.Name)
+        assert isinstance(pri.values[1], ast.Num)
+        pri = self.get_first_stmt("print")
+        assert pri.nl
+        assert pri.values is None
+        pri = self.get_first_stmt("print x,")
+        assert len(pri.values) == 1
+        assert not pri.nl
+        pri = self.get_first_stmt("print >> y, 4")
+        assert isinstance(pri.dest, ast.Name)
+        assert len(pri.values) == 1
+        assert isinstance(pri.values[0], ast.Num)
+        assert pri.nl
+        pri = self.get_first_stmt("print >> y")
+        assert isinstance(pri.dest, ast.Name)
+        assert pri.values is None
+        assert pri.nl
+
+    def test_del(self):
+        d = self.get_first_stmt("del x")
+        assert isinstance(d, ast.Delete)
+        assert len(d.targets) == 1
+        assert isinstance(d.targets[0], ast.Name)
+        assert d.targets[0].ctx is ast.Del
+        d = self.get_first_stmt("del x, y")
+        assert len(d.targets) == 2
+        assert d.targets[0].ctx is ast.Del
+        assert d.targets[1].ctx is ast.Del
+        d = self.get_first_stmt("del x.y")
+        assert len(d.targets) == 1
+        attr = d.targets[0]
+        assert isinstance(attr, ast.Attribute)
+        assert attr.ctx is ast.Del
+        d = self.get_first_stmt("del x[:]")
+        assert len(d.targets) == 1
+        sub = d.targets[0]
+        assert isinstance(sub, ast.Subscript)
+        assert sub.ctx is ast.Del
+
+    def test_break(self):
+        br = self.get_first_stmt("while True: break").body[0]
+        assert isinstance(br, ast.Break)
+
+    def test_continue(self):
+        cont = self.get_first_stmt("while True: continue").body[0]
+        assert isinstance(cont, ast.Continue)
+
+    def test_return(self):
+        ret = self.get_first_stmt("def f(): return").body[0]
+        assert isinstance(ret, ast.Return)
+        assert ret.value is None
+        ret = self.get_first_stmt("def f(): return x").body[0]
+        assert isinstance(ret.value, ast.Name)
+
+    def test_raise(self):
+        ra = self.get_first_stmt("raise")
+        assert ra.type is None
+        assert ra.inst is None
+        assert ra.tback is None
+        ra = self.get_first_stmt("raise x")
+        assert isinstance(ra.type, ast.Name)
+        assert ra.inst is None
+        assert ra.tback is None
+        ra = self.get_first_stmt("raise x, 3")
+        assert isinstance(ra.type, ast.Name)
+        assert isinstance(ra.inst, ast.Num)
+        assert ra.tback is None
+        ra = self.get_first_stmt("raise x, 4, 'hi'")
+        assert isinstance(ra.type, ast.Name)
+        assert isinstance(ra.inst, ast.Num)
+        assert isinstance(ra.tback, ast.Str)
+
+    def test_import(self):
+        im = self.get_first_stmt("import x")
+        assert isinstance(im, ast.Import)
+        assert len(im.names) == 1
+        alias = im.names[0]
+        assert isinstance(alias, ast.alias)
+        assert alias.name == "x"
+        assert alias.asname is None
+        im = self.get_first_stmt("import x.y")
+        assert len(im.names) == 1
+        alias = im.names[0]
+        assert alias.name == "x.y"
+        assert alias.asname is None
+        im = self.get_first_stmt("import x as y")
+        assert len(im.names) == 1
+        alias = im.names[0]
+        assert alias.name == "x"
+        assert alias.asname == "y"
+        im = self.get_first_stmt("import x, y as w")
+        assert len(im.names) == 2
+        a1, a2 = im.names
+        assert a1.name == "x"
+        assert a1.asname is None
+        assert a2.name == "y"
+        assert a2.asname == "w"
+
+    def test_from_import(self):
+        im = self.get_first_stmt("from x import y")
+        assert isinstance(im, ast.ImportFrom)
+        assert im.module == "x"
+        assert im.level == 0
+        assert len(im.names) == 1
+        a = im.names[0]
+        assert isinstance(a, ast.alias)
+        assert a.name == "y"
+        assert a.asname is None
+        im = self.get_first_stmt("from . import y")
+        assert im.level == 1
+        assert im.module is None
+        im = self.get_first_stmt("from ... import y")
+        assert im.level == 3
+        assert im.module is None
+        im = self.get_first_stmt("from .x import y")
+        assert im.level == 1
+        assert im.module == "x"
+        im = self.get_first_stmt("from ..x.y import m")
+        assert im.level == 2
+        assert im.module == "x.y"
+        im = self.get_first_stmt("from x import *")
+        assert len(im.names) == 1
+        a = im.names[0]
+        assert a.name == "*"
+        assert a.asname is None
+        for input in ("from x import x, y", "from x import (x, y)"):
+            im = self.get_first_stmt(input)
+            assert len(im.names) == 2
+            a1, a2 = im.names
+            assert a1.name == "x"
+            assert a1.asname is None
+            assert a2.name == "y"
+            assert a2.asname is None
+        for input in ("from x import a as b, w", "from x import (a as b, w)"):
+            im = self.get_first_stmt(input)
+            assert len(im.names) == 2
+            a1, a2 = im.names
+            assert a1.name == "a"
+            assert a1.asname == "b"
+            assert a2.name == "w"
+            assert a2.asname is None
+        input = "from x import a, b,"
+        exc = py.test.raises(SyntaxError, self.get_ast, input).value
+        assert exc.msg == "trailing comma is only allowed with surronding " \
+            "parenthesis"
+
+    def test_global(self):
+        glob = self.get_first_stmt("global x")
+        assert isinstance(glob, ast.Global)
+        assert glob.names == ["x"]
+        glob = self.get_first_stmt("global x, y")
+        assert glob.names == ["x", "y"]
+
+    def test_exec(self):
+        exc = self.get_first_stmt("exec x")
+        assert isinstance(exc, ast.Exec)
+        assert isinstance(exc.body, ast.Name)
+        assert exc.globals is None
+        assert exc.locals is None
+        exc = self.get_first_stmt("exec 'hi' in x")
+        assert isinstance(exc.body, ast.Str)
+        assert isinstance(exc.globals, ast.Name)
+        assert exc.locals is None
+        exc = self.get_first_stmt("exec 'hi' in x, 2")
+        assert isinstance(exc.body, ast.Str)
+        assert isinstance(exc.globals, ast.Name)
+        assert isinstance(exc.locals, ast.Num)
+
+    def test_assert(self):
+        asrt = self.get_first_stmt("assert x")
+        assert isinstance(asrt, ast.Assert)
+        assert isinstance(asrt.test, ast.Name)
+        assert asrt.msg is None
+        asrt = self.get_first_stmt("assert x, 'hi'")
+        assert isinstance(asrt.test, ast.Name)
+        assert isinstance(asrt.msg, ast.Str)
+
+    def test_suite(self):
+        suite = self.get_first_stmt("while x: n;").body
+        assert len(suite) == 1
+        assert isinstance(suite[0].value, ast.Name)
+        suite = self.get_first_stmt("while x: n").body
+        assert len(suite) == 1
+        suite = self.get_first_stmt("while x: \n    n;").body
+        assert len(suite) == 1
+        suite = self.get_first_stmt("while x: n;").body
+        assert len(suite) == 1
+        suite = self.get_first_stmt("while x:\n    n; f;").body
+        assert len(suite) == 2
+
+    def test_if(self):
+        if_ = self.get_first_stmt("if x: 4")
+        assert isinstance(if_, ast.If)
+        assert isinstance(if_.test, ast.Name)
+        assert if_.test.ctx is ast.Load
+        assert len(if_.body) == 1
+        assert isinstance(if_.body[0].value, ast.Num)
+        assert if_.orelse is None
+        if_ = self.get_first_stmt("if x: 4\nelse: 'hi'")
+        assert isinstance(if_.test, ast.Name)
+        assert len(if_.body) == 1
+        assert isinstance(if_.body[0].value, ast.Num)
+        assert len(if_.orelse) == 1
+        assert isinstance(if_.orelse[0].value, ast.Str)
+        if_ = self.get_first_stmt("if x: 3\nelif 'hi': pass")
+        assert isinstance(if_.test, ast.Name)
+        assert len(if_.orelse) == 1
+        sub_if = if_.orelse[0]
+        assert isinstance(sub_if, ast.If)
+        assert isinstance(sub_if.test, ast.Str)
+        assert sub_if.orelse is None
+        if_ = self.get_first_stmt("if x: pass\nelif 'hi': 3\nelse: ()")
+        assert isinstance(if_.test, ast.Name)
+        assert len(if_.body) == 1
+        assert isinstance(if_.body[0], ast.Pass)
+        assert len(if_.orelse) == 1
+        sub_if = if_.orelse[0]
+        assert isinstance(sub_if, ast.If)
+        assert isinstance(sub_if.test, ast.Str)
+        assert len(sub_if.body) == 1
+        assert isinstance(sub_if.body[0].value, ast.Num)
+        assert len(sub_if.orelse) == 1
+        assert isinstance(sub_if.orelse[0].value, ast.Tuple)
+
+    def test_while(self):
+        wh = self.get_first_stmt("while x: pass")
+        assert isinstance(wh, ast.While)
+        assert isinstance(wh.test, ast.Name)
+        assert wh.test.ctx is ast.Load
+        assert len(wh.body) == 1
+        assert isinstance(wh.body[0], ast.Pass)
+        assert wh.orelse is None
+        wh = self.get_first_stmt("while x: pass\nelse: 4")
+        assert isinstance(wh.test, ast.Name)
+        assert len(wh.body) == 1
+        assert isinstance(wh.body[0], ast.Pass)
+        assert len(wh.orelse) == 1
+        assert isinstance(wh.orelse[0].value, ast.Num)
+
+    def test_for(self):
+        fr = self.get_first_stmt("for x in y: pass")
+        assert isinstance(fr, ast.For)
+        assert isinstance(fr.target, ast.Name)
+        assert fr.target.ctx is ast.Store
+        assert isinstance(fr.iter, ast.Name)
+        assert fr.iter.ctx is ast.Load
+        assert len(fr.body) == 1
+        assert isinstance(fr.body[0], ast.Pass)
+        assert fr.orelse is None
+        fr = self.get_first_stmt("for x, in y: pass")
+        tup = fr.target
+        assert isinstance(tup, ast.Tuple)
+        assert tup.ctx is ast.Store
+        assert len(tup.elts) == 1
+        assert isinstance(tup.elts[0], ast.Name)
+        assert tup.elts[0].ctx is ast.Store
+        fr = self.get_first_stmt("for x, y in g: pass")
+        tup = fr.target
+        assert isinstance(tup, ast.Tuple)
+        assert tup.ctx is ast.Store
+        assert len(tup.elts) == 2
+        for elt in tup.elts:
+            assert isinstance(elt, ast.Name)
+            assert elt.ctx is ast.Store
+        fr = self.get_first_stmt("for x in g: pass\nelse: 4")
+        assert len(fr.body) == 1
+        assert isinstance(fr.body[0], ast.Pass)
+        assert len(fr.orelse) == 1
+        assert isinstance(fr.orelse[0].value, ast.Num)
+
+    def test_try(self):
+        tr = self.get_first_stmt("try: x\nfinally: pass")
+        assert isinstance(tr, ast.TryFinally)
+        assert len(tr.body) == 1
+        assert isinstance(tr.body[0].value, ast.Name)
+        assert len(tr.finalbody) == 1
+        assert isinstance(tr.finalbody[0], ast.Pass)
+        tr = self.get_first_stmt("try: x\nexcept: pass")
+        assert isinstance(tr, ast.TryExcept)
+        assert len(tr.body) == 1
+        assert isinstance(tr.body[0].value, ast.Name)
+        assert len(tr.handlers) == 1
+        handler = tr.handlers[0]
+        assert isinstance(handler, ast.excepthandler)
+        assert handler.type is None
+        assert handler.name is None
+        assert len(handler.body) == 1
+        assert isinstance(handler.body[0], ast.Pass)
+        assert tr.orelse is None
+        tr = self.get_first_stmt("try: x\nexcept Exception: pass")
+        assert len(tr.handlers) == 1
+        handler = tr.handlers[0]
+        assert isinstance(handler.type, ast.Name)
+        assert handler.type.ctx is ast.Load
+        assert handler.name is None
+        assert len(handler.body) == 1
+        assert tr.orelse is None
+        tr = self.get_first_stmt("try: x\nexcept Exception, e: pass")
+        assert len(tr.handlers) == 1
+        handler = tr.handlers[0]
+        assert isinstance(handler.type, ast.Name)
+        assert isinstance(handler.name, ast.Name)
+        assert handler.name.ctx is ast.Store
+        assert handler.name.id == "e"
+        assert len(handler.body) == 1
+        tr = self.get_first_stmt("try: x\nexcept: pass\nelse: 4")
+        assert len(tr.body) == 1
+        assert isinstance(tr.body[0].value, ast.Name)
+        assert len(tr.handlers) == 1
+        assert isinstance(tr.handlers[0].body[0], ast.Pass)
+        assert len(tr.orelse) == 1
+        assert isinstance(tr.orelse[0].value, ast.Num)
+        tr = self.get_first_stmt("try: x\nexcept Exc, a: 5\nexcept F: pass")
+        assert len(tr.handlers) == 2
+        h1, h2 = tr.handlers
+        assert isinstance(h1.type, ast.Name)
+        assert isinstance(h1.name, ast.Name)
+        assert isinstance(h1.body[0].value, ast.Num)
+        assert isinstance(h2.type, ast.Name)
+        assert h2.name is None
+        assert isinstance(h2.body[0], ast.Pass)
+        tr = self.get_first_stmt("try: x\nexcept: 4\nfinally: pass")
+        assert isinstance(tr, ast.TryFinally)
+        assert len(tr.finalbody) == 1
+        assert isinstance(tr.finalbody[0], ast.Pass)
+        assert len(tr.body) == 1
+        exc = tr.body[0]
+        assert isinstance(exc, ast.TryExcept)
+        assert len(exc.handlers) == 1
+        assert len(exc.handlers[0].body) == 1
+        assert isinstance(exc.handlers[0].body[0].value, ast.Num)
+        assert len(exc.body) == 1
+        assert isinstance(exc.body[0].value, ast.Name)
+        tr = self.get_first_stmt("try: x\nexcept: 4\nelse: 'hi'\nfinally: pass")
+        assert isinstance(tr, ast.TryFinally)
+        assert len(tr.finalbody) == 1
+        assert isinstance(tr.finalbody[0], ast.Pass)
+        assert len(tr.body) == 1
+        exc = tr.body[0]
+        assert isinstance(exc, ast.TryExcept)
+        assert len(exc.orelse) == 1
+        assert isinstance(exc.orelse[0].value, ast.Str)
+        assert len(exc.body) == 1
+        assert isinstance(exc.body[0].value, ast.Name)
+        assert len(exc.handlers) == 1
+
+    def test_with(self):
+        wi = self.get_first_stmt("with x: pass")
+        assert isinstance(wi, ast.With)
+        assert isinstance(wi.context_expr, ast.Name)
+        assert len(wi.body) == 1
+        assert wi.optional_vars is None
+        wi = self.get_first_stmt("with x as y: pass")
+        assert isinstance(wi.context_expr, ast.Name)
+        assert len(wi.body) == 1
+        assert isinstance(wi.optional_vars, ast.Name)
+        assert wi.optional_vars.ctx is ast.Store
+        wi = self.get_first_stmt("with x as (y,): pass")
+        assert isinstance(wi.optional_vars, ast.Tuple)
+        assert len(wi.optional_vars.elts) == 1
+        assert wi.optional_vars.ctx is ast.Store
+        assert wi.optional_vars.elts[0].ctx is ast.Store
+        input = "with x hi y: pass"
+        exc = py.test.raises(SyntaxError, self.get_ast, input).value
+        assert exc.msg == "expected \"with [expr] as [var]\""
+
+    def test_class(self):
+        for input in ("class X: pass", "class X(): pass"):
+            cls = self.get_first_stmt(input)
+            assert isinstance(cls, ast.ClassDef)
+            assert cls.name == "X"
+            assert len(cls.body) == 1
+            assert isinstance(cls.body[0], ast.Pass)
+            assert cls.bases is None
+        for input in ("class X(Y): pass", "class X(Y,): pass"):
+            cls = self.get_first_stmt(input)
+            assert len(cls.bases) == 1
+            base = cls.bases[0]
+            assert isinstance(base, ast.Name)
+            assert base.ctx is ast.Load
+            assert base.id == "Y"
+        cls = self.get_first_stmt("class X(Y, Z): pass")
+        assert len(cls.bases) == 2
+        for b in cls.bases:
+            assert isinstance(b, ast.Name)
+            assert b.ctx is ast.Load
+
+    def test_function(self):
+        func = self.get_first_stmt("def f(): pass")
+        assert isinstance(func, ast.FunctionDef)
+        assert func.name == "f"
+        assert len(func.body) == 1
+        assert isinstance(func.body[0], ast.Pass)
+        assert func.decorators is None
+        args = func.args
+        assert isinstance(args, ast.arguments)
+        assert args.args is None
+        assert args.defaults is None
+        assert args.kwarg is None
+        assert args.vararg is None
+        args = self.get_first_stmt("def f(a, b): pass").args
+        assert len(args.args) == 2
+        a1, a2 = args.args
+        assert isinstance(a1, ast.Name)
+        assert a1.id == "a"
+        assert a1.ctx is ast.Store
+        assert isinstance(a2, ast.Name)
+        assert a2.id == "b"
+        assert a2.ctx is ast.Store
+        assert args.vararg is None
+        assert args.kwarg is None
+        args = self.get_first_stmt("def f(a=b): pass").args
+        assert len(args.args) == 1
+        arg = args.args[0]
+        assert isinstance(arg, ast.Name)
+        assert arg.id == "a"
+        assert arg.ctx is ast.Store
+        assert len(args.defaults) == 1
+        default = args.defaults[0]
+        assert isinstance(default, ast.Name)
+        assert default.id == "b"
+        assert default.ctx is ast.Load
+        args = self.get_first_stmt("def f(*a): pass").args
+        assert args.args is None
+        assert args.defaults is None
+        assert args.kwarg is None
+        assert args.vararg == "a"
+        args = self.get_first_stmt("def f(**a): pass").args
+        assert args.args is None
+        assert args.defaults is None
+        assert args.vararg is None
+        assert args.kwarg == "a"
+        args = self.get_first_stmt("def f((a, b)): pass").args
+        assert args.defaults is None
+        assert args.kwarg is None
+        assert args.vararg is None
+        assert len(args.args) == 1
+        tup = args.args[0]
+        assert isinstance(tup, ast.Tuple)
+        assert tup.ctx is ast.Store
+        assert len(tup.elts) == 2
+        e1, e2 = tup.elts
+        assert isinstance(e1, ast.Name)
+        assert e1.ctx is ast.Store
+        assert e1.id == "a"
+        assert isinstance(e2, ast.Name)
+        assert e2.ctx is ast.Store
+        assert e2.id == "b"
+        args = self.get_first_stmt("def f((a, (b, c))): pass").args
+        assert len(args.args) == 1
+        tup = args.args[0]
+        assert isinstance(tup, ast.Tuple)
+        assert len(tup.elts) == 2
+        tup2 = tup.elts[1]
+        assert isinstance(tup2, ast.Tuple)
+        assert tup2.ctx is ast.Store
+        for elt in tup2.elts:
+            assert isinstance(elt, ast.Name)
+            assert elt.ctx is ast.Store
+        assert tup2.elts[0].id == "b"
+        assert tup2.elts[1].id == "c"
+        args = self.get_first_stmt("def f(a, b, c=d, *e, **f): pass").args
+        assert len(args.args) == 3
+        for arg in args.args:
+            assert isinstance(arg, ast.Name)
+            assert arg.ctx is ast.Store
+        assert len(args.defaults) == 1
+        assert isinstance(args.defaults[0], ast.Name)
+        assert args.defaults[0].ctx is ast.Load
+        assert args.vararg == "e"
+        assert args.kwarg == "f"
+
+    def test_decorator(self):
+        func = self.get_first_stmt("@dec\ndef f(): pass")
+        assert isinstance(func, ast.FunctionDef)
+        assert len(func.decorators) == 1
+        dec = func.decorators[0]
+        assert isinstance(dec, ast.Name)
+        assert dec.id == "dec"
+        assert dec.ctx is ast.Load
+        func = self.get_first_stmt("@mod.hi.dec\ndef f(): pass")
+        assert len(func.decorators) == 1
+        dec = func.decorators[0]
+        assert isinstance(dec, ast.Attribute)
+        assert dec.ctx is ast.Load
+        assert dec.attr == "dec"
+        assert isinstance(dec.value, ast.Attribute)
+        assert dec.value.attr == "hi"
+        assert isinstance(dec.value.value, ast.Name)
+        assert dec.value.value.id == "mod"
+        func = self.get_first_stmt("@dec\n at dec2\ndef f(): pass")
+        assert len(func.decorators) == 2
+        for dec in func.decorators:
+            assert isinstance(dec, ast.Name)
+            assert dec.ctx is ast.Load
+        assert func.decorators[0].id == "dec"
+        assert func.decorators[1].id == "dec2"
+        func = self.get_first_stmt("@dec()\ndef f(): pass")
+        assert len(func.decorators) == 1
+        dec = func.decorators[0]
+        assert isinstance(dec, ast.Call)
+        assert isinstance(dec.func, ast.Name)
+        assert dec.func.id == "dec"
+        assert dec.args is None
+        assert dec.keywords is None
+        assert dec.starargs is None
+        assert dec.kwargs is None
+        func = self.get_first_stmt("@dec(a, b)\ndef f(): pass")
+        assert len(func.decorators) == 1
+        dec = func.decorators[0]
+        assert isinstance(dec, ast.Call)
+        assert dec.func.id == "dec"
+        assert len(dec.args) == 2
+        assert dec.keywords is None
+        assert dec.starargs is None
+        assert dec.kwargs is None
+
+    def test_augassign(self):
+        aug_assigns = (
+            ("+=", ast.Add),
+            ("-=", ast.Sub),
+            ("/=", ast.Div),
+            ("//=", ast.FloorDiv),
+            ("%=", ast.Mod),
+            ("<<=", ast.LShift),
+            (">>=", ast.RShift),
+            ("&=", ast.BitAnd),
+            ("|=", ast.BitOr),
+            ("^=", ast.BitXor),
+            ("*=", ast.Mult),
+            ("**=", ast.Pow)
+        )
+        for op, ast_type in aug_assigns:
+            input = "x %s 4" % (op,)
+            assign = self.get_first_stmt(input)
+            assert isinstance(assign, ast.AugAssign)
+            assert assign.op is ast_type
+            assert isinstance(assign.target, ast.Name)
+            assert assign.target.ctx is ast.Store
+            assert isinstance(assign.value, ast.Num)
+
+    def test_assign(self):
+        assign = self.get_first_stmt("hi = 32")
+        assert isinstance(assign, ast.Assign)
+        assert len(assign.targets) == 1
+        name = assign.targets[0]
+        assert isinstance(name, ast.Name)
+        assert name.ctx is ast.Store
+        value = assign.value
+        assert self.space.eq_w(value.n, self.space.wrap(32))
+        assign = self.get_first_stmt("hi, = something")
+        assert len(assign.targets) == 1
+        tup = assign.targets[0]
+        assert isinstance(tup, ast.Tuple)
+        assert tup.ctx is ast.Store
+        assert len(tup.elts) == 1
+        assert isinstance(tup.elts[0], ast.Name)
+        assert tup.elts[0].ctx is ast.Store
+
+    def test_name(self):
+        name = self.get_first_expr("hi")
+        assert isinstance(name, ast.Name)
+        assert name.ctx is ast.Load
+
+    def test_tuple(self):
+        tup = self.get_first_expr("()")
+        assert isinstance(tup, ast.Tuple)
+        assert tup.elts is None
+        assert tup.ctx is ast.Load
+        tup = self.get_first_expr("(3,)")
+        assert len(tup.elts) == 1
+        assert self.space.eq_w(tup.elts[0].n, self.space.wrap(3))
+        tup = self.get_first_expr("2, 3, 4")
+        assert len(tup.elts) == 3
+
+    def test_list(self):
+        seq = self.get_first_expr("[]")
+        assert isinstance(seq, ast.List)
+        assert seq.elts is None
+        assert seq.ctx is ast.Load
+        seq = self.get_first_expr("[3,]")
+        assert len(seq.elts) == 1
+        assert self.space.eq_w(seq.elts[0].n, self.space.wrap(3))
+        seq = self.get_first_expr("[3]")
+        assert len(seq.elts) == 1
+        seq = self.get_first_expr("[1, 2, 3, 4, 5]")
+        assert len(seq.elts) == 5
+        nums = range(1, 6)
+        assert [self.space.int_w(n.n) for n in seq.elts] == nums
+
+    def test_dict(self):
+        d = self.get_first_expr("{}")
+        assert isinstance(d, ast.Dict)
+        assert d.keys is None
+        assert d.values is None
+        d = self.get_first_expr("{4 : x, y : 7}")
+        assert len(d.keys) == len(d.values) == 2
+        key1, key2 = d.keys
+        assert isinstance(key1, ast.Num)
+        assert isinstance(key2, ast.Name)
+        assert key2.ctx is ast.Load
+        v1, v2 = d.values
+        assert isinstance(v1, ast.Name)
+        assert v1.ctx is ast.Load
+        assert isinstance(v2, ast.Num)
+
+    def test_set_context(self):
+        tup = self.get_ast("(a, b) = c").body[0].targets[0]
+        assert all(elt.ctx is ast.Store for elt in tup.elts)
+        seq = self.get_ast("[a, b] = c").body[0].targets[0]
+        assert all(elt.ctx is ast.Store for elt in seq.elts)
+        invalid_stores = (
+            ("(lambda x: x)", "lambda"),
+            ("f()", "call"),
+            ("~x", "operator"),
+            ("+x", "operator"),
+            ("-x", "operator"),
+            ("(x or y)", "operator"),
+            ("(x and y)", "operator"),
+            ("(not g)", "operator"),
+            ("(x for y in g)", "generator expression"),
+            ("(yield x)", "yield expression"),
+            ("[x for y in g]", "list comprehension"),
+            ("'str'", "literal"),
+            ("()", "()"),
+            ("23", "literal"),
+            ("{}", "literal"),
+            ("(x > 4)", "comparison"),
+            ("(x if y else a)", "conditional expression"),
+            ("`x`", "repr")
+        )
+        test_contexts = (
+            ("assign to", "%s = 23"),
+            ("delete", "del %s")
+        )
+        for ctx_type, template in test_contexts:
+            for expr, type_str in invalid_stores:
+                input = template % (expr,)
+                exc = py.test.raises(SyntaxError, self.get_ast, input).value
+                assert exc.msg == "can't %s %s" % (ctx_type, type_str)
+
+    def test_assignment_to_forbidden_names(self):
+        invalid = (
+            "%s = x",
+            "%s, x = y",
+            "def %s(): pass",
+            "class %s(): pass",
+            "def f(%s): pass",
+            "def f(%s=x): pass",
+            "def f(*%s): pass",
+            "def f(**%s): pass",
+            "f(%s=x)",
+            "with x as %s: pass",
+            "import %s",
+            "import x as %s",
+            "from x import %s",
+            "from x import y as %s",
+            "for %s in x: pass",
+        )
+        for name in ("None", "__debug__"):
+            for template in invalid:
+                input = template % (name,)
+                exc = py.test.raises(SyntaxError, self.get_ast, input).value
+                assert exc.msg == "assignment to %s" % (name,)
+
+    def test_lambda(self):
+        lam = self.get_first_expr("lambda x: expr")
+        assert isinstance(lam, ast.Lambda)
+        args = lam.args
+        assert isinstance(args, ast.arguments)
+        assert args.vararg is None
+        assert args.kwarg is None
+        assert args.defaults is None
+        assert len(args.args) == 1
+        assert isinstance(args.args[0], ast.Name)
+        assert isinstance(lam.body, ast.Name)
+        lam = self.get_first_expr("lambda: True")
+        args = lam.args
+        assert args.args is None
+        lam = self.get_first_expr("lambda x=x: y")
+        assert len(lam.args.args) == 1
+        assert len(lam.args.defaults) == 1
+        assert isinstance(lam.args.defaults[0], ast.Name)
+        input = "f(lambda x: x[0] = y)"
+        exc = py.test.raises(SyntaxError, self.get_ast, input).value
+        assert exc.msg == "lambda cannot contain assignment"
+
+    def test_ifexp(self):
+        ifexp = self.get_first_expr("x if y else g")
+        assert isinstance(ifexp, ast.IfExp)
+        assert isinstance(ifexp.test, ast.Name)
+        assert ifexp.test.ctx is ast.Load
+        assert isinstance(ifexp.body, ast.Name)
+        assert ifexp.body.ctx is ast.Load
+        assert isinstance(ifexp.orelse, ast.Name)
+        assert ifexp.orelse.ctx is ast.Load
+
+    def test_boolop(self):
+        for ast_type, op in ((ast.And, "and"), (ast.Or, "or")):
+            bo = self.get_first_expr("x %s a" % (op,))
+            assert isinstance(bo, ast.BoolOp)
+            assert bo.op is ast_type
+            assert len(bo.values) == 2
+            assert isinstance(bo.values[0], ast.Name)
+            assert isinstance(bo.values[1], ast.Name)
+            bo = self.get_first_expr("x %s a %s b" % (op, op))
+            assert bo.op is ast_type
+            assert len(bo.values) == 3
+
+    def test_not(self):
+        n = self.get_first_expr("not x")
+        assert isinstance(n, ast.UnaryOp)
+        assert n.op is ast.Not
+        assert isinstance(n.operand, ast.Name)
+        assert n.operand.ctx is ast.Load
+
+    def test_comparison(self):
+        compares = (
+            (">", ast.Gt),
+            (">=", ast.GtE),
+            ("<", ast.Lt),
+            ("<=", ast.LtE),
+            ("==", ast.Eq),
+            ("!=", ast.NotEq),
+            ("<>", ast.NotEq),
+            ("in", ast.In),
+            ("is", ast.Is),
+            ("is not", ast.IsNot),
+            ("not in", ast.NotIn)
+        )
+        for op, ast_type in compares:
+            comp = self.get_first_expr("x %s y" % (op,))
+            assert isinstance(comp, ast.Compare)
+            assert isinstance(comp.left, ast.Name)
+            assert comp.left.ctx is ast.Load
+            assert len(comp.ops) == 1
+            assert comp.ops[0] is ast_type
+            assert len(comp.comparators) == 1
+            assert isinstance(comp.comparators[0], ast.Name)
+            assert comp.comparators[0].ctx is ast.Load
+        # Just for fun let's randomly combine operators. :)
+        for j in range(10):
+            vars = string.ascii_letters[:random.randint(3, 7)]
+            ops = [random.choice(compares) for i in range(len(vars) - 1)]
+            input = vars[0]
+            for i, (op, _) in enumerate(ops):
+                input += " %s %s" % (op, vars[i + 1])
+            comp = self.get_first_expr(input)
+            assert comp.ops == [tup[1] for tup in ops]
+            names = comp.left.id + "".join(n.id for n in comp.comparators)
+            assert names == vars
+
+    def test_binop(self):
+        binops = (
+            ("|", ast.BitOr),
+            ("&", ast.BitAnd),
+            ("^", ast.BitXor),
+            ("<<", ast.LShift),
+            (">>", ast.RShift),
+            ("+", ast.Add),
+            ("-", ast.Sub),
+            ("/", ast.Div),
+            ("*", ast.Mult),
+            ("//", ast.FloorDiv),
+            ("%", ast.Mod)
+        )
+        for op, ast_type in binops:
+            bin = self.get_first_expr("a %s b" % (op,))
+            assert isinstance(bin, ast.BinOp)
+            assert bin.op is ast_type
+            assert isinstance(bin.left, ast.Name)
+            assert isinstance(bin.right, ast.Name)
+            assert bin.left.ctx is ast.Load
+            assert bin.right.ctx is ast.Load
+            bin = self.get_first_expr("a %s b %s c" % (op, op))
+            assert isinstance(bin.left, ast.BinOp)
+            assert bin.left.op is ast_type
+            assert isinstance(bin.right, ast.Name)
+
+    def test_yield(self):
+        expr = self.get_first_expr("yield")
+        assert isinstance(expr, ast.Yield)
+        assert expr.value is None
+        expr = self.get_first_expr("yield x")
+        assert isinstance(expr.value, ast.Name)
+        assign = self.get_first_stmt("x = yield x")
+        assert isinstance(assign, ast.Assign)
+        assert isinstance(assign.value, ast.Yield)
+
+    def test_unaryop(self):
+        unary_ops = (
+            ("+", ast.UAdd),
+            ("-", ast.USub),
+            ("~", ast.Invert)
+        )
+        for op, ast_type in unary_ops:
+            unary = self.get_first_expr("%sx" % (op,))
+            assert isinstance(unary, ast.UnaryOp)
+            assert unary.op is ast_type
+            assert isinstance(unary.operand, ast.Name)
+            assert unary.operand.ctx is ast.Load
+
+    def test_power(self):
+        power = self.get_first_expr("x**5")
+        assert isinstance(power, ast.BinOp)
+        assert power.op is ast.Pow
+        assert isinstance(power.left , ast.Name)
+        assert power.left.ctx is ast.Load
+        assert isinstance(power.right, ast.Num)
+
+    def test_call(self):
+        call = self.get_first_expr("f()")
+        assert isinstance(call, ast.Call)
+        assert call.args is None
+        assert call.keywords is None
+        assert call.starargs is None
+        assert call.kwargs is None
+        assert isinstance(call.func, ast.Name)
+        assert call.func.ctx is ast.Load
+        call = self.get_first_expr("f(2, 3)")
+        assert len(call.args) == 2
+        assert isinstance(call.args[0], ast.Num)
+        assert isinstance(call.args[1], ast.Num)
+        assert call.keywords is None
+        assert call.starargs is None
+        assert call.kwargs is None
+        call = self.get_first_expr("f(a=3)")
+        assert call.args is None
+        assert len(call.keywords) == 1
+        keyword = call.keywords[0]
+        assert isinstance(keyword, ast.keyword)
+        assert keyword.arg == "a"
+        assert isinstance(keyword.value, ast.Num)
+        call = self.get_first_expr("f(*a, **b)")
+        assert call.args is None
+        assert isinstance(call.starargs, ast.Name)
+        assert call.starargs.id == "a"
+        assert call.starargs.ctx is ast.Load
+        assert isinstance(call.kwargs, ast.Name)
+        assert call.kwargs.id == "b"
+        assert call.kwargs.ctx is ast.Load
+        call = self.get_first_expr("f(a, b, x=4, *m, **f)")
+        assert len(call.args) == 2
+        assert isinstance(call.args[0], ast.Name)
+        assert isinstance(call.args[1], ast.Name)
+        assert len(call.keywords) == 1
+        assert call.keywords[0].arg == "x"
+        assert call.starargs.id == "m"
+        assert call.kwargs.id == "f"
+        call = self.get_first_expr("f(x for x in y)")
+        assert len(call.args) == 1
+        assert isinstance(call.args[0], ast.GeneratorExp)
+        input = "f(x for x in y, 1)"
+        exc = py.test.raises(SyntaxError, self.get_ast, input).value
+        assert exc.msg == "Generator expression must be parenthesized if not " \
+            "sole argument"
+        many_args = ", ".join("x%i" % i for i in range(256))
+        input = "f(%s)" % (many_args,)
+        exc = py.test.raises(SyntaxError, self.get_ast, input).value
+        assert exc.msg == "more than 255 arguments"
+        exc = py.test.raises(SyntaxError, self.get_ast, "f((a+b)=c)").value
+        assert exc.msg == "keyword can't be an expression"
+        exc = py.test.raises(SyntaxError, self.get_ast, "f(a=c, a=d)").value
+        assert exc.msg == "keyword argument repeated"
+
+    def test_attribute(self):
+        attr = self.get_first_expr("x.y")
+        assert isinstance(attr, ast.Attribute)
+        assert isinstance(attr.value, ast.Name)
+        assert attr.value.ctx is ast.Load
+        assert attr.attr == "y"
+        assert attr.ctx is ast.Load
+        assign = self.get_first_stmt("x.y = 54")
+        assert isinstance(assign, ast.Assign)
+        assert len(assign.targets) == 1
+        attr = assign.targets[0]
+        assert isinstance(attr, ast.Attribute)
+        assert attr.value.ctx is ast.Load
+        assert attr.ctx is ast.Store
+
+    def test_subscript_and_slices(self):
+        sub = self.get_first_expr("x[y]")
+        assert isinstance(sub, ast.Subscript)
+        assert isinstance(sub.value, ast.Name)
+        assert sub.value.ctx is ast.Load
+        assert sub.ctx is ast.Load
+        assert isinstance(sub.slice, ast.Index)
+        assert isinstance(sub.slice.value, ast.Name)
+        for input in (":", "::"):
+            slc = self.get_first_expr("x[%s]" % (input,)).slice
+            assert slc.upper is None
+            assert slc.lower is None
+            assert slc.step is None
+        for input in ("1:", "1::"):
+            slc = self.get_first_expr("x[%s]" % (input,)).slice
+            assert isinstance(slc.lower, ast.Num)
+            assert slc.upper is None
+            assert slc.step is None
+        for input in (":2", ":2:"):
+            slc = self.get_first_expr("x[%s]" % (input,)).slice
+            assert slc.lower is None
+            assert isinstance(slc.upper, ast.Num)
+            assert slc.step is None
+        for input in ("2:2:", "2:2"):
+            slc = self.get_first_expr("x[%s]" % (input,)).slice
+            assert isinstance(slc.lower, ast.Num)
+            assert isinstance(slc.upper, ast.Num)
+            assert slc.step is None
+        slc = self.get_first_expr("x[::2]").slice
+        assert slc.lower is None
+        assert slc.upper is None
+        assert isinstance(slc.step, ast.Num)
+        slc = self.get_first_expr("x[2::2]").slice
+        assert isinstance(slc.lower, ast.Num)
+        assert slc.upper is None
+        assert isinstance(slc.step, ast.Num)
+        slc = self.get_first_expr("x[:2:2]").slice
+        assert slc.lower is None
+        assert isinstance(slc.upper, ast.Num)
+        assert isinstance(slc.step, ast.Num)
+        slc = self.get_first_expr("x[1:2:3]").slice
+        for field in (slc.lower, slc.upper, slc.step):
+            assert isinstance(field, ast.Num)
+        sub = self.get_first_expr("x[...]")
+        assert isinstance(sub.slice, ast.Ellipsis)
+        sub = self.get_first_expr("x[1,2,3]")
+        slc = sub.slice
+        assert isinstance(slc, ast.Index)
+        assert isinstance(slc.value, ast.Tuple)
+        assert len(slc.value.elts) == 3
+        assert slc.value.ctx is ast.Load
+        slc = self.get_first_expr("x[1,3:4]").slice
+        assert isinstance(slc, ast.ExtSlice)
+        assert len(slc.dims) == 2
+        complex_slc = slc.dims[1]
+        assert isinstance(complex_slc, ast.Slice)
+        assert isinstance(complex_slc.lower, ast.Num)
+        assert isinstance(complex_slc.upper, ast.Num)
+        assert complex_slc.step is None
+
+    def test_repr(self):
+        rep = self.get_first_expr("`x`")
+        assert isinstance(rep, ast.Repr)
+        assert isinstance(rep.value, ast.Name)
+
+    def test_string(self):
+        space = self.space
+        s = self.get_first_expr("'hi'")
+        assert isinstance(s, ast.Str)
+        assert space.eq_w(s.s, space.wrap("hi"))
+        s = self.get_first_expr("'hi' ' implicitly' ' extra'")
+        assert isinstance(s, ast.Str)
+        assert space.eq_w(s.s, space.wrap("hi implicitly extra"))
+
+    def test_number(self):
+        def get_num(s):
+            node = self.get_first_expr(s)
+            assert isinstance(node, ast.Num)
+            value = node.n
+            assert isinstance(value, W_Root)
+            return value
+        space = self.space
+        assert space.eq_w(get_num("32"), space.wrap(32))
+        assert space.eq_w(get_num("32.5"), space.wrap(32.5))
+        assert space.eq_w(get_num("32L"), space.newlong(32))
+        assert space.eq_w(get_num("32l"), space.newlong(32))
+        assert space.eq_w(get_num("13j"), space.wrap(13j))
+        assert space.eq_w(get_num("13J"), space.wrap(13J))
+
+    def check_comprehension(self, brackets, ast_type):
+        def brack(s):
+            return brackets % s
+        gen = self.get_first_expr(brack("x for x in y"))
+        assert isinstance(gen, ast_type)
+        assert isinstance(gen.elt, ast.Name)
+        assert gen.elt.ctx is ast.Load
+        assert len(gen.generators) == 1
+        comp = gen.generators[0]
+        assert isinstance(comp, ast.comprehension)
+        assert comp.ifs is None
+        assert isinstance(comp.target, ast.Name)
+        assert isinstance(comp.iter, ast.Name)
+        gen = self.get_first_expr(brack("x for x in y if w"))
+        comp = gen.generators[0]
+        assert len(comp.ifs) == 1
+        test = comp.ifs[0]
+        assert isinstance(test, ast.Name)
+        gen = self.get_first_expr(brack("x for x, in y if w"))
+        tup = gen.generators[0].target
+        assert isinstance(tup, ast.Tuple)
+        assert len(tup.elts) == 1
+        gen = self.get_first_expr(brack("a for w in x for m in p if g"))
+        gens = gen.generators
+        assert len(gens) == 2
+        comp1, comp2 = gens
+        assert comp1.ifs is None
+        assert len(comp2.ifs) == 1
+        assert isinstance(comp2.ifs[0], ast.Name)
+        gen = self.get_first_expr(brack("x for x in y if m if g"))
+        comps = gen.generators
+        assert len(comps) == 1
+        assert len(comps[0].ifs) == 2
+        if1, if2 = comps[0].ifs
+        assert isinstance(if1, ast.Name)
+        assert isinstance(if2, ast.Name)
+
+    def test_genexp(self):
+        self.check_comprehension("(%s)", ast.GeneratorExp)
+
+    def test_listcomp(self):
+        self.check_comprehension("[%s]", ast.ListComp)



More information about the Pypy-commit mailing list