[pypy-svn] r75271 - in pypy/branch/multijit-3/pypy/jit/metainterp: . test

arigo at codespeak.net arigo at codespeak.net
Fri Jun 11 15:56:06 CEST 2010


Author: arigo
Date: Fri Jun 11 15:56:03 2010
New Revision: 75271

Modified:
   pypy/branch/multijit-3/pypy/jit/metainterp/blackhole.py
   pypy/branch/multijit-3/pypy/jit/metainterp/compile.py
   pypy/branch/multijit-3/pypy/jit/metainterp/history.py
   pypy/branch/multijit-3/pypy/jit/metainterp/jitdriver.py
   pypy/branch/multijit-3/pypy/jit/metainterp/pyjitpl.py
   pypy/branch/multijit-3/pypy/jit/metainterp/resume.py
   pypy/branch/multijit-3/pypy/jit/metainterp/test/test_basic.py
   pypy/branch/multijit-3/pypy/jit/metainterp/warmspot.py
   pypy/branch/multijit-3/pypy/jit/metainterp/warmstate.py
Log:
Tests start passing again.


Modified: pypy/branch/multijit-3/pypy/jit/metainterp/blackhole.py
==============================================================================
--- pypy/branch/multijit-3/pypy/jit/metainterp/blackhole.py	(original)
+++ pypy/branch/multijit-3/pypy/jit/metainterp/blackhole.py	Fri Jun 11 15:56:03 2010
@@ -315,27 +315,12 @@
     def get_tmpreg_f(self):
         return self.tmpreg_f
 
-    def final_result_i(self):
-        assert self._return_type == 'i'
-        return self.get_tmpreg_i()
-
-    def final_result_r(self):
-        assert self._return_type == 'r'
-        return self.get_tmpreg_r()
-
-    def final_result_f(self):
-        assert self._return_type == 'f'
-        return self.get_tmpreg_f()
-
-    def final_result_v(self):
-        assert self._return_type == 'v'
-
     def _final_result_anytype(self):
         "NOT_RPYTHON"
-        if self._return_type == 'i': return self.final_result_i()
-        if self._return_type == 'r': return self.final_result_r()
-        if self._return_type == 'f': return self.final_result_f()
-        if self._return_type == 'v': return self.final_result_v()
+        if self._return_type == 'i': return self.get_tmpreg_i()
+        if self._return_type == 'r': return self.get_tmpreg_r()
+        if self._return_type == 'f': return self.get_tmpreg_f()
+        if self._return_type == 'v': return None
         raise ValueError(self._return_type)
 
     def cleanup_registers(self):
@@ -790,6 +775,7 @@
             # call the interpreter main loop from here, and just return its
             # result.
             sd = self.builder.metainterp_sd
+            xxxxxxxxxxxx
             if sd.result_type == 'void':
                 self.bhimpl_recursive_call_v(jdindex, *args)
                 self.bhimpl_void_return()
@@ -1176,11 +1162,11 @@
             self._done_with_this_frame()
         kind = self._return_type
         if kind == 'i':
-            caller._setup_return_value_i(self.final_result_i())
+            caller._setup_return_value_i(self.get_tmpreg_i())
         elif kind == 'r':
-            caller._setup_return_value_r(self.final_result_r())
+            caller._setup_return_value_r(self.get_tmpreg_r())
         elif kind == 'f':
-            caller._setup_return_value_f(self.final_result_f())
+            caller._setup_return_value_f(self.get_tmpreg_f())
         else:
             assert kind == 'v'
         return lltype.nullptr(rclass.OBJECTPTR.TO)
@@ -1246,15 +1232,15 @@
         # rare case: we only get there if the blackhole interps all returned
         # normally (in general we get a ContinueRunningNormally exception).
         sd = self.builder.metainterp_sd
-        if sd.result_type == 'void':
-            self.final_result_v()
+        kind = self._return_type
+        if kind == 'v':
             raise sd.DoneWithThisFrameVoid()
-        elif sd.result_type == 'int':
-            raise sd.DoneWithThisFrameInt(self.final_result_i())
-        elif sd.result_type == 'ref':
-            raise sd.DoneWithThisFrameRef(self.cpu, self.final_result_r())
-        elif sd.result_type == 'float':
-            raise sd.DoneWithThisFrameFloat(self.final_result_f())
+        elif kind == 'i':
+            raise sd.DoneWithThisFrameInt(self.get_tmpreg_i())
+        elif kind == 'r':
+            raise sd.DoneWithThisFrameRef(self.cpu, self.get_tmpreg_r())
+        elif kind == 'f':
+            raise sd.DoneWithThisFrameFloat(self.get_tmpreg_f())
         else:
             assert False
 
@@ -1288,12 +1274,14 @@
             blackholeinterp.builder.release_interp(blackholeinterp)
         blackholeinterp = blackholeinterp.nextblackholeinterp
 
