[pypy-svn] r75262 - in pypy/branch/multijit-3/pypy/jit: codewriter codewriter/test metainterp metainterp/test

arigo at codespeak.net arigo at codespeak.net
Fri Jun 11 12:52:15 CEST 2010


Author: arigo
Date: Fri Jun 11 12:52:13 2010
New Revision: 75262

Added:
   pypy/branch/multijit-3/pypy/jit/metainterp/jitdriver.py   (contents, props changed)
Modified:
   pypy/branch/multijit-3/pypy/jit/codewriter/call.py
   pypy/branch/multijit-3/pypy/jit/codewriter/codewriter.py
   pypy/branch/multijit-3/pypy/jit/codewriter/jitcode.py
   pypy/branch/multijit-3/pypy/jit/codewriter/jtransform.py
   pypy/branch/multijit-3/pypy/jit/codewriter/test/test_flatten.py
   pypy/branch/multijit-3/pypy/jit/metainterp/blackhole.py
   pypy/branch/multijit-3/pypy/jit/metainterp/pyjitpl.py
   pypy/branch/multijit-3/pypy/jit/metainterp/test/test_basic.py
Log:
Start porting metainterp/pyjitpl.


Modified: pypy/branch/multijit-3/pypy/jit/codewriter/call.py
==============================================================================
--- pypy/branch/multijit-3/pypy/jit/codewriter/call.py	(original)
+++ pypy/branch/multijit-3/pypy/jit/codewriter/call.py	Fri Jun 11 12:52:13 2010
@@ -147,6 +147,7 @@
     def grab_initial_jitcodes(self):
         for jd in self.jitdrivers_sd:
             jd.mainjitcode = self.get_jitcode(jd.portal_graph)
+            jd.mainjitcode.is_portal = True
 
     def enum_pending_graphs(self):
         while self.unfinished_graphs:

Modified: pypy/branch/multijit-3/pypy/jit/codewriter/codewriter.py
==============================================================================
--- pypy/branch/multijit-3/pypy/jit/codewriter/codewriter.py	(original)
+++ pypy/branch/multijit-3/pypy/jit/codewriter/codewriter.py	Fri Jun 11 12:52:13 2010
@@ -37,7 +37,7 @@
         #
         # step 1: mangle the graph so that it contains the final instructions
         # that we want in the JitCode, but still as a control flow graph
-        transform_graph(graph, self.cpu, self.callcontrol)
+        transform_graph(graph, self.cpu, self.callcontrol, portal_jd)
         #
         # step 2: perform register allocation on it
         regallocs = {}

Modified: pypy/branch/multijit-3/pypy/jit/codewriter/jitcode.py
==============================================================================
--- pypy/branch/multijit-3/pypy/jit/codewriter/jitcode.py	(original)
+++ pypy/branch/multijit-3/pypy/jit/codewriter/jitcode.py	Fri Jun 11 12:52:13 2010
@@ -13,6 +13,7 @@
         self.name = name
         self.fnaddr = fnaddr
         self.calldescr = calldescr
+        self.is_portal = False
         self._called_from = called_from   # debugging
         self._ssarepr     = None          # debugging
 
@@ -24,11 +25,11 @@
         self.constants_i = constants_i or self._empty_i
         self.constants_r = constants_r or self._empty_r
         self.constants_f = constants_f or self._empty_f
-        # encode the three num_regs into a single integer
+        # encode the three num_regs into a single char each
         assert num_regs_i < 256 and num_regs_r < 256 and num_regs_f < 256
-        self.num_regs_encoded = ((num_regs_i << 16) |
-                                 (num_regs_f << 8) |
-                                 (num_regs_r << 0))
+        self.c_num_regs_i = chr(num_regs_i)
+        self.c_num_regs_r = chr(num_regs_r)
+        self.c_num_regs_f = chr(num_regs_f)
         self.liveness = make_liveness_cache(liveness)
         self._startpoints = startpoints   # debugging
         self._alllabels = alllabels       # debugging
@@ -37,13 +38,13 @@
         return heaptracker.adr2int(self.fnaddr)
 
     def num_regs_i(self):
-        return self.num_regs_encoded >> 16
-
-    def num_regs_f(self):
-        return (self.num_regs_encoded >> 8) & 0xFF
+        return ord(self.c_num_regs_i)
 
     def num_regs_r(self):
