what is the ordering of an AST tree?

Robert Brewer fumanchu at amor.org
Fri Feb 27 22:47:44 CET 2004


Stephen Emslie wrote:
> When one creates an AST tree using the compiler module, what 
> order is the tree created in? In other words, is the
> ordering of the nodes in the tree the same as the
> flow of the program?

In my experience (_not_ analysis ;) ... yes, it's in the same order. Here's
the deriver module I ended up with last week, and some unit tests which
demonstrate that, at least for simple expressions, order is preserved.

----derive.py----

import compiler
import operator

class Expression(unicode):
    """A unicode string of Python source that also carries its parsed tree.

    The string value is the expression text itself; ``rootnode`` is set
    from the ``node`` attribute of the object returned by
    ``compiler.parse`` for that text.
    """

    # Class-level fallback; __init__ assigns the per-instance value.
    rootnode = None

    def __new__(cls, arg=u''):
        # The text has to be handed to the unicode constructor here,
        # because the string value is fixed when the object is created.
        return super(Expression, cls).__new__(cls, arg)

    def __init__(self, arg=u''):
        tree = compiler.parse(arg)
        self.rootnode = tree.node

    def evaluate(self, processor):
        """Hand this expression to ``processor.evaluate`` and return its result."""
        result = processor.evaluate(self)
        return result


class Eval(object):
    """Evaluates an expression string against a single Python object.

    The object is bound to the name 'x' while the expression is evaluated.

    NOTE: the expression is passed to the builtin ``eval``, so it can run
    arbitrary code; only use this with trusted expressions.
    """

    def evaluate(self, expr, obj):
        """Return a one-element list holding the value of ``expr``.

        A falsy ``expr`` (for example the empty string) yields ``[True]``
        without evaluating anything.
        """
        if not expr:
            return [True]
        bindings = {'x': obj}
        value = eval(expr, globals(), bindings)
        return [value]


class Processor(object):
    """Base class that dispatches compiler.AST nodes to ``visit*`` methods.

    Subclasses supply one ``visit<NodeClassName>`` method per node class
    they want to handle.
    """

    def process(self, node):
        """Call the ``visit`` method named after the node's class.

        No fallback is provided: a missing handler surfaces as the
        AttributeError raised by the attribute lookup.
        """
        # node.__class__ (not type(node)) so the name is that of the
        # node's own class.
        method_name = 'visit%s' % node.__class__.__name__
        handler = getattr(self, method_name)
        return handler(node)

    def _walk(self, node, childattr=None):
        """Process each child of ``node`` in order; return the results as a list.

        The children come from ``node.getChildNodes()`` unless ``childattr``
        names the attribute that holds them.
        """
        if childattr is not None:
            children = getattr(node, childattr)
        else:
            children = node.getChildNodes()

        results = []
        for child in children:
            results.append(self.process(child))
        return results

    def evaluate(self, expr):
        """Process the root node of ``expr``.

        Anything that is not already an Expression is wrapped in one first.
        """
        if isinstance(expr, Expression):
            parsed = expr
        else:
            parsed = Expression(expr)
        return self.process(parsed.rootnode)


