[pypy-svn] r21263 - in pypy/dist/pypy: jit jit/test translator/c

arigo at codespeak.net arigo at codespeak.net
Sun Dec 18 11:38:12 CET 2005


Author: arigo
Date: Sun Dec 18 11:38:09 2005
New Revision: 21263

Modified:
   pypy/dist/pypy/jit/llabstractinterp.py
   pypy/dist/pypy/jit/test/test_jit_tl.py
   pypy/dist/pypy/jit/test/test_llabstractinterp.py
   pypy/dist/pypy/jit/test/test_tl.py
   pypy/dist/pypy/jit/tl.py
   pypy/dist/pypy/translator/c/funcgen.py
Log:
Start of the hint()-driven approach.  So far, it works on the small example
(and only there :-).  The idea is to have a flag 'fixed' on LLRuntimeValues
that can be set to True a posteriori, meaning that the constant contained in
this LLRuntimeValue should be passed on to the next block as a constant --
instead of the default behavior, which is to turn it into a variable.  The
'hint' operation sets some 'fixed' flags to True and raises
RestartCompleting to restart the whole process.  An 'origin' attribute on
LLRuntimeValues tracks the multiple possible histories of a variable back to
the point(s) where it was constant.

I removed the code for constant propagation with automatic generalization
because the new a-posteriori constantification logic is hard enough to get
right without additionally having to worry about not breaking the
const_propagate policy.  I guess we'll have to reinsert this code once things
start to work again.

Added a hint() call to the TL interpreter source (tl.py) and implemented the
'hint' operation in GenC.  test_jit_tl is disabled for now because it explodes.



Modified: pypy/dist/pypy/jit/llabstractinterp.py
==============================================================================
--- pypy/dist/pypy/jit/llabstractinterp.py	(original)
+++ pypy/dist/pypy/jit/llabstractinterp.py	Sun Dec 18 11:38:09 2005
@@ -57,6 +57,8 @@
 
 
 class LLRuntimeValue(LLAbstractValue):
+    origin = None
+    fixed = False
 
     def __init__(self, orig_v):
         if isinstance(orig_v, Variable):
@@ -81,11 +83,10 @@
         return self.copy_v
 
     def getruntimevars(self, memo):
-        if (isinstance(self.copy_v, Variable) or
-            self not in memo.propagate_as_constants):
-            return [self.copy_v]
+        if memo.get_fixed_flags:
+            return [self.fixed]
         else:
-            return []   # we propagate this constant as a constant
+            return [self.copy_v]
 
     def maybe_get_constant(self):
         if isinstance(self.copy_v, Constant):
@@ -95,22 +96,17 @@
 
     def with_fresh_variables(self, memo):
         # don't use memo.seen here: shared variables must become distinct
-        if (isinstance(self.copy_v, Variable) or
-            self not in memo.propagate_as_constants):
-            return LLRuntimeValue(self.getconcretetype())
-        else:
-            return self   # we are allowed to propagate this constant
+        if memo.key is not None:
+            c = memo.key.next()
+            if c is not None:
+                return LLRuntimeValue(c)
+        result = LLRuntimeValue(self.getconcretetype())
+        result.origin = [self]
+        return result
 
     def match(self, other, memo):
-        if not isinstance(other, LLRuntimeValue):
-            return False
-        if isinstance(self.copy_v, Variable):
-            return True
-        if self.copy_v == other.copy_v:
-            memo.propagate_as_constants[other] = True   # exact match
-        else:
-            memo.exact_match = False
-        return True
+        memo.dependencies.append((self, other, self.fixed))
+        return isinstance(other, LLRuntimeValue)
 
 ll_no_return_value = LLRuntimeValue(const(None, lltype.Void))
 
@@ -334,7 +330,7 @@
         self.a_back = a_back
         self.args_a = args_a
         self.origblock = origblock