-        return self.num_regs_encoded & 0xFF
+        return ord(self.c_num_regs_r)
+
+    def num_regs_f(self):
+        return ord(self.c_num_regs_f)
 
     def has_liveness_info(self, pc):
         return pc in self.liveness

Modified: pypy/branch/multijit-3/pypy/jit/codewriter/jtransform.py
==============================================================================
--- pypy/branch/multijit-3/pypy/jit/codewriter/jtransform.py	(original)
+++ pypy/branch/multijit-3/pypy/jit/codewriter/jtransform.py	Fri Jun 11 12:52:13 2010
@@ -13,19 +13,20 @@
 from pypy.translator.simplify import get_funcobj
 
 
-def transform_graph(graph, cpu=None, callcontrol=None):
+def transform_graph(graph, cpu=None, callcontrol=None, portal_jd=None):
     """Transform a control flow graph to make it suitable for
     being flattened in a JitCode.
     """
-    t = Transformer(cpu, callcontrol)
+    t = Transformer(cpu, callcontrol, portal_jd)
     t.transform(graph)
 
 
 class Transformer(object):
 
-    def __init__(self, cpu=None, callcontrol=None):
+    def __init__(self, cpu=None, callcontrol=None, portal_jd=None):
         self.cpu = cpu
         self.callcontrol = callcontrol
+        self.portal_jd = portal_jd   # non-None only for the portal graph(s)
 
     def transform(self, graph):
         self.graph = graph
@@ -773,9 +774,14 @@
         return getattr(self, 'handle_jit_marker__%s' % key)(op, jitdriver)
 
     def handle_jit_marker__jit_merge_point(self, op, jitdriver):
+        assert self.portal_jd is not None, (
+            "'jit_merge_point' in non-portal graph!")
+        assert jitdriver is self.portal_jd.jitdriver, (
+            "general mix-up of jitdrivers?")
         ops = self.promote_greens(op.args[2:], jitdriver)
         num_green_args = len(jitdriver.greens)
-        args = (self.make_three_lists(op.args[2:2+num_green_args]) +
+        args = ([Constant(self.portal_jd.index, lltype.Signed)] +
+                self.make_three_lists(op.args[2:2+num_green_args]) +
                 self.make_three_lists(op.args[2+num_green_args:]))
         op1 = SpaceOperation('jit_merge_point', args, None)
         return ops + [op1]

Modified: pypy/branch/multijit-3/pypy/jit/codewriter/test/test_flatten.py
==============================================================================
--- pypy/branch/multijit-3/pypy/jit/codewriter/test/test_flatten.py	(original)
+++ pypy/branch/multijit-3/pypy/jit/codewriter/test/test_flatten.py	Fri Jun 11 12:52:13 2010
@@ -115,13 +115,13 @@
         return self.rtyper.annotator.translator.graphs
 
     def encoding_test(self, func, args, expected,
-                      transform=False, liveness=False, cc=None):
+                      transform=False, liveness=False, cc=None, jd=None):
         graphs = self.make_graphs(func, args)
         #graphs[0].show()
         if transform:
             from pypy.jit.codewriter.jtransform import transform_graph
             cc = cc or FakeCallControl()
-            transform_graph(graphs[0], FakeCPU(self.rtyper), cc)
+            transform_graph(graphs[0], FakeCPU(self.rtyper), cc, jd)
         ssarepr = flatten_graph(graphs[0], fake_regallocs(),
                                 _include_all_exc_links=not transform)
         if liveness:
@@ -581,13 +581,16 @@
         def f(x, y):
             myjitdriver.jit_merge_point(x=x, y=y)
             myjitdriver.can_enter_jit(x=y, y=x)