-def resume_in_blackhole(metainterp_sd, resumedescr, all_virtuals=None):
+def resume_in_blackhole(metainterp_sd, jitdriver_sd, resumedescr,
+                        all_virtuals=None):
     from pypy.jit.metainterp.resume import blackhole_from_resumedata
     debug_start('jit-blackhole')
     metainterp_sd.profiler.start_blackhole()
     blackholeinterp = blackhole_from_resumedata(
         metainterp_sd.blackholeinterpbuilder,
+        jitdriver_sd,
         resumedescr,
         all_virtuals)
     current_exc = blackholeinterp._prepare_resume_from_failure(

Modified: pypy/branch/multijit-3/pypy/jit/metainterp/compile.py
==============================================================================
--- pypy/branch/multijit-3/pypy/jit/metainterp/compile.py	(original)
+++ pypy/branch/multijit-3/pypy/jit/metainterp/compile.py	Fri Jun 11 15:56:03 2010
@@ -63,11 +63,12 @@
     # make a copy, because optimize_loop can mutate the ops and descrs
     loop.operations = [op.clone() for op in ops]
     metainterp_sd = metainterp.staticdata
+    jitdriver_sd = metainterp.jitdriver_sd
     loop_token = make_loop_token(len(loop.inputargs))
     loop.token = loop_token
     loop.operations[-1].descr = loop_token     # patch the target of the JUMP
     try:
-        old_loop_token = metainterp_sd.state.optimize_loop(
+        old_loop_token = jitdriver_sd._state.optimize_loop(
             metainterp_sd, old_loop_tokens, loop)
     except InvalidLoop:
         return None
@@ -141,32 +142,32 @@
     pass
 
 class DoneWithThisFrameDescrVoid(_DoneWithThisFrameDescr):
-    def handle_fail(self, metainterp_sd):
-        assert metainterp_sd.result_type == 'void'
+    def handle_fail(self, metainterp_sd, jitdriver_sd):
+        assert jitdriver_sd.result_type == history.VOID
         raise metainterp_sd.DoneWithThisFrameVoid()
 
 class DoneWithThisFrameDescrInt(_DoneWithThisFrameDescr):
-    def handle_fail(self, metainterp_sd):
-        assert metainterp_sd.result_type == 'int'
+    def handle_fail(self, metainterp_sd, jitdriver_sd):
+        assert jitdriver_sd.result_type == history.INT
         result = metainterp_sd.cpu.get_latest_value_int(0)
         raise metainterp_sd.DoneWithThisFrameInt(result)
 
 class DoneWithThisFrameDescrRef(_DoneWithThisFrameDescr):
-    def handle_fail(self, metainterp_sd):
-        assert metainterp_sd.result_type == 'ref'
+    def handle_fail(self, metainterp_sd, jitdriver_sd):
+        assert jitdriver_sd.result_type == history.REF
         cpu = metainterp_sd.cpu
         result = cpu.get_latest_value_ref(0)
         cpu.clear_latest_values(1)
         raise metainterp_sd.DoneWithThisFrameRef(cpu, result)
 
 class DoneWithThisFrameDescrFloat(_DoneWithThisFrameDescr):
-    def handle_fail(self, metainterp_sd):
-        assert metainterp_sd.result_type == 'float'
+    def handle_fail(self, metainterp_sd, jitdriver_sd):
+        assert jitdriver_sd.result_type == history.FLOAT
         result = metainterp_sd.cpu.get_latest_value_float(0)
         raise metainterp_sd.DoneWithThisFrameFloat(result)
 
 class ExitFrameWithExceptionDescrRef(_DoneWithThisFrameDescr):
-    def handle_fail(self, metainterp_sd):
+    def handle_fail(self, metainterp_sd, jitdriver_sd):
         cpu = metainterp_sd.cpu
         value = cpu.get_latest_value_ref(0)
         cpu.clear_latest_values(1)
@@ -258,22 +259,27 @@
             # a negative value
             self._counter = cnt | i
 
-    def handle_fail(self, metainterp_sd):
-        if self.must_compile(metainterp_sd):
-            return self._trace_and_compile_from_bridge(metainterp_sd)
+    def handle_fail(self, metainterp_sd, jitdriver_sd):
+        if self.must_compile(metainterp_sd, jitdriver_sd):
+            return self._trace_and_compile_from_bridge(metainterp_sd,
+                                                       jitdriver_sd)
         else:
             from pypy.jit.metainterp.blackhole import resume_in_blackhole
-            resume_in_blackhole(metainterp_sd, self)
+            resume_in_blackhole(metainterp_sd, jitdriver_sd, self)
             assert 0, "unreachable"
 
-    def _trace_and_compile_from_bridge(self, metainterp_sd):
+    def _trace_and_compile_from_bridge(self, metainterp_sd, jitdriver_sd):
+        # 'jitdriver_sd' corresponds to the outermost one, i.e. the one
+        # of the jit_merge_point where we started the loop, even if the
+        # loop itself may temporarily contain recursion into other
+        # jitdrivers.
         from pypy.jit.metainterp.pyjitpl import MetaInterp
-        metainterp = MetaInterp(metainterp_sd)
+        metainterp = MetaInterp(metainterp_sd, jitdriver_sd)
         return metainterp.handle_guard_failure(self)
     _trace_and_compile_from_bridge._dont_inline_ = True
 
-    def must_compile(self, metainterp_sd):
-        trace_eagerness = metainterp_sd.state.trace_eagerness
+    def must_compile(self, metainterp_sd, jitdriver_sd):
+        trace_eagerness = jitdriver_sd._state.trace_eagerness
         if self._counter >= 0:
             self._counter += 1
             return self._counter >= trace_eagerness
@@ -333,7 +339,7 @@
 
 class ResumeGuardForcedDescr(ResumeGuardDescr):
 
-    def handle_fail(self, metainterp_sd):
+    def handle_fail(self, metainterp_sd, jitdriver_sd):
         # Failures of a GUARD_NOT_FORCED are never compiled, but
         # always just blackholed.  First fish for the data saved when
         # the virtualrefs and virtualizable have been forced by
@@ -343,7 +349,7 @@
         all_virtuals = self.fetch_data(token)
         if all_virtuals is None:
             all_virtuals = []
-        resume_in_blackhole(metainterp_sd, self, all_virtuals)
+        resume_in_blackhole(metainterp_sd, jitdriver_sd, self, all_virtuals)
         assert 0, "unreachable"
 
     @staticmethod
@@ -464,6 +470,7 @@
         # a loop at all but ends in a jump to the target loop.  It starts
         # with completely unoptimized arguments, as in the interpreter.
         metainterp_sd = metainterp.staticdata
+        jitdriver_sd = metainterp.jitdriver_sd
         metainterp.history.inputargs = self.redkey
         new_loop_token = make_loop_token(len(self.redkey))
         new_loop.greenkey = self.original_greenkey
@@ -471,12 +478,11 @@
         new_loop.token = new_loop_token
         send_loop_to_backend(metainterp_sd, new_loop, "entry bridge")
         # send the new_loop to warmspot.py, to be called directly the next time
-        metainterp_sd.state.attach_unoptimized_bridge_from_interp(
+        jitdriver_sd._state.attach_unoptimized_bridge_from_interp(
             self.original_greenkey,
             new_loop_token)
         # store the new loop in compiled_merge_points too
-        glob = metainterp_sd.globaldata
-        old_loop_tokens = glob.get_compiled_merge_points(
+        old_loop_tokens = metainterp.get_compiled_merge_points(
             self.original_greenkey)
         # it always goes at the end of the list, as it is the most
         # general loop token
@@ -500,8 +506,9 @@
     # clone ops, as optimize_bridge can mutate the ops
     new_loop.operations = [op.clone() for op in metainterp.history.operations]
     metainterp_sd = metainterp.staticdata
+    jitdriver_sd = metainterp.jitdriver_sd
     try:
-        target_loop_token = metainterp_sd.state.optimize_bridge(metainterp_sd,
+        target_loop_token = jitdriver_sd._state.optimize_bridge(metainterp_sd,
                                                                 old_loop_tokens,
                                                                 new_loop)
     except InvalidLoop:

Modified: pypy/branch/multijit-3/pypy/jit/metainterp/history.py
==============================================================================
--- pypy/branch/multijit-3/pypy/jit/metainterp/history.py	(original)
+++ pypy/branch/multijit-3/pypy/jit/metainterp/history.py	Fri Jun 11 15:56:03 2010
@@ -174,7 +174,7 @@
 class AbstractFailDescr(AbstractDescr):
     index = -1
 
-    def handle_fail(self, metainterp_sd):
+    def handle_fail(self, metainterp_sd, jitdriver_sd):
         raise NotImplementedError
     def compile_and_attach(self, metainterp, new_loop):
         raise NotImplementedError

Modified: pypy/branch/multijit-3/pypy/jit/metainterp/jitdriver.py
==============================================================================
--- pypy/branch/multijit-3/pypy/jit/metainterp/jitdriver.py	(original)
+++ pypy/branch/multijit-3/pypy/jit/metainterp/jitdriver.py	Fri Jun 11 15:56:03 2010
@@ -13,5 +13,7 @@
     #    self.index             ... pypy.jit.codewriter.call
     #    self.mainjitcode       ... pypy.jit.codewriter.call
 
+    # warmspot sets extra attributes starting with '_' for its own use.
+
     def _freeze_(self):
         return True

Modified: pypy/branch/multijit-3/pypy/jit/metainterp/pyjitpl.py
==============================================================================
--- pypy/branch/multijit-3/pypy/jit/metainterp/pyjitpl.py	(original)
+++ pypy/branch/multijit-3/pypy/jit/metainterp/pyjitpl.py	Fri Jun 11 15:56:03 2010
@@ -557,6 +557,7 @@
 
     def _get_arrayitem_vable_index(self, pc, arrayfielddescr, indexbox):
         indexbox = self.implement_guard_value(pc, indexbox)
+        xxxxxxx
         vinfo = self.metainterp.jitdriver_sd.virtualizable_info
         virtualizable_box = self.metainterp.virtualizable_boxes[-1]
         virtualizable = vinfo.unwrap_virtualizable_box(virtualizable_box)
@@ -606,6 +607,7 @@
             arraybox = self.metainterp.execute_and_record(rop.GETFIELD_GC,
                                                           fdescr, box)
             return self.execute_with_descr(rop.ARRAYLEN_GC, adescr, arraybox)
+        xxxxxxx
         vinfo = self.metainterp.jitdriver_sd.virtualizable_info
         virtualizable_box = self.metainterp.virtualizable_boxes[-1]
         virtualizable = vinfo.unwrap_virtualizable_box(virtualizable_box)
@@ -655,8 +657,8 @@
     opimpl_residual_call_irf_f = _opimpl_residual_call3
     opimpl_residual_call_irf_v = _opimpl_residual_call3
 
-    @arguments("boxes3", "boxes3")
-    def _opimpl_recursive_call(self, greenboxes, redboxes):
+    @arguments("int", "boxes3", "boxes3")
+    def _opimpl_recursive_call(self, jdindex, greenboxes, redboxes):
         allboxes = greenboxes + redboxes
         metainterp_sd = self.metainterp.staticdata
         xxxx
@@ -766,17 +768,18 @@
             raise CannotInlineCanEnterJit()
         self.metainterp.seen_can_enter_jit = True
 
-    def verify_green_args(self, varargs):
-        num_green_args = self.metainterp.staticdata.num_green_args
+    def verify_green_args(self, jdindex, varargs):
+        jitdriver = self.metainterp.staticdata.jitdrivers_sd[jdindex]
+        num_green_args = jitdriver.num_green_args
         assert len(varargs) == num_green_args
         for i in range(num_green_args):
             assert isinstance(varargs[i], Const)
 
-    @arguments("orgpc", "boxes3", "boxes3")
-    def opimpl_jit_merge_point(self, orgpc, greenboxes, redboxes):
-        self.verify_green_args(greenboxes)
+    @arguments("orgpc", "int", "boxes3", "boxes3")
+    def opimpl_jit_merge_point(self, orgpc, jdindex, greenboxes, redboxes):
+        self.verify_green_args(jdindex, greenboxes)
         # xxx we may disable the following line in some context later
-        self.debug_merge_point(greenboxes)
+        self.debug_merge_point(jdindex, greenboxes)
         if self.metainterp.seen_can_enter_jit:
             self.metainterp.seen_can_enter_jit = False
             # Assert that it's impossible to arrive here with in_recursion
@@ -784,6 +787,7 @@
             # to True by opimpl_can_enter_jit, which should be executed
             # just before opimpl_jit_merge_point (no recursion inbetween).
             assert not self.metainterp.in_recursion
+            assert jdindex == self.metainterp.jitdriver_sd.index
             # Set self.pc to point to jit_merge_point instead of just after:
             # if reached_can_enter_jit() raises SwitchToBlackhole, then the
             # pc is still at the jit_merge_point, which is a point that is
@@ -793,10 +797,10 @@
             self.metainterp.reached_can_enter_jit(greenboxes, redboxes)
             self.pc = saved_pc
 
-    def debug_merge_point(self, greenkey):
+    def debug_merge_point(self, jdindex, greenkey):
         # debugging: produce a DEBUG_MERGE_POINT operation
-        sd = self.metainterp.staticdata
-        loc = sd.state.get_location_str(greenkey)
+        jitdriver = self.metainterp.staticdata.jitdrivers_sd[jdindex]
+        loc = jitdriver._state.get_location_str(greenkey)
         debug_print(loc)
         constloc = self.metainterp.cpu.ts.conststr(loc)
         self.metainterp.history.record(rop.DEBUG_MERGE_POINT,
@@ -1131,6 +1135,11 @@
         self._addr2name_keys = [key for key, value in list_of_addr2name]
         self._addr2name_values = [value for key, value in list_of_addr2name]
 
+    def setup_jitdrivers_sd(self, optimizer):
+        if optimizer is not None:
+            for jd in self.jitdrivers_sd:
+                jd._state.set_param_optimizer(optimizer)
+
     def finish_setup(self, codewriter, optimizer=None):
         from pypy.jit.metainterp.blackhole import BlackholeInterpBuilder
         self.blackholeinterpbuilder = BlackholeInterpBuilder(codewriter, self)
@@ -1143,14 +1152,8 @@
         #
         self.jitdrivers_sd = codewriter.callcontrol.jitdrivers_sd
         self.virtualref_info = codewriter.callcontrol.virtualref_info
+        self.setup_jitdrivers_sd(optimizer)
         #
-        warmrunnerdesc = self.warmrunnerdesc
-        if warmrunnerdesc is not None:
-            XXX
-            self.num_green_args = warmrunnerdesc.num_green_args
-            self.state = warmrunnerdesc.state
-            if optimizer is not None:
-                self.state.set_param_optimizer(optimizer)
         self.globaldata = MetaInterpGlobalData(self)
 
     def _setup_once(self):
@@ -1229,34 +1232,19 @@
         self.loopnumbering = 0
         self.resume_virtuals = {}
         self.resume_virtuals_not_translated = []
-        #
-        state = None     # XXX staticdata.state
-        if state is not None:
-            self.jit_cell_at_key = state.jit_cell_at_key
-        else:
-            # for tests only; not RPython
-            class JitCell:
-                compiled_merge_points = None
-            _jitcell_dict = {}
-            def jit_cell_at_key(greenkey):
-                greenkey = tuple(greenkey)
-                return _jitcell_dict.setdefault(greenkey, JitCell())
-            self.jit_cell_at_key = jit_cell_at_key
-
-    def get_compiled_merge_points(self, greenkey):
-        cell = self.jit_cell_at_key(greenkey)
-        if cell.compiled_merge_points is None:
-            cell.compiled_merge_points = []
-        return cell.compiled_merge_points
 
 # ____________________________________________________________
 
 class MetaInterp(object):
     in_recursion = 0
 
-    def __init__(self, staticdata):
+    def __init__(self, staticdata, jitdriver_sd):
         self.staticdata = staticdata
         self.cpu = staticdata.cpu
+        self.jitdriver_sd = jitdriver_sd
+        # Note: self.jitdriver_sd is the JitDriverStaticData that corresponds
+        # to the current loop -- the outermost one.  Be careful, because
+        # during recursion we can also see other jitdrivers.
         self.portal_trace_positions = []
         self.free_frames_list = []
         self.last_exc_value_box = None
@@ -1363,7 +1351,6 @@
             raise AssertionError
 
     def create_empty_history(self):
-        warmrunnerstate = self.staticdata.state
         self.history = history.History()
         self.staticdata.stats.set_history(self.history)
 
@@ -1473,12 +1460,11 @@
         self.resumekey.reset_counter_from_failure()
 
     def blackhole_if_trace_too_long(self):
-        warmrunnerstate = self.staticdata.state
+        warmrunnerstate = self.jitdriver_sd._state
         if len(self.history.operations) > warmrunnerstate.trace_limit:
             greenkey_of_huge_function = self.find_biggest_function()
             self.portal_trace_positions = None
             if greenkey_of_huge_function is not None:
-                warmrunnerstate = self.staticdata.state
                 warmrunnerstate.disable_noninlinable_function(
                     greenkey_of_huge_function)
             raise SwitchToBlackhole(ABORT_TOO_LONG)
@@ -1507,13 +1493,16 @@
 
     @specialize.arg(1)
     def compile_and_run_once(self, jitdriver_sd, *args):
+        # NB. we pass explicitly 'jitdriver_sd' around here, even though it
+        # is also available as 'self.jitdriver_sd', because we need to
+        # specialize this function and a few other ones for the '*args'.
         debug_start('jit-tracing')
         self.staticdata._setup_once()
         self.staticdata.profiler.start_tracing()
+        assert jitdriver_sd is self.jitdriver_sd
         self.create_empty_history()
         try:
             original_boxes = self.initialize_original_boxes(jitdriver_sd,*args)
-            self.jitdriver_sd = jitdriver_sd
             return self._compile_and_run_once(original_boxes)
         finally:
             self.staticdata.profiler.end_tracing()
@@ -1611,11 +1600,12 @@
         # Search in current_merge_points for original_boxes with compatible
         # green keys, representing the beginning of the same loop as the one
         # we end now. 
-       
+
+        num_green_args = self.jitdriver_sd.num_green_args
         for j in range(len(self.current_merge_points)-1, -1, -1):
             original_boxes, start = self.current_merge_points[j]
             assert len(original_boxes) == len(live_arg_boxes) or start < 0
-            for i in range(self.staticdata.num_green_args):
+            for i in range(num_green_args):
                 box1 = original_boxes[i]
                 box2 = live_arg_boxes[i]
                 assert isinstance(box1, Const)
@@ -1638,7 +1628,7 @@
 
     def designate_target_loop(self, gmp):
         loop_token = gmp.target_loop_token
-        num_green_args = self.staticdata.num_green_args
+        num_green_args = self.jitdriver_sd.num_green_args
         residual_args = self.get_residual_args(loop_token.specnodes,
                                                gmp.argboxes[num_green_args:])
         history.set_future_values(self.cpu, residual_args)
@@ -1679,12 +1669,17 @@
             from pypy.jit.metainterp.resoperation import opname
             raise NotImplementedError(opname[opnum])
 
+    def get_compiled_merge_points(self, greenkey):
+        cell = self.jitdriver_sd._state.jit_cell_at_key(greenkey)
+        if cell.compiled_merge_points is None:
+            cell.compiled_merge_points = []
+        return cell.compiled_merge_points
+
     def compile(self, original_boxes, live_arg_boxes, start):
-        num_green_args = self.staticdata.num_green_args
+        num_green_args = self.jitdriver_sd.num_green_args
         self.history.inputargs = original_boxes[num_green_args:]
         greenkey = original_boxes[:num_green_args]
-        glob = self.staticdata.globaldata
-        old_loop_tokens = glob.get_compiled_merge_points(greenkey)
+        old_loop_tokens = self.get_compiled_merge_points(greenkey)
         self.history.record(rop.JUMP, live_arg_boxes[num_green_args:], None)
         loop_token = compile.compile_new_loop(self, old_loop_tokens,
                                               greenkey, start)
@@ -1693,10 +1688,9 @@
         self.history.operations.pop()     # remove the JUMP
 
     def compile_bridge(self, live_arg_boxes):
-        num_green_args = self.staticdata.num_green_args
+        num_green_args = self.jitdriver_sd.num_green_args
         greenkey = live_arg_boxes[:num_green_args]
-        glob = self.staticdata.globaldata
-        old_loop_tokens = glob.get_compiled_merge_points(greenkey)
+        old_loop_tokens = self.get_compiled_merge_points(greenkey)
         if len(old_loop_tokens) == 0:
             return
         self.history.record(rop.JUMP, live_arg_boxes[num_green_args:], None)
@@ -1755,9 +1749,6 @@
 
     @specialize.arg(1)
     def initialize_original_boxes(self, jitdriver_sd, *args):
-        # NB. we pass explicity 'jitdriver_sd' around here, even though it
-        # might also available as 'self.jitdriver_sd', because we need to
-        # specialize these functions for the particular *args.
         original_boxes = []
         self._fill_original_boxes(jitdriver_sd, original_boxes,
                                   jitdriver_sd.num_green_args, *args)
@@ -2164,6 +2155,16 @@
                 position = position3 + 1 + length3
             elif argtype == "orgpc":
                 value = orgpc
+            elif argtype == "int":
+                argcode = argcodes[next_argcode]
+                next_argcode = next_argcode + 1
+                if argcode == 'i':
+                    value = ord(code[position])
+                elif argcode == 'c':
+                    value = signedord(code[position])
+                else:
+                    raise AssertionError("bad argcode")
+                position += 1
             else:
                 raise AssertionError("bad argtype: %r" % (argtype,))
             args += (value,)

Modified: pypy/branch/multijit-3/pypy/jit/metainterp/resume.py
==============================================================================
--- pypy/branch/multijit-3/pypy/jit/metainterp/resume.py	(original)
+++ pypy/branch/multijit-3/pypy/jit/metainterp/resume.py	Fri Jun 11 15:56:03 2010
@@ -709,11 +709,11 @@
 
 # ---------- when resuming for blackholing, get direct values ----------
 
-def blackhole_from_resumedata(blackholeinterpbuilder, storage,
+def blackhole_from_resumedata(blackholeinterpbuilder, jitdriver_sd, storage,
                               all_virtuals=None):
     resumereader = ResumeDataDirectReader(blackholeinterpbuilder.cpu, storage,
                                           all_virtuals)
-    vinfo = blackholeinterpbuilder.metainterp_sd.virtualizable_info
+    vinfo = jitdriver_sd.virtualizable_info
     vrefinfo = blackholeinterpbuilder.metainterp_sd.virtualref_info
     resumereader.consume_vref_and_vable(vrefinfo, vinfo)
     #

Modified: pypy/branch/multijit-3/pypy/jit/metainterp/test/test_basic.py
==============================================================================
--- pypy/branch/multijit-3/pypy/jit/metainterp/test/test_basic.py	(original)
+++ pypy/branch/multijit-3/pypy/jit/metainterp/test/test_basic.py	Fri Jun 11 15:56:03 2010
@@ -16,10 +16,18 @@
     from pypy.jit.codewriter import support, codewriter
     from pypy.jit.metainterp import simple_optimize
 
+    class FakeJitCell:
+        compiled_merge_points = None
+
     class FakeWarmRunnerState:
         def attach_unoptimized_bridge_from_interp(self, greenkey, newloop):
             pass
 
+        def jit_cell_at_key(self, greenkey):
+            assert greenkey == []
+            return self._cell
+        _cell = FakeJitCell()
+
         # pick the optimizer this way
         optimize_loop = staticmethod(simple_optimize.optimize_loop)
         optimize_bridge = staticmethod(simple_optimize.optimize_bridge)
@@ -45,7 +53,7 @@
     testself.cw = cw
     cw.find_all_graphs(JitPolicy())
     #
-    testself.warmrunnerstate = FakeWarmRunnerState()
+    testself.warmrunnerstate = FakeJitDriverSD._state = FakeWarmRunnerState()
     testself.warmrunnerstate.cpu = cpu
     if hasattr(testself, 'finish_setup_for_interp_operations'):
         testself.finish_setup_for_interp_operations()
@@ -88,15 +96,13 @@
     cw = testself.cw
     opt = history.Options(listops=True)
     metainterp_sd = pyjitpl.MetaInterpStaticData(cw.cpu, opt)
-    metainterp_sd.finish_setup(cw, optimizer="bogus")
-    metainterp_sd.state = testself.warmrunnerstate
-    metainterp_sd.state.cpu = metainterp_sd.cpu
-    metainterp = pyjitpl.MetaInterp(metainterp_sd)
+    metainterp_sd.finish_setup(cw)
+    [jitdriver_sd] = metainterp_sd.jitdrivers_sd
+    metainterp = pyjitpl.MetaInterp(metainterp_sd, jitdriver_sd)
     metainterp_sd.DoneWithThisFrameInt = DoneWithThisFrame
     metainterp_sd.DoneWithThisFrameRef = DoneWithThisFrameRef
     metainterp_sd.DoneWithThisFrameFloat = DoneWithThisFrame
     testself.metainterp = metainterp
-    [jitdriver_sd] = cw.callcontrol.jitdrivers_sd
     try:
         metainterp.compile_and_run_once(jitdriver_sd, *args)
     except DoneWithThisFrame, e:
@@ -802,8 +808,9 @@
         translator.config.translation.gc = "boehm"
         warmrunnerdesc = WarmRunnerDesc(translator,
                                         CPUClass=self.CPUClass)
-        warmrunnerdesc.state.set_param_threshold(3)          # for tests
-        warmrunnerdesc.state.set_param_trace_eagerness(0)    # for tests
+        state = warmrunnerdesc.jitdrivers_sd[0]._state
+        state.set_param_threshold(3)          # for tests
+        state.set_param_trace_eagerness(0)    # for tests
         warmrunnerdesc.finish()
         for n, k in [(20, 0), (20, 1)]:
             interp.eval_graph(graph, [n, k])

Modified: pypy/branch/multijit-3/pypy/jit/metainterp/warmspot.py
==============================================================================
--- pypy/branch/multijit-3/pypy/jit/metainterp/warmspot.py	(original)
+++ pypy/branch/multijit-3/pypy/jit/metainterp/warmspot.py	Fri Jun 11 15:56:03 2010
@@ -21,6 +21,7 @@
 from pypy.jit.metainterp.typesystem import LLTypeHelper, OOTypeHelper
 from pypy.jit.metainterp.jitprof import Profiler, EmptyProfiler
 from pypy.jit.metainterp.jitexc import JitException
+from pypy.jit.metainterp.jitdriver import JitDriverStaticData
 from pypy.jit.codewriter import support, codewriter
 from pypy.jit.codewriter.policy import JitPolicy
 from pypy.rlib.jit import DEBUG_STEPS, DEBUG_DETAILED, DEBUG_OFF, DEBUG_PROFILE
@@ -69,11 +70,12 @@
     translator.config.translation.gc = "boehm"
     translator.config.translation.list_comprehension_operations = True
     warmrunnerdesc = WarmRunnerDesc(translator, backendopt=backendopt, **kwds)
-    warmrunnerdesc.state.set_param_threshold(3)          # for tests
-    warmrunnerdesc.state.set_param_trace_eagerness(2)    # for tests
-    warmrunnerdesc.state.set_param_trace_limit(trace_limit)
-    warmrunnerdesc.state.set_param_inlining(inline)
-    warmrunnerdesc.state.set_param_debug(debug_level)
+    for jd in warmrunnerdesc.jitdrivers_sd:
+        jd._state.set_param_threshold(3)          # for tests
+        jd._state.set_param_trace_eagerness(2)    # for tests
+        jd._state.set_param_trace_limit(trace_limit)
+        jd._state.set_param_inlining(inline)
+        jd._state.set_param_debug(debug_level)
     warmrunnerdesc.finish()
     res = interp.eval_graph(graph, args)
     if not kwds.get('translate_support_code', False):
@@ -110,12 +112,11 @@
         raise Exception("no can_enter_jit found!")
     return results
 
-def find_jit_merge_point(graphs):
+def find_jit_merge_points(graphs):
     results = _find_jit_marker(graphs, 'jit_merge_point')
-    if len(results) != 1:
-        raise Exception("found %d jit_merge_points, need exactly one!" %
-                        (len(results),))
-    return results[0]
+    if not results:
+        raise Exception("no jit_merge_point found!")
+    return results
 
 def find_set_param(graphs):
     return _find_jit_marker(graphs, 'set_param')
@@ -146,8 +147,8 @@
         pyjitpl._warmrunnerdesc = self   # this is a global for debugging only!
         self.set_translator(translator)
         self.build_cpu(CPUClass, **kwds)
-        self.find_portal()
-        self.codewriter = codewriter.CodeWriter(self.cpu, self.portal_graph)
+        self.find_portals()
+        self.codewriter = codewriter.CodeWriter(self.cpu, self.jitdrivers_sd)
         if policy is None:
             policy = JitPolicy()
         policy.set_supports_floats(self.cpu.supports_floats)
@@ -158,35 +159,31 @@
             self.prejit_optimizations(policy, graphs)
 
         self.build_meta_interp(ProfilerClass)
-        self.make_args_specification()
+        self.make_args_specifications()
         #
         from pypy.jit.metainterp.virtualref import VirtualRefInfo
         vrefinfo = VirtualRefInfo(self)
         self.codewriter.setup_vrefinfo(vrefinfo)
-        if self.jitdriver.virtualizables:
-            from pypy.jit.metainterp.virtualizable import VirtualizableInfo
-            self.virtualizable_info = VirtualizableInfo(self)
-            self.codewriter.setup_virtualizable_info(self.virtualizable_info)
-        else:
-            self.virtualizable_info = None
         #
+        self.make_virtualizable_infos()
         self.make_exception_classes()
         self.make_driverhook_graphs()
-        self.make_enter_function()
-        self.rewrite_jit_merge_point(policy)
+        self.make_enter_functions()
+        self.rewrite_jit_merge_points(policy)
 
         verbose = not self.cpu.translate_support_code
         self.codewriter.make_jitcodes(verbose=verbose)
-        self.rewrite_can_enter_jit()
+        self.rewrite_can_enter_jits()
         self.rewrite_set_param()
         self.rewrite_force_virtual(vrefinfo)
         self.add_finish()
         self.metainterp_sd.finish_setup(self.codewriter, optimizer=optimizer)
 
     def finish(self):
-        vinfo = self.virtualizable_info
-        if vinfo is not None:
-            vinfo.finish()
+        vinfos = set([jd.virtualizable_info for jd in self.jitdrivers_sd])
+        for vinfo in vinfos:
+            if vinfo is not None:
+                vinfo.finish()
         if self.cpu.translate_support_code:
             self.annhelper.finish()
 
@@ -198,18 +195,27 @@
         self.rtyper = translator.rtyper
         self.gcdescr = gc.get_description(translator.config)
 
-    def find_portal(self):
+    def find_portals(self):
+        self.jitdrivers_sd = []
         graphs = self.translator.graphs
-        self.jit_merge_point_pos = find_jit_merge_point(graphs)
-        graph, block, pos = self.jit_merge_point_pos
+        for jit_merge_point_pos in find_jit_merge_points(graphs):
+            self.split_graph_and_record_jitdriver(*jit_merge_point_pos)
+        #
+        assert (len(set([jd.jitdriver for jd in self.jitdrivers_sd])) ==
+                len(self.jitdrivers_sd)), \
+                "there are multiple jit_merge_points with the same jitdriver"
+
+    def split_graph_and_record_jitdriver(self, graph, block, pos):
+        jd = JitDriverStaticData()
+        jd._jit_merge_point_pos = (graph, block, pos)
         op = block.operations[pos]
         args = op.args[2:]
         s_binding = self.translator.annotator.binding
-        self.portal_args_s = [s_binding(v) for v in args]
+        jd._portal_args_s = [s_binding(v) for v in args]
         graph = copygraph(graph)
         graph.startblock.isstartblock = False
-        graph.startblock = support.split_before_jit_merge_point(
-            *find_jit_merge_point([graph]))
+        [jmpp] = find_jit_merge_points([graph])
+        graph.startblock = support.split_before_jit_merge_point(*jmpp)
         graph.startblock.isstartblock = True
         # a crash in the following checkgraph() means that you forgot
         # to list some variable in greens=[] or reds=[] in JitDriver.
@@ -218,12 +224,16 @@
             assert isinstance(v, Variable)
         assert len(dict.fromkeys(graph.getargs())) == len(graph.getargs())
         self.translator.graphs.append(graph)
-        self.portal_graph = graph
+        jd.portal_graph = graph
         # it's a bit unbelievable to have a portal without func
         assert hasattr(graph, "func")
         graph.func._dont_inline_ = True
         graph.func._jit_unroll_safe_ = True
-        self.jitdriver = block.operations[pos].args[1].value
+        jd.jitdriver = block.operations[pos].args[1].value
+        jd.portal_runner_ptr = "<not set so far>"
+        jd.result_type = history.getkind(jd.portal_graph.getreturnvar()
+                                         .concretetype)[0]
+        self.jitdrivers_sd.append(jd)
 
     def check_access_directly_sanity(self, graphs):
         from pypy.translator.backendopt.inline import collect_called_graphs
@@ -268,6 +278,17 @@
                                                   ProfilerClass=ProfilerClass,
                                                   warmrunnerdesc=self)
 
+    def make_virtualizable_infos(self):
+        for jd in self.jitdrivers_sd:
+            if jd.jitdriver.virtualizables:
+                XXX
+                from pypy.jit.metainterp.virtualizable import VirtualizableInfo
+                vinfo = VirtualizableInfo(self)
+                YYY  # share!
+            else:
+                vinfo = None
+            jd.virtualizable_info = vinfo
+
     def make_exception_classes(self):
 
         class DoneWithThisFrameVoid(JitException):
@@ -317,6 +338,8 @@
                     self.green_int, self.green_ref, self.green_float,
                     self.red_int, self.red_ref, self.red_float)
 
+        # XXX there is no point any more to not just have the exceptions
+        # as globals
         self.DoneWithThisFrameVoid = DoneWithThisFrameVoid
         self.DoneWithThisFrameInt = DoneWithThisFrameInt
         self.DoneWithThisFrameRef = DoneWithThisFrameRef
@@ -330,11 +353,15 @@
         self.metainterp_sd.ExitFrameWithExceptionRef = ExitFrameWithExceptionRef
         self.metainterp_sd.ContinueRunningNormally = ContinueRunningNormally
 
-    def make_enter_function(self):
+    def make_enter_functions(self):
+        for jd in self.jitdrivers_sd:
+            self.make_enter_function(jd)
+
+    def make_enter_function(self, jd):
         from pypy.jit.metainterp.warmstate import WarmEnterState
-        state = WarmEnterState(self)
+        state = WarmEnterState(self, jd)
         maybe_compile_and_run = state.make_entry_point()
-        self.state = state
+        jd._state = state
 
         def crash_in_jit(e):
             if not we_are_translated():
@@ -359,15 +386,16 @@
             def maybe_enter_jit(*args):
                 maybe_compile_and_run(*args)
             maybe_enter_jit._always_inline_ = True
-        self.maybe_enter_jit_fn = maybe_enter_jit
+        jd._maybe_enter_jit_fn = maybe_enter_jit
 
-        can_inline = self.state.can_inline_greenargs
+        can_inline = state.can_inline_greenargs
+        num_green_args = jd.num_green_args
         def maybe_enter_from_start(*args):
-            if can_inline is not None and not can_inline(*args[:self.num_green_args]):
+            if can_inline is not None and not can_inline(*args[:num_green_args]):
                 maybe_compile_and_run(*args)
         maybe_enter_from_start._always_inline_ = True
-        self.maybe_enter_from_start_fn = maybe_enter_from_start
-        
+        jd._maybe_enter_from_start_fn = maybe_enter_from_start
+
     def make_driverhook_graphs(self):
         from pypy.rlib.jit import BaseJitCell
         bk = self.rtyper.annotator.bookkeeper
@@ -378,22 +406,23 @@
         s_Str = annmodel.SomeString()
         #
         annhelper = MixLevelHelperAnnotator(self.translator.rtyper)
-        self.set_jitcell_at_ptr = self._make_hook_graph(
-            annhelper, self.jitdriver.set_jitcell_at, annmodel.s_None,
-            s_BaseJitCell_not_None)
-        self.get_jitcell_at_ptr = self._make_hook_graph(
-            annhelper, self.jitdriver.get_jitcell_at, s_BaseJitCell_or_None)
-        self.can_inline_ptr = self._make_hook_graph(
-            annhelper, self.jitdriver.can_inline, annmodel.s_Bool)
-        self.get_printable_location_ptr = self._make_hook_graph(
-            annhelper, self.jitdriver.get_printable_location, s_Str)
-        self.confirm_enter_jit_ptr = self._make_hook_graph(
-            annhelper, self.jitdriver.confirm_enter_jit, annmodel.s_Bool,
-            onlygreens=False)
+        for jd in self.jitdrivers_sd:
+            jd._set_jitcell_at_ptr = self._make_hook_graph(jd,
+                annhelper, jd.jitdriver.set_jitcell_at, annmodel.s_None,
+                s_BaseJitCell_not_None)
+            jd._get_jitcell_at_ptr = self._make_hook_graph(jd,
+                annhelper, jd.jitdriver.get_jitcell_at, s_BaseJitCell_or_None)
+            jd._can_inline_ptr = self._make_hook_graph(jd,
+                annhelper, jd.jitdriver.can_inline, annmodel.s_Bool)
+            jd._get_printable_location_ptr = self._make_hook_graph(jd,
+                annhelper, jd.jitdriver.get_printable_location, s_Str)
+            jd._confirm_enter_jit_ptr = self._make_hook_graph(jd,
+                annhelper, jd.jitdriver.confirm_enter_jit, annmodel.s_Bool,
+                onlygreens=False)
         annhelper.finish()
 
-    def _make_hook_graph(self, annhelper, func, s_result, s_first_arg=None,
-                         onlygreens=True):
+    def _make_hook_graph(self, jitdriver_sd, annhelper, func,
+                         s_result, s_first_arg=None, onlygreens=True):
         if func is None:
             return None
         #
@@ -401,38 +430,57 @@
         if s_first_arg is not None:
             extra_args_s.append(s_first_arg)
         #
-        args_s = self.portal_args_s
+        args_s = jitdriver_sd._portal_args_s
         if onlygreens:
-            args_s = args_s[:len(self.green_args_spec)]
+            args_s = args_s[:len(jitdriver_sd._green_args_spec)]
         graph = annhelper.getgraph(func, extra_args_s + args_s, s_result)
         funcptr = annhelper.graph2delayed(graph)
         return funcptr
 
-    def make_args_specification(self):
-        graph, block, index = self.jit_merge_point_pos
+    def make_args_specifications(self):
+        for jd in self.jitdrivers_sd:
+            self.make_args_specification(jd)
+
+    def make_args_specification(self, jd):
+        graph, block, index = jd._jit_merge_point_pos
         op = block.operations[index]
         greens_v, reds_v = support.decode_hp_hint_args(op)
         ALLARGS = [v.concretetype for v in (greens_v + reds_v)]
-        self.green_args_spec = [v.concretetype for v in greens_v]
-        self.red_args_types = [history.getkind(v.concretetype) for v in reds_v]
-        self.num_green_args = len(self.green_args_spec)
+        jd._green_args_spec = [v.concretetype for v in greens_v]
+        jd._red_args_types = [history.getkind(v.concretetype) for v in reds_v]
+        jd.num_green_args = len(jd._green_args_spec)
         RESTYPE = graph.getreturnvar().concretetype
-        (self.JIT_ENTER_FUNCTYPE,
-         self.PTR_JIT_ENTER_FUNCTYPE) = self.cpu.ts.get_FuncType(ALLARGS, lltype.Void)
-        (self.PORTAL_FUNCTYPE,
-         self.PTR_PORTAL_FUNCTYPE) = self.cpu.ts.get_FuncType(ALLARGS, RESTYPE)
-        (_, self.PTR_ASSEMBLER_HELPER_FUNCTYPE) = self.cpu.ts.get_FuncType(
+        (jd._JIT_ENTER_FUNCTYPE,
+         jd._PTR_JIT_ENTER_FUNCTYPE) = self.cpu.ts.get_FuncType(ALLARGS, lltype.Void)
+        (jd._PORTAL_FUNCTYPE,
+         jd._PTR_PORTAL_FUNCTYPE) = self.cpu.ts.get_FuncType(ALLARGS, RESTYPE)
+        (_, jd._PTR_ASSEMBLER_HELPER_FUNCTYPE) = self.cpu.ts.get_FuncType(
             [lltype.Signed, llmemory.GCREF], RESTYPE)
 
-    def rewrite_can_enter_jit(self):
-        FUNC = self.JIT_ENTER_FUNCTYPE
-        FUNCPTR = self.PTR_JIT_ENTER_FUNCTYPE
-        jit_enter_fnptr = self.helper_func(FUNCPTR, self.maybe_enter_jit_fn)
+    def rewrite_can_enter_jits(self):
+        can_enter_jits = find_can_enter_jit(self.translator.graphs)
+        sublists = {}
+        for jd in self.jitdrivers_sd:
+            sublists[jd.jitdriver] = []
+        for graph, block, index in can_enter_jits:
+            op = block.operations[index]
+            jitdriver = op.args[1].value
+            assert jitdriver in sublists, \
+                   "can_enter_jit with no matching jit_merge_point"
+            sublists[jitdriver].append((graph, block, index))
+        for jd in self.jitdrivers_sd:
+            sublist = sublists[jd.jitdriver]
+            assert len(sublist) > 0, \
+                   "found no can_enter_jit for %r" % (jd.jitdriver,)
+            self.rewrite_can_enter_jit(jd, sublist)
+
+    def rewrite_can_enter_jit(self, jd, can_enter_jits):
+        FUNC = jd._JIT_ENTER_FUNCTYPE
+        FUNCPTR = jd._PTR_JIT_ENTER_FUNCTYPE
+        jit_enter_fnptr = self.helper_func(FUNCPTR, jd._maybe_enter_jit_fn)
 
-        graphs = self.translator.graphs
-        can_enter_jits = find_can_enter_jit(graphs)
         for graph, block, index in can_enter_jits:
-            if graph is self.jit_merge_point_pos[0]:
+            if graph is jd._jit_merge_point_pos[0]:
                 continue
 
             op = block.operations[index]
@@ -455,7 +503,11 @@
         graph = self.annhelper.getgraph(func, args_s, s_result)
         return self.annhelper.graph2delayed(graph, FUNC)
 
-    def rewrite_jit_merge_point(self, policy):
+    def rewrite_jit_merge_points(self, policy):
+        for jd in self.jitdrivers_sd:
+            self.rewrite_jit_merge_point(jd, policy)
+
+    def rewrite_jit_merge_point(self, jd, policy):
         #
         # Mutate the original portal graph from this:
         #
@@ -486,9 +538,9 @@
         #           while 1:
         #               more stuff
         #
-        origportalgraph = self.jit_merge_point_pos[0]
-        portalgraph = self.portal_graph
-        PORTALFUNC = self.PORTAL_FUNCTYPE
+        origportalgraph = jd._jit_merge_point_pos[0]
+        portalgraph = jd.portal_graph
+        PORTALFUNC = jd._PORTAL_FUNCTYPE
 
         # ____________________________________________________________
         # Prepare the portal_runner() helper
@@ -496,12 +548,12 @@
         from pypy.jit.metainterp.warmstate import specialize_value
         portal_ptr = self.cpu.ts.functionptr(PORTALFUNC, 'portal',
                                          graph = portalgraph)
-        self.portal_ptr = portal_ptr
+        jd._portal_ptr = portal_ptr
         #
         portalfunc_ARGS = []
         nums = {}
         for i, ARG in enumerate(PORTALFUNC.ARGS):
-            if i < len(self.jitdriver.greens):
+            if i < len(jd.jitdriver.greens):
                 color = 'green'
             else:
                 color = 'red'
@@ -519,7 +571,7 @@
         def ll_portal_runner(*args):
             while 1:
                 try:
-                    self.maybe_enter_from_start_fn(*args)
+                    jd._maybe_enter_from_start_fn(*args)
                     return support.maybe_on_top_of_llinterp(rtyper,
                                                       portal_ptr)(*args)
                 except self.ContinueRunningNormally, e:
@@ -548,16 +600,15 @@
                         value = cast_base_ptr_to_instance(Exception, value)
                         raise Exception, value
 
-        self.ll_portal_runner = ll_portal_runner # for debugging
-        self.portal_runner_ptr = self.helper_func(self.PTR_PORTAL_FUNCTYPE,
-                                                  ll_portal_runner)
+        jd._ll_portal_runner = ll_portal_runner # for debugging
+        jd.portal_runner_ptr = self.helper_func(jd._PTR_PORTAL_FUNCTYPE,
+                                                ll_portal_runner)
         self.cpu.portal_calldescr = self.cpu.calldescrof(
-            self.PTR_PORTAL_FUNCTYPE.TO,
-            self.PTR_PORTAL_FUNCTYPE.TO.ARGS,
-            self.PTR_PORTAL_FUNCTYPE.TO.RESULT)
-        self.codewriter.setup_portal_runner_ptr(self.portal_runner_ptr)
+            jd._PTR_PORTAL_FUNCTYPE.TO,
+            jd._PTR_PORTAL_FUNCTYPE.TO.ARGS,
+            jd._PTR_PORTAL_FUNCTYPE.TO.RESULT)
 
-        vinfo = self.virtualizable_info
+        vinfo = jd.virtualizable_info
 
         def assembler_call_helper(failindex, virtualizableref):
             fail_descr = self.cpu.get_fail_descr_from_number(failindex)
@@ -567,6 +618,7 @@
                         virtualizable = lltype.cast_opaque_ptr(
                             vinfo.VTYPEPTR, virtualizableref)
                         vinfo.reset_vable_token(virtualizable)
+                    XXX   # careful here, we must pass the correct jitdriver_sd
                     loop_token = fail_descr.handle_fail(self.metainterp_sd)
                     fail_descr = self.cpu.execute_token(loop_token)
                 except self.ContinueRunningNormally, e:
@@ -596,12 +648,14 @@
                         value = cast_base_ptr_to_instance(Exception, value)
                         raise Exception, value
 
-        self.assembler_call_helper = assembler_call_helper # for debugging
+        jd._assembler_call_helper = assembler_call_helper # for debugging
+        # XXX rewrite me, ugly sticking does not work any more
         self.cpu.assembler_helper_ptr = self.helper_func(
-            self.PTR_ASSEMBLER_HELPER_FUNCTYPE,
+            jd._PTR_ASSEMBLER_HELPER_FUNCTYPE,
             assembler_call_helper)
         # XXX a bit ugly sticking
         if vinfo is not None:
+            XXX     # rewrite me, ugly sticking does not work any more
             self.cpu.index_of_virtualizable = (vinfo.index_of_virtualizable -
                                                self.num_green_args)
         else:
@@ -610,12 +664,12 @@
         # ____________________________________________________________
         # Now mutate origportalgraph to end with a call to portal_runner_ptr
         #
-        _, origblock, origindex = self.jit_merge_point_pos
+        _, origblock, origindex = jd._jit_merge_point_pos
         op = origblock.operations[origindex]
         assert op.opname == 'jit_marker'
         assert op.args[0].value == 'jit_merge_point'
         greens_v, reds_v = support.decode_hp_hint_args(op)
-        vlist = [Constant(self.portal_runner_ptr, self.PTR_PORTAL_FUNCTYPE)]
+        vlist = [Constant(jd.portal_runner_ptr, jd._PTR_PORTAL_FUNCTYPE)]
         vlist += greens_v
         vlist += reds_v
         v_result = Variable()
@@ -642,8 +696,8 @@
         graphs = self.translator.graphs
         _, PTR_SET_PARAM_FUNCTYPE = self.cpu.ts.get_FuncType([lltype.Signed],
                                                              lltype.Void)
-        def make_closure(fullfuncname):
-            state = self.state
+        def make_closure(jd, fullfuncname):
+            state = jd._state
             def closure(i):
                 getattr(state, fullfuncname)(i)
             funcptr = self.helper_func(PTR_SET_PARAM_FUNCTYPE, closure)
@@ -651,12 +705,17 @@
         #
         for graph, block, i in find_set_param(graphs):
             op = block.operations[i]
-            assert op.args[1].value == self.jitdriver
+            for jd in self.jitdrivers_sd:
+                if jd.jitdriver is op.args[1].value:
+                    break
+            else:
+                assert 0, "jitdriver of set_param() not found"
             funcname = op.args[2].value
-            if funcname not in closures:
-                closures[funcname] = make_closure('set_param_' + funcname)
+            key = jd, funcname
+            if key not in closures:
+                closures[key] = make_closure(jd, 'set_param_' + funcname)
             op.opname = 'direct_call'
-            op.args[:3] = [closures[funcname]]
+            op.args[:3] = [closures[key]]
 
     def rewrite_force_virtual(self, vrefinfo):
         if self.cpu.ts.name != 'lltype':

Modified: pypy/branch/multijit-3/pypy/jit/metainterp/warmstate.py
==============================================================================
--- pypy/branch/multijit-3/pypy/jit/metainterp/warmstate.py	(original)
+++ pypy/branch/multijit-3/pypy/jit/metainterp/warmstate.py	Fri Jun 11 15:56:03 2010
@@ -1,7 +1,7 @@
 import sys
 from pypy.rpython.lltypesystem import lltype, llmemory, rstr
 from pypy.rpython.ootypesystem import ootype
-from pypy.rpython.annlowlevel import hlstr, cast_base_ptr_to_instance
+from pypy.rpython.annlowlevel import hlstr, llstr, cast_base_ptr_to_instance
 from pypy.rpython.annlowlevel import cast_object_to_ptr
 from pypy.rlib.objectmodel import specialize, we_are_translated, r_dict
 from pypy.rlib.rarithmetic import intmask
@@ -120,6 +120,16 @@
     else:
         assert False
 
+class JitCell(BaseJitCell):
+    # the counter can mean the following things:
+    #     counter >=  0: not yet traced, wait till threshold is reached
+    #     counter == -1: there is an entry bridge for this cell
+    #     counter == -2: tracing is currently going on for this cell
+    counter = 0
+    compiled_merge_points = None
+    dont_trace_here = False
+    entry_loop_token = None
+
 # ____________________________________________________________
 
 
@@ -127,9 +137,10 @@
     THRESHOLD_LIMIT = sys.maxint // 2
     default_jitcell_dict = None
 
-    def __init__(self, warmrunnerdesc):
+    def __init__(self, warmrunnerdesc, jitdriver_sd):
         "NOT_RPYTHON"
         self.warmrunnerdesc = warmrunnerdesc
+        self.jitdriver_sd = jitdriver_sd
         try:
             self.profiler = warmrunnerdesc.metainterp_sd.profiler
         except AttributeError:       # for tests
@@ -195,8 +206,9 @@
             return self.maybe_compile_and_run
 
         metainterp_sd = self.warmrunnerdesc.metainterp_sd
-        vinfo = self.warmrunnerdesc.virtualizable_info
-        num_green_args = self.warmrunnerdesc.num_green_args
+        jitdriver_sd = self.jitdriver_sd
+        vinfo = jitdriver_sd.virtualizable_info
+        num_green_args = jitdriver_sd.num_green_args
         get_jitcell = self.make_jitcell_getter()
         set_future_values = self.make_set_future_values()
         self.make_jitdriver_callbacks()
@@ -206,7 +218,6 @@
             """Entry point to the JIT.  Called at the point with the
             can_enter_jit() hint.
             """
-            globaldata = metainterp_sd.globaldata
             if NonConstant(False):
                 # make sure we always see the saner optimizer from an
                 # annotation point of view, otherwise we get lots of
@@ -234,11 +245,12 @@
                     return
                 # bound reached; start tracing
                 from pypy.jit.metainterp.pyjitpl import MetaInterp
-                metainterp = MetaInterp(metainterp_sd)
+                metainterp = MetaInterp(metainterp_sd, jitdriver_sd)
                 # set counter to -2, to mean "tracing in effect"
                 cell.counter = -2
                 try:
-                    loop_token = metainterp.compile_and_run_once(*args)
+                    loop_token = metainterp.compile_and_run_once(jitdriver_sd,
+                                                                 *args)
                 finally:
                     if cell.counter == -2:
                         cell.counter = 0
@@ -264,7 +276,8 @@
                 metainterp_sd.profiler.end_running()
                 if vinfo is not None:
                     vinfo.reset_vable_token(virtualizable)
-                loop_token = fail_descr.handle_fail(metainterp_sd)
+                loop_token = fail_descr.handle_fail(metainterp_sd,
+                                                    jitdriver_sd)
        
         maybe_compile_and_run._dont_inline_ = True
         self.maybe_compile_and_run = maybe_compile_and_run
@@ -277,8 +290,8 @@
         if hasattr(self, 'unwrap_greenkey'):
             return self.unwrap_greenkey
         #
-        warmrunnerdesc = self.warmrunnerdesc
-        green_args_spec = unrolling_iterable(warmrunnerdesc.green_args_spec)
+        jitdriver_sd = self.jitdriver_sd
+        green_args_spec = unrolling_iterable(jitdriver_sd._green_args_spec)
         #
         def unwrap_greenkey(greenkey):
             greenargs = ()
@@ -302,20 +315,10 @@
         if hasattr(self, 'jit_getter'):
             return self.jit_getter
         #
-        class JitCell(BaseJitCell):
-            # the counter can mean the following things:
-            #     counter >=  0: not yet traced, wait till threshold is reached
-            #     counter == -1: there is an entry bridge for this cell
-            #     counter == -2: tracing is currently going on for this cell
-            counter = 0
-            compiled_merge_points = None
-            dont_trace_here = False
-            entry_loop_token = None
-        #
-        if self.warmrunnerdesc.get_jitcell_at_ptr is None:
-            jit_getter = self._make_jitcell_getter_default(JitCell)
+        if self.jitdriver_sd._get_jitcell_at_ptr is None:
+            jit_getter = self._make_jitcell_getter_default()
         else:
-            jit_getter = self._make_jitcell_getter_custom(JitCell)
+            jit_getter = self._make_jitcell_getter_custom()
         #
         unwrap_greenkey = self.make_unwrap_greenkey()
         #
@@ -327,10 +330,10 @@
         #
         return jit_getter
 
-    def _make_jitcell_getter_default(self, JitCell):
+    def _make_jitcell_getter_default(self):
         "NOT_RPYTHON"
-        warmrunnerdesc = self.warmrunnerdesc
-        green_args_spec = unrolling_iterable(warmrunnerdesc.green_args_spec)
+        jitdriver_sd = self.jitdriver_sd
+        green_args_spec = unrolling_iterable(jitdriver_sd._green_args_spec)
         #
         def comparekey(greenargs1, greenargs2):
             i = 0
@@ -361,11 +364,11 @@
             return cell
         return get_jitcell
 
-    def _make_jitcell_getter_custom(self, JitCell):
+    def _make_jitcell_getter_custom(self):
         "NOT_RPYTHON"
         rtyper = self.warmrunnerdesc.rtyper
-        get_jitcell_at_ptr = self.warmrunnerdesc.get_jitcell_at_ptr
-        set_jitcell_at_ptr = self.warmrunnerdesc.set_jitcell_at_ptr
+        get_jitcell_at_ptr = self.jitdriver_sd._get_jitcell_at_ptr
+        set_jitcell_at_ptr = self.jitdriver_sd._set_jitcell_at_ptr
         lltohlhack = {}
         #
         def get_jitcell(*greenargs):
@@ -415,9 +418,10 @@
             return self.set_future_values
 
         warmrunnerdesc = self.warmrunnerdesc
+        jitdriver_sd   = self.jitdriver_sd
         cpu = warmrunnerdesc.cpu
-        vinfo = warmrunnerdesc.virtualizable_info
-        red_args_types = unrolling_iterable(warmrunnerdesc.red_args_types)
+        vinfo = jitdriver_sd.virtualizable_info
+        red_args_types = unrolling_iterable(jitdriver_sd._red_args_types)
         #
         def set_future_values(*redargs):
             i = 0
@@ -428,8 +432,8 @@
                 set_future_values_from_vinfo(*redargs)
         #
         if vinfo is not None:
-            i0 = len(warmrunnerdesc.red_args_types)
-            num_green_args = warmrunnerdesc.num_green_args
+            i0 = len(jitdriver_sd._red_args_types)
+            num_green_args = jitdriver_sd.num_green_args
             vable_static_fields = unrolling_iterable(
                 zip(vinfo.static_extra_types, vinfo.static_fields))
             vable_array_fields = unrolling_iterable(
@@ -464,7 +468,7 @@
         if hasattr(self, 'get_location_str'):
             return
         #
-        can_inline_ptr = self.warmrunnerdesc.can_inline_ptr
+        can_inline_ptr = self.jitdriver_sd._can_inline_ptr
         unwrap_greenkey = self.make_unwrap_greenkey()
         if can_inline_ptr is None:
             def can_inline_callable(*greenargs):
@@ -497,10 +501,15 @@
         self.get_assembler_token = get_assembler_token
         
         #
-        get_location_ptr = self.warmrunnerdesc.get_printable_location_ptr
+        get_location_ptr = self.jitdriver_sd._get_printable_location_ptr
         if get_location_ptr is None:
+            missing = '(no jitdriver.get_printable_location!)'
+            missingll = llstr(missing)
             def get_location_str(greenkey):
-                return '(no jitdriver.get_printable_location!)'
+                if we_are_translated():
+                    return missingll
+                else:
+                    return missing
         else:
             rtyper = self.warmrunnerdesc.rtyper
             unwrap_greenkey = self.make_unwrap_greenkey()
@@ -514,7 +523,7 @@
                 return res
         self.get_location_str = get_location_str
         #
-        confirm_enter_jit_ptr = self.warmrunnerdesc.confirm_enter_jit_ptr
+        confirm_enter_jit_ptr = self.jitdriver_sd._confirm_enter_jit_ptr
         if confirm_enter_jit_ptr is None:
             def confirm_enter_jit(*args):
                 return True



More information about the Pypy-commit mailing list