-        self.copyblock = None
+        self.copyblocks = {}
         assert len(args_a) == len(self.getlivevars())
 
     def key(self):
@@ -388,18 +384,6 @@
         else:
             return True
 
-    def resolveblock(self, newblock):
-        #print "RESOLVING BLOCK", newblock
-        if self.copyblock is not None:
-            # uncommon case: must patch the existing Block
-            assert len(self.copyblock.inputargs) == len(newblock.inputargs)
-            self.copyblock.inputargs  = newblock.inputargs
-            self.copyblock.operations = newblock.operations
-            self.copyblock.exitswitch = newblock.exitswitch
-            self.copyblock.recloseblock(*newblock.exits)
-        else:
-            self.copyblock = newblock
-
     def getbindings(self):
         return dict(zip(self.getlivevars(), self.args_a))
 
@@ -407,7 +391,6 @@
 class LLBlockState(LLState):
     """Entry state of a block, as a combination of LLAbstractValues
     for its input arguments."""
-    propagate_as_constants = {}
 
     def localkey(self):
         return (self.origblock,)
@@ -434,13 +417,13 @@
 # ____________________________________________________________
 
 class Policy(object):
-    def __init__(self, inlining=False, const_propagate=False,
-                       concrete_propagate=True):
+    def __init__(self, inlining=False,
+                       concrete_propagate=True, concrete_args=True):
         self.inlining = inlining
-        self.const_propagate = const_propagate
         self.concrete_propagate = concrete_propagate
+        self.concrete_args = concrete_args
 
-best_policy = Policy(inlining=True, const_propagate=True)
+best_policy = Policy(inlining=True, concrete_args=False)
 
 
 class LLAbstractInterp(object):
@@ -460,7 +443,10 @@
         args_a = []
         for i, v in enumerate(origgraph.getargs()):
             if i in arghints:
-                a = LLConcreteValue(arghints[i])
+                if self.policy.concrete_args:
+                    a = LLConcreteValue(arghints[i])
+                else:
+                    a = LLRuntimeValue(const(arghints[i]))
             else:
                 a = LLRuntimeValue(orig_v=v)
             args_a.append(a)
@@ -484,8 +470,7 @@
     def schedule(self, inputstate):
         #print "SCHEDULE", args_a, origblock
         state = self.schedule_getstate(inputstate)
-        memo = VarMemo(state.propagate_as_constants)
-        args_v = inputstate.getruntimevars(memo)
+        args_v = inputstate.getruntimevars(VarMemo())
         newlink = Link(args_v, None)
         self.pendingstates[newlink] = state
         return newlink
@@ -498,24 +483,41 @@
             memo = MatchMemo()
             if state.match(inputstate, memo):
                 # already matched
-                if memo.exact_match:
-                    return state    # exact match
-                if not self.policy.const_propagate:
-                    return state    # all constants will be generalized anyway
-                # partial match: in the old state, some constants need to
-                # be turned into variables.
-                inputstate.propagate_as_constants = memo.propagate_as_constants
-                # The generalized state replaces the existing one.
-                pendingstates[i] = inputstate
-                state.generalized_by = inputstate
-                return inputstate
+                must_restart = False
+                for statevar, inputvar, fixed in memo.dependencies:
+                    if fixed:
+                        must_restart |= self.hint_needs_constant(inputvar)
+                if must_restart:
+                    raise RestartCompleting
+                for statevar, inputvar, fixed in memo.dependencies:
+                    if statevar.origin is None:
+                        statevar.origin = []
+                    statevar.origin.append(inputvar)
+                return state
         else:
             # cache and return this new state
-            if self.policy.const_propagate:
-                inputstate.propagate_as_constants = ALL
             pendingstates.append(inputstate)
             return inputstate
 