+        class FakeJitDriverSD:
+            jitdriver = myjitdriver
+            index = 27
         self.encoding_test(f, [4, 5], """
             -live- %i0, %i1
             int_guard_value %i0
-            jit_merge_point I[%i0], R[], F[], I[%i1], R[], F[]
+            jit_merge_point $27, I[%i0], R[], F[], I[%i1], R[], F[]
             can_enter_jit
             void_return
-        """, transform=True, liveness=True)
+        """, transform=True, liveness=True, jd=FakeJitDriverSD())
 
     def test_keepalive(self):
         S = lltype.GcStruct('S')

Modified: pypy/branch/multijit-3/pypy/jit/metainterp/blackhole.py
==============================================================================
--- pypy/branch/multijit-3/pypy/jit/metainterp/blackhole.py	(original)
+++ pypy/branch/multijit-3/pypy/jit/metainterp/blackhole.py	Fri Jun 11 12:52:13 2010
@@ -775,8 +775,8 @@
     def bhimpl_can_enter_jit():
         pass
 
-    @arguments("self", "I", "R", "F", "I", "R", "F")
-    def bhimpl_jit_merge_point(self, *args):
+    @arguments("self", "i", "I", "R", "F", "I", "R", "F")
+    def bhimpl_jit_merge_point(self, jdindex, *args):
         if self.nextblackholeinterp is None:    # we are the last level
             CRN = self.builder.metainterp_sd.ContinueRunningNormally
             raise CRN(*args)
@@ -791,54 +791,55 @@
             # result.
             sd = self.builder.metainterp_sd
             if sd.result_type == 'void':
-                self.bhimpl_recursive_call_v(*args)
+                self.bhimpl_recursive_call_v(jdindex, *args)
                 self.bhimpl_void_return()
             elif sd.result_type == 'int':
-                x = self.bhimpl_recursive_call_i(*args)
+                x = self.bhimpl_recursive_call_i(jdindex, *args)
                 self.bhimpl_int_return(x)
             elif sd.result_type == 'ref':
-                x = self.bhimpl_recursive_call_r(*args)
+                x = self.bhimpl_recursive_call_r(jdindex, *args)
                 self.bhimpl_ref_return(x)
             elif sd.result_type == 'float':
-                x = self.bhimpl_recursive_call_f(*args)
+                x = self.bhimpl_recursive_call_f(jdindex, *args)
                 self.bhimpl_float_return(x)
             assert False
 
-    def get_portal_runner(self):
+    def get_portal_runner(self, jdindex):
         metainterp_sd = self.builder.metainterp_sd
-        fnptr = llmemory.cast_ptr_to_adr(metainterp_sd._portal_runner_ptr)
+        jitdriver_sd = metainterp_sd.jitdrivers_sd[jdindex]
+        fnptr = llmemory.cast_ptr_to_adr(jitdriver_sd.portal_runner_ptr)
         fnptr = heaptracker.adr2int(fnptr)
-        calldescr = metainterp_sd.portal_code.calldescr
+        calldescr = jitdriver_sd.mainjitcode.calldescr
         return fnptr, calldescr
 
-    @arguments("self", "I", "R", "F", "I", "R", "F", returns="i")
-    def bhimpl_recursive_call_i(self, greens_i, greens_r, greens_f,
-                                      reds_i,   reds_r,   reds_f):
-        fnptr, calldescr = self.get_portal_runner()
+    @arguments("self", "i", "I", "R", "F", "I", "R", "F", returns="i")
+    def bhimpl_recursive_call_i(self, jdindex, greens_i, greens_r, greens_f,
+                                               reds_i,   reds_r,   reds_f):
+        fnptr, calldescr = self.get_portal_runner(jdindex)
         return self.cpu.bh_call_i(fnptr, calldescr,
                                   greens_i + reds_i,
                                   greens_r + reds_r,
                                   greens_f + reds_f)
-    @arguments("self", "I", "R", "F", "I", "R", "F", returns="r")
-    def bhimpl_recursive_call_r(self, greens_i, greens_r, greens_f,
-                                      reds_i,   reds_r,   reds_f):
-        fnptr, calldescr = self.get_portal_runner()
+    @arguments("self", "i", "I", "R", "F", "I", "R", "F", returns="r")
+    def bhimpl_recursive_call_r(self, jdindex, greens_i, greens_r, greens_f,
+                                               reds_i,   reds_r,   reds_f):
+        fnptr, calldescr = self.get_portal_runner(jdindex)
         return self.cpu.bh_call_r(fnptr, calldescr,
                                   greens_i + reds_i,
                                   greens_r + reds_r,
                                   greens_f + reds_f)
-    @arguments("self", "I", "R", "F", "I", "R", "F", returns="f")
-    def bhimpl_recursive_call_f(self, greens_i, greens_r, greens_f,
-                                      reds_i,   reds_r,   reds_f):
-        fnptr, calldescr = self.get_portal_runner()
+    @arguments("self", "i", "I", "R", "F", "I", "R", "F", returns="f")
+    def bhimpl_recursive_call_f(self, jdindex, greens_i, greens_r, greens_f,
+                                               reds_i,   reds_r,   reds_f):
+        fnptr, calldescr = self.get_portal_runner(jdindex)
         return self.cpu.bh_call_f(fnptr, calldescr,
                                   greens_i + reds_i,
                                   greens_r + reds_r,
                                   greens_f + reds_f)
-    @arguments("self", "I", "R", "F", "I", "R", "F")
-    def bhimpl_recursive_call_v(self, greens_i, greens_r, greens_f,
-                                      reds_i,   reds_r,   reds_f):
-        fnptr, calldescr = self.get_portal_runner()
+    @arguments("self", "i", "I", "R", "F", "I", "R", "F")
+    def bhimpl_recursive_call_v(self, jdindex, greens_i, greens_r, greens_f,
+                                               reds_i,   reds_r,   reds_f):
+        fnptr, calldescr = self.get_portal_runner(jdindex)
         return self.cpu.bh_call_v(fnptr, calldescr,
                                   greens_i + reds_i,
                                   greens_r + reds_r,

Added: pypy/branch/multijit-3/pypy/jit/metainterp/jitdriver.py
==============================================================================
--- (empty file)
+++ pypy/branch/multijit-3/pypy/jit/metainterp/jitdriver.py	Fri Jun 11 12:52:13 2010
@@ -0,0 +1,17 @@
+
+
+class JitDriverStaticData:
+    """There is one instance of this class per JitDriver used in the program.
+    """
+    # This is just a container with the following attributes (... set by):
+    #    self.jitdriver         ... pypy.jit.metainterp.warmspot
+    #    self.portal_graph      ... pypy.jit.metainterp.warmspot
+    #    self.portal_runner_ptr ... pypy.jit.metainterp.warmspot
+    #    self.num_green_args    ... pypy.jit.metainterp.warmspot
+    #    self.result_type       ... pypy.jit.metainterp.warmspot
+    #    self.virtualizable_info... pypy.jit.metainterp.warmspot
+    #    self.index             ... pypy.jit.codewriter.call
+    #    self.mainjitcode       ... pypy.jit.codewriter.call
+
+    def _freeze_(self):
+        return True

Modified: pypy/branch/multijit-3/pypy/jit/metainterp/pyjitpl.py
==============================================================================
--- pypy/branch/multijit-3/pypy/jit/metainterp/pyjitpl.py	(original)
+++ pypy/branch/multijit-3/pypy/jit/metainterp/pyjitpl.py	Fri Jun 11 12:52:13 2010
@@ -56,6 +56,7 @@
         self.parent_resumedata_snapshot = None
         self.parent_resumedata_frame_info_list = None
 
+    @specialize.arg(3)
     def copy_constants(self, registers, constants, ConstClass):
         """Copy jitcode.constants[0] to registers[255],
                 jitcode.constants[1] to registers[254],