class Filter(Processor):
    """Evaluates compiler.AST nodes against a namespace of Python objects.

    Each ``visit*`` method returns the value of its node; Name nodes are
    looked up in the ``namespace`` mapping given to the constructor.
    """

    # Maps the operator strings stored in Compare.ops to two-argument
    # callables taking (left operand, right operand).
    comparisons = {'<': operator.lt,
                   '<=': operator.le,
                   '==': operator.eq,
                   '!=': operator.ne,
                   '>': operator.gt,
                   '>=': operator.ge,
                   'in': lambda x, y: x in y,
                   'not in': lambda x, y: x not in y,
                   'is': operator.is_,
                   'is not': operator.is_not,
                   }

    def __init__(self, namespace):
        """Store ``namespace``, the mapping used to resolve Name nodes."""
        self.namespace = namespace

    def visitAnd(self, node):
        """Return False at the first false operand, True if none is false."""
        # Operands are processed one at a time so that, as in Python,
        # those after the first false one are never evaluated.
        for child in node.getChildNodes():
            if not self.process(child):
                return False
        return True

    def visitOr(self, node):
        """Return True at the first true operand, False if none is true."""
        # Lazy for the same reason as visitAnd.
        for child in node.getChildNodes():
            if self.process(child):
                return True
        return False

    def visitNot(self, node):
        """Return the boolean negation of the operand."""
        return (not self.process(node.expr))

    def visitUnarySub(self, node):
        """Return the arithmetic negation of the operand."""
        return -(self.process(node.expr))

    def visitCompare(self, node):
        """Evaluate a comparison, including chains such as ``a < b < c``.

        Each (operator, operand) pair in ``node.ops`` is applied to the
        previous operand; evaluation stops at the first false result.
        Raises KeyError for an operator missing from ``comparisons``.
        """
        left = self.process(node.expr)
        result = True
        for op, operand in node.ops:
            right = self.process(operand)
            result = self.comparisons[op](left, right)
            if not result:
                break
            # The right operand becomes the left operand of the next link.
            left = right
        return result

    def visitGetattr(self, node):
        """Return attribute ``node.attrname`` of the evaluated ``node.expr``."""
        obj = self.process(node.expr)
        return getattr(obj, node.attrname)

    def visitName(self, node):
        """Look the name up in the namespace (KeyError if it is absent)."""
        return self.namespace[node.name]

    def visitConst(self, node):
        """Return the constant's value."""
        return node.value

    def visitTuple(self, node):
        """Return a tuple of the evaluated elements."""
        return tuple(self._walk(node))

    def visitList(self, node):
        """Return a list of the evaluated elements."""
        return self._walk(node)

    def visitCallFunc(self, node):
        """Call the evaluated callee with the evaluated arguments.

        ``node.args`` mixes positional arguments and Keyword nodes;
        ``node.star_args`` and ``node.dstar_args`` are each a single
        expression node (or None), not a list of nodes.
        """
        args = []
        kwargs = {}
        for child in node.args or ():
            if child.__class__.__name__ == 'Keyword':
                # visitKeyword returns a {name: value} dict.
                kwargs.update(self.process(child))
            else:
                args.append(self.process(child))

        if node.star_args is not None:
            args.extend(self.process(node.star_args))
        if node.dstar_args is not None:
            kwargs.update(self.process(node.dstar_args))

        callFunc = self.process(node.node)
        return callFunc(*args, **kwargs)

    def visitKeyword(self, node):
        """Return a one-item dict mapping the keyword name to its value."""
        return {node.name: self.process(node.expr)}

    def visitDiscard(self, node):
        """Return the value of the wrapped expression."""
        return self.process(node.expr)

    def visitStmt(self, node):
        """Return a list with one result per statement, or [True] if empty."""
        if node.nodes:
            return self._walk(node, 'nodes')
        else:
            return [True]


