[pypy-svn] r75300 - in pypy/branch/multijit-4/pypy/jit: codewriter codewriter/test metainterp metainterp/test

arigo at codespeak.net arigo at codespeak.net
Sat Jun 12 09:32:19 CEST 2010


Author: arigo
Date: Sat Jun 12 09:32:17 2010
New Revision: 75300

Added:
   pypy/branch/multijit-4/pypy/jit/metainterp/jitdriver.py
      - copied unchanged from r75299, pypy/branch/multijit-3/pypy/jit/metainterp/jitdriver.py
Modified:
   pypy/branch/multijit-4/pypy/jit/codewriter/call.py
   pypy/branch/multijit-4/pypy/jit/codewriter/codewriter.py
   pypy/branch/multijit-4/pypy/jit/codewriter/jitcode.py
   pypy/branch/multijit-4/pypy/jit/codewriter/jtransform.py
   pypy/branch/multijit-4/pypy/jit/codewriter/test/test_call.py
   pypy/branch/multijit-4/pypy/jit/codewriter/test/test_codewriter.py
   pypy/branch/multijit-4/pypy/jit/codewriter/test/test_flatten.py
   pypy/branch/multijit-4/pypy/jit/metainterp/blackhole.py
   pypy/branch/multijit-4/pypy/jit/metainterp/compile.py
   pypy/branch/multijit-4/pypy/jit/metainterp/history.py
   pypy/branch/multijit-4/pypy/jit/metainterp/pyjitpl.py
   pypy/branch/multijit-4/pypy/jit/metainterp/resume.py
   pypy/branch/multijit-4/pypy/jit/metainterp/test/test_basic.py
   pypy/branch/multijit-4/pypy/jit/metainterp/warmspot.py
   pypy/branch/multijit-4/pypy/jit/metainterp/warmstate.py
Log:
Merge branch/multijit-3.


Modified: pypy/branch/multijit-4/pypy/jit/codewriter/call.py
==============================================================================
--- pypy/branch/multijit-4/pypy/jit/codewriter/call.py	(original)
+++ pypy/branch/multijit-4/pypy/jit/codewriter/call.py	Sat Jun 12 09:32:17 2010
@@ -16,21 +16,22 @@
 
 class CallControl(object):
     virtualref_info = None     # optionally set from outside
-    virtualizable_info = None  # optionally set from outside
-    portal_runner_ptr = None   # optionally set from outside
 
-    def __init__(self, cpu=None, portal_graph=None):
+    def __init__(self, cpu=None, jitdrivers_sd=[]):
+        assert isinstance(jitdrivers_sd, list)   # debugging
         self.cpu = cpu
-        self.portal_graph = portal_graph
+        self.jitdrivers_sd = jitdrivers_sd
         self.jitcodes = {}             # map {graph: jitcode}
         self.unfinished_graphs = []    # list of graphs with pending jitcodes
-        self.jitdriver = None
         if hasattr(cpu, 'rtyper'):     # for tests
             self.rtyper = cpu.rtyper
             translator = self.rtyper.annotator.translator
             self.raise_analyzer = RaiseAnalyzer(translator)
             self.readwrite_analyzer = ReadWriteAnalyzer(translator)
             self.virtualizable_analyzer = VirtualizableAnalyzer(translator)
+        #
+        for index, jd in enumerate(jitdrivers_sd):
+            jd.index = index
 
     def find_all_graphs(self, policy):
         try:
@@ -41,8 +42,8 @@
         def is_candidate(graph):
             return policy.look_inside_graph(graph)
 