@@ -68,7 +69,6 @@
             assert j >= 0
             registers[j] = ConstClass(constants[i])
             i -= 1
-    copy_constants._annspecialcase_ = 'specialize:arg(3)'
 
     def cleanup_registers(self):
         # To avoid keeping references alive, this cleans up the registers_r.
@@ -80,6 +80,7 @@
     # ------------------------------
     # Decoding of the JitCode
 
+    @specialize.arg(4)
     def prepare_list_of_boxes(self, outvalue, startindex, position, argcode):
         assert argcode in 'IRF'
         code = self.bytecode
@@ -92,7 +93,6 @@
             elif argcode == 'F': reg = self.registers_f[index]
             else: raise AssertionError(argcode)
             outvalue[startindex+i] = reg
-    prepare_list_of_boxes._annspecialcase_ = 'specialize:arg(4)'
 
     def get_current_position_info(self):
         return self.jitcode.get_live_vars_info(self.pc)
@@ -518,15 +518,15 @@
         return not isstandard
 
     def _get_virtualizable_field_index(self, fielddescr):
-        vinfo = self.metainterp.staticdata.virtualizable_info
+        vinfo = self.metainterp.jitdriver_sd.virtualizable_info
         return vinfo.static_field_by_descrs[fielddescr]
 
     def _get_virtualizable_array_field_descr(self, index):
-        vinfo = self.metainterp.staticdata.virtualizable_info
+        vinfo = self.metainterp.jitdriver_sd.virtualizable_info
         return vinfo.array_field_descrs[index]
 
     def _get_virtualizable_array_descr(self, index):
-        vinfo = self.metainterp.staticdata.virtualizable_info
+        vinfo = self.metainterp.jitdriver_sd.virtualizable_info
         return vinfo.array_descrs[index]
 
     @arguments("orgpc", "box", "descr")