class Deriver(Processor):
    """Derives Python source text from compiler.AST nodes.

    Each ``visit*`` method returns the source string for its node.
    """

    def visitAnd(self, node):
        """Join the parenthesized operands with ' and '."""
        return ' and '.join(["(" + x + ")" for x in self._walk(node)])

    def visitOr(self, node):
        """Join the parenthesized operands with ' or '."""
        return ' or '.join(["(" + x + ")" for x in self._walk(node)])

    def visitNot(self, node):
        """Return 'not (...)' around the operand."""
        return "not (" + self.process(node.expr) + ")"

    def visitUnarySub(self, node):
        """Prefix the operand with '-'."""
        return "-" + self.process(node.expr)

    def visitCompare(self, node):
        """Return the comparison text, including every link of a chain.

        ``node.ops`` holds one (operator, operand) pair per link, so
        ``a < b < c`` is emitted in full rather than cut off after ``b``.
        """
        parts = [self.process(node.expr)]
        for op, operand in node.ops:
            parts.append(op)
            parts.append(self.process(operand))
        return " ".join(parts)

    def visitGetattr(self, node):
        """Return '<expr>.<attrname>'."""
        return self.process(node.expr) + "." + node.attrname

    def visitName(self, node):
        """Return the name itself."""
        return node.name

    def visitConst(self, node):
        """Return the repr of the constant's value."""
        return repr(node.value)

    def visitTuple(self, node):
        """Return the elements in parentheses, separated by ', '."""
        atoms = self._walk(node)
        if len(atoms) == 1:
            # Without the trailing comma "(x)" would read back as plain x.
            return "(" + atoms[0] + ",)"
        return "(" + ", ".join(atoms) + ")"

    def visitList(self, node):
        """Return the elements in square brackets, separated by ', '."""
        return "[" + ", ".join(self._walk(node)) + "]"

    def visitCallFunc(self, node):
        """Return the call text: callee followed by the argument list.

        ``node.star_args`` and ``node.dstar_args`` are each a single
        expression node (or None); they are emitted with their ``*`` and
        ``**`` prefixes.
        """
        callFunc = self.process(node.node)

        atoms = []
        if node.args:
            atoms.extend(self._walk(node, 'args'))
        if node.star_args is not None:
            atoms.append("*" + self.process(node.star_args))
        if node.dstar_args is not None:
            atoms.append("**" + self.process(node.dstar_args))
        return callFunc + "(" + ", ".join(atoms) + ")"

    def visitKeyword(self, node):
        """Return 'name=<expr>'."""
        return node.name + "=" + self.process(node.expr)

    def visitDiscard(self, node):
        """Return the text of the wrapped expression."""
        return self.process(node.expr)

    def visitStmt(self, node):
        """Join the statements with '; ' (empty string when there are none)."""
        return "; ".join(self._walk(node, 'nodes'))


----derive_test.py----

import unittest
from test import test_support
import derive

class Dummy(object):
    """Empty class; the tests set attributes on its instances."""

class ExprTests(unittest.TestCase):
    """Tests for derive.Filter and derive.Deriver on simple expressions."""

    def test_Filter(self):
        """Filter evaluates expressions against the objects in locals()."""
        x = Dummy()
        x.a = 3
        x.b = 5
        e = derive.Expression('x.a == 3 and (x.b > 1 or x.b < -1)')
        self.assertEqual(derive.Filter(locals()).evaluate(e), [True])

        # With x.b == 0 neither side of the 'or' holds.
        x.b = 0
        self.assertEqual(derive.Filter(locals()).evaluate(e), [False])

        x.Name = 'a fishy consignment'
        x.Content = 'sea-bass'
        e = derive.Expression("'fish' in x.Name or 'slap' in x.Content")
        self.assertEqual(derive.Filter(locals()).evaluate(e), [True])

    def test_Deriver(self):
        """Deriver regenerates source text with the operands in order."""
        expected = u'(x.a == 3) and ((x.b > 1) or (x.b < -10))'

        e = derive.Expression("(x.a == 3 and (x.b > 1 or x.b < -10))")
        d = derive.Deriver().evaluate(e)
        self.assertEqual(d, expected)

        # Try an alternate version of the same expression.
        e = derive.Expression("(x.a == 3) and ((x.b > 1) or (x.b < -10))")
        d = derive.Deriver().evaluate(e)
        self.assertEqual(d, expected)

        # Test the empty Expression.
        e = derive.Expression()
        self.assertEqual(derive.Deriver().evaluate(e), u'')


# When this file is executed as a script, run the ExprTests case
# through test_support.run_unittest.
if __name__ == "__main__":
    test_support.run_unittest(ExprTests)




Robert Brewer
MIS
Amor Ministries
fumanchu at amor.org




More information about the Python-list mailing list