+    def hint_needs_constant(self, a):
+        if a.maybe_get_constant() is not None:
+            return False
+        fix_me = [a]
+        while fix_me:
+            a = fix_me.pop()
+            if not a.origin:
+                raise Exception("hint() failed: cannot trace the variable %r "
+                                "back to a link where it was a constant" % (a,))
+            for a_origin in a.origin:
+                # 'a_origin' is a LLRuntimeValue attached to a saved state
+                assert isinstance(a_origin, LLRuntimeValue)
+                if not a_origin.fixed:
+                    print 'fixing:', a_origin
+                    a_origin.fixed = True
+                    if a_origin.maybe_get_constant() is None:
+                        fix_me.append(a_origin)
+        return True
+
 
 class GraphState(object):
     """Entry state of a graph."""
@@ -534,6 +536,9 @@
                                 self.copygraph.exceptblock.inputargs[1])]:
             if hasattr(orig_v, 'concretetype'):
                 copy_v.concretetype = orig_v.concretetype
+        # The 'args' attribute is needed by process_constant_input(),
+        # which looks for it on either a GraphState or a Link
+        self.args = inputstate.getruntimevars(VarMemo())
         self.a_return = None
         self.state = "before"
 
@@ -546,6 +551,15 @@
         if self.state == "after":
             return
         self.state = "during"
+        while True:
+            try:
+                self.try_to_complete()
+                break
+            except RestartCompleting:
+                print '--- restarting ---'
+                continue
+
+    def try_to_complete(self):
         graph = self.copygraph
         interp = self.interp
         pending = [self]
@@ -555,10 +569,22 @@
         while pending:
             next = pending.pop()
             state = interp.pendingstates[next]
-            if state.copyblock is None:
-                self.flowin(state)
-            next.settarget(state.copyblock)
-            for link in state.copyblock.exits:
+            fixed_flags = state.getruntimevars(VarMemo(get_fixed_flags=True))
+            key = []
+            for fixed, c in zip(fixed_flags, next.args):
+                if fixed:
+                    assert isinstance(c, Constant), (
+                        "unexpected Variable %r reaching a fixed input arg" %
+                        (c,))
+                    key.append(c)
+                else:
+                    key.append(None)
+            key = tuple(key)
+            if key not in state.copyblocks:
+                self.flowin(state, key)
+            block = state.copyblocks[key]
+            next.settarget(block)
+            for link in block.exits:
                 if link.target is None or link.target.operations != ():
                     if link not in seen:
                         seen[link] = True
@@ -575,8 +601,7 @@
                     else:
                         raise Exception("uh?")
 
-        if interp.policy.const_propagate:
-            self.compactify(seen)
+        remove_constant_inputargs(graph)
 
         # the graph should be complete now; sanity-check
         try:
@@ -592,33 +617,28 @@
         join_blocks(graph)
         self.state = "after"
 
-    def compactify(self, links):
-        # remove the parts of the graph that use constants that were later
-        # generalized
-        interp = self.interp
-        for link in links:
-            oldstate = interp.pendingstates[link]
-            if oldstate.generalized_by is not None:
-                newstate = oldstate.generalized_by
-                while newstate.generalized_by:
-                    newstate = newstate.generalized_by
-                # Patch oldstate.block to point to the new state,
-                # as in the flow object space
-                builder = BlockBuilder(self, oldstate)
-                memo = VarMemo(newstate.propagate_as_constants)
-                args_v = builder.runningstate.getruntimevars(memo)
-                oldlink = Link(args_v, newstate.copyblock)
-                oldblock = builder.buildblock(None, [oldlink])
-                oldstate.resolveblock(oldblock)
-
-    def flowin(self, state):
+    def flowin(self, state, key):
         # flow in the block
         assert isinstance(state, LLBlockState)
         origblock = state.origblock
         origposition = 0
-        builder = BlockBuilder(self.interp, state)
+        builder = BlockBuilder(self.interp, state, key)
         newexitswitch = None