@@ -557,7 +557,7 @@
 
     def _get_arrayitem_vable_index(self, pc, arrayfielddescr, indexbox):
         indexbox = self.implement_guard_value(pc, indexbox)
-        vinfo = self.metainterp.staticdata.virtualizable_info
+        vinfo = self.metainterp.jitdriver_sd.virtualizable_info
         virtualizable_box = self.metainterp.virtualizable_boxes[-1]
         virtualizable = vinfo.unwrap_virtualizable_box(virtualizable_box)
         arrayindex = vinfo.array_field_by_descrs[arrayfielddescr]
@@ -606,7 +606,7 @@
             arraybox = self.metainterp.execute_and_record(rop.GETFIELD_GC,
                                                           fdescr, box)
             return self.execute_with_descr(rop.ARRAYLEN_GC, adescr, arraybox)
-        vinfo = self.metainterp.staticdata.virtualizable_info
+        vinfo = self.metainterp.jitdriver_sd.virtualizable_info
         virtualizable_box = self.metainterp.virtualizable_boxes[-1]
         virtualizable = vinfo.unwrap_virtualizable_box(virtualizable_box)
         arrayindex = vinfo.array_field_by_descrs[fdescr]
@@ -945,7 +945,7 @@
         guard_op = metainterp.history.record(opnum, moreargs, None,
                                              descr=resumedescr)       
         virtualizable_boxes = None
-        if metainterp.staticdata.virtualizable_info is not None:
+        if metainterp.jitdriver_sd.virtualizable_info is not None:
             virtualizable_boxes = metainterp.virtualizable_boxes
         saved_pc = self.pc
         if resumepc >= 0:
@@ -1140,22 +1140,16 @@
         self.setup_indirectcalltargets(asm.indirectcalltargets)
         self.setup_list_of_addr2name(asm.list_of_addr2name)
         #
-        self.portal_code = codewriter.mainjitcode
-        self._portal_runner_ptr = codewriter.callcontrol.portal_runner_ptr
+        self.jitdrivers_sd = codewriter.callcontrol.jitdrivers_sd
         self.virtualref_info = codewriter.callcontrol.virtualref_info
-        self.virtualizable_info = codewriter.callcontrol.virtualizable_info
-        RESULT = codewriter.portal_graph.getreturnvar().concretetype
-        self.result_type = history.getkind(RESULT)
         #
         warmrunnerdesc = self.warmrunnerdesc
         if warmrunnerdesc is not None:
+            XXX
             self.num_green_args = warmrunnerdesc.num_green_args
             self.state = warmrunnerdesc.state
             if optimizer is not None:
                 self.state.set_param_optimizer(optimizer)
-        else:
-            self.num_green_args = 0
-            self.state = None
         self.globaldata = MetaInterpGlobalData(self)
 
     def _setup_once(self):
@@ -1176,8 +1170,7 @@
                 # Build the dictionary at run-time.  This is needed
                 # because the keys are function/class addresses, so they
                 # can change from run to run.
-                k = llmemory.cast_ptr_to_adr(self._portal_runner_ptr)
-                d = {k: 'recursive call'}
+                d = {}
                 keys = self._addr2name_keys
                 values = self._addr2name_values
                 for i in range(len(keys)):
@@ -1236,7 +1229,7 @@
         self.resume_virtuals = {}
         self.resume_virtuals_not_translated = []
         #
-        state = staticdata.state
+        state = None     # XXX staticdata.state
         if state is not None:
             self.jit_cell_at_key = state.jit_cell_at_key
         else:
@@ -1269,13 +1262,12 @@
 
     def perform_call(self, jitcode, boxes, greenkey=None):
         # causes the metainterp to enter the given subfunction
-        # with a special case for recursive portal calls
         f = self.newframe(jitcode, greenkey)
         f.setup_call(boxes)
         raise ChangeFrame
 
     def newframe(self, jitcode, greenkey=None):
