parallel class structures for AST-based objects

Richard Thomas chardster at gmail.com
Sat Nov 21 19:33:28 EST 2009


On 22 Nov, 00:07, MRAB <pyt... at mrabarnett.plus.com> wrote:
> Steve Howell wrote:
> > I have been writing some code that parses a mini-language, and I am
> > running into what I know is a pretty common design pattern problem,
> > but I am wondering the most Pythonic way to solve it.
>
> > Basically, I have a bunch of really simple classes that work together
> > to define an expression--in my oversimplified example code below,
> > those are Integer, Sum, and Product.
>
> > Then I want to write different modules that work with those expression
> > objects.  In my example, I have a parallel set of classes that enable
> > you to evaluate an expression, and another set of objects that enable
> > you to pretty-print the expression.
>
> > The below code works as intended so far (tested under 2.6), but before
> > I go too much further with this design, I want to get a sanity check
> > and some ideas on other ways to represent the interrelationships
> > within the code.  Basically, the issue here is that you have varying
> > behavior in two dimensions--a node right now is only a Product/Integer/
> > Sum so far, but I might want to introduce new concepts like
> > Difference, Quotient, etc.  And right now the only things that you can
> > do to expressions is eval() them and pprint() them, but you eventually
> > might want to operate on the expressions in new ways, including fairly
> > abstract operations that go beyond a simple walking of the tree.
>
> > Here is the code:
>
> >     #######
> >     # base classes just represents the expression itself, which
> >     # get created by a parser or unit tests
> >     # example usage is
> >     # expression = Product(Sum(Integer(5),Integer(2)), Integer(6))
> >     class Integer:
> >         def __init__(self, val):
> >             self.val = val
>
> >     class BinaryOp:
> >         def __init__(self, a,b):
> >             self.a = a
> >             self.b = b
>
> >     class Sum(BinaryOp):
> >         pass
>
> >     class Product(BinaryOp):
> >         pass
>
> >     ########
>
> >     class EvalNode:
> >         def __init__(self, node):
> >             self.node = node
>
> >         def evaluatechild(self, child):
> >             return EvalNode.factory(child).eval()
>
> >         @staticmethod
> >         def factory(child):
> >             mapper = {
> >                 'Sum': SumEvalNode,
> >                 'Product': ProductEvalNode,
> >                 'Integer': IntegerEvalNode
> >                 }
> >             return abstract_factory(child, mapper)
>
> >     class SumEvalNode(EvalNode):
> >         def eval(self):
> >             a = self.evaluatechild(self.node.a)
> >             b = self.evaluatechild(self.node.b)
> >             return a + b
>
> >     class ProductEvalNode(EvalNode):
> >         def eval(self):
> >             a = self.evaluatechild(self.node.a)
> >             b = self.evaluatechild(self.node.b)
> >             return a * b
>
> >     class IntegerEvalNode(EvalNode):
> >         def eval(self): return self.node.val
>
> >     #######
>
> >     class PrettyPrintNode:
> >         def __init__(self, node):
> >             self.node = node
>
> >         def pprint_child(self, child):
> >             return PrettyPrintNode.factory(child).pprint()
>
> >         @staticmethod
> >         def factory(child):
> >             mapper = {
> >                 'Sum': SumPrettyPrintNode,
> >                 'Product': ProductPrettyPrintNode,
> >                 'Integer': IntegerPrettyPrintNode
> >                 }
> >             return abstract_factory(child, mapper)
>
> >     class SumPrettyPrintNode(PrettyPrintNode):
> >         def pprint(self):
> >             a = self.pprint_child(self.node.a)
> >             b = self.pprint_child(self.node.b)
> >             return '(the sum of %s and %s)' % (a, b)
>
> >     class ProductPrettyPrintNode(PrettyPrintNode):
> >         def pprint(self):
> >             a = self.pprint_child(self.node.a)
> >             b = self.pprint_child(self.node.b)
> >             return '(the product of %s and %s)' % (a, b)
>
> >     class IntegerPrettyPrintNode(PrettyPrintNode):
> >         def pprint(self): return self.node.val
>
> >     ##############
> >     # Not sure where this method really "wants to be" structurally,
> >     # or what it should be named, but it reduces some duplication
>
> >     def abstract_factory(node, node_class_mapper):
> >         return node_class_mapper[node.__class__.__name__](node)
>
> >     expression = Product(Sum(Integer(5),Integer(2)), Integer(6))
>
> >     evaluator = EvalNode.factory(expression)
> >     print evaluator.eval()
>
> >     pprinter = PrettyPrintNode.factory(expression)
> >     print pprinter.pprint()
>
> I don't see the point of EvalNode and PrettyPrintNode. Why don't you
> just give Integer, Sum and Product 'eval' and 'pprint' methods?

