[pypy-svn] r75261 - in pypy/branch/multijit-3/pypy/jit/codewriter: . test

arigo at codespeak.net arigo at codespeak.net
Fri Jun 11 11:52:44 CEST 2010


Author: arigo
Date: Fri Jun 11 11:52:42 2010
New Revision: 75261

Modified:
   pypy/branch/multijit-3/pypy/jit/codewriter/call.py
   pypy/branch/multijit-3/pypy/jit/codewriter/codewriter.py
   pypy/branch/multijit-3/pypy/jit/codewriter/jtransform.py
   pypy/branch/multijit-3/pypy/jit/codewriter/test/test_call.py
   pypy/branch/multijit-3/pypy/jit/codewriter/test/test_codewriter.py
   pypy/branch/multijit-3/pypy/jit/codewriter/test/test_flatten.py
Log:
Fix codewriter to do (in a single pass) the transformation
for multiple JitDrivers.


Modified: pypy/branch/multijit-3/pypy/jit/codewriter/call.py
==============================================================================
--- pypy/branch/multijit-3/pypy/jit/codewriter/call.py	(original)
+++ pypy/branch/multijit-3/pypy/jit/codewriter/call.py	Fri Jun 11 11:52:42 2010
@@ -16,21 +16,22 @@
 
 class CallControl(object):
     virtualref_info = None     # optionally set from outside
-    virtualizable_info = None  # optionally set from outside
-    portal_runner_ptr = None   # optionally set from outside
 
-    def __init__(self, cpu=None, portal_graph=None):
+    def __init__(self, cpu=None, jitdrivers_sd=[]):
+        assert isinstance(jitdrivers_sd, list)   # debugging
         self.cpu = cpu
-        self.portal_graph = portal_graph
+        self.jitdrivers_sd = jitdrivers_sd
         self.jitcodes = {}             # map {graph: jitcode}
         self.unfinished_graphs = []    # list of graphs with pending jitcodes
-        self.jitdriver = None
         if hasattr(cpu, 'rtyper'):     # for tests
             self.rtyper = cpu.rtyper
             translator = self.rtyper.annotator.translator
             self.raise_analyzer = RaiseAnalyzer(translator)
             self.readwrite_analyzer = ReadWriteAnalyzer(translator)
             self.virtualizable_analyzer = VirtualizableAnalyzer(translator)
+        #
+        for index, jd in enumerate(jitdrivers_sd):
+            jd.index = index
 
     def find_all_graphs(self, policy):
         try:
@@ -41,8 +42,8 @@
         def is_candidate(graph):
             return policy.look_inside_graph(graph)
 