-        if jitcode is self.staticdata.portal_code:
+        if jitcode.is_portal:
             self.in_recursion += 1
         if greenkey is not None:
             self.portal_trace_positions.append(
@@ -1290,7 +1282,7 @@
 
     def popframe(self):
         frame = self.framestack.pop()
-        if frame.jitcode is self.staticdata.portal_code:
+        if frame.jitcode.is_portal:
             self.in_recursion -= 1
         if frame.greenkey is not None:
             self.portal_trace_positions.append(
@@ -1314,14 +1306,15 @@
             except SwitchToBlackhole, stb:
                 self.aborted_tracing(stb.reason)
             sd = self.staticdata
-            if sd.result_type == 'void':
+            result_type = self.jitdriver_sd.result_type
+            if result_type == history.VOID:
                 assert resultbox is None
                 raise sd.DoneWithThisFrameVoid()
-            elif sd.result_type == 'int':
+            elif result_type == history.INT:
                 raise sd.DoneWithThisFrameInt(resultbox.getint())
-            elif sd.result_type == 'ref':
+            elif result_type == history.REF:
                 raise sd.DoneWithThisFrameRef(self.cpu, resultbox.getref_base())
-            elif sd.result_type == 'float':
+            elif result_type == history.FLOAT:
                 raise sd.DoneWithThisFrameFloat(resultbox.getfloat())
             else:
                 assert False
@@ -1508,21 +1501,24 @@
                     self.staticdata.log(sys.exc_info()[0].__name__)
                 raise
 
-    def compile_and_run_once(self, *args):
+    @specialize.arg(1)
+    def compile_and_run_once(self, jitdriver_sd, *args):
         debug_start('jit-tracing')
         self.staticdata._setup_once()
         self.staticdata.profiler.start_tracing()
         self.create_empty_history()
         try:
-            return self._compile_and_run_once(*args)
+            original_boxes = self.initialize_original_boxes(jitdriver_sd,*args)
+            self.jitdriver_sd = jitdriver_sd
+            return self._compile_and_run_once(original_boxes)
         finally:
             self.staticdata.profiler.end_tracing()
             debug_stop('jit-tracing')
 
-    def _compile_and_run_once(self, *args):
-        original_boxes = self.initialize_state_from_start(*args)
+    def _compile_and_run_once(self, original_boxes):
+        self.initialize_state_from_start(original_boxes)
         self.current_merge_points = [(original_boxes, 0)]
-        num_green_args = self.staticdata.num_green_args
+        num_green_args = self.jitdriver_sd.num_green_args
         original_greenkey = original_boxes[:num_green_args]
         redkey = original_boxes[num_green_args:]
         self.resumekey = compile.ResumeFromInterpDescr(original_greenkey,
@@ -1589,7 +1585,7 @@
         self.remove_consts_and_duplicates(redboxes, len(redboxes),
                                           duplicates)
         live_arg_boxes = greenboxes + redboxes
-        if self.staticdata.virtualizable_info is not None:
+        if self.jitdriver_sd.virtualizable_info is not None:
             # we use pop() to remove the last item, which is the virtualizable
             # itself
             self.remove_consts_and_duplicates(self.virtualizable_boxes,
@@ -1710,17 +1706,18 @@
         self.gen_store_back_in_virtualizable()
         # temporarily put a JUMP to a pseudo-loop
         sd = self.staticdata
-        if sd.result_type == 'void':
+        result_type = self.jitdriver_sd.result_type
+        if result_type == history.VOID:
             assert exitbox is None
             exits = []
             loop_tokens = sd.loop_tokens_done_with_this_frame_void
-        elif sd.result_type == 'int':
+        elif result_type == history.INT:
             exits = [exitbox]
             loop_tokens = sd.loop_tokens_done_with_this_frame_int
-        elif sd.result_type == 'ref':
+        elif result_type == history.REF:
             exits = [exitbox]
             loop_tokens = sd.loop_tokens_done_with_this_frame_ref
-        elif sd.result_type == 'float':
+        elif result_type == history.FLOAT:
             exits = [exitbox]
             loop_tokens = sd.loop_tokens_done_with_this_frame_float
         else:
@@ -1752,26 +1749,35 @@
             specnode.extract_runtime_data(self.cpu, args[i], expanded_args)
         return expanded_args
 
-    def _initialize_from_start(self, original_boxes, num_green_args, *args):
+    @specialize.arg(1)
+    def initialize_original_boxes(self, jitdriver_sd, *args):
+        # NB. we explicitly pass 'jitdriver_sd' around here, even though it
+        # might also be available as 'self.jitdriver_sd', because we need to
+        # specialize these functions for the particular *args.
+        original_boxes = []
+        self._fill_original_boxes(jitdriver_sd, original_boxes,
+                                  jitdriver_sd.num_green_args, *args)
+        return original_boxes
+
+    @specialize.arg(1)
+    def _fill_original_boxes(self, jitdriver_sd, original_boxes,
+                             num_green_args, *args):
         if args:
             from pypy.jit.metainterp.warmstate import wrap
             box = wrap(self.cpu, args[0], num_green_args > 0)
             original_boxes.append(box)
-            self._initialize_from_start(original_boxes, num_green_args-1,
-                                        *args[1:])
+            self._fill_original_boxes(jitdriver_sd, original_boxes,
+                                      num_green_args-1, *args[1:])
 
-    def initialize_state_from_start(self, *args):
-        self.in_recursion = -1 # always one portal around
-        num_green_args = self.staticdata.num_green_args
-        original_boxes = []
-        self._initialize_from_start(original_boxes, num_green_args, *args)
+    def initialize_state_from_start(self, original_boxes):
         # ----- make a new frame -----
+        self.in_recursion = -1 # always one portal around
         self.framestack = []
-        f = self.newframe(self.staticdata.portal_code)
+        f = self.newframe(self.jitdriver_sd.mainjitcode)
         f.setup_call(original_boxes)
+        assert self.in_recursion == 0
         self.virtualref_boxes = []
         self.initialize_virtualizable(original_boxes)
-        return original_boxes
 
     def initialize_state_from_guard_failure(self, resumedescr):
         # guard failure: rebuild a complete MIFrame stack
@@ -1781,7 +1787,7 @@
         self.history.inputargs = [box for box in inputargs_and_holes if box]
 
     def initialize_virtualizable(self, original_boxes):
-        vinfo = self.staticdata.virtualizable_info
+        vinfo = self.jitdriver_sd.virtualizable_info
         if vinfo is not None:
             virtualizable_box = original_boxes[vinfo.index_of_virtualizable]
             virtualizable = vinfo.unwrap_virtualizable_box(virtualizable_box)
@@ -1794,7 +1800,7 @@
             self.initialize_virtualizable_enter()
 
     def initialize_virtualizable_enter(self):
-        vinfo = self.staticdata.virtualizable_info
+        vinfo = self.jitdriver_sd.virtualizable_info
         virtualizable_box = self.virtualizable_boxes[-1]
         virtualizable = vinfo.unwrap_virtualizable_box(virtualizable_box)
         vinfo.clear_vable_token(virtualizable)
@@ -1808,7 +1814,7 @@
             # the FORCE_TOKEN is already set at runtime in each vref when
             # it is created, by optimizeopt.py.
         #
-        vinfo = self.staticdata.virtualizable_info
+        vinfo = self.jitdriver_sd.virtualizable_info
         if vinfo is not None:
             virtualizable_box = self.virtualizable_boxes[-1]
             virtualizable = vinfo.unwrap_virtualizable_box(virtualizable_box)
@@ -1834,7 +1840,7 @@
                 self.stop_tracking_virtualref(i)
 
     def vable_after_residual_call(self):
-        vinfo = self.staticdata.virtualizable_info
+        vinfo = self.jitdriver_sd.virtualizable_info
         if vinfo is not None:
             virtualizable_box = self.virtualizable_boxes[-1]
             virtualizable = vinfo.unwrap_virtualizable_box(virtualizable_box)
@@ -1885,7 +1891,7 @@
         assert self.last_exc_value_box is None
 
     def rebuild_state_after_failure(self, resumedescr):
-        vinfo = self.staticdata.virtualizable_info
+        vinfo = self.jitdriver_sd.virtualizable_info
         self.framestack = []
         boxlists = resume.rebuild_from_resumedata(self, resumedescr, vinfo)
         inputargs_and_holes, virtualizable_boxes, virtualref_boxes = boxlists
@@ -1921,13 +1927,13 @@
 
     def check_synchronized_virtualizable(self):
         if not we_are_translated():
-            vinfo = self.staticdata.virtualizable_info
+            vinfo = self.jitdriver_sd.virtualizable_info
             virtualizable_box = self.virtualizable_boxes[-1]
             virtualizable = vinfo.unwrap_virtualizable_box(virtualizable_box)
             vinfo.check_boxes(virtualizable, self.virtualizable_boxes)
 
     def synchronize_virtualizable(self):
-        vinfo = self.staticdata.virtualizable_info
+        vinfo = self.jitdriver_sd.virtualizable_info
         virtualizable_box = self.virtualizable_boxes[-1]
         virtualizable = vinfo.unwrap_virtualizable_box(virtualizable_box)
         vinfo.write_boxes(virtualizable, self.virtualizable_boxes)
@@ -1936,7 +1942,7 @@
         # Force a reload of the virtualizable fields into the local
         # boxes (called only in escaping cases).  Only call this function
         # just before SwitchToBlackhole.
-        vinfo = self.staticdata.virtualizable_info
+        vinfo = self.jitdriver_sd.virtualizable_info
         if vinfo is not None:
             virtualizable_box = self.virtualizable_boxes[-1]
             virtualizable = vinfo.unwrap_virtualizable_box(virtualizable_box)
@@ -1945,7 +1951,7 @@
             self.virtualizable_boxes.append(virtualizable_box)
 
     def gen_store_back_in_virtualizable(self):
-        vinfo = self.staticdata.virtualizable_info
+        vinfo = self.jitdriver_sd.virtualizable_info
         if vinfo is not None:
             # xxx only write back the fields really modified
             vbox = self.virtualizable_boxes[-1]
@@ -1967,7 +1973,7 @@
             assert i + 1 == len(self.virtualizable_boxes)
 
     def gen_load_from_other_virtualizable(self, vbox):
-        vinfo = self.staticdata.virtualizable_info
+        vinfo = self.jitdriver_sd.virtualizable_info
         boxes = []
         assert vinfo is not None
         for i in range(vinfo.num_static_extra_boxes):
@@ -1991,7 +1997,7 @@
         for i in range(len(boxes)):
             if boxes[i] is oldbox:
                 boxes[i] = newbox
-        if self.staticdata.virtualizable_info is not None:
+        if self.jitdriver_sd.virtualizable_info is not None:
             boxes = self.virtualizable_boxes
             for i in range(len(boxes)):
                 if boxes[i] is oldbox:
@@ -2047,8 +2053,9 @@
         assert op.opnum == rop.CALL_MAY_FORCE
         num_green_args = self.staticdata.num_green_args
         args = op.args[num_green_args + 1:]
-        if self.staticdata.virtualizable_info is not None:
-            vindex = self.staticdata.virtualizable_info.index_of_virtualizable
+        vinfo = self.jitdriver_sd.virtualizable_info
+        if vinfo is not None:
+            vindex = vinfo.index_of_virtualizable
             vbox = args[vindex - num_green_args]
             args = args + self.gen_load_from_other_virtualizable(vbox)
             # ^^^ and not "+=", which makes 'args' a resizable list

Modified: pypy/branch/multijit-3/pypy/jit/metainterp/test/test_basic.py
==============================================================================
--- pypy/branch/multijit-3/pypy/jit/metainterp/test/test_basic.py	(original)
+++ pypy/branch/multijit-3/pypy/jit/metainterp/test/test_basic.py	Fri Jun 11 12:52:13 2010
@@ -30,9 +30,17 @@
     func._jit_unroll_safe_ = True
     rtyper = support.annotate(func, values, type_system=type_system)
     graphs = rtyper.annotator.translator.graphs
+    result_kind = history.getkind(graphs[0].getreturnvar().concretetype)[0]
+
+    class FakeJitDriverSD:
+        num_green_args = 0
+        portal_graph = graphs[0]
+        virtualizable_info = None
+        result_type = result_kind
+
     stats = history.Stats()
     cpu = CPUClass(rtyper, stats, None, False)
-    cw = codewriter.CodeWriter(cpu, graphs[0])
+    cw = codewriter.CodeWriter(cpu, [FakeJitDriverSD()])
     testself.cw = cw
     cw.find_all_graphs(JitPolicy())
     #
@@ -62,7 +70,8 @@
             count_f += 1
         else:
             raise TypeError(T)
-    blackholeinterp.setposition(cw.mainjitcode, 0)
+    [jitdriver_sd] = cw.callcontrol.jitdrivers_sd
+    blackholeinterp.setposition(jitdriver_sd.mainjitcode, 0)
     blackholeinterp.run()
     return blackholeinterp._final_result_anytype()
 
@@ -86,8 +95,9 @@
     metainterp_sd.DoneWithThisFrameRef = DoneWithThisFrameRef
     metainterp_sd.DoneWithThisFrameFloat = DoneWithThisFrame
     testself.metainterp = metainterp
+    [jitdriver_sd] = cw.callcontrol.jitdrivers_sd
     try:
-        metainterp.compile_and_run_once(*args)
+        metainterp.compile_and_run_once(jitdriver_sd, *args)
     except DoneWithThisFrame, e:
         #if conftest.option.view:
         #    metainterp.stats.view()



More information about the Pypy-commit mailing list