-        assert self.portal_graph is not None
-        todo = [self.portal_graph]
+        assert len(self.jitdrivers_sd) > 0
+        todo = [jd.portal_graph for jd in self.jitdrivers_sd]
         if hasattr(self, 'rtyper'):
             for oopspec_name, ll_args, ll_res in support.inline_calls_to:
                 c_func, _ = support.builtin_func_for_spec(self.rtyper,
@@ -122,7 +123,7 @@
     def guess_call_kind(self, op, is_candidate=None):
         if op.opname == 'direct_call':
             funcptr = op.args[0].value
-            if funcptr is self.portal_runner_ptr:
+            if self.jitdriver_sd_from_portal_runner_ptr(funcptr) is not None:
                 return 'recursive'
             funcobj = get_funcobj(funcptr)
             if getattr(funcobj, 'graph', None) is None:
@@ -143,6 +144,11 @@
         # used only after find_all_graphs()
         return graph in self.candidate_graphs
 
+    def grab_initial_jitcodes(self):
+        for jd in self.jitdrivers_sd:
+            jd.mainjitcode = self.get_jitcode(jd.portal_graph)
+            jd.mainjitcode.is_portal = True
+
     def enum_pending_graphs(self):
         while self.unfinished_graphs:
             graph = self.unfinished_graphs.pop()
@@ -241,12 +247,26 @@
         return (effectinfo is None or
                 effectinfo.extraeffect >= EffectInfo.EF_CAN_RAISE)
 
-    def found_jitdriver(self, jitdriver):
-        if self.jitdriver is None:
-            self.jitdriver = jitdriver
-        else:
-            assert self.jitdriver is jitdriver
+    def jitdriver_sd_from_portal_graph(self, graph):
+        for jd in self.jitdrivers_sd:
+            if jd.portal_graph is graph:
+                return jd
+        return None
 
-    def getjitdriver(self):
-        assert self.jitdriver is not None, "order dependency issue?"
-        return self.jitdriver
+    def jitdriver_sd_from_portal_runner_ptr(self, funcptr):
+        for jd in self.jitdrivers_sd:
+            if funcptr is jd.portal_runner_ptr:
+                return jd
+        return None
+
+    def get_vinfo(self, VTYPEPTR):
+        seen = set()
+        for jd in self.jitdrivers_sd:
+            if jd.virtualizable_info is not None:
+                if jd.virtualizable_info.is_vtypeptr(VTYPEPTR):
+                    seen.add(jd.virtualizable_info)
+        if seen:
+            assert len(seen) == 1
+            return seen.pop()
+        else:
+            return None

Modified: pypy/branch/multijit-4/pypy/jit/codewriter/codewriter.py
==============================================================================
--- pypy/branch/multijit-4/pypy/jit/codewriter/codewriter.py	(original)
+++ pypy/branch/multijit-4/pypy/jit/codewriter/codewriter.py	Sat Jun 12 09:32:17 2010
@@ -14,29 +14,30 @@
 class CodeWriter(object):
     callcontrol = None    # for tests
 
-    def __init__(self, cpu=None, maingraph=None):
+    def __init__(self, cpu=None, jitdrivers_sd=[]):
         self.cpu = cpu
         self.assembler = Assembler()
-        self.portal_graph = maingraph
-        self.callcontrol = CallControl(cpu, maingraph)
+        self.callcontrol = CallControl(cpu, jitdrivers_sd)
+        self._seen_files = set()
 
     def transform_func_to_jitcode(self, func, values, type_system='lltype'):
         """For testing."""
         rtyper = support.annotate(func, values, type_system=type_system)
         graph = rtyper.annotator.translator.graphs[0]
         jitcode = JitCode("test")
-        self.transform_graph_to_jitcode(graph, jitcode, True, True)
+        self.transform_graph_to_jitcode(graph, jitcode, True)
         return jitcode
 
-    def transform_graph_to_jitcode(self, graph, jitcode, portal, verbose):
+    def transform_graph_to_jitcode(self, graph, jitcode, verbose):
         """Transform a graph into a JitCode containing the same bytecode
         in a different format.
         """
+        portal_jd = self.callcontrol.jitdriver_sd_from_portal_graph(graph)
         graph = copygraph(graph, shallowvars=True)
         #
         # step 1: mangle the graph so that it contains the final instructions
         # that we want in the JitCode, but still as a control flow graph
-        transform_graph(graph, self.cpu, self.callcontrol, portal)
+        transform_graph(graph, self.cpu, self.callcontrol, portal_jd)
         #
         # step 2: perform register allocation on it
         regallocs = {}
@@ -59,16 +60,14 @@
         self.assembler.assemble(ssarepr, jitcode)
         #
         # print the resulting assembler
-        self.print_ssa_repr(ssarepr, portal, verbose)
+        self.print_ssa_repr(ssarepr, portal_jd, verbose)
 
     def make_jitcodes(self, verbose=False):
         log.info("making JitCodes...")
-        maingraph = self.portal_graph
-        self.mainjitcode = self.callcontrol.get_jitcode(maingraph)
+        self.callcontrol.grab_initial_jitcodes()
         count = 0
         for graph, jitcode in self.callcontrol.enum_pending_graphs():
-            self.transform_graph_to_jitcode(graph, jitcode,
-                                            graph is maingraph, verbose)
+            self.transform_graph_to_jitcode(graph, jitcode, verbose)
             count += 1
             if not count % 500:
                 log.info("Produced %d jitcodes" % count)
@@ -76,33 +75,35 @@
         log.info("there are %d JitCode instances." % count)
 
     def setup_vrefinfo(self, vrefinfo):
+        # must be called at most once
+        assert self.callcontrol.virtualref_info is None
         self.callcontrol.virtualref_info = vrefinfo
 
-    def setup_virtualizable_info(self, vinfo):
-        self.callcontrol.virtualizable_info = vinfo
-
-    def setup_portal_runner_ptr(self, portal_runner_ptr):
-        self.callcontrol.portal_runner_ptr = portal_runner_ptr
+    def setup_jitdriver(self, jitdriver_sd):
+        # Must be called once per jitdriver.  Usually jitdriver_sd is an
+        # instance of pypy.jit.metainterp.jitdriver.JitDriverStaticData.
+        self.callcontrol.jitdrivers_sd.append(jitdriver_sd)
 
     def find_all_graphs(self, policy):
         return self.callcontrol.find_all_graphs(policy)
 
-    def print_ssa_repr(self, ssarepr, portal, verbose):
+    def print_ssa_repr(self, ssarepr, portal_jitdriver, verbose):
         if verbose:
             print '%s:' % (ssarepr.name,)
             print format_assembler(ssarepr)
         else:
             dir = udir.ensure("jitcodes", dir=1)
-            if portal:
-                name = "00_portal_runner"
+            if portal_jitdriver:
+                name = "%02d_portal_runner" % (portal_jitdriver.index,)
             elif ssarepr.name and ssarepr.name != '?':
                 name = ssarepr.name
             else:
                 name = 'unnamed' % id(ssarepr)
             i = 1
             extra = ''
-            while dir.join(name+extra).check(exists=1):
+            while name+extra in self._seen_files:
                 i += 1
                 extra = '.%d' % i
+            self._seen_files.add(name+extra)
             dir.join(name+extra).write(format_assembler(ssarepr))
             log.dot()

Modified: pypy/branch/multijit-4/pypy/jit/codewriter/jitcode.py
==============================================================================
--- pypy/branch/multijit-4/pypy/jit/codewriter/jitcode.py	(original)
+++ pypy/branch/multijit-4/pypy/jit/codewriter/jitcode.py	Sat Jun 12 09:32:17 2010
@@ -13,6 +13,7 @@
         self.name = name
         self.fnaddr = fnaddr
         self.calldescr = calldescr
+        self.is_portal = False
         self._called_from = called_from   # debugging
         self._ssarepr     = None          # debugging
 
@@ -24,11 +25,11 @@
         self.constants_i = constants_i or self._empty_i
         self.constants_r = constants_r or self._empty_r
         self.constants_f = constants_f or self._empty_f
-        # encode the three num_regs into a single integer
+        # encode the three num_regs into a single char each
         assert num_regs_i < 256 and num_regs_r < 256 and num_regs_f < 256
-        self.num_regs_encoded = ((num_regs_i << 16) |
-                                 (num_regs_f << 8) |
-                                 (num_regs_r << 0))
+        self.c_num_regs_i = chr(num_regs_i)
+        self.c_num_regs_r = chr(num_regs_r)
+        self.c_num_regs_f = chr(num_regs_f)
         self.liveness = make_liveness_cache(liveness)
         self._startpoints = startpoints   # debugging
         self._alllabels = alllabels       # debugging
@@ -37,13 +38,13 @@
         return heaptracker.adr2int(self.fnaddr)
 
     def num_regs_i(self):
-        return self.num_regs_encoded >> 16
-
-    def num_regs_f(self):
-        return (self.num_regs_encoded >> 8) & 0xFF
+        return ord(self.c_num_regs_i)
 
     def num_regs_r(self):
-        return self.num_regs_encoded & 0xFF
+        return ord(self.c_num_regs_r)
+
+    def num_regs_f(self):
+        return ord(self.c_num_regs_f)
 
     def has_liveness_info(self, pc):
         return pc in self.liveness

Modified: pypy/branch/multijit-4/pypy/jit/codewriter/jtransform.py
==============================================================================
--- pypy/branch/multijit-4/pypy/jit/codewriter/jtransform.py	(original)
+++ pypy/branch/multijit-4/pypy/jit/codewriter/jtransform.py	Sat Jun 12 09:32:17 2010
@@ -13,23 +13,23 @@
 from pypy.translator.simplify import get_funcobj
 
 
-def transform_graph(graph, cpu=None, callcontrol=None, portal=True):
+def transform_graph(graph, cpu=None, callcontrol=None, portal_jd=None):
     """Transform a control flow graph to make it suitable for
     being flattened in a JitCode.
     """
-    t = Transformer(cpu, callcontrol)
-    t.transform(graph, portal)
+    t = Transformer(cpu, callcontrol, portal_jd)
+    t.transform(graph)
 
 
 class Transformer(object):
 
-    def __init__(self, cpu=None, callcontrol=None):
+    def __init__(self, cpu=None, callcontrol=None, portal_jd=None):
         self.cpu = cpu
         self.callcontrol = callcontrol
+        self.portal_jd = portal_jd   # non-None only for the portal graph(s)
 
-    def transform(self, graph, portal):
+    def transform(self, graph):
         self.graph = graph
-        self.portal = portal
         for block in list(graph.iterblocks()):
             self.optimize_block(block)
 
@@ -325,10 +325,12 @@
         return op1
 
     def handle_recursive_call(self, op):
-        ops = self.promote_greens(op.args[1:])
-        targetgraph = self.callcontrol.portal_graph
-        num_green_args = len(self.callcontrol.getjitdriver().greens)
-        args = (self.make_three_lists(op.args[1:1+num_green_args]) +
+        jitdriver_sd = self.callcontrol.jitdriver_sd_from_portal_runner_ptr(
+            op.args[0])
+        ops = self.promote_greens(op.args[1:], jitdriver_sd.jitdriver)
+        num_green_args = len(jitdriver_sd.jitdriver.greens)
+        args = ([Constant(jitdriver_sd.index, lltype.Signed)] +
+                self.make_three_lists(op.args[1:1+num_green_args]) +
                 self.make_three_lists(op.args[1+num_green_args:]))
         kind = getkind(op.result.concretetype)[0]
         op0 = SpaceOperation('recursive_call_%s' % kind, args, op.result)
@@ -483,14 +485,14 @@
         # check for virtualizable
         try:
             if self.is_virtualizable_getset(op):
-                descr = self.get_virtualizable_field_descr(op.args[1].value)
+                descr = self.get_virtualizable_field_descr(op)
                 kind = getkind(RESULT)[0]
                 return [SpaceOperation('-live-', [], None),
                         SpaceOperation('getfield_vable_%s' % kind,
                                        [v_inst, descr], op.result)]
-        except VirtualizableArrayField:
+        except VirtualizableArrayField, e:
             # xxx hack hack hack
-            vinfo = self.callcontrol.virtualizable_info
+            vinfo = e.args[1]
             arrayindex = vinfo.array_field_counter[op.args[1].value]
             arrayfielddescr = vinfo.array_field_descrs[arrayindex]
             arraydescr = vinfo.array_descrs[arrayindex]
@@ -527,7 +529,7 @@
             return
         # check for virtualizable
         if self.is_virtualizable_getset(op):
-            descr = self.get_virtualizable_field_descr(op.args[1].value)
+            descr = self.get_virtualizable_field_descr(op)
             kind = getkind(RESULT)[0]
             return [SpaceOperation('-live-', [], None),
                     SpaceOperation('setfield_vable_%s' % kind,
@@ -544,21 +546,23 @@
         return (op.args[1].value == 'typeptr' and
                 op.args[0].concretetype.TO._hints.get('typeptr'))
 
+    def get_vinfo(self, v_virtualizable):
+        if self.callcontrol is None:      # for tests
+            return None
+        return self.callcontrol.get_vinfo(v_virtualizable.concretetype)
+
     def is_virtualizable_getset(self, op):
         # every access of an object of exactly the type VTYPEPTR is
         # likely to be a virtualizable access, but we still have to
         # check it in pyjitpl.py.
-        try:
-            vinfo = self.callcontrol.virtualizable_info
-        except AttributeError:
-            return False
-        if vinfo is None or not vinfo.is_vtypeptr(op.args[0].concretetype):
+        vinfo = self.get_vinfo(op.args[0])
+        if vinfo is None:
             return False
         res = False
         if op.args[1].value in vinfo.static_field_to_extra_box:
             res = True
         if op.args[1].value in vinfo.array_fields:
-            res = VirtualizableArrayField(self.graph)
+            res = VirtualizableArrayField(self.graph, vinfo)
 
         if res:
             flags = self.vable_flags[op.args[0]]
@@ -568,8 +572,9 @@
             raise res
         return res
 
-    def get_virtualizable_field_descr(self, fieldname):
-        vinfo = self.callcontrol.virtualizable_info
+    def get_virtualizable_field_descr(self, op):
+        fieldname = op.args[1].value
+        vinfo = self.get_vinfo(op.args[0])
         index = vinfo.static_field_to_extra_box[fieldname]
         return vinfo.static_field_descrs[index]
 
@@ -750,9 +755,10 @@
             return Constant(value, lltype.Bool)
         return op
 
-    def promote_greens(self, args):
+    def promote_greens(self, args, jitdriver):
         ops = []
-        num_green_args = len(self.callcontrol.getjitdriver().greens)
+        num_green_args = len(jitdriver.greens)
+        assert len(args) == num_green_args + len(jitdriver.reds)
         for v in args[:num_green_args]:
             if isinstance(v, Variable) and v.concretetype is not lltype.Void:
                 kind = getkind(v.concretetype)
@@ -762,20 +768,24 @@
         return ops
 
     def rewrite_op_jit_marker(self, op):
-        self.callcontrol.found_jitdriver(op.args[1].value)
         key = op.args[0].value
-        return getattr(self, 'handle_jit_marker__%s' % key)(op)
+        jitdriver = op.args[1].value
+        return getattr(self, 'handle_jit_marker__%s' % key)(op, jitdriver)
 
-    def handle_jit_marker__jit_merge_point(self, op):
-        assert self.portal, "jit_merge_point in non-main graph!"
-        ops = self.promote_greens(op.args[2:])
-        num_green_args = len(self.callcontrol.getjitdriver().greens)
-        args = (self.make_three_lists(op.args[2:2+num_green_args]) +
+    def handle_jit_marker__jit_merge_point(self, op, jitdriver):
+        assert self.portal_jd is not None, (
+            "'jit_merge_point' in non-portal graph!")
+        assert jitdriver is self.portal_jd.jitdriver, (
+            "general mix-up of jitdrivers?")
+        ops = self.promote_greens(op.args[2:], jitdriver)
+        num_green_args = len(jitdriver.greens)
+        args = ([Constant(self.portal_jd.index, lltype.Signed)] +
+                self.make_three_lists(op.args[2:2+num_green_args]) +
                 self.make_three_lists(op.args[2+num_green_args:]))
         op1 = SpaceOperation('jit_merge_point', args, None)
         return ops + [op1]
 
-    def handle_jit_marker__can_enter_jit(self, op):
+    def handle_jit_marker__can_enter_jit(self, op, jitdriver):
         return SpaceOperation('can_enter_jit', [], None)
 
     def rewrite_op_debug_assert(self, op):
@@ -974,9 +984,8 @@
 
     def rewrite_op_jit_force_virtualizable(self, op):
         # this one is for virtualizables
-        vinfo = self.callcontrol.virtualizable_info
+        vinfo = self.get_vinfo(op.args[0])
         assert vinfo is not None
-        assert vinfo.is_vtypeptr(op.args[0].concretetype)
         self.vable_flags[op.args[0]] = op.args[2].value
         return []
 

Modified: pypy/branch/multijit-4/pypy/jit/codewriter/test/test_call.py
==============================================================================
--- pypy/branch/multijit-4/pypy/jit/codewriter/test/test_call.py	(original)
+++ pypy/branch/multijit-4/pypy/jit/codewriter/test/test_call.py	Sat Jun 12 09:32:17 2010
@@ -52,13 +52,19 @@
 
 # ____________________________________________________________
 
+class FakeJitDriverSD:
+    def __init__(self, portal_graph):
+        self.portal_graph = portal_graph
+        self.portal_runner_ptr = "???"
+
 def test_find_all_graphs():
     def g(x):
         return x + 2
     def f(x):
         return g(x) + 1
     rtyper = support.annotate(f, [7])
-    cc = CallControl(portal_graph=rtyper.annotator.translator.graphs[0])
+    jitdriver_sd = FakeJitDriverSD(rtyper.annotator.translator.graphs[0])
+    cc = CallControl(jitdrivers_sd=[jitdriver_sd])
     res = cc.find_all_graphs(FakePolicy())
     funcs = set([graph.func for graph in res])
     assert funcs == set([f, g])
@@ -69,7 +75,8 @@
     def f(x):
         return g(x) + 1
     rtyper = support.annotate(f, [7])
-    cc = CallControl(portal_graph=rtyper.annotator.translator.graphs[0])
+    jitdriver_sd = FakeJitDriverSD(rtyper.annotator.translator.graphs[0])
+    cc = CallControl(jitdrivers_sd=[jitdriver_sd])
     class CustomFakePolicy:
         def look_inside_graph(self, graph):
             assert graph.name == 'g'
@@ -83,10 +90,11 @@
 def test_guess_call_kind_and_calls_from_graphs():
     class portal_runner_obj:
         graph = object()
+    class FakeJitDriverSD:
+        portal_runner_ptr = portal_runner_obj
     g = object()
     g1 = object()
-    cc = CallControl()
-    cc.portal_runner_ptr = portal_runner_obj
+    cc = CallControl(jitdrivers_sd=[FakeJitDriverSD()])
     cc.candidate_graphs = [g, g1]
 
     op = SpaceOperation('direct_call', [Constant(portal_runner_obj)],

Modified: pypy/branch/multijit-4/pypy/jit/codewriter/test/test_codewriter.py
==============================================================================
--- pypy/branch/multijit-4/pypy/jit/codewriter/test/test_codewriter.py	(original)
+++ pypy/branch/multijit-4/pypy/jit/codewriter/test/test_codewriter.py	Sat Jun 12 09:32:17 2010
@@ -35,6 +35,12 @@
     def look_inside_graph(self, graph):
         return graph.name != 'dont_look'
 
+class FakeJitDriverSD:
+    def __init__(self, portal_graph):
+        self.portal_graph = portal_graph
+        self.portal_runner_ptr = "???"
+        self.virtualizable_info = None
+
 
 def test_loop():
     def f(a, b):
@@ -70,11 +76,11 @@
     def fff(a, b):
         return ggg(b) - ggg(a)
     rtyper = support.annotate(fff, [35, 42])
-    maingraph = rtyper.annotator.translator.graphs[0]
-    cw = CodeWriter(FakeCPU(rtyper), maingraph)
+    jitdriver_sd = FakeJitDriverSD(rtyper.annotator.translator.graphs[0])
+    cw = CodeWriter(FakeCPU(rtyper), [jitdriver_sd])
     cw.find_all_graphs(FakePolicy())
     cw.make_jitcodes(verbose=True)
-    jitcode = cw.mainjitcode
+    jitcode = jitdriver_sd.mainjitcode
     print jitcode.dump()
     [jitcode2] = cw.assembler.descrs
     print jitcode2.dump()
@@ -117,7 +123,7 @@
         return x().id + y().id + dont_look(n)
     rtyper = support.annotate(f, [35])
     maingraph = rtyper.annotator.translator.graphs[0]
-    cw = CodeWriter(FakeCPU(rtyper), maingraph)
+    cw = CodeWriter(FakeCPU(rtyper), [FakeJitDriverSD(maingraph)])
     cw.find_all_graphs(FakePolicy())
     cw.make_jitcodes(verbose=True)
     #
@@ -144,10 +150,10 @@
     def f(n):
         return abs(n)
     rtyper = support.annotate(f, [35])
-    maingraph = rtyper.annotator.translator.graphs[0]
-    cw = CodeWriter(FakeCPU(rtyper), maingraph)
+    jitdriver_sd = FakeJitDriverSD(rtyper.annotator.translator.graphs[0])
+    cw = CodeWriter(FakeCPU(rtyper), [jitdriver_sd])
     cw.find_all_graphs(FakePolicy())
     cw.make_jitcodes(verbose=True)
     #
-    s = cw.mainjitcode.dump()
+    s = jitdriver_sd.mainjitcode.dump()
     assert "inline_call_ir_i <JitCode '_ll_1_int_abs__Signed'>" in s

Modified: pypy/branch/multijit-4/pypy/jit/codewriter/test/test_flatten.py
==============================================================================
--- pypy/branch/multijit-4/pypy/jit/codewriter/test/test_flatten.py	(original)
+++ pypy/branch/multijit-4/pypy/jit/codewriter/test/test_flatten.py	Sat Jun 12 09:32:17 2010
@@ -68,11 +68,8 @@
         return FakeDescr()
     def calldescr_canraise(self, calldescr):
         return calldescr is not self._descr_cannot_raise
-    def found_jitdriver(self, jitdriver):
-        assert isinstance(jitdriver, JitDriver)
-        self.jitdriver = jitdriver
-    def getjitdriver(self):
-        return self.jitdriver
+    def get_vinfo(self, VTYPEPTR):
+        return None
 
 class FakeCallControlWithVRefInfo:
     class virtualref_info:
@@ -118,13 +115,13 @@
         return self.rtyper.annotator.translator.graphs
 
     def encoding_test(self, func, args, expected,
-                      transform=False, liveness=False, cc=None):
+                      transform=False, liveness=False, cc=None, jd=None):
         graphs = self.make_graphs(func, args)
         #graphs[0].show()
         if transform:
             from pypy.jit.codewriter.jtransform import transform_graph
             cc = cc or FakeCallControl()
-            transform_graph(graphs[0], FakeCPU(self.rtyper), cc)
+            transform_graph(graphs[0], FakeCPU(self.rtyper), cc, jd)
         ssarepr = flatten_graph(graphs[0], fake_regallocs(),
                                 _include_all_exc_links=not transform)
         if liveness:
@@ -584,13 +581,16 @@
         def f(x, y):
             myjitdriver.jit_merge_point(x=x, y=y)
             myjitdriver.can_enter_jit(x=y, y=x)
+        class FakeJitDriverSD:
+            jitdriver = myjitdriver
+            index = 27
         self.encoding_test(f, [4, 5], """
             -live- %i0, %i1
             int_guard_value %i0
-            jit_merge_point I[%i0], R[], F[], I[%i1], R[], F[]
+            jit_merge_point $27, I[%i0], R[], F[], I[%i1], R[], F[]
             can_enter_jit
             void_return
-        """, transform=True, liveness=True)
+        """, transform=True, liveness=True, jd=FakeJitDriverSD())
 
     def test_keepalive(self):
         S = lltype.GcStruct('S')

Modified: pypy/branch/multijit-4/pypy/jit/metainterp/blackhole.py
==============================================================================
--- pypy/branch/multijit-4/pypy/jit/metainterp/blackhole.py	(original)
+++ pypy/branch/multijit-4/pypy/jit/metainterp/blackhole.py	Sat Jun 12 09:32:17 2010
@@ -315,27 +315,12 @@
     def get_tmpreg_f(self):
         return self.tmpreg_f
 
-    def final_result_i(self):
-        assert self._return_type == 'i'
-        return self.get_tmpreg_i()
-
-    def final_result_r(self):
-        assert self._return_type == 'r'
-        return self.get_tmpreg_r()
-
-    def final_result_f(self):
-        assert self._return_type == 'f'
-        return self.get_tmpreg_f()
-
-    def final_result_v(self):
-        assert self._return_type == 'v'
-
     def _final_result_anytype(self):
         "NOT_RPYTHON"
-        if self._return_type == 'i': return self.final_result_i()
-        if self._return_type == 'r': return self.final_result_r()
-        if self._return_type == 'f': return self.final_result_f()
-        if self._return_type == 'v': return self.final_result_v()
+        if self._return_type == 'i': return self.get_tmpreg_i()
+        if self._return_type == 'r': return self.get_tmpreg_r()
+        if self._return_type == 'f': return self.get_tmpreg_v()
+        if self._return_type == 'v': return None
         raise ValueError(self._return_type)
 
     def cleanup_registers(self):
@@ -775,8 +760,8 @@
     def bhimpl_can_enter_jit():
         pass
 
-    @arguments("self", "I", "R", "F", "I", "R", "F")
-    def bhimpl_jit_merge_point(self, *args):
+    @arguments("self", "i", "I", "R", "F", "I", "R", "F")
+    def bhimpl_jit_merge_point(self, jdindex, *args):
         if self.nextblackholeinterp is None:    # we are the last level
             CRN = self.builder.metainterp_sd.ContinueRunningNormally
             raise CRN(*args)
@@ -790,55 +775,57 @@
             # call the interpreter main loop from here, and just return its
             # result.
             sd = self.builder.metainterp_sd
+            xxxxxxxxxxxx
             if sd.result_type == 'void':
-                self.bhimpl_recursive_call_v(*args)
+                self.bhimpl_recursive_call_v(jdindex, *args)
                 self.bhimpl_void_return()
             elif sd.result_type == 'int':
-                x = self.bhimpl_recursive_call_i(*args)
+                x = self.bhimpl_recursive_call_i(jdindex, *args)
                 self.bhimpl_int_return(x)
             elif sd.result_type == 'ref':
-                x = self.bhimpl_recursive_call_r(*args)
+                x = self.bhimpl_recursive_call_r(jdindex, *args)
                 self.bhimpl_ref_return(x)
             elif sd.result_type == 'float':
-                x = self.bhimpl_recursive_call_f(*args)
+                x = self.bhimpl_recursive_call_f(jdindex, *args)
                 self.bhimpl_float_return(x)
             assert False
 
-    def get_portal_runner(self):
+    def get_portal_runner(self, jdindex):
         metainterp_sd = self.builder.metainterp_sd
-        fnptr = llmemory.cast_ptr_to_adr(metainterp_sd._portal_runner_ptr)
+        jitdriver_sd = metainterp_sd.jitdrivers_sd[jdindex]
+        fnptr = llmemory.cast_ptr_to_adr(jitdriver_sd.portal_runner_ptr)
         fnptr = heaptracker.adr2int(fnptr)
-        calldescr = metainterp_sd.portal_code.calldescr
+        calldescr = jitdriver_sd.mainjitcode.calldescr
         return fnptr, calldescr
 
-    @arguments("self", "I", "R", "F", "I", "R", "F", returns="i")
-    def bhimpl_recursive_call_i(self, greens_i, greens_r, greens_f,
-                                      reds_i,   reds_r,   reds_f):
-        fnptr, calldescr = self.get_portal_runner()
+    @arguments("self", "i", "I", "R", "F", "I", "R", "F", returns="i")
+    def bhimpl_recursive_call_i(self, jdindex, greens_i, greens_r, greens_f,
+                                               reds_i,   reds_r,   reds_f):
+        fnptr, calldescr = self.get_portal_runner(jdindex)
         return self.cpu.bh_call_i(fnptr, calldescr,
                                   greens_i + reds_i,
                                   greens_r + reds_r,
                                   greens_f + reds_f)
-    @arguments("self", "I", "R", "F", "I", "R", "F", returns="r")
-    def bhimpl_recursive_call_r(self, greens_i, greens_r, greens_f,
-                                      reds_i,   reds_r,   reds_f):
-        fnptr, calldescr = self.get_portal_runner()
+    @arguments("self", "i", "I", "R", "F", "I", "R", "F", returns="r")
+    def bhimpl_recursive_call_r(self, jdindex, greens_i, greens_r, greens_f,
+                                               reds_i,   reds_r,   reds_f):
+        fnptr, calldescr = self.get_portal_runner(jdindex)
         return self.cpu.bh_call_r(fnptr, calldescr,
                                   greens_i + reds_i,
                                   greens_r + reds_r,
                                   greens_f + reds_f)
-    @arguments("self", "I", "R", "F", "I", "R", "F", returns="f")
-    def bhimpl_recursive_call_f(self, greens_i, greens_r, greens_f,
-                                      reds_i,   reds_r,   reds_f):
-        fnptr, calldescr = self.get_portal_runner()
+    @arguments("self", "i", "I", "R", "F", "I", "R", "F", returns="f")
+    def bhimpl_recursive_call_f(self, jdindex, greens_i, greens_r, greens_f,
+                                               reds_i,   reds_r,   reds_f):
+        fnptr, calldescr = self.get_portal_runner(jdindex)
         return self.cpu.bh_call_f(fnptr, calldescr,
                                   greens_i + reds_i,
                                   greens_r + reds_r,
                                   greens_f + reds_f)
-    @arguments("self", "I", "R", "F", "I", "R", "F")
-    def bhimpl_recursive_call_v(self, greens_i, greens_r, greens_f,
-                                      reds_i,   reds_r,   reds_f):
-        fnptr, calldescr = self.get_portal_runner()
+    @arguments("self", "i", "I", "R", "F", "I", "R", "F")
+    def bhimpl_recursive_call_v(self, jdindex, greens_i, greens_r, greens_f,
+                                               reds_i,   reds_r,   reds_f):
+        fnptr, calldescr = self.get_portal_runner(jdindex)
         return self.cpu.bh_call_v(fnptr, calldescr,
                                   greens_i + reds_i,
                                   greens_r + reds_r,
@@ -1175,11 +1162,11 @@
             self._done_with_this_frame()
         kind = self._return_type
         if kind == 'i':
-            caller._setup_return_value_i(self.final_result_i())
+            caller._setup_return_value_i(self.get_tmpreg_i())
         elif kind == 'r':
-            caller._setup_return_value_r(self.final_result_r())
+            caller._setup_return_value_r(self.get_tmpreg_r())
         elif kind == 'f':
-            caller._setup_return_value_f(self.final_result_f())
+            caller._setup_return_value_f(self.get_tmpreg_f())
         else:
             assert kind == 'v'
         return lltype.nullptr(rclass.OBJECTPTR.TO)
@@ -1245,15 +1232,15 @@
         # rare case: we only get there if the blackhole interps all returned
         # normally (in general we get a ContinueRunningNormally exception).
         sd = self.builder.metainterp_sd
-        if sd.result_type == 'void':
-            self.final_result_v()
+        kind = self._return_type
+        if kind == 'v':
             raise sd.DoneWithThisFrameVoid()
-        elif sd.result_type == 'int':
-            raise sd.DoneWithThisFrameInt(self.final_result_i())
-        elif sd.result_type == 'ref':
-            raise sd.DoneWithThisFrameRef(self.cpu, self.final_result_r())
-        elif sd.result_type == 'float':
-            raise sd.DoneWithThisFrameFloat(self.final_result_f())
+        elif kind == 'i':
+            raise sd.DoneWithThisFrameInt(self.get_tmpreg_i())
+        elif kind == 'r':
+            raise sd.DoneWithThisFrameRef(self.cpu, self.get_tmpreg_r())
+        elif kind == 'f':
+            raise sd.DoneWithThisFrameFloat(self.get_tmpreg_f())
         else:
             assert False
 
@@ -1287,12 +1274,14 @@
             blackholeinterp.builder.release_interp(blackholeinterp)
         blackholeinterp = blackholeinterp.nextblackholeinterp
 
-def resume_in_blackhole(metainterp_sd, resumedescr, all_virtuals=None):
+def resume_in_blackhole(metainterp_sd, jitdriver_sd, resumedescr,
+                        all_virtuals=None):
     from pypy.jit.metainterp.resume import blackhole_from_resumedata
     debug_start('jit-blackhole')
     metainterp_sd.profiler.start_blackhole()
     blackholeinterp = blackhole_from_resumedata(
         metainterp_sd.blackholeinterpbuilder,
+        jitdriver_sd,
         resumedescr,
         all_virtuals)
     current_exc = blackholeinterp._prepare_resume_from_failure(

Modified: pypy/branch/multijit-4/pypy/jit/metainterp/compile.py
==============================================================================
--- pypy/branch/multijit-4/pypy/jit/metainterp/compile.py	(original)
+++ pypy/branch/multijit-4/pypy/jit/metainterp/compile.py	Sat Jun 12 09:32:17 2010
@@ -63,11 +63,12 @@
     # make a copy, because optimize_loop can mutate the ops and descrs
     loop.operations = [op.clone() for op in ops]
     metainterp_sd = metainterp.staticdata
+    jitdriver_sd = metainterp.jitdriver_sd
     loop_token = make_loop_token(len(loop.inputargs))
     loop.token = loop_token
     loop.operations[-1].descr = loop_token     # patch the target of the JUMP
     try:
-        old_loop_token = metainterp_sd.state.optimize_loop(
+        old_loop_token = jitdriver_sd._state.optimize_loop(
             metainterp_sd, old_loop_tokens, loop)
     except InvalidLoop:
         return None
@@ -141,32 +142,32 @@
     pass
 
 class DoneWithThisFrameDescrVoid(_DoneWithThisFrameDescr):
-    def handle_fail(self, metainterp_sd):
-        assert metainterp_sd.result_type == 'void'
+    def handle_fail(self, metainterp_sd, jitdriver_sd):
+        assert jitdriver_sd.result_type == history.VOID
         raise metainterp_sd.DoneWithThisFrameVoid()
 
 class DoneWithThisFrameDescrInt(_DoneWithThisFrameDescr):
-    def handle_fail(self, metainterp_sd):
-        assert metainterp_sd.result_type == 'int'
+    def handle_fail(self, metainterp_sd, jitdriver_sd):
+        assert jitdriver_sd.result_type == history.INT
         result = metainterp_sd.cpu.get_latest_value_int(0)
         raise metainterp_sd.DoneWithThisFrameInt(result)
 
 class DoneWithThisFrameDescrRef(_DoneWithThisFrameDescr):
-    def handle_fail(self, metainterp_sd):
-        assert metainterp_sd.result_type == 'ref'
+    def handle_fail(self, metainterp_sd, jitdriver_sd):
+        assert jitdriver_sd.result_type == history.REF
         cpu = metainterp_sd.cpu
         result = cpu.get_latest_value_ref(0)
         cpu.clear_latest_values(1)
         raise metainterp_sd.DoneWithThisFrameRef(cpu, result)
 
 class DoneWithThisFrameDescrFloat(_DoneWithThisFrameDescr):
-    def handle_fail(self, metainterp_sd):
-        assert metainterp_sd.result_type == 'float'
+    def handle_fail(self, metainterp_sd, jitdriver_sd):
+        assert jitdriver_sd.result_type == history.FLOAT
         result = metainterp_sd.cpu.get_latest_value_float(0)
         raise metainterp_sd.DoneWithThisFrameFloat(result)
 
 class ExitFrameWithExceptionDescrRef(_DoneWithThisFrameDescr):
-    def handle_fail(self, metainterp_sd):
+    def handle_fail(self, metainterp_sd, jitdriver_sd):
         cpu = metainterp_sd.cpu
         value = cpu.get_latest_value_ref(0)
         cpu.clear_latest_values(1)
@@ -258,22 +259,27 @@
             # a negative value
             self._counter = cnt | i
 
-    def handle_fail(self, metainterp_sd):
-        if self.must_compile(metainterp_sd):
-            return self._trace_and_compile_from_bridge(metainterp_sd)
+    def handle_fail(self, metainterp_sd, jitdriver_sd):
+        if self.must_compile(metainterp_sd, jitdriver_sd):
+            return self._trace_and_compile_from_bridge(metainterp_sd,
+                                                       jitdriver_sd)
         else:
             from pypy.jit.metainterp.blackhole import resume_in_blackhole
-            resume_in_blackhole(metainterp_sd, self)
+            resume_in_blackhole(metainterp_sd, jitdriver_sd, self)
             assert 0, "unreachable"
 
-    def _trace_and_compile_from_bridge(self, metainterp_sd):
+    def _trace_and_compile_from_bridge(self, metainterp_sd, jitdriver_sd):
+        # 'jitdriver_sd' corresponds to the outermost one, i.e. the one
+        # of the jit_merge_point where we started the loop, even if the
+        # loop itself may contain temporarily recursion into other
+        # jitdrivers.
         from pypy.jit.metainterp.pyjitpl import MetaInterp
-        metainterp = MetaInterp(metainterp_sd)
+        metainterp = MetaInterp(metainterp_sd, jitdriver_sd)
         return metainterp.handle_guard_failure(self)
     _trace_and_compile_from_bridge._dont_inline_ = True
 
-    def must_compile(self, metainterp_sd):
-        trace_eagerness = metainterp_sd.state.trace_eagerness
+    def must_compile(self, metainterp_sd, jitdriver_sd):
+        trace_eagerness = jitdriver_sd._state.trace_eagerness
         if self._counter >= 0:
             self._counter += 1
             return self._counter >= trace_eagerness
@@ -333,7 +339,7 @@
 
 class ResumeGuardForcedDescr(ResumeGuardDescr):
 
-    def handle_fail(self, metainterp_sd):
+    def handle_fail(self, metainterp_sd, jitdriver_sd):
         # Failures of a GUARD_NOT_FORCED are never compiled, but
         # always just blackholed.  First fish for the data saved when
         # the virtualrefs and virtualizable have been forced by
@@ -343,7 +349,7 @@
         all_virtuals = self.fetch_data(token)
         if all_virtuals is None:
             all_virtuals = []
-        resume_in_blackhole(metainterp_sd, self, all_virtuals)
+        resume_in_blackhole(metainterp_sd, jitdriver_sd, self, all_virtuals)
         assert 0, "unreachable"
 
     @staticmethod
@@ -464,6 +470,7 @@
         # a loop at all but ends in a jump to the target loop.  It starts
         # with completely unoptimized arguments, as in the interpreter.
         metainterp_sd = metainterp.staticdata
+        jitdriver_sd = metainterp.jitdriver_sd
         metainterp.history.inputargs = self.redkey
         new_loop_token = make_loop_token(len(self.redkey))
         new_loop.greenkey = self.original_greenkey
@@ -471,12 +478,11 @@
         new_loop.token = new_loop_token
         send_loop_to_backend(metainterp_sd, new_loop, "entry bridge")
         # send the new_loop to warmspot.py, to be called directly the next time
-        metainterp_sd.state.attach_unoptimized_bridge_from_interp(
+        jitdriver_sd._state.attach_unoptimized_bridge_from_interp(
             self.original_greenkey,
             new_loop_token)
         # store the new loop in compiled_merge_points too
-        glob = metainterp_sd.globaldata
-        old_loop_tokens = glob.get_compiled_merge_points(
+        old_loop_tokens = metainterp.get_compiled_merge_points(
             self.original_greenkey)
         # it always goes at the end of the list, as it is the most
         # general loop token
@@ -500,8 +506,9 @@
     # clone ops, as optimize_bridge can mutate the ops
     new_loop.operations = [op.clone() for op in metainterp.history.operations]
     metainterp_sd = metainterp.staticdata
+    jitdriver_sd = metainterp.jitdriver_sd
     try:
-        target_loop_token = metainterp_sd.state.optimize_bridge(metainterp_sd,
+        target_loop_token = jitdriver_sd._state.optimize_bridge(metainterp_sd,
                                                                 old_loop_tokens,
                                                                 new_loop)
     except InvalidLoop:

Modified: pypy/branch/multijit-4/pypy/jit/metainterp/history.py
==============================================================================
--- pypy/branch/multijit-4/pypy/jit/metainterp/history.py	(original)
+++ pypy/branch/multijit-4/pypy/jit/metainterp/history.py	Sat Jun 12 09:32:17 2010
@@ -174,7 +174,7 @@
 class AbstractFailDescr(AbstractDescr):
     index = -1
 
-    def handle_fail(self, metainterp_sd):
+    def handle_fail(self, metainterp_sd, jitdriver_sd):
         raise NotImplementedError
     def compile_and_attach(self, metainterp, new_loop):
         raise NotImplementedError

Modified: pypy/branch/multijit-4/pypy/jit/metainterp/pyjitpl.py
==============================================================================
--- pypy/branch/multijit-4/pypy/jit/metainterp/pyjitpl.py	(original)
+++ pypy/branch/multijit-4/pypy/jit/metainterp/pyjitpl.py	Sat Jun 12 09:32:17 2010
@@ -56,6 +56,7 @@
         self.parent_resumedata_snapshot = None
         self.parent_resumedata_frame_info_list = None
 
+    @specialize.arg(3)
     def copy_constants(self, registers, constants, ConstClass):
         """Copy jitcode.constants[0] to registers[255],
                 jitcode.constants[1] to registers[254],
@@ -68,7 +69,6 @@
             assert j >= 0
             registers[j] = ConstClass(constants[i])
             i -= 1
-    copy_constants._annspecialcase_ = 'specialize:arg(3)'
 
     def cleanup_registers(self):
         # To avoid keeping references alive, this cleans up the registers_r.
@@ -80,6 +80,7 @@
     # ------------------------------
     # Decoding of the JitCode
 
+    @specialize.arg(4)
     def prepare_list_of_boxes(self, outvalue, startindex, position, argcode):
         assert argcode in 'IRF'
         code = self.bytecode
@@ -92,7 +93,6 @@
             elif argcode == 'F': reg = self.registers_f[index]
             else: raise AssertionError(argcode)
             outvalue[startindex+i] = reg
-    prepare_list_of_boxes._annspecialcase_ = 'specialize:arg(4)'
 
     def get_current_position_info(self):
         return self.jitcode.get_live_vars_info(self.pc)
@@ -518,15 +518,15 @@
         return not isstandard
 
     def _get_virtualizable_field_index(self, fielddescr):
-        vinfo = self.metainterp.staticdata.virtualizable_info
+        vinfo = self.metainterp.jitdrivers_sd.virtualizable_info
         return vinfo.static_field_by_descrs[fielddescr]
 
     def _get_virtualizable_array_field_descr(self, index):
-        vinfo = self.metainterp.staticdata.virtualizable_info
+        vinfo = self.metainterp.jitdrivers_sd.virtualizable_info
         return vinfo.array_field_descrs[index]
 
     def _get_virtualizable_array_descr(self, index):
-        vinfo = self.metainterp.staticdata.virtualizable_info
+        vinfo = self.metainterp.jitdrivers_sd.virtualizable_info
         return vinfo.array_descrs[index]
 
     @arguments("orgpc", "box", "descr")
@@ -557,7 +557,8 @@
 
     def _get_arrayitem_vable_index(self, pc, arrayfielddescr, indexbox):
         indexbox = self.implement_guard_value(pc, indexbox)
-        vinfo = self.metainterp.staticdata.virtualizable_info
+        xxxxxxx
+        vinfo = self.metainterp.jitdriver_sd.virtualizable_info
         virtualizable_box = self.metainterp.virtualizable_boxes[-1]
         virtualizable = vinfo.unwrap_virtualizable_box(virtualizable_box)
         arrayindex = vinfo.array_field_by_descrs[arrayfielddescr]
@@ -606,7 +607,8 @@
             arraybox = self.metainterp.execute_and_record(rop.GETFIELD_GC,
                                                           fdescr, box)
             return self.execute_with_descr(rop.ARRAYLEN_GC, adescr, arraybox)
-        vinfo = self.metainterp.staticdata.virtualizable_info
+        xxxxxxx
+        vinfo = self.metainterp.jitdriver_sd.virtualizable_info
         virtualizable_box = self.metainterp.virtualizable_boxes[-1]
         virtualizable = vinfo.unwrap_virtualizable_box(virtualizable_box)
         arrayindex = vinfo.array_field_by_descrs[fdescr]
@@ -655,10 +657,11 @@
     opimpl_residual_call_irf_f = _opimpl_residual_call3
     opimpl_residual_call_irf_v = _opimpl_residual_call3
 
-    @arguments("boxes3", "boxes3")
-    def _opimpl_recursive_call(self, greenboxes, redboxes):
+    @arguments("int", "boxes3", "boxes3")
+    def _opimpl_recursive_call(self, jdindex, greenboxes, redboxes):
         allboxes = greenboxes + redboxes
         metainterp_sd = self.metainterp.staticdata
+        xxxx
         portal_code = metainterp_sd.portal_code
         warmrunnerstate = metainterp_sd.state
         token = None
@@ -765,17 +768,18 @@
             raise CannotInlineCanEnterJit()
         self.metainterp.seen_can_enter_jit = True
 
-    def verify_green_args(self, varargs):
-        num_green_args = self.metainterp.staticdata.num_green_args
+    def verify_green_args(self, jdindex, varargs):
+        jitdriver = self.metainterp.staticdata.jitdrivers_sd[jdindex]
+        num_green_args = jitdriver.num_green_args
         assert len(varargs) == num_green_args
         for i in range(num_green_args):
             assert isinstance(varargs[i], Const)
 
-    @arguments("orgpc", "boxes3", "boxes3")
-    def opimpl_jit_merge_point(self, orgpc, greenboxes, redboxes):
-        self.verify_green_args(greenboxes)
+    @arguments("orgpc", "int", "boxes3", "boxes3")
+    def opimpl_jit_merge_point(self, orgpc, jdindex, greenboxes, redboxes):
+        self.verify_green_args(jdindex, greenboxes)
         # xxx we may disable the following line in some context later
-        self.debug_merge_point(greenboxes)
+        self.debug_merge_point(jdindex, greenboxes)
         if self.metainterp.seen_can_enter_jit:
             self.metainterp.seen_can_enter_jit = False
             # Assert that it's impossible to arrive here with in_recursion
@@ -783,6 +787,7 @@
             # to True by opimpl_can_enter_jit, which should be executed
             # just before opimpl_jit_merge_point (no recursion inbetween).
             assert not self.metainterp.in_recursion
+            assert jdindex == self.metainterp.jitdriver_sd.index
             # Set self.pc to point to jit_merge_point instead of just after:
             # if reached_can_enter_jit() raises SwitchToBlackhole, then the
             # pc is still at the jit_merge_point, which is a point that is
@@ -792,10 +797,10 @@
             self.metainterp.reached_can_enter_jit(greenboxes, redboxes)
             self.pc = saved_pc
 
-    def debug_merge_point(self, greenkey):
+    def debug_merge_point(self, jdindex, greenkey):
         # debugging: produce a DEBUG_MERGE_POINT operation
-        sd = self.metainterp.staticdata
-        loc = sd.state.get_location_str(greenkey)
+        jitdriver = self.metainterp.staticdata.jitdrivers_sd[jdindex]
+        loc = jitdriver._state.get_location_str(greenkey)
         debug_print(loc)
         constloc = self.metainterp.cpu.ts.conststr(loc)
         self.metainterp.history.record(rop.DEBUG_MERGE_POINT,
@@ -945,7 +950,7 @@
         guard_op = metainterp.history.record(opnum, moreargs, None,
                                              descr=resumedescr)       
         virtualizable_boxes = None
-        if metainterp.staticdata.virtualizable_info is not None:
+        if metainterp.jitdriver_sd.virtualizable_info is not None:
             virtualizable_boxes = metainterp.virtualizable_boxes
         saved_pc = self.pc
         if resumepc >= 0:
@@ -1127,6 +1132,11 @@
         self._addr2name_keys = [key for key, value in list_of_addr2name]
         self._addr2name_values = [value for key, value in list_of_addr2name]
 
+    def setup_jitdrivers_sd(self, optimizer):
+        if optimizer is not None:
+            for jd in self.jitdrivers_sd:
+                jd._state.set_param_optimizer(optimizer)
+
     def finish_setup(self, codewriter, optimizer=None):
         from pypy.jit.metainterp.blackhole import BlackholeInterpBuilder
         self.blackholeinterpbuilder = BlackholeInterpBuilder(codewriter, self)
@@ -1137,28 +1147,21 @@
         self.setup_indirectcalltargets(asm.indirectcalltargets)
         self.setup_list_of_addr2name(asm.list_of_addr2name)
         #
-        self.portal_code = codewriter.mainjitcode
-        self._portal_runner_ptr = codewriter.callcontrol.portal_runner_ptr
+        self.jitdrivers_sd = codewriter.callcontrol.jitdrivers_sd
         self.virtualref_info = codewriter.callcontrol.virtualref_info
-        self.virtualizable_info = codewriter.callcontrol.virtualizable_info
-        RESULT = codewriter.portal_graph.getreturnvar().concretetype
-        self.result_type = history.getkind(RESULT)
+        self.setup_jitdrivers_sd(optimizer)
         #
         # store this information for fastpath of call_assembler
-        name = self.result_type
-        tokens = getattr(self, 'loop_tokens_done_with_this_frame_%s' % name)
-        num = self.cpu.get_fail_descr_number(tokens[0].finishdescr)
-        setattr(self.cpu, 'done_with_this_frame_%s_v' % name, num)
+        # (only the paths that can actually be taken)
+        for jd in self.jitdrivers_sd:
+            name = {history.INT: 'int',
+                    history.REF: 'ref',
+                    history.FLOAT: 'float',
+                    history.VOID: 'void'}[jd.result_type]
+            tokens = getattr(self, 'loop_tokens_done_with_this_frame_%s' % name)
+            num = self.cpu.get_fail_descr_number(tokens[0].finishdescr)
+            setattr(self.cpu, 'done_with_this_frame_%s_v' % name, num)
         #
-        warmrunnerdesc = self.warmrunnerdesc
-        if warmrunnerdesc is not None:
-            self.num_green_args = warmrunnerdesc.num_green_args
-            self.state = warmrunnerdesc.state
-            if optimizer is not None:
-                self.state.set_param_optimizer(optimizer)
-        else:
-            self.num_green_args = 0
-            self.state = None
         self.globaldata = MetaInterpGlobalData(self)
 
     def _setup_once(self):
@@ -1179,8 +1182,7 @@
                 # Build the dictionary at run-time.  This is needed
                 # because the keys are function/class addresses, so they
                 # can change from run to run.
-                k = llmemory.cast_ptr_to_adr(self._portal_runner_ptr)
-                d = {k: 'recursive call'}
+                d = {}
                 keys = self._addr2name_keys
                 values = self._addr2name_values
                 for i in range(len(keys)):
@@ -1238,47 +1240,31 @@
         self.loopnumbering = 0
         self.resume_virtuals = {}
         self.resume_virtuals_not_translated = []
-        #
-        state = staticdata.state
-        if state is not None:
-            self.jit_cell_at_key = state.jit_cell_at_key
-        else:
-            # for tests only; not RPython
-            class JitCell:
-                compiled_merge_points = None
-            _jitcell_dict = {}
-            def jit_cell_at_key(greenkey):
-                greenkey = tuple(greenkey)
-                return _jitcell_dict.setdefault(greenkey, JitCell())
-            self.jit_cell_at_key = jit_cell_at_key
-
-    def get_compiled_merge_points(self, greenkey):
-        cell = self.jit_cell_at_key(greenkey)
-        if cell.compiled_merge_points is None:
-            cell.compiled_merge_points = []
-        return cell.compiled_merge_points
 
 # ____________________________________________________________
 
 class MetaInterp(object):
     in_recursion = 0
 
-    def __init__(self, staticdata):
+    def __init__(self, staticdata, jitdriver_sd):
         self.staticdata = staticdata
         self.cpu = staticdata.cpu
+        self.jitdriver_sd = jitdriver_sd
+        # Note: self.jitdriver_sd is the JitDriverStaticData that corresponds
+        # to the current loop -- the outermost one.  Be careful, because
+        # during recursion we can also see other jitdrivers.
         self.portal_trace_positions = []
         self.free_frames_list = []
         self.last_exc_value_box = None
 
     def perform_call(self, jitcode, boxes, greenkey=None):
         # causes the metainterp to enter the given subfunction
-        # with a special case for recursive portal calls
         f = self.newframe(jitcode, greenkey)
         f.setup_call(boxes)
         raise ChangeFrame
 
     def newframe(self, jitcode, greenkey=None):
-        if jitcode is self.staticdata.portal_code:
+        if jitcode.is_portal:
             self.in_recursion += 1
         if greenkey is not None:
             self.portal_trace_positions.append(
@@ -1293,7 +1279,7 @@
 
     def popframe(self):
         frame = self.framestack.pop()
-        if frame.jitcode is self.staticdata.portal_code:
+        if frame.jitcode.is_portal:
             self.in_recursion -= 1
         if frame.greenkey is not None:
             self.portal_trace_positions.append(
@@ -1317,14 +1303,15 @@
             except SwitchToBlackhole, stb:
                 self.aborted_tracing(stb.reason)
             sd = self.staticdata
-            if sd.result_type == 'void':
+            result_type = self.jitdriver_sd.result_type
+            if result_type == history.VOID:
                 assert resultbox is None
                 raise sd.DoneWithThisFrameVoid()
-            elif sd.result_type == 'int':
+            elif result_type == history.INT:
                 raise sd.DoneWithThisFrameInt(resultbox.getint())
-            elif sd.result_type == 'ref':
+            elif result_type == history.REF:
                 raise sd.DoneWithThisFrameRef(self.cpu, resultbox.getref_base())
-            elif sd.result_type == 'float':
+            elif result_type == history.FLOAT:
                 raise sd.DoneWithThisFrameFloat(resultbox.getfloat())
             else:
                 assert False
@@ -1354,14 +1341,17 @@
         in_recursion = -1
         for frame in self.framestack:
             jitcode = frame.jitcode
-            if jitcode is self.staticdata.portal_code:
+            assert jitcode.is_portal == len([
+                jd for jd in self.staticdata.jitdrivers_sd
+                   if jd.mainjitcode is jitcode])
+            if jitcode.is_portal:
                 in_recursion += 1
         if in_recursion != self.in_recursion:
             print "in_recursion problem!!!"
             print in_recursion, self.in_recursion
             for frame in self.framestack:
                 jitcode = frame.jitcode
-                if jitcode is self.staticdata.portal_code:
+                if jitcode.is_portal:
                     print "P",
                 else:
                     print " ",
@@ -1369,7 +1359,6 @@
             raise AssertionError
 
     def create_empty_history(self):
-        warmrunnerstate = self.staticdata.state
         self.history = history.History()
         self.staticdata.stats.set_history(self.history)
 
@@ -1479,12 +1468,11 @@
         self.resumekey.reset_counter_from_failure()
 
     def blackhole_if_trace_too_long(self):
-        warmrunnerstate = self.staticdata.state
+        warmrunnerstate = self.jitdriver_sd._state
         if len(self.history.operations) > warmrunnerstate.trace_limit:
             greenkey_of_huge_function = self.find_biggest_function()
             self.portal_trace_positions = None
             if greenkey_of_huge_function is not None:
-                warmrunnerstate = self.staticdata.state
                 warmrunnerstate.disable_noninlinable_function(
                     greenkey_of_huge_function)
             raise SwitchToBlackhole(ABORT_TOO_LONG)
@@ -1511,21 +1499,27 @@
                     self.staticdata.log(sys.exc_info()[0].__name__)
                 raise
 
-    def compile_and_run_once(self, *args):
+    @specialize.arg(1)
+    def compile_and_run_once(self, jitdriver_sd, *args):
+        # NB. we pass explicity 'jitdriver_sd' around here, even though it
+        # is also available as 'self.jitdriver_sd', because we need to
+        # specialize this function and a few other ones for the '*args'.
         debug_start('jit-tracing')
         self.staticdata._setup_once()
         self.staticdata.profiler.start_tracing()
+        assert jitdriver_sd is self.jitdriver_sd
         self.create_empty_history()
         try:
-            return self._compile_and_run_once(*args)
+            original_boxes = self.initialize_original_boxes(jitdriver_sd,*args)
+            return self._compile_and_run_once(original_boxes)
         finally:
             self.staticdata.profiler.end_tracing()
             debug_stop('jit-tracing')
 
-    def _compile_and_run_once(self, *args):
-        original_boxes = self.initialize_state_from_start(*args)
+    def _compile_and_run_once(self, original_boxes):
+        self.initialize_state_from_start(original_boxes)
         self.current_merge_points = [(original_boxes, 0)]
-        num_green_args = self.staticdata.num_green_args
+        num_green_args = self.jitdriver_sd.num_green_args
         original_greenkey = original_boxes[:num_green_args]
         redkey = original_boxes[num_green_args:]
         self.resumekey = compile.ResumeFromInterpDescr(original_greenkey,
@@ -1592,7 +1586,7 @@
         self.remove_consts_and_duplicates(redboxes, len(redboxes),
                                           duplicates)
         live_arg_boxes = greenboxes + redboxes
-        if self.staticdata.virtualizable_info is not None:
+        if self.jitdriver_sd.virtualizable_info is not None:
             # we use pop() to remove the last item, which is the virtualizable
             # itself
             self.remove_consts_and_duplicates(self.virtualizable_boxes,
@@ -1614,11 +1608,12 @@
         # Search in current_merge_points for original_boxes with compatible
         # green keys, representing the beginning of the same loop as the one
         # we end now. 
-       
+
+        num_green_args = self.jitdriver_sd.num_green_args
         for j in range(len(self.current_merge_points)-1, -1, -1):
             original_boxes, start = self.current_merge_points[j]
             assert len(original_boxes) == len(live_arg_boxes) or start < 0
-            for i in range(self.staticdata.num_green_args):
+            for i in range(num_green_args):
                 box1 = original_boxes[i]
                 box2 = live_arg_boxes[i]
                 assert isinstance(box1, Const)
@@ -1641,7 +1636,7 @@
 
     def designate_target_loop(self, gmp):
         loop_token = gmp.target_loop_token
-        num_green_args = self.staticdata.num_green_args
+        num_green_args = self.jitdriver_sd.num_green_args
         residual_args = self.get_residual_args(loop_token.specnodes,
                                                gmp.argboxes[num_green_args:])
         history.set_future_values(self.cpu, residual_args)
@@ -1682,12 +1677,17 @@
             from pypy.jit.metainterp.resoperation import opname
             raise NotImplementedError(opname[opnum])
 
+    def get_compiled_merge_points(self, greenkey):
+        cell = self.jitdriver_sd._state.jit_cell_at_key(greenkey)
+        if cell.compiled_merge_points is None:
+            cell.compiled_merge_points = []
+        return cell.compiled_merge_points
+
     def compile(self, original_boxes, live_arg_boxes, start):
-        num_green_args = self.staticdata.num_green_args
+        num_green_args = self.jitdriver_sd.num_green_args
         self.history.inputargs = original_boxes[num_green_args:]
         greenkey = original_boxes[:num_green_args]
-        glob = self.staticdata.globaldata
-        old_loop_tokens = glob.get_compiled_merge_points(greenkey)
+        old_loop_tokens = self.get_compiled_merge_points(greenkey)
         self.history.record(rop.JUMP, live_arg_boxes[num_green_args:], None)
         loop_token = compile.compile_new_loop(self, old_loop_tokens,
                                               greenkey, start)
@@ -1696,10 +1696,9 @@
         self.history.operations.pop()     # remove the JUMP
 
     def compile_bridge(self, live_arg_boxes):
-        num_green_args = self.staticdata.num_green_args
+        num_green_args = self.jitdriver_sd.num_green_args
         greenkey = live_arg_boxes[:num_green_args]
-        glob = self.staticdata.globaldata
-        old_loop_tokens = glob.get_compiled_merge_points(greenkey)
+        old_loop_tokens = self.get_compiled_merge_points(greenkey)
         if len(old_loop_tokens) == 0:
             return
         self.history.record(rop.JUMP, live_arg_boxes[num_green_args:], None)
@@ -1713,17 +1712,18 @@
         self.gen_store_back_in_virtualizable()
         # temporarily put a JUMP to a pseudo-loop
         sd = self.staticdata
-        if sd.result_type == 'void':
+        result_type = self.jitdriver_sd.result_type
+        if result_type == history.VOID:
             assert exitbox is None
             exits = []
             loop_tokens = sd.loop_tokens_done_with_this_frame_void
-        elif sd.result_type == 'int':
+        elif result_type == history.INT:
             exits = [exitbox]
             loop_tokens = sd.loop_tokens_done_with_this_frame_int
-        elif sd.result_type == 'ref':
+        elif result_type == history.REF:
             exits = [exitbox]
             loop_tokens = sd.loop_tokens_done_with_this_frame_ref
-        elif sd.result_type == 'float':
+        elif result_type == history.FLOAT:
             exits = [exitbox]
             loop_tokens = sd.loop_tokens_done_with_this_frame_float
         else:
@@ -1755,26 +1755,32 @@
             specnode.extract_runtime_data(self.cpu, args[i], expanded_args)
         return expanded_args
 
-    def _initialize_from_start(self, original_boxes, num_green_args, *args):
+    @specialize.arg(1)
+    def initialize_original_boxes(self, jitdriver_sd, *args):
+        original_boxes = []
+        self._fill_original_boxes(jitdriver_sd, original_boxes,
+                                  jitdriver_sd.num_green_args, *args)
+        return original_boxes
+
+    @specialize.arg(1)
+    def _fill_original_boxes(self, jitdriver_sd, original_boxes,
+                             num_green_args, *args):
         if args:
             from pypy.jit.metainterp.warmstate import wrap
             box = wrap(self.cpu, args[0], num_green_args > 0)
             original_boxes.append(box)
-            self._initialize_from_start(original_boxes, num_green_args-1,
-                                        *args[1:])
+            self._fill_original_boxes(jitdriver_sd, original_boxes,
+                                      num_green_args-1, *args[1:])
 
-    def initialize_state_from_start(self, *args):
-        self.in_recursion = -1 # always one portal around
-        num_green_args = self.staticdata.num_green_args
-        original_boxes = []
-        self._initialize_from_start(original_boxes, num_green_args, *args)
+    def initialize_state_from_start(self, original_boxes):
         # ----- make a new frame -----
+        self.in_recursion = -1 # always one portal around
         self.framestack = []
-        f = self.newframe(self.staticdata.portal_code)
+        f = self.newframe(self.jitdriver_sd.mainjitcode)
         f.setup_call(original_boxes)
+        assert self.in_recursion == 0
         self.virtualref_boxes = []
         self.initialize_virtualizable(original_boxes)
-        return original_boxes
 
     def initialize_state_from_guard_failure(self, resumedescr):
         # guard failure: rebuild a complete MIFrame stack
@@ -1784,7 +1790,7 @@
         self.history.inputargs = [box for box in inputargs_and_holes if box]
 
     def initialize_virtualizable(self, original_boxes):
-        vinfo = self.staticdata.virtualizable_info
+        vinfo = self.jitdriver_sd.virtualizable_info
         if vinfo is not None:
             virtualizable_box = original_boxes[vinfo.index_of_virtualizable]
             virtualizable = vinfo.unwrap_virtualizable_box(virtualizable_box)
@@ -1797,7 +1803,7 @@
             self.initialize_virtualizable_enter()
 
     def initialize_virtualizable_enter(self):
-        vinfo = self.staticdata.virtualizable_info
+        vinfo = self.jitdriver_sd.virtualizable_info
         virtualizable_box = self.virtualizable_boxes[-1]
         virtualizable = vinfo.unwrap_virtualizable_box(virtualizable_box)
         vinfo.clear_vable_token(virtualizable)
@@ -1811,7 +1817,7 @@
             # the FORCE_TOKEN is already set at runtime in each vref when
             # it is created, by optimizeopt.py.
         #
-        vinfo = self.staticdata.virtualizable_info
+        vinfo = self.jitdriver_sd.virtualizable_info
         if vinfo is not None:
             virtualizable_box = self.virtualizable_boxes[-1]
             virtualizable = vinfo.unwrap_virtualizable_box(virtualizable_box)
@@ -1837,7 +1843,7 @@
                 self.stop_tracking_virtualref(i)
 
     def vable_after_residual_call(self):
-        vinfo = self.staticdata.virtualizable_info
+        vinfo = self.jitdriver_sd.virtualizable_info
         if vinfo is not None:
             virtualizable_box = self.virtualizable_boxes[-1]
             virtualizable = vinfo.unwrap_virtualizable_box(virtualizable_box)
@@ -1888,7 +1894,7 @@
         assert self.last_exc_value_box is None
 
     def rebuild_state_after_failure(self, resumedescr):
-        vinfo = self.staticdata.virtualizable_info
+        vinfo = self.jitdriver_sd.virtualizable_info
         self.framestack = []
         boxlists = resume.rebuild_from_resumedata(self, resumedescr, vinfo)
         inputargs_and_holes, virtualizable_boxes, virtualref_boxes = boxlists
@@ -1924,13 +1930,13 @@
 
     def check_synchronized_virtualizable(self):
         if not we_are_translated():
-            vinfo = self.staticdata.virtualizable_info
+            vinfo = self.jitdriver_sd.virtualizable_info
             virtualizable_box = self.virtualizable_boxes[-1]
             virtualizable = vinfo.unwrap_virtualizable_box(virtualizable_box)
             vinfo.check_boxes(virtualizable, self.virtualizable_boxes)
 
     def synchronize_virtualizable(self):
-        vinfo = self.staticdata.virtualizable_info
+        vinfo = self.jitdriver_sd.virtualizable_info
         virtualizable_box = self.virtualizable_boxes[-1]
         virtualizable = vinfo.unwrap_virtualizable_box(virtualizable_box)
         vinfo.write_boxes(virtualizable, self.virtualizable_boxes)
@@ -1939,7 +1945,7 @@
         # Force a reload of the virtualizable fields into the local
         # boxes (called only in escaping cases).  Only call this function
         # just before SwitchToBlackhole.
-        vinfo = self.staticdata.virtualizable_info
+        vinfo = self.jitdriver_sd.virtualizable_info
         if vinfo is not None:
             virtualizable_box = self.virtualizable_boxes[-1]
             virtualizable = vinfo.unwrap_virtualizable_box(virtualizable_box)
@@ -1948,7 +1954,7 @@
             self.virtualizable_boxes.append(virtualizable_box)
 
     def gen_store_back_in_virtualizable(self):
-        vinfo = self.staticdata.virtualizable_info
+        vinfo = self.jitdriver_sd.virtualizable_info
         if vinfo is not None:
             # xxx only write back the fields really modified
             vbox = self.virtualizable_boxes[-1]
@@ -1970,7 +1976,7 @@
             assert i + 1 == len(self.virtualizable_boxes)
 
     def gen_load_from_other_virtualizable(self, vbox):
-        vinfo = self.staticdata.virtualizable_info
+        vinfo = self.jitdriver_sd.virtualizable_info
         boxes = []
         assert vinfo is not None
         for i in range(vinfo.num_static_extra_boxes):
@@ -1994,7 +2000,7 @@
         for i in range(len(boxes)):
             if boxes[i] is oldbox:
                 boxes[i] = newbox
-        if self.staticdata.virtualizable_info is not None:
+        if self.jitdriver_sd.virtualizable_info is not None:
             boxes = self.virtualizable_boxes
             for i in range(len(boxes)):
                 if boxes[i] is oldbox:
@@ -2050,8 +2056,9 @@
         assert op.opnum == rop.CALL_MAY_FORCE
         num_green_args = self.staticdata.num_green_args
         args = op.args[num_green_args + 1:]
-        if self.staticdata.virtualizable_info is not None:
-            vindex = self.staticdata.virtualizable_info.index_of_virtualizable
+        vinfo = self.jitdriver_sd.virtualizable_info
+        if vinfo is not None:
+            vindex = vinfo.index_of_virtualizable
             vbox = args[vindex - num_green_args]
             args = args + self.gen_load_from_other_virtualizable(vbox)
             # ^^^ and not "+=", which makes 'args' a resizable list
@@ -2156,6 +2163,16 @@
                 position = position3 + 1 + length3
             elif argtype == "orgpc":
                 value = orgpc
+            elif argtype == "int":
+                argcode = argcodes[next_argcode]
+                next_argcode = next_argcode + 1
+                if argcode == 'i':
+                    value = ord(code[position])
+                elif argcode == 'c':
+                    value = signedord(code[position])
+                else:
+                    raise AssertionError("bad argcode")
+                position += 1
             else:
                 raise AssertionError("bad argtype: %r" % (argtype,))
             args += (value,)

Modified: pypy/branch/multijit-4/pypy/jit/metainterp/resume.py
==============================================================================
--- pypy/branch/multijit-4/pypy/jit/metainterp/resume.py	(original)
+++ pypy/branch/multijit-4/pypy/jit/metainterp/resume.py	Sat Jun 12 09:32:17 2010
@@ -709,11 +709,11 @@
 
 # ---------- when resuming for blackholing, get direct values ----------
 
-def blackhole_from_resumedata(blackholeinterpbuilder, storage,
+def blackhole_from_resumedata(blackholeinterpbuilder, jitdriver_sd, storage,
                               all_virtuals=None):
     resumereader = ResumeDataDirectReader(blackholeinterpbuilder.cpu, storage,
                                           all_virtuals)
-    vinfo = blackholeinterpbuilder.metainterp_sd.virtualizable_info
+    vinfo = jitdriver_sd.virtualizable_info
     vrefinfo = blackholeinterpbuilder.metainterp_sd.virtualref_info
     resumereader.consume_vref_and_vable(vrefinfo, vinfo)
     #

Modified: pypy/branch/multijit-4/pypy/jit/metainterp/test/test_basic.py
==============================================================================
--- pypy/branch/multijit-4/pypy/jit/metainterp/test/test_basic.py	(original)
+++ pypy/branch/multijit-4/pypy/jit/metainterp/test/test_basic.py	Sat Jun 12 09:32:17 2010
@@ -16,10 +16,18 @@
     from pypy.jit.codewriter import support, codewriter
     from pypy.jit.metainterp import simple_optimize
 
+    class FakeJitCell:
+        compiled_merge_points = None
+
     class FakeWarmRunnerState:
         def attach_unoptimized_bridge_from_interp(self, greenkey, newloop):
             pass
 
+        def jit_cell_at_key(self, greenkey):
+            assert greenkey == []
+            return self._cell
+        _cell = FakeJitCell()
+
         # pick the optimizer this way
         optimize_loop = staticmethod(simple_optimize.optimize_loop)
         optimize_bridge = staticmethod(simple_optimize.optimize_bridge)
@@ -30,13 +38,22 @@
     func._jit_unroll_safe_ = True
     rtyper = support.annotate(func, values, type_system=type_system)
     graphs = rtyper.annotator.translator.graphs
+    result_kind = history.getkind(graphs[0].getreturnvar().concretetype)[0]
+
+    class FakeJitDriverSD:
+        num_green_args = 0
+        portal_graph = graphs[0]
+        virtualizable_info = None
+        result_type = result_kind
+        portal_runner_ptr = "???"
+
     stats = history.Stats()
     cpu = CPUClass(rtyper, stats, None, False)
-    cw = codewriter.CodeWriter(cpu, graphs[0])
+    cw = codewriter.CodeWriter(cpu, [FakeJitDriverSD()])
     testself.cw = cw
     cw.find_all_graphs(JitPolicy())
     #
-    testself.warmrunnerstate = FakeWarmRunnerState()
+    testself.warmrunnerstate = FakeJitDriverSD._state = FakeWarmRunnerState()
     testself.warmrunnerstate.cpu = cpu
     if hasattr(testself, 'finish_setup_for_interp_operations'):
         testself.finish_setup_for_interp_operations()
@@ -62,7 +79,8 @@
             count_f += 1
         else:
             raise TypeError(T)
-    blackholeinterp.setposition(cw.mainjitcode, 0)
+    [jitdriver_sd] = cw.callcontrol.jitdrivers_sd
+    blackholeinterp.setposition(jitdriver_sd.mainjitcode, 0)
     blackholeinterp.run()
     return blackholeinterp._final_result_anytype()
 
@@ -78,16 +96,15 @@
     cw = testself.cw
     opt = history.Options(listops=True)
     metainterp_sd = pyjitpl.MetaInterpStaticData(cw.cpu, opt)
-    metainterp_sd.finish_setup(cw, optimizer="bogus")
-    metainterp_sd.state = testself.warmrunnerstate
-    metainterp_sd.state.cpu = metainterp_sd.cpu
-    metainterp = pyjitpl.MetaInterp(metainterp_sd)
+    metainterp_sd.finish_setup(cw)
+    [jitdriver_sd] = metainterp_sd.jitdrivers_sd
+    metainterp = pyjitpl.MetaInterp(metainterp_sd, jitdriver_sd)
     metainterp_sd.DoneWithThisFrameInt = DoneWithThisFrame
     metainterp_sd.DoneWithThisFrameRef = DoneWithThisFrameRef
     metainterp_sd.DoneWithThisFrameFloat = DoneWithThisFrame
     testself.metainterp = metainterp
     try:
-        metainterp.compile_and_run_once(*args)
+        metainterp.compile_and_run_once(jitdriver_sd, *args)
     except DoneWithThisFrame, e:
         #if conftest.option.view:
         #    metainterp.stats.view()
@@ -814,8 +831,9 @@
         translator.config.translation.gc = "boehm"
         warmrunnerdesc = WarmRunnerDesc(translator,
                                         CPUClass=self.CPUClass)
-        warmrunnerdesc.state.set_param_threshold(3)          # for tests
-        warmrunnerdesc.state.set_param_trace_eagerness(0)    # for tests
+        state = warmrunnerdesc.jitdrivers_sd[0]._state
+        state.set_param_threshold(3)          # for tests
+        state.set_param_trace_eagerness(0)    # for tests
         warmrunnerdesc.finish()
         for n, k in [(20, 0), (20, 1)]:
             interp.eval_graph(graph, [n, k])

Modified: pypy/branch/multijit-4/pypy/jit/metainterp/warmspot.py
==============================================================================
--- pypy/branch/multijit-4/pypy/jit/metainterp/warmspot.py	(original)
+++ pypy/branch/multijit-4/pypy/jit/metainterp/warmspot.py	Sat Jun 12 09:32:17 2010
@@ -21,6 +21,7 @@
 from pypy.jit.metainterp.typesystem import LLTypeHelper, OOTypeHelper
 from pypy.jit.metainterp.jitprof import Profiler, EmptyProfiler
 from pypy.jit.metainterp.jitexc import JitException
+from pypy.jit.metainterp.jitdriver import JitDriverStaticData
 from pypy.jit.codewriter import support, codewriter
 from pypy.jit.codewriter.policy import JitPolicy
 from pypy.rlib.jit import DEBUG_STEPS, DEBUG_DETAILED, DEBUG_OFF, DEBUG_PROFILE
@@ -69,11 +70,12 @@
     translator.config.translation.gc = "boehm"
     translator.config.translation.list_comprehension_operations = True
     warmrunnerdesc = WarmRunnerDesc(translator, backendopt=backendopt, **kwds)
-    warmrunnerdesc.state.set_param_threshold(3)          # for tests
-    warmrunnerdesc.state.set_param_trace_eagerness(2)    # for tests
-    warmrunnerdesc.state.set_param_trace_limit(trace_limit)
-    warmrunnerdesc.state.set_param_inlining(inline)
-    warmrunnerdesc.state.set_param_debug(debug_level)
+    for jd in warmrunnerdesc.jitdrivers_sd:
+        jd._state.set_param_threshold(3)          # for tests
+        jd._state.set_param_trace_eagerness(2)    # for tests
+        jd._state.set_param_trace_limit(trace_limit)
+        jd._state.set_param_inlining(inline)
+        jd._state.set_param_debug(debug_level)
     warmrunnerdesc.finish()
     res = interp.eval_graph(graph, args)
     if not kwds.get('translate_support_code', False):
@@ -110,12 +112,11 @@
         raise Exception("no can_enter_jit found!")
     return results
 
-def find_jit_merge_point(graphs):
+def find_jit_merge_points(graphs):
     results = _find_jit_marker(graphs, 'jit_merge_point')
-    if len(results) != 1:
-        raise Exception("found %d jit_merge_points, need exactly one!" %
-                        (len(results),))
-    return results[0]
+    if not results:
+        raise Exception("no jit_merge_point found!")
+    return results
 
 def find_set_param(graphs):
     return _find_jit_marker(graphs, 'set_param')
@@ -146,8 +147,8 @@
         pyjitpl._warmrunnerdesc = self   # this is a global for debugging only!
         self.set_translator(translator)
         self.build_cpu(CPUClass, **kwds)
-        self.find_portal()
-        self.codewriter = codewriter.CodeWriter(self.cpu, self.portal_graph)
+        self.find_portals()
+        self.codewriter = codewriter.CodeWriter(self.cpu, self.jitdrivers_sd)
         if policy is None:
             policy = JitPolicy()
         policy.set_supports_floats(self.cpu.supports_floats)
@@ -158,35 +159,31 @@
             self.prejit_optimizations(policy, graphs)
 
         self.build_meta_interp(ProfilerClass)
-        self.make_args_specification()
+        self.make_args_specifications()
         #
         from pypy.jit.metainterp.virtualref import VirtualRefInfo
         vrefinfo = VirtualRefInfo(self)
         self.codewriter.setup_vrefinfo(vrefinfo)
-        if self.jitdriver.virtualizables:
-            from pypy.jit.metainterp.virtualizable import VirtualizableInfo
-            self.virtualizable_info = VirtualizableInfo(self)
-            self.codewriter.setup_virtualizable_info(self.virtualizable_info)
-        else:
-            self.virtualizable_info = None
         #
+        self.make_virtualizable_infos()
         self.make_exception_classes()
         self.make_driverhook_graphs()
-        self.make_enter_function()
-        self.rewrite_jit_merge_point(policy)
+        self.make_enter_functions()
+        self.rewrite_jit_merge_points(policy)
 
         verbose = not self.cpu.translate_support_code
         self.codewriter.make_jitcodes(verbose=verbose)
-        self.rewrite_can_enter_jit()
+        self.rewrite_can_enter_jits()
         self.rewrite_set_param()
         self.rewrite_force_virtual(vrefinfo)
         self.add_finish()
         self.metainterp_sd.finish_setup(self.codewriter, optimizer=optimizer)
 
     def finish(self):
-        vinfo = self.virtualizable_info
-        if vinfo is not None:
-            vinfo.finish()
+        vinfos = set([jd.virtualizable_info for jd in self.jitdrivers_sd])
+        for vinfo in vinfos:
+            if vinfo is not None:
+                vinfo.finish()
         if self.cpu.translate_support_code:
             self.annhelper.finish()
 
@@ -198,18 +195,27 @@
         self.rtyper = translator.rtyper
         self.gcdescr = gc.get_description(translator.config)
 
-    def find_portal(self):
+    def find_portals(self):
+        self.jitdrivers_sd = []
         graphs = self.translator.graphs
-        self.jit_merge_point_pos = find_jit_merge_point(graphs)
-        graph, block, pos = self.jit_merge_point_pos
+        for jit_merge_point_pos in find_jit_merge_points(graphs):
+            self.split_graph_and_record_jitdriver(*jit_merge_point_pos)
+        #
+        assert (len(set([jd.jitdriver for jd in self.jitdrivers_sd])) ==
+                len(self.jitdrivers_sd)), \
+                "there are multiple jit_merge_points with the same jitdriver"
+
+    def split_graph_and_record_jitdriver(self, graph, block, pos):
+        jd = JitDriverStaticData()
+        jd._jit_merge_point_pos = (graph, block, pos)
         op = block.operations[pos]
         args = op.args[2:]
         s_binding = self.translator.annotator.binding
-        self.portal_args_s = [s_binding(v) for v in args]
+        jd._portal_args_s = [s_binding(v) for v in args]
         graph = copygraph(graph)
         graph.startblock.isstartblock = False
-        graph.startblock = support.split_before_jit_merge_point(
-            *find_jit_merge_point([graph]))
+        [jmpp] = find_jit_merge_points([graph])
+        graph.startblock = support.split_before_jit_merge_point(*jmpp)
         graph.startblock.isstartblock = True
         # a crash in the following checkgraph() means that you forgot
         # to list some variable in greens=[] or reds=[] in JitDriver.
@@ -218,12 +224,16 @@
             assert isinstance(v, Variable)
         assert len(dict.fromkeys(graph.getargs())) == len(graph.getargs())
         self.translator.graphs.append(graph)
-        self.portal_graph = graph
+        jd.portal_graph = graph
         # it's a bit unbelievable to have a portal without func
         assert hasattr(graph, "func")
         graph.func._dont_inline_ = True
         graph.func._jit_unroll_safe_ = True
-        self.jitdriver = block.operations[pos].args[1].value
+        jd.jitdriver = block.operations[pos].args[1].value
+        jd.portal_runner_ptr = "<not set so far>"
+        jd.result_type = history.getkind(jd.portal_graph.getreturnvar()
+                                         .concretetype)[0]
+        self.jitdrivers_sd.append(jd)
 
     def check_access_directly_sanity(self, graphs):
         from pypy.translator.backendopt.inline import collect_called_graphs
@@ -268,6 +278,17 @@
                                                   ProfilerClass=ProfilerClass,
                                                   warmrunnerdesc=self)
 
+    def make_virtualizable_infos(self):
+        for jd in self.jitdrivers_sd:
+            if jd.jitdriver.virtualizables:
+                XXX
+                from pypy.jit.metainterp.virtualizable import VirtualizableInfo
+                vinfo = VirtualizableInfo(self)
+                YYY  # share!
+            else:
+                vinfo = None
+            jd.virtualizable_info = vinfo
+
     def make_exception_classes(self):
 
         class DoneWithThisFrameVoid(JitException):
@@ -317,6 +338,8 @@
                     self.green_int, self.green_ref, self.green_float,
                     self.red_int, self.red_ref, self.red_float)
 
+        # XXX there is no point any more to not just have the exceptions
+        # as globals
         self.DoneWithThisFrameVoid = DoneWithThisFrameVoid
         self.DoneWithThisFrameInt = DoneWithThisFrameInt
         self.DoneWithThisFrameRef = DoneWithThisFrameRef
@@ -330,11 +353,15 @@
         self.metainterp_sd.ExitFrameWithExceptionRef = ExitFrameWithExceptionRef
         self.metainterp_sd.ContinueRunningNormally = ContinueRunningNormally
 
-    def make_enter_function(self):
+    def make_enter_functions(self):
+        for jd in self.jitdrivers_sd:
+            self.make_enter_function(jd)
+
+    def make_enter_function(self, jd):
         from pypy.jit.metainterp.warmstate import WarmEnterState
-        state = WarmEnterState(self)
+        state = WarmEnterState(self, jd)
         maybe_compile_and_run = state.make_entry_point()
-        self.state = state
+        jd._state = state
 
         def crash_in_jit(e):
             if not we_are_translated():
@@ -359,15 +386,16 @@
             def maybe_enter_jit(*args):
                 maybe_compile_and_run(*args)
             maybe_enter_jit._always_inline_ = True
-        self.maybe_enter_jit_fn = maybe_enter_jit
+        jd._maybe_enter_jit_fn = maybe_enter_jit
 
-        can_inline = self.state.can_inline_greenargs
+        can_inline = state.can_inline_greenargs
+        num_green_args = jd.num_green_args
         def maybe_enter_from_start(*args):
-            if can_inline is not None and not can_inline(*args[:self.num_green_args]):
+            if can_inline is not None and not can_inline(*args[:num_green_args]):
                 maybe_compile_and_run(*args)
         maybe_enter_from_start._always_inline_ = True
-        self.maybe_enter_from_start_fn = maybe_enter_from_start
-        
+        jd._maybe_enter_from_start_fn = maybe_enter_from_start
+
     def make_driverhook_graphs(self):
         from pypy.rlib.jit import BaseJitCell
         bk = self.rtyper.annotator.bookkeeper
@@ -378,22 +406,23 @@
         s_Str = annmodel.SomeString()
         #
         annhelper = MixLevelHelperAnnotator(self.translator.rtyper)
-        self.set_jitcell_at_ptr = self._make_hook_graph(
-            annhelper, self.jitdriver.set_jitcell_at, annmodel.s_None,
-            s_BaseJitCell_not_None)
-        self.get_jitcell_at_ptr = self._make_hook_graph(
-            annhelper, self.jitdriver.get_jitcell_at, s_BaseJitCell_or_None)
-        self.can_inline_ptr = self._make_hook_graph(
-            annhelper, self.jitdriver.can_inline, annmodel.s_Bool)
-        self.get_printable_location_ptr = self._make_hook_graph(
-            annhelper, self.jitdriver.get_printable_location, s_Str)
-        self.confirm_enter_jit_ptr = self._make_hook_graph(
-            annhelper, self.jitdriver.confirm_enter_jit, annmodel.s_Bool,
-            onlygreens=False)
+        for jd in self.jitdrivers_sd:
+            jd._set_jitcell_at_ptr = self._make_hook_graph(jd,
+                annhelper, jd.jitdriver.set_jitcell_at, annmodel.s_None,
+                s_BaseJitCell_not_None)
+            jd._get_jitcell_at_ptr = self._make_hook_graph(jd,
+                annhelper, jd.jitdriver.get_jitcell_at, s_BaseJitCell_or_None)
+            jd._can_inline_ptr = self._make_hook_graph(jd,
+                annhelper, jd.jitdriver.can_inline, annmodel.s_Bool)
+            jd._get_printable_location_ptr = self._make_hook_graph(jd,
+                annhelper, jd.jitdriver.get_printable_location, s_Str)
+            jd._confirm_enter_jit_ptr = self._make_hook_graph(jd,
+                annhelper, jd.jitdriver.confirm_enter_jit, annmodel.s_Bool,
+                onlygreens=False)
         annhelper.finish()
 
-    def _make_hook_graph(self, annhelper, func, s_result, s_first_arg=None,
-                         onlygreens=True):
+    def _make_hook_graph(self, jitdriver_sd, annhelper, func,
+                         s_result, s_first_arg=None, onlygreens=True):
         if func is None:
             return None
         #
@@ -401,38 +430,57 @@
         if s_first_arg is not None:
             extra_args_s.append(s_first_arg)
         #
-        args_s = self.portal_args_s
+        args_s = jitdriver_sd._portal_args_s
         if onlygreens:
-            args_s = args_s[:len(self.green_args_spec)]
+            args_s = args_s[:len(jitdriver_sd._green_args_spec)]
         graph = annhelper.getgraph(func, extra_args_s + args_s, s_result)
         funcptr = annhelper.graph2delayed(graph)
         return funcptr
 
-    def make_args_specification(self):
-        graph, block, index = self.jit_merge_point_pos
+    def make_args_specifications(self):
+        for jd in self.jitdrivers_sd:
+            self.make_args_specification(jd)
+
+    def make_args_specification(self, jd):
+        graph, block, index = jd._jit_merge_point_pos
         op = block.operations[index]
         greens_v, reds_v = support.decode_hp_hint_args(op)
         ALLARGS = [v.concretetype for v in (greens_v + reds_v)]
-        self.green_args_spec = [v.concretetype for v in greens_v]
-        self.red_args_types = [history.getkind(v.concretetype) for v in reds_v]
-        self.num_green_args = len(self.green_args_spec)
+        jd._green_args_spec = [v.concretetype for v in greens_v]
+        jd._red_args_types = [history.getkind(v.concretetype) for v in reds_v]
+        jd.num_green_args = len(jd._green_args_spec)
         RESTYPE = graph.getreturnvar().concretetype
-        (self.JIT_ENTER_FUNCTYPE,
-         self.PTR_JIT_ENTER_FUNCTYPE) = self.cpu.ts.get_FuncType(ALLARGS, lltype.Void)
-        (self.PORTAL_FUNCTYPE,
-         self.PTR_PORTAL_FUNCTYPE) = self.cpu.ts.get_FuncType(ALLARGS, RESTYPE)
-        (_, self.PTR_ASSEMBLER_HELPER_FUNCTYPE) = self.cpu.ts.get_FuncType(
+        (jd._JIT_ENTER_FUNCTYPE,
+         jd._PTR_JIT_ENTER_FUNCTYPE) = self.cpu.ts.get_FuncType(ALLARGS, lltype.Void)
+        (jd._PORTAL_FUNCTYPE,
+         jd._PTR_PORTAL_FUNCTYPE) = self.cpu.ts.get_FuncType(ALLARGS, RESTYPE)
+        (_, jd._PTR_ASSEMBLER_HELPER_FUNCTYPE) = self.cpu.ts.get_FuncType(
             [lltype.Signed, llmemory.GCREF], RESTYPE)
 
-    def rewrite_can_enter_jit(self):
-        FUNC = self.JIT_ENTER_FUNCTYPE
-        FUNCPTR = self.PTR_JIT_ENTER_FUNCTYPE
-        jit_enter_fnptr = self.helper_func(FUNCPTR, self.maybe_enter_jit_fn)
+    def rewrite_can_enter_jits(self):
+        can_enter_jits = find_can_enter_jit(self.translator.graphs)
+        sublists = {}
+        for jd in self.jitdrivers_sd:
+            sublists[jd.jitdriver] = []
+        for graph, block, index in can_enter_jits:
+            op = block.operations[index]
+            jitdriver = op.args[1].value
+            assert jitdriver in sublists, \
+                   "can_enter_jit with no matching jit_merge_point"
+            sublists[jitdriver].append((graph, block, index))
+        for jd in self.jitdrivers_sd:
+            sublist = sublists[jd.jitdriver]
+            assert len(sublist) > 0, \
+                   "found no can_enter_jit for %r" % (jd.jitdriver,)
+            self.rewrite_can_enter_jit(jd, sublist)
+
+    def rewrite_can_enter_jit(self, jd, can_enter_jits):
+        FUNC = jd._JIT_ENTER_FUNCTYPE
+        FUNCPTR = jd._PTR_JIT_ENTER_FUNCTYPE
+        jit_enter_fnptr = self.helper_func(FUNCPTR, jd._maybe_enter_jit_fn)
 
-        graphs = self.translator.graphs
-        can_enter_jits = find_can_enter_jit(graphs)
         for graph, block, index in can_enter_jits:
-            if graph is self.jit_merge_point_pos[0]:
+            if graph is jd._jit_merge_point_pos[0]:
                 continue
 
             op = block.operations[index]
@@ -455,7 +503,11 @@
         graph = self.annhelper.getgraph(func, args_s, s_result)
         return self.annhelper.graph2delayed(graph, FUNC)
 
-    def rewrite_jit_merge_point(self, policy):
+    def rewrite_jit_merge_points(self, policy):
+        for jd in self.jitdrivers_sd:
+            self.rewrite_jit_merge_point(jd, policy)
+
+    def rewrite_jit_merge_point(self, jd, policy):
         #
         # Mutate the original portal graph from this:
         #
@@ -486,9 +538,9 @@
         #           while 1:
         #               more stuff
         #
-        origportalgraph = self.jit_merge_point_pos[0]
-        portalgraph = self.portal_graph
-        PORTALFUNC = self.PORTAL_FUNCTYPE
+        origportalgraph = jd._jit_merge_point_pos[0]
+        portalgraph = jd.portal_graph
+        PORTALFUNC = jd._PORTAL_FUNCTYPE
 
         # ____________________________________________________________
         # Prepare the portal_runner() helper
@@ -496,12 +548,12 @@
         from pypy.jit.metainterp.warmstate import specialize_value
         portal_ptr = self.cpu.ts.functionptr(PORTALFUNC, 'portal',
                                          graph = portalgraph)
-        self.portal_ptr = portal_ptr
+        jd._portal_ptr = portal_ptr
         #
         portalfunc_ARGS = []
         nums = {}
         for i, ARG in enumerate(PORTALFUNC.ARGS):
-            if i < len(self.jitdriver.greens):
+            if i < len(jd.jitdriver.greens):
                 color = 'green'
             else:
                 color = 'red'
@@ -519,7 +571,7 @@
         def ll_portal_runner(*args):
             while 1:
                 try:
-                    self.maybe_enter_from_start_fn(*args)
+                    jd._maybe_enter_from_start_fn(*args)
                     return support.maybe_on_top_of_llinterp(rtyper,
                                                       portal_ptr)(*args)
                 except self.ContinueRunningNormally, e:
@@ -548,16 +600,15 @@
                         value = cast_base_ptr_to_instance(Exception, value)
                         raise Exception, value
 
-        self.ll_portal_runner = ll_portal_runner # for debugging
-        self.portal_runner_ptr = self.helper_func(self.PTR_PORTAL_FUNCTYPE,
-                                                  ll_portal_runner)
+        jd._ll_portal_runner = ll_portal_runner # for debugging
+        jd.portal_runner_ptr = self.helper_func(jd._PTR_PORTAL_FUNCTYPE,
+                                                ll_portal_runner)
         self.cpu.portal_calldescr = self.cpu.calldescrof(
-            self.PTR_PORTAL_FUNCTYPE.TO,
-            self.PTR_PORTAL_FUNCTYPE.TO.ARGS,
-            self.PTR_PORTAL_FUNCTYPE.TO.RESULT)
-        self.codewriter.setup_portal_runner_ptr(self.portal_runner_ptr)
+            jd._PTR_PORTAL_FUNCTYPE.TO,
+            jd._PTR_PORTAL_FUNCTYPE.TO.ARGS,
+            jd._PTR_PORTAL_FUNCTYPE.TO.RESULT)
 
-        vinfo = self.virtualizable_info
+        vinfo = jd.virtualizable_info
 
         def assembler_call_helper(failindex, virtualizableref):
             fail_descr = self.cpu.get_fail_descr_from_number(failindex)
@@ -567,6 +618,7 @@
                         virtualizable = lltype.cast_opaque_ptr(
                             vinfo.VTYPEPTR, virtualizableref)
                         vinfo.reset_vable_token(virtualizable)
+                    XXX   # careful here, we must pass the correct jitdriver_sd
                     loop_token = fail_descr.handle_fail(self.metainterp_sd)
                     fail_descr = self.cpu.execute_token(loop_token)
                 except self.ContinueRunningNormally, e:
@@ -596,12 +648,14 @@
                         value = cast_base_ptr_to_instance(Exception, value)
                         raise Exception, value
 
-        self.assembler_call_helper = assembler_call_helper # for debugging
+        jd._assembler_call_helper = assembler_call_helper # for debugging
+        # XXX rewrite me, ugly sticking does not work any more
         self.cpu.assembler_helper_ptr = self.helper_func(
-            self.PTR_ASSEMBLER_HELPER_FUNCTYPE,
+            jd._PTR_ASSEMBLER_HELPER_FUNCTYPE,
             assembler_call_helper)
         # XXX a bit ugly sticking
         if vinfo is not None:
+            XXX     # rewrite me, ugly sticking does not work any more
             self.cpu.index_of_virtualizable = (vinfo.index_of_virtualizable -
                                                self.num_green_args)
             self.cpu.vable_token_descr = vinfo.vable_token_descr
@@ -612,12 +666,12 @@
         # ____________________________________________________________
         # Now mutate origportalgraph to end with a call to portal_runner_ptr
         #
-        _, origblock, origindex = self.jit_merge_point_pos
+        _, origblock, origindex = jd._jit_merge_point_pos
         op = origblock.operations[origindex]
         assert op.opname == 'jit_marker'
         assert op.args[0].value == 'jit_merge_point'
         greens_v, reds_v = support.decode_hp_hint_args(op)
-        vlist = [Constant(self.portal_runner_ptr, self.PTR_PORTAL_FUNCTYPE)]
+        vlist = [Constant(jd.portal_runner_ptr, jd._PTR_PORTAL_FUNCTYPE)]
         vlist += greens_v
         vlist += reds_v
         v_result = Variable()
@@ -644,8 +698,8 @@
         graphs = self.translator.graphs
         _, PTR_SET_PARAM_FUNCTYPE = self.cpu.ts.get_FuncType([lltype.Signed],
                                                              lltype.Void)
-        def make_closure(fullfuncname):
-            state = self.state
+        def make_closure(jd, fullfuncname):
+            state = jd._state
             def closure(i):
                 getattr(state, fullfuncname)(i)
             funcptr = self.helper_func(PTR_SET_PARAM_FUNCTYPE, closure)
@@ -653,12 +707,17 @@
         #
         for graph, block, i in find_set_param(graphs):
             op = block.operations[i]
-            assert op.args[1].value == self.jitdriver
+            for jd in self.jitdrivers_sd:
+                if jd.jitdriver is op.args[1].value:
+                    break
+            else:
+                assert 0, "jitdriver of set_param() not found"
             funcname = op.args[2].value
-            if funcname not in closures:
-                closures[funcname] = make_closure('set_param_' + funcname)
+            key = jd, funcname
+            if key not in closures:
+                closures[key] = make_closure(jd, 'set_param_' + funcname)
             op.opname = 'direct_call'
-            op.args[:3] = [closures[funcname]]
+            op.args[:3] = [closures[key]]
 
     def rewrite_force_virtual(self, vrefinfo):
         if self.cpu.ts.name != 'lltype':

Modified: pypy/branch/multijit-4/pypy/jit/metainterp/warmstate.py
==============================================================================
--- pypy/branch/multijit-4/pypy/jit/metainterp/warmstate.py	(original)
+++ pypy/branch/multijit-4/pypy/jit/metainterp/warmstate.py	Sat Jun 12 09:32:17 2010
@@ -1,7 +1,7 @@
 import sys
 from pypy.rpython.lltypesystem import lltype, llmemory, rstr
 from pypy.rpython.ootypesystem import ootype
-from pypy.rpython.annlowlevel import hlstr, cast_base_ptr_to_instance
+from pypy.rpython.annlowlevel import hlstr, llstr, cast_base_ptr_to_instance
 from pypy.rpython.annlowlevel import cast_object_to_ptr
 from pypy.rlib.objectmodel import specialize, we_are_translated, r_dict
 from pypy.rlib.rarithmetic import intmask
@@ -120,6 +120,16 @@
     else:
         assert False
 
+class JitCell(BaseJitCell):   # module-level: one cell class shared by every WarmEnterState/jitdriver
+    # the counter can mean the following things:
+    #     counter >=  0: not yet traced, wait till threshold is reached
+    #     counter == -1: there is an entry bridge for this cell
+    #     counter == -2: tracing is currently going on for this cell (reset to 0 if tracing ends without a result)
+    counter = 0                   # class-level default: fresh cells start "not yet traced"
+    compiled_merge_points = None  # class-level default; presumably set once loops exist for this greenkey -- confirm
+    dont_trace_here = False       # class-level default flag; assigned by code outside this view
+    entry_loop_token = None       # class-level default; presumably the entry bridge token when counter == -1 -- confirm
+
 # ____________________________________________________________
 
 
@@ -127,9 +137,10 @@
     THRESHOLD_LIMIT = sys.maxint // 2
     default_jitcell_dict = None
 
-    def __init__(self, warmrunnerdesc):
+    def __init__(self, warmrunnerdesc, jitdriver_sd):
         "NOT_RPYTHON"
         self.warmrunnerdesc = warmrunnerdesc
+        self.jitdriver_sd = jitdriver_sd
         try:
             self.profiler = warmrunnerdesc.metainterp_sd.profiler
         except AttributeError:       # for tests
@@ -195,8 +206,9 @@
             return self.maybe_compile_and_run
 
         metainterp_sd = self.warmrunnerdesc.metainterp_sd
-        vinfo = self.warmrunnerdesc.virtualizable_info
-        num_green_args = self.warmrunnerdesc.num_green_args
+        jitdriver_sd = self.jitdriver_sd
+        vinfo = jitdriver_sd.virtualizable_info
+        num_green_args = jitdriver_sd.num_green_args
         get_jitcell = self.make_jitcell_getter()
         set_future_values = self.make_set_future_values()
         self.make_jitdriver_callbacks()
@@ -206,7 +218,6 @@
             """Entry point to the JIT.  Called at the point with the
             can_enter_jit() hint.
             """
-            globaldata = metainterp_sd.globaldata
             if NonConstant(False):
                 # make sure we always see the saner optimizer from an
                 # annotation point of view, otherwise we get lots of
@@ -234,11 +245,12 @@
                     return
                 # bound reached; start tracing
                 from pypy.jit.metainterp.pyjitpl import MetaInterp
-                metainterp = MetaInterp(metainterp_sd)
+                metainterp = MetaInterp(metainterp_sd, jitdriver_sd)
                 # set counter to -2, to mean "tracing in effect"
                 cell.counter = -2
                 try:
-                    loop_token = metainterp.compile_and_run_once(*args)
+                    loop_token = metainterp.compile_and_run_once(jitdriver_sd,
+                                                                 *args)
                 finally:
                     if cell.counter == -2:
                         cell.counter = 0
@@ -264,7 +276,8 @@
                 metainterp_sd.profiler.end_running()
                 if vinfo is not None:
                     vinfo.reset_vable_token(virtualizable)
-                loop_token = fail_descr.handle_fail(metainterp_sd)
+                loop_token = fail_descr.handle_fail(metainterp_sd,
+                                                    jitdriver_sd)
        
         maybe_compile_and_run._dont_inline_ = True
         self.maybe_compile_and_run = maybe_compile_and_run
@@ -277,8 +290,8 @@
         if hasattr(self, 'unwrap_greenkey'):
             return self.unwrap_greenkey
         #
-        warmrunnerdesc = self.warmrunnerdesc
-        green_args_spec = unrolling_iterable(warmrunnerdesc.green_args_spec)
+        jitdriver_sd = self.jitdriver_sd
+        green_args_spec = unrolling_iterable(jitdriver_sd._green_args_spec)
         #
         def unwrap_greenkey(greenkey):
             greenargs = ()
@@ -302,20 +315,10 @@
         if hasattr(self, 'jit_getter'):
             return self.jit_getter
         #
-        class JitCell(BaseJitCell):
-            # the counter can mean the following things:
-            #     counter >=  0: not yet traced, wait till threshold is reached
-            #     counter == -1: there is an entry bridge for this cell
-            #     counter == -2: tracing is currently going on for this cell
-            counter = 0
-            compiled_merge_points = None
-            dont_trace_here = False
-            entry_loop_token = None
-        #
-        if self.warmrunnerdesc.get_jitcell_at_ptr is None:
-            jit_getter = self._make_jitcell_getter_default(JitCell)
+        if self.jitdriver_sd._get_jitcell_at_ptr is None:
+            jit_getter = self._make_jitcell_getter_default()
         else:
-            jit_getter = self._make_jitcell_getter_custom(JitCell)
+            jit_getter = self._make_jitcell_getter_custom()
         #
         unwrap_greenkey = self.make_unwrap_greenkey()
         #
@@ -327,10 +330,10 @@
         #
         return jit_getter
 
-    def _make_jitcell_getter_default(self, JitCell):
+    def _make_jitcell_getter_default(self):
         "NOT_RPYTHON"
-        warmrunnerdesc = self.warmrunnerdesc
-        green_args_spec = unrolling_iterable(warmrunnerdesc.green_args_spec)
+        jitdriver_sd = self.jitdriver_sd
+        green_args_spec = unrolling_iterable(jitdriver_sd._green_args_spec)
         #
         def comparekey(greenargs1, greenargs2):
             i = 0
@@ -361,11 +364,11 @@
             return cell
         return get_jitcell
 
-    def _make_jitcell_getter_custom(self, JitCell):
+    def _make_jitcell_getter_custom(self):
         "NOT_RPYTHON"
         rtyper = self.warmrunnerdesc.rtyper
-        get_jitcell_at_ptr = self.warmrunnerdesc.get_jitcell_at_ptr
-        set_jitcell_at_ptr = self.warmrunnerdesc.set_jitcell_at_ptr
+        get_jitcell_at_ptr = self.jitdriver_sd._get_jitcell_at_ptr
+        set_jitcell_at_ptr = self.jitdriver_sd._set_jitcell_at_ptr
         lltohlhack = {}
         #
         def get_jitcell(*greenargs):
@@ -415,9 +418,10 @@
             return self.set_future_values
 
         warmrunnerdesc = self.warmrunnerdesc
+        jitdriver_sd   = self.jitdriver_sd
         cpu = warmrunnerdesc.cpu
-        vinfo = warmrunnerdesc.virtualizable_info
-        red_args_types = unrolling_iterable(warmrunnerdesc.red_args_types)
+        vinfo = jitdriver_sd.virtualizable_info
+        red_args_types = unrolling_iterable(jitdriver_sd._red_args_types)
         #
         def set_future_values(*redargs):
             i = 0
@@ -428,8 +432,8 @@
                 set_future_values_from_vinfo(*redargs)
         #
         if vinfo is not None:
-            i0 = len(warmrunnerdesc.red_args_types)
-            num_green_args = warmrunnerdesc.num_green_args
+            i0 = len(jitdriver_sd._red_args_types)
+            num_green_args = jitdriver_sd.num_green_args
             vable_static_fields = unrolling_iterable(
                 zip(vinfo.static_extra_types, vinfo.static_fields))
             vable_array_fields = unrolling_iterable(
@@ -464,7 +468,7 @@
         if hasattr(self, 'get_location_str'):
             return
         #
-        can_inline_ptr = self.warmrunnerdesc.can_inline_ptr
+        can_inline_ptr = self.jitdriver_sd._can_inline_ptr
         unwrap_greenkey = self.make_unwrap_greenkey()
         if can_inline_ptr is None:
             def can_inline_callable(*greenargs):
@@ -497,10 +501,15 @@
         self.get_assembler_token = get_assembler_token
         
         #
-        get_location_ptr = self.warmrunnerdesc.get_printable_location_ptr
+        get_location_ptr = self.jitdriver_sd._get_printable_location_ptr
         if get_location_ptr is None:
+            missing = '(no jitdriver.get_printable_location!)'
+            missingll = llstr(missing)
             def get_location_str(greenkey):
-                return '(no jitdriver.get_printable_location!)'
+                if we_are_translated():
+                    return missingll
+                else:
+                    return missing
         else:
             rtyper = self.warmrunnerdesc.rtyper
             unwrap_greenkey = self.make_unwrap_greenkey()
@@ -514,7 +523,7 @@
                 return res
         self.get_location_str = get_location_str
         #
-        confirm_enter_jit_ptr = self.warmrunnerdesc.confirm_enter_jit_ptr
+        confirm_enter_jit_ptr = self.jitdriver_sd._confirm_enter_jit_ptr
         if confirm_enter_jit_ptr is None:
             def confirm_enter_jit(*args):
                 return True



More information about the Pypy-commit mailing list