+        # debugging print
+        arglist = []
+        for v1, v2, k in zip(state.getruntimevars(VarMemo()),
+                             builder.runningstate.getruntimevars(VarMemo()),
+                             key):
+            if k is None:
+                assert isinstance(v2, Variable)
+            else:
+                assert v2 == k
+            arglist.append('%s => %s' % (v1, v2))
         print
+        print '--> %s [%s]' % (origblock, ', '.join(arglist))
+        for op in origblock.operations:
+            print '\t\t', op
+        # end of debugging print
         try:
             if origblock.operations == ():
                 if state.a_back is None:
@@ -691,17 +711,16 @@
                 newlinks.append(newlink)
 
         newblock = builder.buildblock(newexitswitch, newlinks)
-        state.resolveblock(newblock)
+        state.copyblocks[key] = newblock
 
 
 class BlockBuilder(object):
 
-    def __init__(self, interp, initialstate):
+    def __init__(self, interp, initialstate, key):
         self.interp = interp
-        memo = VarMemo(initialstate.propagate_as_constants)
+        memo = VarMemo(iter(key))
         self.runningstate = initialstate.with_fresh_variables(memo)
-        memo = VarMemo(initialstate.propagate_as_constants)
-        self.newinputargs = self.runningstate.getruntimevars(memo)
+        self.newinputargs = self.runningstate.getruntimevars(VarMemo())
         # {Variables-of-origblock: a_value}
         self.bindings = self.runningstate.getbindings()
         self.residual_operations = []
@@ -743,9 +762,9 @@
             concretevalues.append(v.value)
             any_concrete = any_concrete or isinstance(a, LLConcreteValue)
         # can constant-fold
-        print 'fold:', constant_op, concretevalues
+        print 'fold:', constant_op.__name__, concretevalues
         concreteresult = constant_op(*concretevalues)
-        if any_concrete and self.policy.concrete_propagate:
+        if any_concrete and self.interp.policy.concrete_propagate:
             return LLConcreteValue(concreteresult)
         else:
             return LLRuntimeValue(const(concreteresult))
@@ -769,9 +788,21 @@
             if a_result is not None:
                 return a_result
         a_result = LLRuntimeValue(op.result)
+        if constant_op:
+            self.record_origin(a_result, args_a)
         self.residual(op.opname, args_a, a_result)
         return a_result
 
+    def record_origin(self, a_result, args_a):
+        origin = []
+        for a in args_a:
+            if a.maybe_get_constant() is not None:
+                continue
+            if not isinstance(a, LLRuntimeValue) or a.origin is None:
+                return
+            origin.extend(a.origin)
+        a_result.origin = origin
+
     # ____________________________________________________________
     # Operation handlers
 
@@ -814,6 +845,9 @@
     def op_int_ne(self, op, a1, a2):
         return self.residualize(op, [a1, a2], operator.ne)
 
+    op_char_eq = op_int_eq
+    op_char_ne = op_int_ne
+
     def op_cast_char_to_int(self, op, a):
         return self.residualize(op, [a], ord)
 
@@ -823,6 +857,23 @@
     def op_same_as(self, op, a):
         return a
 
+    def op_hint(self, op, a, a_hints):
+        c_hints = a_hints.maybe_get_constant()
+        assert c_hints is not None, "hint dict not constant"
+        hints = c_hints.value
+        if hints.get('concrete'):
+            # turn this 'a' into a concrete value
+            c = a.forcevarorconst(self)
+            if isinstance(c, Constant):
+                a = LLConcreteValue(c.value)
+            else:
+                # Oups! it's not a constant.  Try to trace it back to a
+                # constant that was turned into a variable by a link.
+                restart = self.interp.hint_needs_constant(a)
+                assert restart
+                raise RestartCompleting
+        return a
+
     def op_direct_call(self, op, *args_a):
         a_result = self.handle_call(op, *args_a)
         if a_result is None:
@@ -983,22 +1034,20 @@
     def __init__(self, link):
         self.link = link
 