-        assert self.portal_graph is not None
-        todo = [self.portal_graph]
+        assert len(self.jitdrivers_sd) > 0
+        todo = [jd.portal_graph for jd in self.jitdrivers_sd]
         if hasattr(self, 'rtyper'):
             for oopspec_name, ll_args, ll_res in support.inline_calls_to:
                 c_func, _ = support.builtin_func_for_spec(self.rtyper,
@@ -122,7 +123,7 @@
     def guess_call_kind(self, op, is_candidate=None):
         if op.opname == 'direct_call':
             funcptr = op.args[0].value
-            if funcptr is self.portal_runner_ptr:
+            if self.jitdriver_sd_from_portal_runner_ptr(funcptr) is not None:
                 return 'recursive'
             funcobj = get_funcobj(funcptr)
             if getattr(funcobj, 'graph', None) is None:
@@ -143,6 +144,10 @@
         # used only after find_all_graphs()
         return graph in self.candidate_graphs
 
+    def grab_initial_jitcodes(self):
+        for jd in self.jitdrivers_sd:
+            jd.mainjitcode = self.get_jitcode(jd.portal_graph)
+
     def enum_pending_graphs(self):
         while self.unfinished_graphs:
             graph = self.unfinished_graphs.pop()
@@ -241,12 +246,26 @@
         return (effectinfo is None or
                 effectinfo.extraeffect >= EffectInfo.EF_CAN_RAISE)
 
-    def found_jitdriver(self, jitdriver):
-        if self.jitdriver is None:
-            self.jitdriver = jitdriver
-        else:
-            assert self.jitdriver is jitdriver
+    def jitdriver_sd_from_portal_graph(self, graph):
+        for jd in self.jitdrivers_sd:
+            if jd.portal_graph is graph:
+                return jd
+        return None
 
-    def getjitdriver(self):
-        assert self.jitdriver is not None, "order dependency issue?"
-        return self.jitdriver
+    def jitdriver_sd_from_portal_runner_ptr(self, funcptr):
+        for jd in self.jitdrivers_sd:
+            if funcptr is jd.portal_runner_ptr:
+                return jd
+        return None
+
+    def get_vinfo(self, VTYPEPTR):
+        seen = set()
+        for jd in self.jitdrivers_sd:
+            if jd.virtualizable_info is not None:
+                if jd.virtualizable_info.is_vtypeptr(VTYPEPTR):
+                    seen.add(jd.virtualizable_info)
+        if seen:
+            assert len(seen) == 1
+            return seen.pop()
+        else:
+            return None

Modified: pypy/branch/multijit-3/pypy/jit/codewriter/codewriter.py
==============================================================================
--- pypy/branch/multijit-3/pypy/jit/codewriter/codewriter.py	(original)
+++ pypy/branch/multijit-3/pypy/jit/codewriter/codewriter.py	Fri Jun 11 11:52:42 2010
@@ -14,29 +14,30 @@
 class CodeWriter(object):
     callcontrol = None    # for tests
 
-    def __init__(self, cpu=None, maingraph=None):
+    def __init__(self, cpu=None, jitdrivers_sd=[]):
         self.cpu = cpu
         self.assembler = Assembler()
-        self.portal_graph = maingraph
-        self.callcontrol = CallControl(cpu, maingraph)
+        self.callcontrol = CallControl(cpu, jitdrivers_sd)
+        self._seen_files = set()
 
     def transform_func_to_jitcode(self, func, values, type_system='lltype'):
         """For testing."""
         rtyper = support.annotate(func, values, type_system=type_system)
         graph = rtyper.annotator.translator.graphs[0]
         jitcode = JitCode("test")
-        self.transform_graph_to_jitcode(graph, jitcode, True, True)
+        self.transform_graph_to_jitcode(graph, jitcode, True)
         return jitcode
 
-    def transform_graph_to_jitcode(self, graph, jitcode, portal, verbose):
+    def transform_graph_to_jitcode(self, graph, jitcode, verbose):
         """Transform a graph into a JitCode containing the same bytecode
         in a different format.
         """
+        portal_jd = self.callcontrol.jitdriver_sd_from_portal_graph(graph)
         graph = copygraph(graph, shallowvars=True)
         #
         # step 1: mangle the graph so that it contains the final instructions
         # that we want in the JitCode, but still as a control flow graph
-        transform_graph(graph, self.cpu, self.callcontrol, portal)
+        transform_graph(graph, self.cpu, self.callcontrol)
         #
         # step 2: perform register allocation on it
         regallocs = {}
@@ -59,16 +60,14 @@
         self.assembler.assemble(ssarepr, jitcode)
         #
         # print the resulting assembler
-        self.print_ssa_repr(ssarepr, portal, verbose)
+        self.print_ssa_repr(ssarepr, portal_jd, verbose)
 
     def make_jitcodes(self, verbose=False):
         log.info("making JitCodes...")
-        maingraph = self.portal_graph
-        self.mainjitcode = self.callcontrol.get_jitcode(maingraph)
+        self.callcontrol.grab_initial_jitcodes()
         count = 0
         for graph, jitcode in self.callcontrol.enum_pending_graphs():
-            self.transform_graph_to_jitcode(graph, jitcode,
-                                            graph is maingraph, verbose)
+            self.transform_graph_to_jitcode(graph, jitcode, verbose)
             count += 1
             if not count % 500:
                 log.info("Produced %d jitcodes" % count)
@@ -76,33 +75,35 @@
         log.info("there are %d JitCode instances." % count)
 
     def setup_vrefinfo(self, vrefinfo):
+        # must be called at most once
+        assert self.callcontrol.virtualref_info is None
         self.callcontrol.virtualref_info = vrefinfo
 
-    def setup_virtualizable_info(self, vinfo):
-        self.callcontrol.virtualizable_info = vinfo
-
-    def setup_portal_runner_ptr(self, portal_runner_ptr):
-        self.callcontrol.portal_runner_ptr = portal_runner_ptr
+    def setup_jitdriver(self, jitdriver_sd):
+        # Must be called once per jitdriver.  Usually jitdriver_sd is an
+        # instance of pypy.jit.metainterp.jitdriver.JitDriverStaticData.
+        self.callcontrol.jitdrivers_sd.append(jitdriver_sd)
 
     def find_all_graphs(self, policy):
         return self.callcontrol.find_all_graphs(policy)
 
-    def print_ssa_repr(self, ssarepr, portal, verbose):
+    def print_ssa_repr(self, ssarepr, portal_jitdriver, verbose):
         if verbose:
             print '%s:' % (ssarepr.name,)
             print format_assembler(ssarepr)
         else:
             dir = udir.ensure("jitcodes", dir=1)
-            if portal:
-                name = "00_portal_runner"
+            if portal_jitdriver:
+                name = "%02d_portal_runner" % (portal_jitdriver.index,)
             elif ssarepr.name and ssarepr.name != '?':
                 name = ssarepr.name
             else:
                 name = 'unnamed' % id(ssarepr)
             i = 1
             extra = ''
-            while dir.join(name+extra).check(exists=1):
+            while name+extra in self._seen_files:
                 i += 1
                 extra = '.%d' % i
+            self._seen_files.add(name+extra)
             dir.join(name+extra).write(format_assembler(ssarepr))
             log.dot()

Modified: pypy/branch/multijit-3/pypy/jit/codewriter/jtransform.py
==============================================================================
--- pypy/branch/multijit-3/pypy/jit/codewriter/jtransform.py	(original)
+++ pypy/branch/multijit-3/pypy/jit/codewriter/jtransform.py	Fri Jun 11 11:52:42 2010
@@ -13,12 +13,12 @@
 from pypy.translator.simplify import get_funcobj
 
 
-def transform_graph(graph, cpu=None, callcontrol=None, portal=True):
+def transform_graph(graph, cpu=None, callcontrol=None):
     """Transform a control flow graph to make it suitable for
     being flattened in a JitCode.
     """
     t = Transformer(cpu, callcontrol)
-    t.transform(graph, portal)
+    t.transform(graph)
 
 
 class Transformer(object):
@@ -27,9 +27,8 @@
         self.cpu = cpu
         self.callcontrol = callcontrol
 
-    def transform(self, graph, portal):
+    def transform(self, graph):
         self.graph = graph
-        self.portal = portal
         for block in list(graph.iterblocks()):
             self.optimize_block(block)
 
@@ -317,10 +316,12 @@
         return op1
 
     def handle_recursive_call(self, op):
-        ops = self.promote_greens(op.args[1:])
-        targetgraph = self.callcontrol.portal_graph
-        num_green_args = len(self.callcontrol.getjitdriver().greens)
-        args = (self.make_three_lists(op.args[1:1+num_green_args]) +
+        jitdriver_sd = self.callcontrol.jitdriver_sd_from_portal_runner_ptr(
+            op.args[0])
+        ops = self.promote_greens(op.args[1:], jitdriver_sd.jitdriver)
+        num_green_args = len(jitdriver_sd.jitdriver.greens)
+        args = ([Constant(jitdriver_sd.index, lltype.Signed)] +
+                self.make_three_lists(op.args[1:1+num_green_args]) +
                 self.make_three_lists(op.args[1+num_green_args:]))
         kind = getkind(op.result.concretetype)[0]
         op0 = SpaceOperation('recursive_call_%s' % kind, args, op.result)
@@ -475,14 +476,14 @@
         # check for virtualizable
         try:
             if self.is_virtualizable_getset(op):
-                descr = self.get_virtualizable_field_descr(op.args[1].value)
+                descr = self.get_virtualizable_field_descr(op)
                 kind = getkind(RESULT)[0]
                 return [SpaceOperation('-live-', [], None),
                         SpaceOperation('getfield_vable_%s' % kind,
                                        [v_inst, descr], op.result)]
-        except VirtualizableArrayField:
+        except VirtualizableArrayField, e:
             # xxx hack hack hack
-            vinfo = self.callcontrol.virtualizable_info
+            vinfo = e.args[1]
             arrayindex = vinfo.array_field_counter[op.args[1].value]
             arrayfielddescr = vinfo.array_field_descrs[arrayindex]
             arraydescr = vinfo.array_descrs[arrayindex]
@@ -519,7 +520,7 @@
             return
         # check for virtualizable
         if self.is_virtualizable_getset(op):
-            descr = self.get_virtualizable_field_descr(op.args[1].value)
+            descr = self.get_virtualizable_field_descr(op)
             kind = getkind(RESULT)[0]
             return [SpaceOperation('-live-', [], None),
                     SpaceOperation('setfield_vable_%s' % kind,
@@ -536,21 +537,23 @@
         return (op.args[1].value == 'typeptr' and
                 op.args[0].concretetype.TO._hints.get('typeptr'))
 
+    def get_vinfo(self, v_virtualizable):
+        if self.callcontrol is None:      # for tests
+            return None
+        return self.callcontrol.get_vinfo(v_virtualizable.concretetype)
+
     def is_virtualizable_getset(self, op):
         # every access of an object of exactly the type VTYPEPTR is
         # likely to be a virtualizable access, but we still have to
         # check it in pyjitpl.py.
-        try:
-            vinfo = self.callcontrol.virtualizable_info
-        except AttributeError:
-            return False
-        if vinfo is None or not vinfo.is_vtypeptr(op.args[0].concretetype):
+        vinfo = self.get_vinfo(op.args[0])
+        if vinfo is None:
             return False
         res = False
         if op.args[1].value in vinfo.static_field_to_extra_box:
             res = True
         if op.args[1].value in vinfo.array_fields:
-            res = VirtualizableArrayField(self.graph)
+            res = VirtualizableArrayField(self.graph, vinfo)
 
         if res:
             flags = self.vable_flags[op.args[0]]
@@ -560,8 +563,9 @@
             raise res
         return res
 
-    def get_virtualizable_field_descr(self, fieldname):
-        vinfo = self.callcontrol.virtualizable_info
+    def get_virtualizable_field_descr(self, op):
+        fieldname = op.args[1].value
+        vinfo = self.get_vinfo(op.args[0])
         index = vinfo.static_field_to_extra_box[fieldname]
         return vinfo.static_field_descrs[index]
 
@@ -751,9 +755,10 @@
             return Constant(value, lltype.Bool)
         return op
 
-    def promote_greens(self, args):
+    def promote_greens(self, args, jitdriver):
         ops = []
-        num_green_args = len(self.callcontrol.getjitdriver().greens)
+        num_green_args = len(jitdriver.greens)
+        assert len(args) == num_green_args + len(jitdriver.reds)
         for v in args[:num_green_args]:
             if isinstance(v, Variable) and v.concretetype is not lltype.Void:
                 kind = getkind(v.concretetype)
@@ -763,20 +768,19 @@
         return ops
 
     def rewrite_op_jit_marker(self, op):
-        self.callcontrol.found_jitdriver(op.args[1].value)
         key = op.args[0].value
-        return getattr(self, 'handle_jit_marker__%s' % key)(op)
+        jitdriver = op.args[1].value
+        return getattr(self, 'handle_jit_marker__%s' % key)(op, jitdriver)
 
-    def handle_jit_marker__jit_merge_point(self, op):
-        assert self.portal, "jit_merge_point in non-main graph!"
-        ops = self.promote_greens(op.args[2:])
-        num_green_args = len(self.callcontrol.getjitdriver().greens)
+    def handle_jit_marker__jit_merge_point(self, op, jitdriver):
+        ops = self.promote_greens(op.args[2:], jitdriver)
+        num_green_args = len(jitdriver.greens)
         args = (self.make_three_lists(op.args[2:2+num_green_args]) +
                 self.make_three_lists(op.args[2+num_green_args:]))
         op1 = SpaceOperation('jit_merge_point', args, None)
         return ops + [op1]
 
-    def handle_jit_marker__can_enter_jit(self, op):
+    def handle_jit_marker__can_enter_jit(self, op, jitdriver):
         return SpaceOperation('can_enter_jit', [], None)
 
     def rewrite_op_debug_assert(self, op):
@@ -975,9 +979,8 @@
 
     def rewrite_op_jit_force_virtualizable(self, op):
         # this one is for virtualizables
-        vinfo = self.callcontrol.virtualizable_info
+        vinfo = self.get_vinfo(op.args[0])
         assert vinfo is not None
-        assert vinfo.is_vtypeptr(op.args[0].concretetype)
         self.vable_flags[op.args[0]] = op.args[2].value
         return []
 

Modified: pypy/branch/multijit-3/pypy/jit/codewriter/test/test_call.py
==============================================================================
--- pypy/branch/multijit-3/pypy/jit/codewriter/test/test_call.py	(original)
+++ pypy/branch/multijit-3/pypy/jit/codewriter/test/test_call.py	Fri Jun 11 11:52:42 2010
@@ -52,13 +52,19 @@
 
 # ____________________________________________________________
 
+class FakeJitDriverSD:
+    def __init__(self, portal_graph):
+        self.portal_graph = portal_graph
+        self.portal_runner_ptr = "???"
+
 def test_find_all_graphs():
     def g(x):
         return x + 2
     def f(x):
         return g(x) + 1
     rtyper = support.annotate(f, [7])
-    cc = CallControl(portal_graph=rtyper.annotator.translator.graphs[0])
+    jitdriver_sd = FakeJitDriverSD(rtyper.annotator.translator.graphs[0])
+    cc = CallControl(jitdrivers_sd=[jitdriver_sd])
     res = cc.find_all_graphs(FakePolicy())
     funcs = set([graph.func for graph in res])
     assert funcs == set([f, g])
@@ -69,7 +75,8 @@
     def f(x):
         return g(x) + 1
     rtyper = support.annotate(f, [7])
-    cc = CallControl(portal_graph=rtyper.annotator.translator.graphs[0])
+    jitdriver_sd = FakeJitDriverSD(rtyper.annotator.translator.graphs[0])
+    cc = CallControl(jitdrivers_sd=[jitdriver_sd])
     class CustomFakePolicy:
         def look_inside_graph(self, graph):
             assert graph.name == 'g'
@@ -83,10 +90,11 @@
 def test_guess_call_kind_and_calls_from_graphs():
     class portal_runner_obj:
         graph = object()
+    class FakeJitDriverSD:
+        portal_runner_ptr = portal_runner_obj
     g = object()
     g1 = object()
-    cc = CallControl()
-    cc.portal_runner_ptr = portal_runner_obj
+    cc = CallControl(jitdrivers_sd=[FakeJitDriverSD()])
     cc.candidate_graphs = [g, g1]
 
     op = SpaceOperation('direct_call', [Constant(portal_runner_obj)],

Modified: pypy/branch/multijit-3/pypy/jit/codewriter/test/test_codewriter.py
==============================================================================
--- pypy/branch/multijit-3/pypy/jit/codewriter/test/test_codewriter.py	(original)
+++ pypy/branch/multijit-3/pypy/jit/codewriter/test/test_codewriter.py	Fri Jun 11 11:52:42 2010
@@ -35,6 +35,12 @@
     def look_inside_graph(self, graph):
         return graph.name != 'dont_look'
 
+class FakeJitDriverSD:
+    def __init__(self, portal_graph):
+        self.portal_graph = portal_graph
+        self.portal_runner_ptr = "???"
+        self.virtualizable_info = None
+
 
 def test_loop():
     def f(a, b):
@@ -70,11 +76,11 @@
     def fff(a, b):
         return ggg(b) - ggg(a)
     rtyper = support.annotate(fff, [35, 42])
-    maingraph = rtyper.annotator.translator.graphs[0]
-    cw = CodeWriter(FakeCPU(rtyper), maingraph)
+    jitdriver_sd = FakeJitDriverSD(rtyper.annotator.translator.graphs[0])
+    cw = CodeWriter(FakeCPU(rtyper), [jitdriver_sd])
     cw.find_all_graphs(FakePolicy())
     cw.make_jitcodes(verbose=True)
-    jitcode = cw.mainjitcode
+    jitcode = jitdriver_sd.mainjitcode
     print jitcode.dump()
     [jitcode2] = cw.assembler.descrs
     print jitcode2.dump()
@@ -117,7 +123,7 @@
         return x().id + y().id + dont_look(n)
     rtyper = support.annotate(f, [35])
     maingraph = rtyper.annotator.translator.graphs[0]
-    cw = CodeWriter(FakeCPU(rtyper), maingraph)
+    cw = CodeWriter(FakeCPU(rtyper), [FakeJitDriverSD(maingraph)])
     cw.find_all_graphs(FakePolicy())
     cw.make_jitcodes(verbose=True)
     #
@@ -144,10 +150,10 @@
     def f(n):
         return abs(n)
     rtyper = support.annotate(f, [35])
-    maingraph = rtyper.annotator.translator.graphs[0]
-    cw = CodeWriter(FakeCPU(rtyper), maingraph)
+    jitdriver_sd = FakeJitDriverSD(rtyper.annotator.translator.graphs[0])
+    cw = CodeWriter(FakeCPU(rtyper), [jitdriver_sd])
     cw.find_all_graphs(FakePolicy())
     cw.make_jitcodes(verbose=True)
     #
-    s = cw.mainjitcode.dump()
+    s = jitdriver_sd.mainjitcode.dump()
     assert "inline_call_ir_i <JitCode '_ll_1_int_abs__Signed'>" in s

Modified: pypy/branch/multijit-3/pypy/jit/codewriter/test/test_flatten.py
==============================================================================
--- pypy/branch/multijit-3/pypy/jit/codewriter/test/test_flatten.py	(original)
+++ pypy/branch/multijit-3/pypy/jit/codewriter/test/test_flatten.py	Fri Jun 11 11:52:42 2010
@@ -68,11 +68,8 @@
         return FakeDescr()
     def calldescr_canraise(self, calldescr):
         return calldescr is not self._descr_cannot_raise
-    def found_jitdriver(self, jitdriver):
-        assert isinstance(jitdriver, JitDriver)
-        self.jitdriver = jitdriver
-    def getjitdriver(self):
-        return self.jitdriver
+    def get_vinfo(self, VTYPEPTR):
+        return None
 
 class FakeCallControlWithVRefInfo:
     class virtualref_info:



More information about the Pypy-commit mailing list