[Python-checkins] cpython: add an AST validator (closes #12575)

benjamin.peterson python-checkins at python.org
Tue Aug 9 23:17:26 CEST 2011


http://hg.python.org/cpython/rev/4090dfdf91a4
changeset:   71795:4090dfdf91a4
user:        Benjamin Peterson <benjamin at python.org>
date:        Tue Aug 09 16:15:04 2011 -0500
summary:
  add an AST validator (closes #12575)

files:
  Include/ast.h        |    1 +
  Lib/test/test_ast.py |  410 ++++++++++++++++++++++++++-
  Misc/NEWS            |    2 +
  Python/ast.c         |  486 ++++++++++++++++++++++++++++++-
  Python/bltinmodule.c |    4 +
  5 files changed, 897 insertions(+), 6 deletions(-)


diff --git a/Include/ast.h b/Include/ast.h
--- a/Include/ast.h
+++ b/Include/ast.h
@@ -4,6 +4,7 @@
 extern "C" {
 #endif
 
+PyAPI_FUNC(int) PyAST_Validate(mod_ty);
 PyAPI_FUNC(mod_ty) PyAST_FromNode(
     const node *n,
     PyCompilerFlags *flags,
diff --git a/Lib/test/test_ast.py b/Lib/test/test_ast.py
--- a/Lib/test/test_ast.py
+++ b/Lib/test/test_ast.py
@@ -1,4 +1,6 @@
-import sys, unittest
+import os
+import sys
+import unittest
 from test import support
 import ast
 
@@ -490,8 +492,412 @@
         self.assertEqual(ast.literal_eval('1.5 - 2j'), 1.5 - 2j)
 
 
+class ASTValidatorTests(unittest.TestCase):  # compile() must reject malformed hand-built ASTs
+
+    def mod(self, mod, msg=None, mode="exec", *, exc=ValueError):  # assert compile(mod) raises exc whose text contains msg
+        mod.lineno = mod.col_offset = 0
+        ast.fix_missing_locations(mod)
+        with self.assertRaises(exc) as cm:
+            compile(mod, "<test>", mode)
+        if msg is not None:
+            self.assertIn(msg, str(cm.exception))
+
+    def expr(self, node, msg=None, *, exc=ValueError):  # check a lone expression wrapped as Module([Expr(node)])
+        mod = ast.Module([ast.Expr(node)])
+        self.mod(mod, msg, exc=exc)
+
+    def stmt(self, stmt, msg=None, *, exc=ValueError):  # check a lone statement wrapped as Module([stmt])
+        mod = ast.Module([stmt])
+        self.mod(mod, msg, exc=exc)
+
+    def test_module(self):  # Interactive and Expression roots are checked too
+        m = ast.Interactive([ast.Expr(ast.Name("x", ast.Store()))])
+        self.mod(m, "must have Load context", "single")
+        m = ast.Expression(ast.Name("x", ast.Store()))
+        self.mod(m, "must have Load context", "eval")
+
+    def _check_arguments(self, fac, check):  # fac wraps ast.arguments in a node; check(node, msg) asserts the error
+        def arguments(args=None, vararg=None, varargannotation=None,
+                      kwonlyargs=None, kwarg=None, kwargannotation=None,
+                      defaults=None, kw_defaults=None):
+            if args is None:
+                args = []
+            if kwonlyargs is None:
+                kwonlyargs = []
+            if defaults is None:
+                defaults = []
+            if kw_defaults is None:
+                kw_defaults = []
+            args = ast.arguments(args, vararg, varargannotation, kwonlyargs,
+                                 kwarg, kwargannotation, defaults, kw_defaults)
+            return fac(args)
+        args = [ast.arg("x", ast.Name("x", ast.Store()))]
+        check(arguments(args=args), "must have Load context")
+        check(arguments(varargannotation=ast.Num(3)),
+              "varargannotation but no vararg")
+        check(arguments(varargannotation=ast.Name("x", ast.Store()), vararg="x"),
+              "must have Load context")
+        check(arguments(kwonlyargs=args), "must have Load context")
+        check(arguments(kwargannotation=ast.Num(42)),
+              "kwargannotation but no kwarg")
+        check(arguments(kwargannotation=ast.Name("x", ast.Store()),
+                        kwarg="x"), "must have Load context")
+        check(arguments(defaults=[ast.Num(3)]),
+              "more positional defaults than args")
+        check(arguments(kw_defaults=[ast.Num(4)]),
+              "length of kwonlyargs is not the same as kw_defaults")
+        args = [ast.arg("x", ast.Name("x", ast.Load()))]
+        check(arguments(args=args, defaults=[ast.Name("x", ast.Store())]),
+              "must have Load context")
+        args = [ast.arg("a", ast.Name("x", ast.Load())),
+                ast.arg("b", ast.Name("y", ast.Load()))]
+        check(arguments(kwonlyargs=args,
+                        kw_defaults=[None, ast.Name("x", ast.Store())]),
+              "must have Load context")
+
+    def test_funcdef(self):  # empty body, decorators, return annotation, arguments
+        a = ast.arguments([], None, None, [], None, None, [], [])
+        f = ast.FunctionDef("x", a, [], [], None)
+        self.stmt(f, "empty body on FunctionDef")
+        f = ast.FunctionDef("x", a, [ast.Pass()], [ast.Name("x", ast.Store())],
+                            None)
+        self.stmt(f, "must have Load context")
+        f = ast.FunctionDef("x", a, [ast.Pass()], [],
+                            ast.Name("x", ast.Store()))
+        self.stmt(f, "must have Load context")
+        def fac(args):
+            return ast.FunctionDef("x", args, [ast.Pass()], [], None)
+        self._check_arguments(fac, self.stmt)
+
+    def test_classdef(self):  # bases, keywords, starargs, kwargs, body, decorators
+        def cls(bases=None, keywords=None, starargs=None, kwargs=None,
+                body=None, decorator_list=None):
+            if bases is None:
+                bases = []
+            if keywords is None:
+                keywords = []
+            if body is None:
+                body = [ast.Pass()]
+            if decorator_list is None:
+                decorator_list = []
+            return ast.ClassDef("myclass", bases, keywords, starargs,
+                                kwargs, body, decorator_list)
+        self.stmt(cls(bases=[ast.Name("x", ast.Store())]),
+                  "must have Load context")
+        self.stmt(cls(keywords=[ast.keyword("x", ast.Name("x", ast.Store()))]),
+                  "must have Load context")
+        self.stmt(cls(starargs=ast.Name("x", ast.Store())),
+                  "must have Load context")
+        self.stmt(cls(kwargs=ast.Name("x", ast.Store())),
+                  "must have Load context")
+        self.stmt(cls(body=[]), "empty body on ClassDef")
+        self.stmt(cls(body=[None]), "None disallowed")
+        self.stmt(cls(decorator_list=[ast.Name("x", ast.Store())]),
+                  "must have Load context")
+
+    def test_delete(self):  # targets: non-empty, no None, Del context
+        self.stmt(ast.Delete([]), "empty targets on Delete")
+        self.stmt(ast.Delete([None]), "None disallowed")
+        self.stmt(ast.Delete([ast.Name("x", ast.Load())]),
+                  "must have Del context")
+
+    def test_assign(self):  # targets need Store context; value needs Load
+        self.stmt(ast.Assign([], ast.Num(3)), "empty targets on Assign")
+        self.stmt(ast.Assign([None], ast.Num(3)), "None disallowed")
+        self.stmt(ast.Assign([ast.Name("x", ast.Load())], ast.Num(3)),
+                  "must have Store context")
+        self.stmt(ast.Assign([ast.Name("x", ast.Store())],
+                             ast.Name("y", ast.Store())),
+                  "must have Load context")
+
+    def test_augassign(self):  # Store target, Load value
+        aug = ast.AugAssign(ast.Name("x", ast.Load()), ast.Add(),
+                            ast.Name("y", ast.Load()))
+        self.stmt(aug, "must have Store context")
+        aug = ast.AugAssign(ast.Name("x", ast.Store()), ast.Add(),
+                            ast.Name("y", ast.Store()))
+        self.stmt(aug, "must have Load context")
+
+    def test_for(self):  # target, iter, body and orelse
+        x = ast.Name("x", ast.Store())
+        y = ast.Name("y", ast.Load())
+        p = ast.Pass()
+        self.stmt(ast.For(x, y, [], []), "empty body on For")
+        self.stmt(ast.For(ast.Name("x", ast.Load()), y, [p], []),
+                  "must have Store context")
+        self.stmt(ast.For(x, ast.Name("y", ast.Store()), [p], []),
+                  "must have Load context")
+        e = ast.Expr(ast.Name("x", ast.Store()))
+        self.stmt(ast.For(x, y, [e], []), "must have Load context")
+        self.stmt(ast.For(x, y, [p], [e]), "must have Load context")
+
+    def test_while(self):  # test, body and orelse
+        self.stmt(ast.While(ast.Num(3), [], []), "empty body on While")
+        self.stmt(ast.While(ast.Name("x", ast.Store()), [ast.Pass()], []),
+                  "must have Load context")
+        self.stmt(ast.While(ast.Num(3), [ast.Pass()],
+                            [ast.Expr(ast.Name("x", ast.Store()))]),
+                  "must have Load context")
+
+    def test_if(self):  # test, body and orelse
+        self.stmt(ast.If(ast.Num(3), [], []), "empty body on If")
+        i = ast.If(ast.Name("x", ast.Store()), [ast.Pass()], [])
+        self.stmt(i, "must have Load context")
+        i = ast.If(ast.Num(3), [ast.Expr(ast.Name("x", ast.Store()))], [])
+        self.stmt(i, "must have Load context")
+        i = ast.If(ast.Num(3), [ast.Pass()],
+                   [ast.Expr(ast.Name("x", ast.Store()))])
+        self.stmt(i, "must have Load context")
+
+    def test_with(self):  # non-empty items and body; Load context_expr; Store optional_vars
+        p = ast.Pass()
+        self.stmt(ast.With([], [p]), "empty items on With")
+        i = ast.withitem(ast.Num(3), None)
+        self.stmt(ast.With([i], []), "empty body on With")
+        i = ast.withitem(ast.Name("x", ast.Store()), None)
+        self.stmt(ast.With([i], [p]), "must have Load context")
+        i = ast.withitem(ast.Num(3), ast.Name("x", ast.Load()))
+        self.stmt(ast.With([i], [p]), "must have Store context")
+
+    def test_raise(self):  # a cause requires an exception; both need Load
+        r = ast.Raise(None, ast.Num(3))
+        self.stmt(r, "Raise with cause but no exception")
+        r = ast.Raise(ast.Name("x", ast.Store()), None)
+        self.stmt(r, "must have Load context")
+        r = ast.Raise(ast.Num(4), ast.Name("x", ast.Store()))
+        self.stmt(r, "must have Load context")
+
+    def test_try(self):  # body, handlers, orelse and finalbody rules
+        p = ast.Pass()
+        t = ast.Try([], [], [], [p])
+        self.stmt(t, "empty body on Try")
+        t = ast.Try([ast.Expr(ast.Name("x", ast.Store()))], [], [], [p])
+        self.stmt(t, "must have Load context")
+        t = ast.Try([p], [], [], [])
+        self.stmt(t, "Try has neither except handlers nor finalbody")
+        t = ast.Try([p], [], [p], [p])
+        self.stmt(t, "Try has orelse but no except handlers")
+        t = ast.Try([p], [ast.ExceptHandler(None, "x", [])], [], [])
+        self.stmt(t, "empty body on ExceptHandler")
+        e = [ast.ExceptHandler(ast.Name("x", ast.Store()), "y", [p])]
+        self.stmt(ast.Try([p], e, [], []), "must have Load context")
+        e = [ast.ExceptHandler(None, "x", [p])]
+        t = ast.Try([p], e, [ast.Expr(ast.Name("x", ast.Store()))], [p])
+        self.stmt(t, "must have Load context")
+        t = ast.Try([p], e, [p], [ast.Expr(ast.Name("x", ast.Store()))])
+        self.stmt(t, "must have Load context")
+
+    def test_assert(self):  # test and msg must both be Load
+        self.stmt(ast.Assert(ast.Name("x", ast.Store()), None),
+                  "must have Load context")
+        assrt = ast.Assert(ast.Name("x", ast.Load()),
+                           ast.Name("y", ast.Store()))
+        self.stmt(assrt, "must have Load context")
+
+    def test_import(self):  # names must be non-empty
+        self.stmt(ast.Import([]), "empty names on Import")
+
+    def test_importfrom(self):  # level must not be below -1; names non-empty
+        imp = ast.ImportFrom(None, [ast.alias("x", None)], -42)
+        self.stmt(imp, "level less than -1")
+        self.stmt(ast.ImportFrom(None, [], 0), "empty names on ImportFrom")
+
+    def test_global(self):  # names must be non-empty
+        self.stmt(ast.Global([]), "empty names on Global")
+
+    def test_nonlocal(self):  # names must be non-empty
+        self.stmt(ast.Nonlocal([]), "empty names on Nonlocal")
+
+    def test_expr(self):  # expression statements need Load
+        e = ast.Expr(ast.Name("x", ast.Store()))
+        self.stmt(e, "must have Load context")
+
+    def test_boolop(self):  # at least 2 values, none of them None, all Load
+        b = ast.BoolOp(ast.And(), [])
+        self.expr(b, "less than 2 values")
+        b = ast.BoolOp(ast.And(), [ast.Num(3)])
+        self.expr(b, "less than 2 values")
+        b = ast.BoolOp(ast.And(), [ast.Num(4), None])
+        self.expr(b, "None disallowed")
+        b = ast.BoolOp(ast.And(), [ast.Num(4), ast.Name("x", ast.Store())])
+        self.expr(b, "must have Load context")
+
+    def test_unaryop(self):  # operand needs Load
+        u = ast.UnaryOp(ast.Not(), ast.Name("x", ast.Store()))
+        self.expr(u, "must have Load context")
+
+    def test_lambda(self):  # body needs Load; arguments checked like FunctionDef
+        a = ast.arguments([], None, None, [], None, None, [], [])
+        self.expr(ast.Lambda(a, ast.Name("x", ast.Store())),
+                  "must have Load context")
+        def fac(args):
+            return ast.Lambda(args, ast.Name("x", ast.Load()))
+        self._check_arguments(fac, self.expr)
+
+    def test_ifexp(self):  # each of test/body/orelse needs Load
+        l = ast.Name("x", ast.Load())
+        s = ast.Name("y", ast.Store())
+        for args in (s, l, l), (l, s, l), (l, l, s):
+            self.expr(ast.IfExp(*args), "must have Load context")
+
+    def test_dict(self):  # keys/values of equal length, no None entries
+        d = ast.Dict([], [ast.Name("x", ast.Load())])
+        self.expr(d, "same number of keys as values")
+        d = ast.Dict([None], [ast.Name("x", ast.Load())])
+        self.expr(d, "None disallowed")
+        d = ast.Dict([ast.Name("x", ast.Load())], [None])
+        self.expr(d, "None disallowed")
+
+    def test_set(self):  # no None elements; elements need Load
+        self.expr(ast.Set([None]), "None disallowed")
+        s = ast.Set([ast.Name("x", ast.Store())])
+        self.expr(s, "must have Load context")
+
+    def _check_comprehension(self, fac):  # shared generator checks; fac(generators) builds the node
+        self.expr(fac([]), "comprehension with no generators")
+        g = ast.comprehension(ast.Name("x", ast.Load()),
+                              ast.Name("x", ast.Load()), [])
+        self.expr(fac([g]), "must have Store context")
+        g = ast.comprehension(ast.Name("x", ast.Store()),
+                              ast.Name("x", ast.Store()), [])
+        self.expr(fac([g]), "must have Load context")
+        x = ast.Name("x", ast.Store())
+        y = ast.Name("y", ast.Load())
+        g = ast.comprehension(x, y, [None])
+        self.expr(fac([g]), "None disallowed")
+        g = ast.comprehension(x, y, [ast.Name("x", ast.Store())])
+        self.expr(fac([g]), "must have Load context")
+
+    def _simple_comp(self, fac):  # elt needs Load, plus the shared generator checks
+        g = ast.comprehension(ast.Name("x", ast.Store()),
+                              ast.Name("x", ast.Load()), [])
+        self.expr(fac(ast.Name("x", ast.Store()), [g]),
+                  "must have Load context")
+        def wrap(gens):
+            return fac(ast.Name("x", ast.Store()), gens)
+        self._check_comprehension(wrap)
+
+    def test_listcomp(self):
+        self._simple_comp(ast.ListComp)
+
+    def test_setcomp(self):
+        self._simple_comp(ast.SetComp)
+
+    def test_generatorexp(self):
+        self._simple_comp(ast.GeneratorExp)
+
+    def test_dictcomp(self):  # key and value need Load, plus the shared generator checks
+        g = ast.comprehension(ast.Name("y", ast.Store()),
+                              ast.Name("p", ast.Load()), [])
+        c = ast.DictComp(ast.Name("x", ast.Store()),
+                         ast.Name("y", ast.Load()), [g])
+        self.expr(c, "must have Load context")
+        c = ast.DictComp(ast.Name("x", ast.Load()),
+                         ast.Name("y", ast.Store()), [g])
+        self.expr(c, "must have Load context")
+        def factory(comps):
+            k = ast.Name("x", ast.Load())
+            v = ast.Name("y", ast.Load())
+            return ast.DictComp(k, v, comps)
+        self._check_comprehension(factory)
+
+    def test_yield(self):  # a yielded value needs Load
+        self.expr(ast.Yield(ast.Name("x", ast.Store())), "must have Load")
+
+    def test_compare(self):  # comparator count and Num type checks
+        left = ast.Name("x", ast.Load())
+        comp = ast.Compare(left, [ast.In()], [])
+        self.expr(comp, "no comparators")
+        comp = ast.Compare(left, [ast.In()], [ast.Num(4), ast.Num(5)])
+        self.expr(comp, "different number of comparators and operands")
+        comp = ast.Compare(ast.Num("blah"), [ast.In()], [left])
+        self.expr(comp, "non-numeric", exc=TypeError)
+        comp = ast.Compare(left, [ast.In()], [ast.Num("blah")])
+        self.expr(comp, "non-numeric", exc=TypeError)
+
+    def test_call(self):  # func, args, keywords, starargs and kwargs
+        func = ast.Name("x", ast.Load())
+        args = [ast.Name("y", ast.Load())]
+        keywords = [ast.keyword("w", ast.Name("z", ast.Load()))]
+        stararg = ast.Name("p", ast.Load())
+        kwarg = ast.Name("q", ast.Load())
+        call = ast.Call(ast.Name("x", ast.Store()), args, keywords, stararg,
+                        kwarg)
+        self.expr(call, "must have Load context")
+        call = ast.Call(func, [None], keywords, stararg, kwarg)
+        self.expr(call, "None disallowed")
+        bad_keywords = [ast.keyword("w", ast.Name("z", ast.Store()))]
+        call = ast.Call(func, args, bad_keywords, stararg, kwarg)
+        self.expr(call, "must have Load context")
+        call = ast.Call(func, args, keywords, ast.Name("z", ast.Store()), kwarg)
+        self.expr(call, "must have Load context")
+        call = ast.Call(func, args, keywords, stararg,
+                        ast.Name("w", ast.Store()))
+        self.expr(call, "must have Load context")
+
+    def test_num(self):  # str payloads and int/float/complex subclasses raise TypeError
+        class subint(int):
+            pass
+        class subfloat(float):
+            pass
+        class subcomplex(complex):
+            pass
+        for obj in "0", "hello", subint(), subfloat(), subcomplex():
+            self.expr(ast.Num(obj), "non-numeric", exc=TypeError)
+
+    def test_attribute(self):  # the attribute's value needs Load
+        attr = ast.Attribute(ast.Name("x", ast.Store()), "y", ast.Load())
+        self.expr(attr, "must have Load context")
+
+    def test_subscript(self):  # value, Index, Slice parts and ExtSlice dims
+        sub = ast.Subscript(ast.Name("x", ast.Store()), ast.Index(ast.Num(3)),
+                            ast.Load())
+        self.expr(sub, "must have Load context")
+        x = ast.Name("x", ast.Load())
+        sub = ast.Subscript(x, ast.Index(ast.Name("y", ast.Store())),
+                            ast.Load())
+        self.expr(sub, "must have Load context")
+        s = ast.Name("x", ast.Store())
+        for args in (s, None, None), (None, s, None), (None, None, s):
+            sl = ast.Slice(*args)
+            self.expr(ast.Subscript(x, sl, ast.Load()),
+                      "must have Load context")
+        sl = ast.ExtSlice([])
+        self.expr(ast.Subscript(x, sl, ast.Load()), "empty dims on ExtSlice")
+        sl = ast.ExtSlice([ast.Index(s)])
+        self.expr(ast.Subscript(x, sl, ast.Load()), "must have Load context")
+
+    def test_starred(self):  # Starred inside a Store List must itself be Store
+        left = ast.List([ast.Starred(ast.Name("x", ast.Load()), ast.Store())],
+                        ast.Store())
+        assign = ast.Assign([left], ast.Num(4))
+        self.stmt(assign, "must have Store context")
+
+    def _sequence(self, fac):  # fac(elts, ctx) builds a List or Tuple
+        self.expr(fac([None], ast.Load()), "None disallowed")
+        self.expr(fac([ast.Name("x", ast.Store())], ast.Load()),
+                  "must have Load context")
+
+    def test_list(self):
+        self._sequence(ast.List)
+
+    def test_tuple(self):
+        self._sequence(ast.Tuple)
+
+    def test_stdlib_validates(self):  # ASTs parsed from real sources must still compile
+        stdlib = os.path.dirname(ast.__file__)
+        tests = [fn for fn in os.listdir(stdlib) if fn.endswith(".py")]
+        tests.extend(["test/test_grammar.py", "test/test_unpack_ex.py"])
+        for module in tests:
+            fn = os.path.join(stdlib, module)
+            with open(fn, "r", encoding="utf-8") as fp:
+                source = fp.read()
+            mod = ast.parse(source)
+            compile(mod, fn, "exec")
+
+
 def test_main():
-    support.run_unittest(AST_Tests, ASTHelpers_Test)
+    support.run_unittest(AST_Tests, ASTHelpers_Test, ASTValidatorTests)  # also run the new validator tests
 
 def main():
     if __name__ != '__main__':
diff --git a/Misc/NEWS b/Misc/NEWS
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -10,6 +10,8 @@
 Core and Builtins
 -----------------
 
+- Issue #12575: Validate user-generated AST before it is compiled.
+
 - Make type(None), type(Ellipsis), and type(NotImplemented) callable. They
   return the respective singleton instances.
 
diff --git a/Python/ast.c b/Python/ast.c
--- a/Python/ast.c
+++ b/Python/ast.c
@@ -1,19 +1,497 @@
 /*
  * This file includes functions to transform a concrete syntax tree (CST) to
- * an abstract syntax tree (AST).  The main function is PyAST_FromNode().
+ * an abstract syntax tree (AST). The main function is PyAST_FromNode().
  *
  */
 #include "Python.h"
 #include "Python-ast.h"
-#include "grammar.h"
 #include "node.h"
 #include "ast.h"
 #include "token.h"
+
+#include <assert.h>
+
+static int validate_stmts(asdl_seq *);
+static int validate_exprs(asdl_seq *, expr_context_ty, int);
+static int validate_nonempty_seq(asdl_seq *, const char *, const char *);
+static int validate_stmt(stmt_ty);
+static int validate_expr(expr_ty, expr_context_ty);
+
+static int
+validate_comprehension(asdl_seq *gens) /* >= 1 generator; each: Store target, Load iter, Load ifs. 1 if valid, 0 on failure */
+{
+    int i;
+    if (!asdl_seq_LEN(gens)) { /* an empty generator list is rejected */
+        PyErr_SetString(PyExc_ValueError, "comprehension with no generators");
+        return 0;
+    }
+    for (i = 0; i < asdl_seq_LEN(gens); i++) {
+        comprehension_ty comp = asdl_seq_GET(gens, i);
+        if (!validate_expr(comp->target, Store) ||
+            !validate_expr(comp->iter, Load) ||
+            !validate_exprs(comp->ifs, Load, 0))
+            return 0;
+    }
+    return 1;
+}
+
+static int
+validate_slice(slice_ty slice) /* check a subscript slice and its sub-expressions; 1 if valid, 0 on failure */
+{
+    switch (slice->kind) {
+    case Slice_kind: /* lower, upper and step are each optional; present ones need Load */
+        return (!slice->v.Slice.lower || validate_expr(slice->v.Slice.lower, Load)) &&
+            (!slice->v.Slice.upper || validate_expr(slice->v.Slice.upper, Load)) &&
+            (!slice->v.Slice.step || validate_expr(slice->v.Slice.step, Load));
+    case ExtSlice_kind: { /* needs at least one dim; each dim is validated recursively */
+        int i;
+        if (!validate_nonempty_seq(slice->v.ExtSlice.dims, "dims", "ExtSlice"))
+            return 0;
+        for (i = 0; i < asdl_seq_LEN(slice->v.ExtSlice.dims); i++)
+            if (!validate_slice(asdl_seq_GET(slice->v.ExtSlice.dims, i)))
+                return 0;
+        return 1;
+    }
+    case Index_kind:
+        return validate_expr(slice->v.Index.value, Load);
+    default: /* a slice kind this switch does not know about */
+        PyErr_SetString(PyExc_SystemError, "unknown slice node");
+        return 0;
+    }
+}
+
+static int
+validate_keywords(asdl_seq *keywords) /* every keyword argument's value must be a Load expression */
+{
+    int i;
+    for (i = 0; i < asdl_seq_LEN(keywords); i++)
+        if (!validate_expr(((keyword_ty)asdl_seq_GET(keywords, i))->value, Load))
+            return 0;
+    return 1;
+}
+
+static int
+validate_args(asdl_seq *args) /* only annotations are checked; a missing annotation is allowed */
+{
+    int i;
+    for (i = 0; i < asdl_seq_LEN(args); i++) {
+        arg_ty arg = asdl_seq_GET(args, i);
+        if (arg->annotation && !validate_expr(arg->annotation, Load))
+            return 0;
+    }
+    return 1;
+}
+
+static const char *
+expr_context_name(expr_context_ty ctx) /* printable context name used in validator error messages */
+{
+    switch (ctx) {
+    case Load:
+        return "Load";
+    case Store:
+        return "Store";
+    case Del:
+        return "Del";
+    case AugLoad:
+        return "AugLoad";
+    case AugStore:
+        return "AugStore";
+    case Param:
+        return "Param";
+    default:
+        assert(0); /* only reached for a ctx value not listed above */
+        return "(unknown)";
+    }
+}
+
+static int
+validate_arguments(arguments_ty args) /* annotations, vararg/kwarg annotation pairing and default counts; 1 if valid */
+{
+    if (!validate_args(args->args))
+        return 0;
+    if (args->varargannotation) {
+        if (!args->vararg) { /* an annotation without the *vararg it annotates */
+            PyErr_SetString(PyExc_ValueError, "varargannotation but no vararg on arguments");
+            return 0;
+        }
+        if (!validate_expr(args->varargannotation, Load))
+            return 0;
+    }
+    if (!validate_args(args->kwonlyargs))
+        return 0;
+    if (args->kwargannotation) {
+        if (!args->kwarg) { /* an annotation without the **kwarg it annotates */
+            PyErr_SetString(PyExc_ValueError, "kwargannotation but no kwarg on arguments");
+            return 0;
+        }
+        if (!validate_expr(args->kwargannotation, Load))
+            return 0;
+    }
+    if (asdl_seq_LEN(args->defaults) > asdl_seq_LEN(args->args)) { /* fewer defaults than args is allowed */
+        PyErr_SetString(PyExc_ValueError, "more positional defaults than args on arguments");
+        return 0;
+    }
+    if (asdl_seq_LEN(args->kw_defaults) != asdl_seq_LEN(args->kwonlyargs)) { /* one slot per kw-only arg */
+        PyErr_SetString(PyExc_ValueError, "length of kwonlyargs is not the same as "
+                        "kw_defaults on arguments");
+        return 0;
+    }
+    return validate_exprs(args->defaults, Load, 0) && validate_exprs(args->kw_defaults, Load, 1); /* 1: NULL kw_defaults entries accepted */
+}
+
+static int
+validate_expr(expr_ty exp, expr_context_ty ctx) /* check exp appears in context ctx, then validate its children; 1 if valid, 0 on failure */
+{
+    int check_ctx = 1; /* cleared below for node kinds that have no ctx field */
+    expr_context_ty actual_ctx; /* only read when check_ctx is still set */
+
+    /* First check expression context. */
+    switch (exp->kind) {
+    case Attribute_kind:
+        actual_ctx = exp->v.Attribute.ctx;
+        break;
+    case Subscript_kind:
+        actual_ctx = exp->v.Subscript.ctx;
+        break;
+    case Starred_kind:
+        actual_ctx = exp->v.Starred.ctx;
+        break;
+    case Name_kind:
+        actual_ctx = exp->v.Name.ctx;
+        break;
+    case List_kind:
+        actual_ctx = exp->v.List.ctx;
+        break;
+    case Tuple_kind:
+        actual_ctx = exp->v.Tuple.ctx;
+        break;
+    default: /* kinds without a ctx field are only accepted where Load is expected */
+        if (ctx != Load) {
+            PyErr_Format(PyExc_ValueError, "expression which can't be "
+                         "assigned to in %s context", expr_context_name(ctx));
+            return 0;
+        }
+        check_ctx = 0;
+    }
+    if (check_ctx && actual_ctx != ctx) {
+        PyErr_Format(PyExc_ValueError, "expression must have %s context but has %s instead",
+                     expr_context_name(ctx), expr_context_name(actual_ctx));
+        return 0;
+    }
+
+    /* Now validate expression. */
+    switch (exp->kind) {
+    case BoolOp_kind: /* needs at least two operands */
+        if (asdl_seq_LEN(exp->v.BoolOp.values) < 2) {
+            PyErr_SetString(PyExc_ValueError, "BoolOp with less than 2 values");
+            return 0;
+        }
+        return validate_exprs(exp->v.BoolOp.values, Load, 0);
+    case BinOp_kind:
+        return validate_expr(exp->v.BinOp.left, Load) &&
+            validate_expr(exp->v.BinOp.right, Load);
+    case UnaryOp_kind:
+        return validate_expr(exp->v.UnaryOp.operand, Load);
+    case Lambda_kind:
+        return validate_arguments(exp->v.Lambda.args) &&
+            validate_expr(exp->v.Lambda.body, Load);
+    case IfExp_kind:
+        return validate_expr(exp->v.IfExp.test, Load) &&
+            validate_expr(exp->v.IfExp.body, Load) &&
+            validate_expr(exp->v.IfExp.orelse, Load);
+    case Dict_kind: /* keys and values are parallel sequences */
+        if (asdl_seq_LEN(exp->v.Dict.keys) != asdl_seq_LEN(exp->v.Dict.values)) {
+            PyErr_SetString(PyExc_ValueError,
+                            "Dict doesn't have the same number of keys as values");
+            return 0;
+        }
+        return validate_exprs(exp->v.Dict.keys, Load, 0) &&
+            validate_exprs(exp->v.Dict.values, Load, 0);
+    case Set_kind:
+        return validate_exprs(exp->v.Set.elts, Load, 0);
+#define COMP(NAME) \
+        case NAME ## _kind: \
+            return validate_comprehension(exp->v.NAME.generators) && \
+                validate_expr(exp->v.NAME.elt, Load);
+    COMP(ListComp) /* each COMP(...) expands to a complete case: generators, then elt */
+    COMP(SetComp)
+    COMP(GeneratorExp)
+#undef COMP
+    case DictComp_kind:
+        return validate_comprehension(exp->v.DictComp.generators) &&
+            validate_expr(exp->v.DictComp.key, Load) &&
+            validate_expr(exp->v.DictComp.value, Load);
+    case Yield_kind: /* a bare yield has no value to check */
+        return !exp->v.Yield.value || validate_expr(exp->v.Yield.value, Load);
+    case Compare_kind:
+        if (!asdl_seq_LEN(exp->v.Compare.comparators)) {
+            PyErr_SetString(PyExc_ValueError, "Compare with no comparators");
+            return 0;
+        }
+        if (asdl_seq_LEN(exp->v.Compare.comparators) !=
+            asdl_seq_LEN(exp->v.Compare.ops)) {
+            PyErr_SetString(PyExc_ValueError, "Compare has a different number "
+                            "of comparators and operands");
+            return 0;
+        }
+        return validate_exprs(exp->v.Compare.comparators, Load, 0) &&
+            validate_expr(exp->v.Compare.left, Load);
+    case Call_kind: /* starargs and kwargs are optional */
+        return validate_expr(exp->v.Call.func, Load) &&
+            validate_exprs(exp->v.Call.args, Load, 0) &&
+            validate_keywords(exp->v.Call.keywords) &&
+            (!exp->v.Call.starargs || validate_expr(exp->v.Call.starargs, Load)) &&
+            (!exp->v.Call.kwargs || validate_expr(exp->v.Call.kwargs, Load));
+    case Num_kind: { /* exact int/float/complex only; subclasses are rejected */
+        PyObject *n = exp->v.Num.n;
+        if (!PyLong_CheckExact(n) && !PyFloat_CheckExact(n) &&
+            !PyComplex_CheckExact(n)) {
+            PyErr_SetString(PyExc_TypeError, "non-numeric type in Num");
+            return 0;
+        }
+        return 1;
+    }
+    case Str_kind: { /* exact str only */
+        PyObject *s = exp->v.Str.s;
+        if (!PyUnicode_CheckExact(s)) {
+            PyErr_SetString(PyExc_TypeError, "non-string type in Str");
+            return 0;
+        }
+        return 1;
+    }
+    case Bytes_kind: { /* exact bytes only */
+        PyObject *b = exp->v.Bytes.s;
+        if (!PyBytes_CheckExact(b)) {
+            PyErr_SetString(PyExc_TypeError, "non-bytes type in Bytes");
+            return 0;
+        }
+        return 1;
+    }
+    case Attribute_kind:
+        return validate_expr(exp->v.Attribute.value, Load);
+    case Subscript_kind:
+        return validate_slice(exp->v.Subscript.slice) &&
+            validate_expr(exp->v.Subscript.value, Load);
+    case Starred_kind: /* the starred value takes the enclosing context */
+        return validate_expr(exp->v.Starred.value, ctx);
+    case List_kind: /* elements take the list's own context */
+        return validate_exprs(exp->v.List.elts, ctx, 0);
+    case Tuple_kind: /* elements take the tuple's own context */
+        return validate_exprs(exp->v.Tuple.elts, ctx, 0);
+    /* These last cases don't have any checking. */
+    case Name_kind:
+    case Ellipsis_kind:
+        return 1;
+    default: /* an expression kind this switch does not know about */
+        PyErr_SetString(PyExc_SystemError, "unexpected expression");
+        return 0;
+    }
+}
+
+static int
+validate_nonempty_seq(asdl_seq *seq, const char *what, const char *owner) /* 1 if seq has items, else ValueError "empty <what> on <owner>" and 0 */
+{
+    if (asdl_seq_LEN(seq))
+        return 1;
+    PyErr_Format(PyExc_ValueError, "empty %s on %s", what, owner);
+    return 0;
+}
+
+static int
+validate_assignlist(asdl_seq *targets, expr_context_ty ctx) /* Delete (ctx == Del) or Assign targets: non-empty, each in ctx */
+{
+    return validate_nonempty_seq(targets, "targets", ctx == Del ? "Delete" : "Assign") &&
+        validate_exprs(targets, ctx, 0);
+}
+
+static int
+validate_body(asdl_seq *body, const char *owner) /* a statement body must be non-empty and hold valid statements */
+{
+    return validate_nonempty_seq(body, "body", owner) && validate_stmts(body);
+}
+
+/* Validate a single statement node by dispatching on its kind and
+   recursing into its child statements and expressions.  Returns 1 if the
+   node is valid and 0 otherwise.  The checks written inline below raise
+   ValueError (SystemError for an unrecognized kind); failures found by the
+   validate_* helpers are reported by those helpers.  Within each case the
+   checks are evaluated left to right and stop at the first failure, so
+   their order decides which error the caller sees. */
+static int
+validate_stmt(stmt_ty stmt)
+{
+    int i;
+    switch (stmt->kind) {
+    case FunctionDef_kind:
+        /* The return annotation is optional (NULL when absent). */
+        return validate_body(stmt->v.FunctionDef.body, "FunctionDef") &&
+            validate_arguments(stmt->v.FunctionDef.args) &&
+            validate_exprs(stmt->v.FunctionDef.decorator_list, Load, 0) &&
+            (!stmt->v.FunctionDef.returns ||
+             validate_expr(stmt->v.FunctionDef.returns, Load));
+    case ClassDef_kind:
+        /* starargs and kwargs are optional; they are checked only when set. */
+        return validate_body(stmt->v.ClassDef.body, "ClassDef") &&
+            validate_exprs(stmt->v.ClassDef.bases, Load, 0) &&
+            validate_keywords(stmt->v.ClassDef.keywords) &&
+            validate_exprs(stmt->v.ClassDef.decorator_list, Load, 0) &&
+            (!stmt->v.ClassDef.starargs || validate_expr(stmt->v.ClassDef.starargs, Load)) &&
+            (!stmt->v.ClassDef.kwargs || validate_expr(stmt->v.ClassDef.kwargs, Load));
+    case Return_kind:
+        /* The returned value is optional (NULL when absent). */
+        return !stmt->v.Return.value || validate_expr(stmt->v.Return.value, Load);
+    case Delete_kind:
+        /* Delete targets are checked in Del context. */
+        return validate_assignlist(stmt->v.Delete.targets, Del);
+    case Assign_kind:
+        /* Targets are checked in Store context, the value in Load. */
+        return validate_assignlist(stmt->v.Assign.targets, Store) &&
+            validate_expr(stmt->v.Assign.value, Load);
+    case AugAssign_kind:
+        return validate_expr(stmt->v.AugAssign.target, Store) &&
+            validate_expr(stmt->v.AugAssign.value, Load);
+    case For_kind:
+        /* For/While/If: the body must be non-empty (validate_body), but the
+           orelse list may be empty, so it only goes through validate_stmts. */
+        return validate_expr(stmt->v.For.target, Store) &&
+            validate_expr(stmt->v.For.iter, Load) &&
+            validate_body(stmt->v.For.body, "For") &&
+            validate_stmts(stmt->v.For.orelse);
+    case While_kind:
+        return validate_expr(stmt->v.While.test, Load) &&
+            validate_body(stmt->v.While.body, "While") &&
+            validate_stmts(stmt->v.While.orelse);
+    case If_kind:
+        return validate_expr(stmt->v.If.test, Load) &&
+            validate_body(stmt->v.If.body, "If") &&
+            validate_stmts(stmt->v.If.orelse);
+    case With_kind:
+        /* At least one item is required.  Each item's context_expr is
+           checked in Load context; optional_vars may be NULL and, when
+           present, is checked in Store context. */
+        if (!validate_nonempty_seq(stmt->v.With.items, "items", "With"))
+            return 0;
+        for (i = 0; i < asdl_seq_LEN(stmt->v.With.items); i++) {
+            withitem_ty item = asdl_seq_GET(stmt->v.With.items, i);
+            if (!validate_expr(item->context_expr, Load) ||
+                (item->optional_vars && !validate_expr(item->optional_vars, Store)))
+                return 0;
+        }
+        return validate_body(stmt->v.With.body, "With");
+    case Raise_kind:
+        /* Both exc and cause are optional, but a cause without an
+           exception is rejected. */
+        if (stmt->v.Raise.exc) {
+            return validate_expr(stmt->v.Raise.exc, Load) &&
+                (!stmt->v.Raise.cause || validate_expr(stmt->v.Raise.cause, Load));
+        }
+        if (stmt->v.Raise.cause) {
+            PyErr_SetString(PyExc_ValueError, "Raise with cause but no exception");
+            return 0;
+        }
+        return 1;
+    case Try_kind:
+        /* Structural rules: a non-empty body, at least one of handlers or
+           finalbody, and an orelse only when there are handlers. */
+        if (!validate_body(stmt->v.Try.body, "Try"))
+            return 0;
+        if (!asdl_seq_LEN(stmt->v.Try.handlers) &&
+            !asdl_seq_LEN(stmt->v.Try.finalbody)) {
+            PyErr_SetString(PyExc_ValueError, "Try has neither except handlers nor finalbody");
+            return 0;
+        }
+        if (!asdl_seq_LEN(stmt->v.Try.handlers) &&
+            asdl_seq_LEN(stmt->v.Try.orelse)) {
+            PyErr_SetString(PyExc_ValueError, "Try has orelse but no except handlers");
+            return 0;
+        }
+        /* Each handler's type is optional; its body must be non-empty. */
+        for (i = 0; i < asdl_seq_LEN(stmt->v.Try.handlers); i++) {
+            excepthandler_ty handler = asdl_seq_GET(stmt->v.Try.handlers, i);
+            if ((handler->v.ExceptHandler.type &&
+                 !validate_expr(handler->v.ExceptHandler.type, Load)) ||
+                !validate_body(handler->v.ExceptHandler.body, "ExceptHandler"))
+                return 0;
+        }
+        /* validate_stmts accepts an empty sequence, so these length checks
+           only short-circuit; they do not change the result. */
+        return (!asdl_seq_LEN(stmt->v.Try.finalbody) ||
+                validate_stmts(stmt->v.Try.finalbody)) &&
+            (!asdl_seq_LEN(stmt->v.Try.orelse) ||
+             validate_stmts(stmt->v.Try.orelse));
+    case Assert_kind:
+        /* The message expression is optional. */
+        return validate_expr(stmt->v.Assert.test, Load) &&
+            (!stmt->v.Assert.msg || validate_expr(stmt->v.Assert.msg, Load));
+    case Import_kind:
+        /* Only a non-empty name list is required; the aliases themselves
+           are not inspected here. */
+        return validate_nonempty_seq(stmt->v.Import.names, "names", "Import");
+    case ImportFrom_kind:
+        /* NOTE(review): only levels below -1 are rejected, so -1 passes.
+           Python 3 relative-import levels are non-negative; confirm whether
+           -1 should still be accepted here. */
+        if (stmt->v.ImportFrom.level < -1) {
+            PyErr_SetString(PyExc_ValueError, "ImportFrom level less than -1");
+            return 0;
+        }
+        return validate_nonempty_seq(stmt->v.ImportFrom.names, "names", "ImportFrom");
+    case Global_kind:
+        /* As with Import, only a non-empty name list is required. */
+        return validate_nonempty_seq(stmt->v.Global.names, "names", "Global");
+    case Nonlocal_kind:
+        return validate_nonempty_seq(stmt->v.Nonlocal.names, "names", "Nonlocal");
+    case Expr_kind:
+        return validate_expr(stmt->v.Expr.value, Load);
+    case Pass_kind:
+    case Break_kind:
+    case Continue_kind:
+        /* No child nodes to validate. */
+        return 1;
+    default:
+        /* An unknown kind is reported as SystemError, not ValueError. */
+        PyErr_SetString(PyExc_SystemError, "unexpected statement");
+        return 0;
+    }
+}
+
+/* Run validate_stmt() over every element of SEQ.  A NULL element is
+   rejected with ValueError ("None disallowed in statement list").
+   Returns 1 when the whole list is valid, 0 at the first failure. */
+static int
+validate_stmts(asdl_seq *seq)
+{
+    int pos;
+
+    for (pos = 0; pos < asdl_seq_LEN(seq); pos++) {
+        stmt_ty cur = asdl_seq_GET(seq, pos);
+        if (cur == NULL) {
+            PyErr_SetString(PyExc_ValueError,
+                            "None disallowed in statement list");
+            return 0;
+        }
+        if (!validate_stmt(cur))
+            return 0;
+    }
+    return 1;
+}
+
+/* Validate each expression in EXPRS against the expected context CTX.
+   NULL elements are skipped when NULL_OK is nonzero and rejected with
+   ValueError otherwise.  Returns 1 if every element passes, 0 at the
+   first failure. */
+static int
+validate_exprs(asdl_seq *exprs, expr_context_ty ctx, int null_ok)
+{
+    int pos;
+
+    for (pos = 0; pos < asdl_seq_LEN(exprs); pos++) {
+        expr_ty cur = asdl_seq_GET(exprs, pos);
+        if (cur == NULL) {
+            if (null_ok)
+                continue;
+            PyErr_SetString(PyExc_ValueError,
+                            "None disallowed in expression list");
+            return 0;
+        }
+        if (!validate_expr(cur, ctx))
+            return 0;
+    }
+    return 1;
+}
+
+/* Public entry point: check that a whole module tree is well formed before
+   it is handed to the compiler.  Module and Interactive trees are checked
+   as statement lists, and Expression trees as a single Load-context
+   expression.  Suite trees are always rejected with ValueError, and an
+   unknown module kind raises SystemError.  Returns 1 if the tree is valid
+   and 0 otherwise. */
+int
+PyAST_Validate(mod_ty mod)
+{
+    switch (mod->kind) {
+    case Module_kind:
+        return validate_stmts(mod->v.Module.body);
+    case Interactive_kind:
+        return validate_stmts(mod->v.Interactive.body);
+    case Expression_kind:
+        return validate_expr(mod->v.Expression.body, Load);
+    case Suite_kind:
+        PyErr_SetString(PyExc_ValueError, "Suite is not valid in the CPython compiler");
+        return 0;
+    default:
+        PyErr_SetString(PyExc_SystemError, "impossible module node");
+        return 0;
+    }
+}
+
+/* This is down here so that defines like "test" don't interfere with the AST access above. */
+#include "grammar.h"
 #include "parsetok.h"
 #include "graminit.h"
 
-#include <assert.h>
-
 /* Data structure used internally */
 struct compiling {
     char *c_encoding; /* source encoding */
diff --git a/Python/bltinmodule.c b/Python/bltinmodule.c
--- a/Python/bltinmodule.c
+++ b/Python/bltinmodule.c
@@ -604,6 +604,10 @@
                 PyArena_Free(arena);
                 goto error;
             }
+            if (!PyAST_Validate(mod)) {
+                PyArena_Free(arena);
+                goto error;
+            }
             result = (PyObject*)PyAST_CompileEx(mod, filename,
                                                 &cf, optimize, arena);
             PyArena_Free(arena);

-- 
Repository URL: http://hg.python.org/cpython


More information about the Python-checkins mailing list