+class RestartCompleting(Exception):
+    pass
+
 class MatchMemo(object):
     def __init__(self):
-        self.exact_match = True
-        self.propagate_as_constants = {}
+        self.dependencies = []
         self.self_alias = {}
         self.other_alias = {}
 
 class VarMemo(object):
-    def __init__(self, propagate_as_constants={}):
+    def __init__(self, key=None, get_fixed_flags=False):
         self.seen = {}
-        self.propagate_as_constants = propagate_as_constants
-
-class ALL(object):
-    def __contains__(self, other):
-        return True
-ALL = ALL()
+        self.key = key
+        self.get_fixed_flags = get_fixed_flags
 
 
 def live_variables(block, position):
@@ -1019,3 +1068,22 @@
         if op.result in used:
             result.append(op.result)
     return result
+
+def remove_constant_inputargs(graph):
+    # for simplicity, the logic in GraphState produces graphs that can
+    # pass constants from one block to the next explicitly, via a
+    # link.args -> block.inputargs.  Remove them now.
+    for link in graph.iterlinks():
+        i = 0
+        for v in link.target.inputargs:
+            if isinstance(v, Constant):
+                del link.args[i]
+            else:
+                i += 1
+    for block in graph.iterblocks():
+        i = 0
+        for v in block.inputargs[:]:
+            if isinstance(v, Constant):
+                del block.inputargs[i]
+            else:
+                i += 1

Modified: pypy/dist/pypy/jit/test/test_jit_tl.py
==============================================================================
--- pypy/dist/pypy/jit/test/test_jit_tl.py	(original)
+++ pypy/dist/pypy/jit/test/test_jit_tl.py	Sun Dec 18 11:38:09 2005
@@ -8,7 +8,7 @@
 from pypy.rpython.llinterp import LLInterpreter
 #from pypy.translator.backendopt import inline
 
-#py.test.skip("in-progress")
+py.test.skip("in-progress")
 
 def setup_module(mod):
     t = TranslationContext()
@@ -32,7 +32,7 @@
 
     assert result1 == result2
 
-    #interp.graphs[0].show()
+    interp.graphs[0].show()
 
 
 def run_jit(code):
@@ -90,4 +90,29 @@
             PUSH 5
             ADD
             RETURN
-   ''')
+    ''')
+
+def test_factorial():
+    run_jit('''
+            PUSH 1   #  accumulator
+            PUSH 7   #  N
+
+        start:
+            PICK 0
+            PUSH 1
+            LE
+            BR_COND exit
+
+            SWAP
+            PICK 1
+            MUL
+            SWAP
+            PUSH 1
+            SUB
+            PUSH 1
+            BR_COND start
+
+        exit:
+            POP
+            RETURN
+    ''')

Modified: pypy/dist/pypy/jit/test/test_llabstractinterp.py
==============================================================================
--- pypy/dist/pypy/jit/test/test_llabstractinterp.py	(original)
+++ pypy/dist/pypy/jit/test/test_llabstractinterp.py	Sun Dec 18 11:38:09 2005
@@ -62,8 +62,7 @@
     return graph2, insns
 
 P_INLINE = Policy(inlining=True)
-P_CONST_INLINE = Policy(inlining=True, const_propagate=True)
-P_HINT_DRIVEN = Policy(inlining=True, concrete_propagate=False)
+P_HINT_DRIVEN = Policy(inlining=True, concrete_args=False)
 
 
 def test_simple():
@@ -316,13 +315,13 @@
     graph2, insns = abstrinterp(ll1, [3, 4, 5], [1, 2], policy=P_INLINE)
     assert insns == {'int_add': 1}
 