This looks more structurally sound:

class Node(object):
    """Base class for expression-tree nodes.

    Subclasses implement ``eval`` (compute the node's value) and ``pprint``
    (render the node as a string).
    """

    def eval(self):
        raise NotImplementedError

    def pprint(self):
        raise NotImplementedError


class Integer(Node):
    """Leaf node holding a literal value."""

    def __init__(self, value):
        self.value = value

    def eval(self):
        """Return the stored literal value."""
        return self.value

    def pprint(self):
        """Return the literal rendered with ``str``."""
        return str(self.value)


class BinaryOperatorNode(Node):
    """Node combining two child nodes with a two-argument callable.

    Subclasses set the class attribute ``operator``.  It must be wrapped in
    ``staticmethod``: a bare function stored on the class is turned into a
    bound method when read through ``self``, so the node instance would be
    passed as an extra first argument.
    """

    operator = None

    def __init__(self, first, second):
        self.first = first
        self.second = second

    def eval(self):
        """Evaluate both children and combine their values with ``operator``."""
        return self.operator(self.first.eval(), self.second.eval())

    def pprint(self):
        """Return ``"<ClassName>(<first>, <second>)"``, e.g. ``"Sum(5, 2)"``."""
        return "%s(%s, %s)" % (type(self).__name__,
                               self.first.pprint(),
                               self.second.pprint())


class Sum(BinaryOperatorNode):
    """Node whose value is the sum of its two children."""
    # staticmethod prevents the lambda from being bound to the instance.
    operator = staticmethod(lambda x, y: x + y)


class Product(BinaryOperatorNode):
    """Node whose value is the product of its two children."""
    # staticmethod prevents the lambda from being bound to the instance.
    operator = staticmethod(lambda x, y: x * y)

I don't know exactly what you're doing, but if all you need is to be
able to parse and evaluate expressions, then you can get very decent
mileage out of overloading operators, to the extent that the whole
thing you are trying to do could be a single class:

class Expression(object):
    """A lazily evaluated arithmetic expression.

    Wraps a function ``func(context)`` that computes the expression's value
    from a dict of variable bindings.  Arithmetic operators return new
    Expressions, so ``2 * X + 1`` builds an Expression instead of a number;
    calling it with keyword arguments (``expr(X=3)``) evaluates it.
    """

    def __init__(self, func):
        # func: callable taking the context dict and returning a value.
        self.func = func

    @staticmethod
    def _resolve(value, context):
        """Return *value*, evaluated in *context* if it is an Expression."""
        # Keep evaluating while the result is still an Expression, so a
        # func that returns another Expression is fully reduced.
        while isinstance(value, Expression):
            value = value.func(context)
        return value

    def _apply(self, other, op):
        """Return an Expression computing ``op(value(self), value(other))``.

        Both operands are resolved in the same context, so *other* may be a
        plain value or another Expression.
        """
        resolve = Expression._resolve
        return Expression(
            lambda context: op(resolve(self, context), resolve(other, context)))

    def __call__(self, **context):
        """Evaluate the expression with the given variable bindings."""
        return Expression._resolve(self, context)

    def __add__(self, other):
        return self._apply(other, lambda mine, theirs: mine + theirs)

    def __sub__(self, other):
        return self._apply(other, lambda mine, theirs: mine - theirs)

    def __mul__(self, other):
        return self._apply(other, lambda mine, theirs: mine * theirs)

    # Reflected operators: ``other <op> self``.  The operand order is kept
    # so that non-commutative operations (subtraction, string concatenation)
    # come out correctly.
    def __radd__(self, other):
        return self._apply(other, lambda mine, theirs: theirs + mine)

    def __rsub__(self, other):
        return self._apply(other, lambda mine, theirs: theirs - mine)

    def __rmul__(self, other):
        return self._apply(other, lambda mine, theirs: theirs * mine)
    # ... and so forth ...

def integer(value):
    """Return an Expression that ignores its context and yields *value*."""
    def constant(_context):
        return value
    return Expression(constant)

def variable(name, default):
    """Return an Expression that looks up *name* in the evaluation context.

    If *name* is not bound in the context, *default* is used instead.
    """
    def lookup(context):
        return context.get(name, default)
    return Expression(lookup)

# Example: X is a variable (defaulting to 0).  Operator overloading turns
# ``2 * X + 1`` into an Expression, which is evaluated by calling it.
# (``Expression("X", 0)`` would fail: Expression takes a single function.)
X = variable("X", 0)
expr = 2 * X + 1
assert expr(X=3) == 7

But maybe that's not what you need. If it is, though, there's no need
to overengineer: keep it simple. Simple is better than complex.



More information about the Python-list mailing list