Bytecode optimisation

Corran Webster cwebster at math.tamu.edu
Mon May 17 19:13:16 EDT 1999


I've been mucking around with the bytecodehacks modules of Michael Hudson 
this past weekend, and have used them to hack up an ad hoc bytecode 
optimiser.  I'm posting this to stimulate a bit of interest from others, 
since I have zero experience in building optimising compilers - my knowledge 
is all gathered from reading the dragon book.

This is proof-of-concept stuff at present: I intend a complete re-write, 
and this code is almost certainly incorrect, since it has not been 
seriously tested.  Use it at your own risk.

However, it can:
  * precalculate constants: 8 + 4 gets replaced by 12.
  * remove redundant code: strip SET_LINENO bytecodes and unreachable code.
  * simplify jumps-to-jumps and jumps-to-breaks.
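
To make the first item concrete, here is a minimal sketch of constant 
folding on a toy instruction list.  It is modern Python 3 and does not use 
bytecodehacks at all: the (opname, arg) tuple format and the name 
fold_constants are made up for this illustration, and LOAD_CONST carries 
the value itself rather than an index into co_consts.

```python
import operator

# Hypothetical toy format: each instruction is an (opname, arg) tuple.
BINARY_OPS = {
    "BINARY_ADD": operator.add,
    "BINARY_SUBTRACT": operator.sub,
    "BINARY_MULTIPLY": operator.mul,
}

def fold_constants(code):
    """Replace LOAD_CONST a; LOAD_CONST b; BINARY_op by LOAD_CONST (a op b)."""
    out = []
    for opname, arg in code:
        if (opname in BINARY_OPS and len(out) >= 2
                and out[-1][0] == "LOAD_CONST"
                and out[-2][0] == "LOAD_CONST"):
            b = out.pop()[1]
            a = out.pop()[1]
            out.append(("LOAD_CONST", BINARY_OPS[opname](a, b)))
        else:
            out.append((opname, arg))
    return out

code = [
    ("LOAD_CONST", 8),
    ("LOAD_CONST", 4),
    ("BINARY_ADD", None),
    ("LOAD_FAST", "x"),
    ("BINARY_MULTIPLY", None),
]
folded = fold_constants(code)
print(folded)
# [('LOAD_CONST', 12), ('LOAD_FAST', 'x'), ('BINARY_MULTIPLY', None)]
```

The sketch ignores jump targets; on real bytecode a fold is only safe when 
nothing jumps into the middle of the three instructions being replaced.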

In particular it will simplify the bytecode generated by the traditional 
Python idiom of the form:

while 1:
  line = file.readline()
  if not line:
    break
  frob(line)

to look more like the pseudo-Python:

do:
  line = file.readline()
while line:
  frob(line)

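i.e. the loop test gets moved into the middle of the loop, which considerably
reduces the loop bytecode (in the sample session below, the whole function
shrinks from 75 to 42 bytes, SET_LINENO stripping included).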
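
Part of the clean-up behind this is the jumps-to-jumps simplification 
mentioned above.  A minimal sketch of just that step, in modern Python 3 on 
a toy instruction list, might look as follows.  The tuple format, the name 
thread_jumps and the use of list indices (rather than byte offsets) as jump 
targets are all assumptions made for the illustration; the break handling 
is not shown.

```python
UNCONDJUMP = {"JUMP_FORWARD", "JUMP_ABSOLUTE"}

def thread_jumps(code):
    """Retarget each unconditional jump past any chain of unconditional jumps.

    Toy format: code is a list of (opname, arg) tuples, and the arg of a
    jump is the list index of its target instruction.
    """
    out = []
    for opname, arg in code:
        if opname in UNCONDJUMP:
            target, seen = arg, set()
            # Follow the chain; 'seen' guards against a cycle of jumps.
            while code[target][0] in UNCONDJUMP and target not in seen:
                seen.add(target)
                target = code[target][1]
            # The final target may lie behind us, so emit an absolute jump.
            out.append(("JUMP_ABSOLUTE", target))
        else:
            out.append((opname, arg))
    return out

code = [
    ("LOAD_FAST", "x"),       # 0
    ("JUMP_FORWARD", 3),      # 1: jumps to another jump
    ("LOAD_CONST", 1),        # 2
    ("JUMP_ABSOLUTE", 5),     # 3
    ("LOAD_CONST", 2),        # 4
    ("RETURN_VALUE", None),   # 5
]
threaded = thread_jumps(code)
print(threaded[1])
# ('JUMP_ABSOLUTE', 5)
```

Once instruction 1 goes straight to 5, instructions 2 to 4 can no longer 
be reached, and a separate unreachable-code pass can delete them.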

At present the gains as compared to hand-optimised Python are minimal - a 
few percent at most; but in tight inner loops where the loop code itself is 
a significant proportion of running time, the gains can be substantial.

The morals of this hack are:

  * Michael's bytecode hacks _enormously_ simplify this sort of thing 
    because they track jump destinations, and allow easy editing of code
    objects.

  * There's considerable room for peephole-type optimisation.

  * Optimisation is probably only worthwhile when speed is essential.

  * More complex optimisation is certainly possible, although a lot more 
    work would be required.  I've made a first stab at block-level analysis
    and it seems feasible.  In particular it would be nice to detect and 
    shift loop invariant code.

  * If type information could be guessed, then more could be squeezed 
    out.  At the moment I can't simplify something like 0 * x, because 
    x could be anything, and could call an arbitrary function to perform 
    the operation.  If it were known that x was an integer, then we could 
    safely replace the whole expression with 0.  Assert statements could 
    be used for this sort of thing.

  * If you really need to make your code run faster, you're still better 
    off squeezing the Python source; if that doesn't work, then you 
    probably want to re-write in C.  However, there is hope for us lazy 
    coders in the future.
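
To see why 0 * x can't be folded without type information, consider this 
small sketch (modern Python 3; the class Noisy is hypothetical and exists 
only for the illustration).  The result of 0 * x depends entirely on what 
x is, and evaluating it may run arbitrary code:

```python
import math

class Noisy:
    """Hypothetical class whose multiplication has a visible side effect."""
    def __init__(self):
        self.calls = 0

    def __rmul__(self, other):
        # Reached for 0 * x, since int does not know how to multiply by Noisy.
        self.calls += 1
        return "not zero"

x = Noisy()
results = [0 * 5, 0 * "abc", 0 * [1, 2], 0 * float("inf"), 0 * x]

print(results[:3])              # [0, '', []]
print(math.isnan(results[3]))   # True
print(results[4], x.calls)      # not zero 1
```

Only the first case gives the integer 0, and replacing the last one with a 
constant would also skip the call to __rmul__.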

I must warn again that this code is 
ugly-mindbending-probably-wrong-pre-alpha-use-it-at-your-own-risk-has-to-be-rewritten-evil-nasty 
stuff.

But-it-is-fun-to-play-with-ly yours,
Corran

---- sample use ----
Python 1.5.2b1 (#7, Dec 25 1998, 08:55:40)  [GCC 2.7.2.2+myc1] on netbsd1
Copyright 1991-1995 Stichting Mathematisch Centrum, Amsterdam
>>> from opt import optimize
>>> import dis
>>> def f(file):
...   while 1:
...     line = file.readline()
...     if not line:
...       break
...     frob(line)
... 
>>> f2 = optimize(f)
>>> dis.dis(f)
          0 SET_LINENO          1

          3 SET_LINENO          2
          6 SETUP_LOOP         62 (to 71)

    >>    9 SET_LINENO          2
         12 LOAD_CONST          1 (1)
         15 JUMP_IF_FALSE      51 (to 69)
         18 POP_TOP        

         19 SET_LINENO          3
         22 LOAD_FAST           0 (file)
         25 LOAD_ATTR           1 (readline)
         28 CALL_FUNCTION       0
         31 STORE_FAST          1 (line)

         34 SET_LINENO          4
         37 LOAD_FAST           1 (line)
         40 UNARY_NOT      
         41 JUMP_IF_FALSE       8 (to 52)
         44 POP_TOP        

         45 SET_LINENO          5
         48 BREAK_LOOP     
         49 JUMP_FORWARD        1 (to 53)
    >>   52 POP_TOP        

    >>   53 SET_LINENO          6
         56 LOAD_GLOBAL         3 (frob)
         59 LOAD_FAST           1 (line)
         62 CALL_FUNCTION       1
         65 POP_TOP        
         66 JUMP_ABSOLUTE       9
    >>   69 POP_TOP        
         70 POP_BLOCK      
    >>   71 LOAD_CONST          0 (None)
         74 RETURN_VALUE   
>>> dis.dis(f2)
          0 SETUP_LOOP         35 (to 38)
    >>    3 LOAD_FAST           0 (file)
          6 LOAD_ATTR           1 (readline)
          9 CALL_FUNCTION       0
         12 STORE_FAST          1 (line)
         15 LOAD_FAST           1 (line)
         18 UNARY_NOT      
         19 JUMP_IF_TRUE       14 (to 36)
         22 POP_TOP        
         23 LOAD_GLOBAL         3 (frob)
         26 LOAD_FAST           1 (line)
         29 CALL_FUNCTION       1
         32 POP_TOP        
         33 JUMP_ABSOLUTE       3
    >>   36 POP_TOP        
         37 POP_BLOCK      
    >>   38 LOAD_CONST          0 (None)
         41 RETURN_VALUE   
>>> def frob(x):    
...   pass
... 
>>> import time      
>>> file = open("test")
>>> s = time.time(); f(file); print time.time() - s
3.18224000931
>>> file.close()
>>> file = open("test")
>>> s = time.time(); f2(file); print time.time() - s
3.02694892883
>>> file.close()

---- file opt.py ----
import bytecodehacks.code_editor
import bytecodehacks.ops
import operator

ce = bytecodehacks.code_editor
ops = bytecodehacks.ops


CONDJUMP = [ ops.JUMP_IF_TRUE, ops.JUMP_IF_FALSE ]
UNCONDJUMP = [ ops.JUMP_FORWARD, ops.JUMP_ABSOLUTE ]
UNCOND = UNCONDJUMP + [ ops.BREAK_LOOP, ops.STOP_CODE, ops.RETURN_VALUE, \
    ops.RAISE_VARARGS ]
PYBLOCK = [ ops.SETUP_LOOP, ops.SETUP_EXCEPT, ops.SETUP_FINALLY ]
PYENDBLOCK = [ ops.POP_BLOCK ]

binaryops = {
    'BINARY_ADD': operator.add,
    'BINARY_SUBTRACT': operator.sub,
    'BINARY_MULTIPLY': operator.mul,
    'BINARY_DIVIDE': operator.div,
    'BINARY_MODULO': operator.mod,
    'BINARY_POWER': pow,
    'BINARY_LSHIFT': operator.lshift,
    'BINARY_RSHIFT': operator.rshift,
    'BINARY_AND': operator.and_,
    'BINARY_OR': operator.or_,
    'BINARY_XOR': operator.xor
  }

unaryops = {
    'UNARY_POS': operator.pos,
    'UNARY_NEG': operator.neg,
    'UNARY_NOT': operator.not_
  }

def optimize(func):
  """Optimize a function."""
  f = ce.Function(func)

  # perform some simplifications - no attempt at completeness
  calculateConstants(f.func_code)
  strip_setlineno(f.func_code)
  simplifyjumps(f.func_code)
  removeconstjump(f.func_code)
  simplifyjumps(f.func_code)

  return f.make_function()

def calculateConstants(co):
  """Precalculate results of operations involving constants."""

  cs = co.co_code
  cc = co.co_consts

  stack = []
  i = 0
  while i < len(cs):
    op = cs[i]
    if repr(op) in binaryops.keys():
      if map(lambda x: x.opc, stack[-2:]) == ['d', 'd']:
        arg1 = cc[stack[-2].arg]
        arg2 = cc[stack[-1].arg]
        result = binaryops[repr(op)](arg1,arg2)

        if result in cc:
          arg = cc.index(result)
        else:
          arg = len(cc)
          cc.append(result)
        cs.remove(stack[-2])
        cs.remove(stack[-1])
        i = i - 2
        cs[i] = ops.LOAD_CONST(arg)

        stack.pop()
        stack.pop()
        stack.append(cs[i])
      else:
        op.execute(stack)
    elif repr(op) in unaryops.keys():
      if stack[-1].__class__ == ops.LOAD_CONST:
        arg1 = cc[stack[-1].arg]
        result = unaryops[repr(op)](arg1)

        if result in cc:
          arg = cc.index(result)
        else:
          arg = len(cc)
          cc.append(result)
        cs.remove(stack[-1])
        i = i - 1
        cs[i] = ops.LOAD_CONST(arg)

        stack.pop()
        stack.append(cs[i])
      else:
        op.execute(stack)
    else:
      # this is almost certainly wrong
      try:
        op.execute(stack)
      except: pass
    i = i + 1

def strip_setlineno(co):
  """Take in an EditableCode object and strip the SET_LINENO bytecodes"""
  
  i = 0
  while i < len(co.co_code):
    op = co.co_code[i]
    if op.__class__ == ops.SET_LINENO:
      co.co_code.remove(op)
    else:
      i = i+1

def simplifyjumps(co):
  """Remove unreachable code and simplify jumps-to-jumps and jumps-to-breaks."""
  cs = co.co_code
  
  i = 0
  pyblockstack = [None]
  loopstack = [None]
  trystack = [None]
  firstlook = 1

  while i < len(cs):
    op = cs[i]

    # new pyblock?
    if firstlook:
      if op.__class__ in PYBLOCK:
        pyblockstack.append(op)
        if op.__class__ == ops.SETUP_LOOP:
          loopstack.append(op.label.op)
        else:
          trystack.append(op.label.op)
      # end of pyblock?
      elif op.__class__ == ops.POP_BLOCK:
        op2 = pyblockstack.pop()
        if op2.__class__ == ops.SETUP_LOOP:
          loopstack.pop()
        else:
          trystack.pop()

    # Is the code inaccessible
    if i >= 1:
      if cs[i-1].__class__ in UNCOND and not (cs.find_labels(i) or \
          op.__class__ in PYENDBLOCK):
        cs.remove(op)
        if op.is_jump():
          cs.labels.remove(op.label)
        firstlook = 1
        continue

      # are we jumping from the statement before?
      if cs[i-1].__class__ in UNCONDJUMP:
        if cs[i-1].label.op == op:
          cs.labels.remove(cs[i-1].label)
          cs.remove(cs[i-1])
          firstlook = 1
          continue

      # break before end of loop?
      elif cs[i-1].__class__ == ops.BREAK_LOOP:
        if op.__class__ == ops.POP_BLOCK:
          cs.remove(cs[i-1])
          firstlook = 1
          continue

    # Do we have an unconditional jump to an unconditional jump?
    if op.__class__ in UNCONDJUMP:
      if op.label.op.__class__ in UNCONDJUMP:
        refop = op.label.op
        if op.__class__ == ops.JUMP_FORWARD:
          newop = ops.JUMP_ABSOLUTE()
          newop.label = ce.Label()
          newop.label.op = refop.label.op
          cs.labels.append(newop.label)
          cs.labels.remove(op.label)
          cs[i] = newop
        else: 
          op.label.op = refop.label.op
        firstlook = 0
        continue

    # Do we have a conditional jump to a break?
    if op.__class__ in CONDJUMP and loopstack[-1]:
      destindex = cs.index(op.label.op)
      preendindex = cs.index(loopstack[-1])-2
      if cs[i+2].__class__ == ops.BREAK_LOOP and cs[preendindex].__class__ \
          == ops.POP_TOP:
        if op.__class__ == ops.JUMP_IF_FALSE:
          newop = ops.JUMP_IF_TRUE()
        else:
          newop = ops.JUMP_IF_FALSE()
        label = ce.Label()
        newop.label = label
        label.op = cs[preendindex]
        cs.labels.append(label)
        cs.labels.remove(op.label)
        cs[i] = newop
        cs.remove(cs[i+1])
        cs.remove(cs[i+1])
        cs.remove(cs[i+1])
        firstlook = 0
        continue
      elif cs[destindex+1].__class__ == ops.BREAK_LOOP and \
          cs[preendindex].__class__ == ops.POP_TOP:
        op.label.op = cs[preendindex]
        cs.remove(cs[destindex])
        cs.remove(cs[destindex])
        cs.remove(cs[destindex])
        firstlook = 0
        continue

    firstlook = 1
    i = i+1

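def removeconstjump(co):
  """Resolve conditional jumps whose test is a constant."""
  cs = co.co_code
  cc = co.co_consts

  i = 0
  while i < len(cs):
    op = cs[i]
    # i >= 1 stops cs[i-1] from wrapping round to the last instruction
    if i >= 1 and op.__class__ in CONDJUMP and \
        cs[i-1].__class__ == ops.LOAD_CONST:
      if (op.__class__ == ops.JUMP_IF_FALSE and cc[cs[i-1].arg]) or \
          (op.__class__ == ops.JUMP_IF_TRUE and not cc[cs[i-1].arg]):
        # jump is never taken: drop the LOAD_CONST, the jump and the
        # POP_TOP that follows it
        cs.remove(cs[i-1])
        cs.remove(cs[i-1])
        cs.remove(cs[i-1])
        cs.labels.remove(op.label)
        i = i-2
      else:
        # jump is always taken: drop the LOAD_CONST and the POP_TOP, and
        # jump unconditionally to just past the POP_TOP at the target
        cs.remove(cs[i-1])
        cs.remove(cs[i])
        newop = ops.JUMP_FORWARD()
        newop.label = ce.Label()
        newop.label.op = cs[cs.index(op.label.op)+1]
        # register the new label, as simplifyjumps does for its new jumps
        cs.labels.append(newop.label)
        cs[i-1] = newop
        cs.labels.remove(op.label)
        i = i-1
    i = i+1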