-def test_const_propagate():
-    def ll_add(x, y):
-        return x + y
-    def ll1(x):
-        return ll_add(x, 42)
-    graph2, insns = abstrinterp(ll1, [3], [0], policy=P_CONST_INLINE)
-    assert insns == {}
+##def test_const_propagate():
+##    def ll_add(x, y):
+##        return x + y
+##    def ll1(x):
+##        return ll_add(x, 42)
+##    graph2, insns = abstrinterp(ll1, [3], [0], policy=P_CONST_INLINE)
+##    assert insns == {}
 
 def test_dont_unroll_loop():
     def ll_factorial(n):
@@ -332,12 +331,12 @@
             i += 1
             result *= i
         return result
-    graph2, insns = abstrinterp(ll_factorial, [7], [], policy=P_CONST_INLINE)
+    graph2, insns = abstrinterp(ll_factorial, [7], [], policy=P_INLINE)
     assert insns == {'int_lt': 1, 'int_add': 1, 'int_mul': 1}
 
-def INPROGRESS_test_hint():
+def test_hint():
     from pypy.rpython.objectmodel import hint
-    A = lltype.GcArray(lltype.Char)
+    A = lltype.GcArray(lltype.Char, hints={'immutable': True})
     def ll_interp(code):
         accum = 0
         pc = 0
@@ -356,6 +355,6 @@
     bytecode[2] = 'A'
     bytecode[3] = 'B'
     bytecode[4] = 'A'
-    graph2, insns = abstrinterp(ll_interp, [bytecode], [],
+    graph2, insns = abstrinterp(ll_interp, [bytecode], [0],
                                 policy=P_HINT_DRIVEN)
     assert insns == {'int_add': 4, 'int_lt': 1}

Modified: pypy/dist/pypy/jit/test/test_tl.py
==============================================================================
--- pypy/dist/pypy/jit/test/test_tl.py	(original)
+++ pypy/dist/pypy/jit/test/test_tl.py	Sun Dec 18 11:38:09 2005
@@ -163,3 +163,30 @@
     assert code == list2bytecode([PUSH,1, CALL,5, PUSH,3, CALL,4, RETURN,
                                   PUSH,2, RETURN,
                                   PUSH,4, PUSH,5, ADD, RETURN])
+
+def test_factorial():
+    code = compile('''
+            PUSH 1   #  accumulator
+            PUSH 7   #  N
+
+        start:
+            PICK 0
+            PUSH 1
+            LE
+            BR_COND exit
+
+            SWAP
+            PICK 1
+            MUL
+            SWAP
+            PUSH 1
+            SUB
+            PUSH 1
+            BR_COND start
+
+        exit:
+            POP
+            RETURN
+    ''')
+    res = interp(code)
+    assert res == 5040

Modified: pypy/dist/pypy/jit/tl.py
==============================================================================
--- pypy/dist/pypy/jit/tl.py	(original)
+++ pypy/dist/pypy/jit/tl.py	Sun Dec 18 11:38:09 2005
@@ -3,6 +3,7 @@
 import py
 from tlopcode import *
 import tlopcode
+from pypy.rpython.objectmodel import hint
 
 def char2int(c):
     t = ord(c)
@@ -19,6 +20,7 @@
 
     while pc < code_len:
         opcode = ord(code[pc])
+        opcode = hint(opcode, concrete=True)
         pc += 1
 
         if opcode == PUSH:

Modified: pypy/dist/pypy/translator/c/funcgen.py
==============================================================================
--- pypy/dist/pypy/translator/c/funcgen.py	(original)
+++ pypy/dist/pypy/translator/c/funcgen.py	Sun Dec 18 11:38:09 2005
@@ -584,6 +584,10 @@
                 result.append(self.pyobj_incref(op.result))
         return '\t'.join(result)
 
+    def OP_HINT(self, op, err):
+        hints = op.args[1].value
+        return '%s\t/* hint: %r */' % (self.OP_SAME_AS(op, err), hints)
+
     def OP_KEEPALIVE(self, op, err): # xxx what should be the sematics consequences of this
         return "/* kept alive: %s */ ;" % self.expr(op.args[0], special_case_void=False)
 



More information about the Pypy-commit mailing list