parallel class structures for AST-based objects

Steve Howell showell30 at yahoo.com
Sat Nov 21 18:23:26 EST 2009


I have been writing some code that parses a mini-language, and I am
running into what I know is a pretty common design pattern problem,
but I am wondering what the most Pythonic way to solve it is.

Basically, I have a bunch of really simple classes that work together
to define an expression--in my oversimplified example code below,
those are Integer, Sum, and Product.

Then I want to write different modules that work with those expression
objects.  In my example, I have a parallel set of classes that enable
you to evaluate an expression, and another parallel set of classes that
enable you to pretty-print the expression.

The code below works as intended so far (tested under 2.6), but before
I go much further with this design, I want a sanity check and some
ideas on other ways to represent the interrelationships within the
code.  The issue is that behavior varies in two dimensions.  Right
now a node can only be a Product, Integer, or Sum, but I might want to
introduce new concepts like Difference, Quotient, etc.  And right now
the only things you can do to expressions are eval() and pprint(),
but eventually you might want to operate on them in new ways,
including fairly abstract operations that go beyond a simple walk of
the tree.
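
For comparison, here is one possible alternative, sketched under the
assumption that the node classes stay plain data holders: a
visitor-style design in which each operation is a single class with one
method per node type, dispatched on the node's class name.  This is the
same naming convention the standard library's ast.NodeVisitor uses
(visit_ClassName).  The names Visitor, Evaluator, PrettyPrinter and the
visit_* methods are hypothetical, and the code is written for Python 3.

```python
# Plain node classes, mirroring the ones in the post.
class Integer(object):
    def __init__(self, val):
        self.val = val

class BinaryOp(object):
    def __init__(self, a, b):
        self.a = a
        self.b = b

class Sum(BinaryOp):
    pass

class Product(BinaryOp):
    pass

class Visitor(object):
    """Hypothetical base class: dispatch to visit_<ClassName>."""
    def visit(self, node):
        method = getattr(self, 'visit_' + type(node).__name__, None)
        if method is None:
            raise TypeError('%s cannot handle %s'
                            % (type(self).__name__, type(node).__name__))
        return method(node)

# One class per operation; one method per node type.
class Evaluator(Visitor):
    def visit_Integer(self, node):
        return node.val

    def visit_Sum(self, node):
        return self.visit(node.a) + self.visit(node.b)

    def visit_Product(self, node):
        return self.visit(node.a) * self.visit(node.b)

class PrettyPrinter(Visitor):
    def visit_Integer(self, node):
        return str(node.val)

    def visit_Sum(self, node):
        return '(the sum of %s and %s)' % (self.visit(node.a),
                                           self.visit(node.b))

    def visit_Product(self, node):
        return '(the product of %s and %s)' % (self.visit(node.a),
                                               self.visit(node.b))

expression = Product(Sum(Integer(5), Integer(2)), Integer(6))
print(Evaluator().visit(expression))      # 42
print(PrettyPrinter().visit(expression))  # (the product of (the sum of 5 and 2) and 6)
```

The trade-off is the usual one for this pattern: adding a new operation
means writing one new visitor class without touching the nodes, while
adding a new node type (Difference, say) means adding a method to every
existing visitor.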

Here is the code:

    #######
    # The base classes just represent the expression itself; instances
    # get created by a parser or by unit tests.
    # Example usage:
    # expression = Product(Sum(Integer(5),Integer(2)), Integer(6))
    class Integer:
        def __init__(self, val):
            self.val = val

    class BinaryOp:
        def __init__(self, a,b):
            self.a = a
            self.b = b

    class Sum(BinaryOp):
        pass

    class Product(BinaryOp):
        pass

    ########

    class EvalNode:
        def __init__(self, node):
            self.node = node

        def evaluatechild(self, child):
            return EvalNode.factory(child).eval()

        @staticmethod
        def factory(child):
            mapper = {
                'Sum': SumEvalNode,
                'Product': ProductEvalNode,
                'Integer': IntegerEvalNode
                }
            return abstract_factory(child, mapper)

    class SumEvalNode(EvalNode):
        def eval(self):
            a = self.evaluatechild(self.node.a)
            b = self.evaluatechild(self.node.b)
            return a + b

    class ProductEvalNode(EvalNode):
        def eval(self):
            a = self.evaluatechild(self.node.a)
            b = self.evaluatechild(self.node.b)
            return a * b

    class IntegerEvalNode(EvalNode):
        def eval(self): return self.node.val

    #######

    class PrettyPrintNode:
        def __init__(self, node):
            self.node = node

        def pprint_child(self, child):
            return PrettyPrintNode.factory(child).pprint()

        @staticmethod
        def factory(child):
            mapper = {
                'Sum': SumPrettyPrintNode,
                'Product': ProductPrettyPrintNode,
                'Integer': IntegerPrettyPrintNode
                }
            return abstract_factory(child, mapper)

    class SumPrettyPrintNode(PrettyPrintNode):
        def pprint(self):
            a = self.pprint_child(self.node.a)
            b = self.pprint_child(self.node.b)
            return '(the sum of %s and %s)' % (a, b)

    class ProductPrettyPrintNode(PrettyPrintNode):
        def pprint(self):
            a = self.pprint_child(self.node.a)
            b = self.pprint_child(self.node.b)
            return '(the product of %s and %s)' % (a, b)

    class IntegerPrettyPrintNode(PrettyPrintNode):
        def pprint(self): return str(self.node.val)

    ##############
    # Not sure where this function really "wants to be" structurally,
    # or what it should be named, but it removes some duplication
    # between the two factory methods

    def abstract_factory(node, node_class_mapper):
        return node_class_mapper[node.__class__.__name__](node)


    expression = Product(Sum(Integer(5),Integer(2)), Integer(6))

    evaluator = EvalNode.factory(expression)
    print(evaluator.eval())

    pprinter = PrettyPrintNode.factory(expression)
    print(pprinter.pprint())
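
A different way to answer the "where does abstract_factory want to be"
question is to drop the string-keyed mapper and make each operation a
generic function that dispatches on the class of its argument.  The
sketch below assumes Python 3.4 or later, because it relies on
functools.singledispatch, which was added in 3.4 and so was not
available under 2.6.  The function names evaluate and pretty, and the
Difference handlers, are hypothetical illustrations.

```python
# Requires Python 3.4+ for functools.singledispatch.
from functools import singledispatch

# Plain node classes, mirroring the ones in the post.
class Integer:
    def __init__(self, val):
        self.val = val

class BinaryOp:
    def __init__(self, a, b):
        self.a = a
        self.b = b

class Sum(BinaryOp):
    pass

class Product(BinaryOp):
    pass

# Each operation is one generic function; the undecorated body is the
# fallback for any node type with no registered implementation.
@singledispatch
def evaluate(node):
    raise TypeError('evaluate() cannot handle %s' % type(node).__name__)

@evaluate.register(Integer)
def _(node):
    return node.val

@evaluate.register(Sum)
def _(node):
    return evaluate(node.a) + evaluate(node.b)

@evaluate.register(Product)
def _(node):
    return evaluate(node.a) * evaluate(node.b)

@singledispatch
def pretty(node):
    raise TypeError('pretty() cannot handle %s' % type(node).__name__)

@pretty.register(Integer)
def _(node):
    return str(node.val)

@pretty.register(Sum)
def _(node):
    return '(the sum of %s and %s)' % (pretty(node.a), pretty(node.b))

@pretty.register(Product)
def _(node):
    return '(the product of %s and %s)' % (pretty(node.a), pretty(node.b))

# A new node type can be added later, even from another module, by
# registering one implementation per existing operation.
class Difference(BinaryOp):
    pass

@evaluate.register(Difference)
def _(node):
    return evaluate(node.a) - evaluate(node.b)

@pretty.register(Difference)
def _(node):
    return '(the difference of %s and %s)' % (pretty(node.a), pretty(node.b))

expression = Product(Sum(Integer(5), Integer(2)), Integer(6))
print(evaluate(expression))  # 42
print(pretty(expression))    # (the product of (the sum of 5 and 2) and 6)
```

Because singledispatch looks up the implementation along the argument's
method resolution order, a handler registered for BinaryOp would also
cover any subclass that has no handler of its own, which the
__class__.__name__ lookup in abstract_factory cannot do.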


