[pypy-svn] r35052 - in pypy/dist/pypy/jit/timeshifter: . test

ac at codespeak.net ac at codespeak.net
Mon Nov 27 20:21:10 CET 2006


Author: ac
Date: Mon Nov 27 20:21:09 2006
New Revision: 35052

Modified:
   pypy/dist/pypy/jit/timeshifter/hrtyper.py
   pypy/dist/pypy/jit/timeshifter/test/test_portal.py
Log:
(pedronis, arre)
Add support for recursive calls to the portal.



Modified: pypy/dist/pypy/jit/timeshifter/hrtyper.py
==============================================================================
--- pypy/dist/pypy/jit/timeshifter/hrtyper.py	(original)
+++ pypy/dist/pypy/jit/timeshifter/hrtyper.py	Mon Nov 27 20:21:09 2006
@@ -153,11 +153,12 @@
         bk = self.annotator.bookkeeper
         bk.compute_after_normalization()
         entrygraph = self.annotator.translator.graphs[0]
+        self.origportalgraph = origportalgraph
         if origportalgraph:
-            portalgraph = bk.get_graph_by_key(origportalgraph, None)
-            leaveportalgraph = portalgraph
+            self.portalgraph = bk.get_graph_by_key(origportalgraph, None)
+            leaveportalgraph = self.portalgraph
         else:
-            portalgraph = None
+            self.portalgraph = None
             # in the case of tests not specifying a portal
             # we still need to force merges when entry
             # returns
@@ -182,10 +183,26 @@
             self.timeshift_graph(graph)
 
         if origportalgraph:
-            self.rewire_portal(origportalgraph, portalgraph)
+            self.rewire_portal()
+
+    # remember a shared pointer for the portal graph,
+    # so that it can be later patched by rewire_portal.
+    # this pointer is going to be used by the resuming logic
+    # and portal (re)entry.
+    def naked_tsfnptr(self, tsgraph):
+        if tsgraph is self.portalgraph:
+            try:
+                return self.portal_tsfnptr
+            except AttributeError:
+                self.portal_tsfnptr = self.gettscallable(tsgraph)
+                return self.portal_tsfnptr
+        return self.gettscallable(tsgraph)
         
-    def rewire_portal(self, origportalgraph, portalgraph):
+    def rewire_portal(self):
+        origportalgraph = self.origportalgraph
+        portalgraph = self.portalgraph
         annhelper = self.annhelper
+        rgenop = self.RGenOp()
 
         argcolors = []
         portal_args_s = []
@@ -200,12 +217,20 @@
                 portal_args_s.append(self.s_RedBox)
             argcolors.append(color)
 
-        portal_fnptr = self.rtyper.type_system.getcallable(portalgraph)
+        tsportalgraph = portalgraph
+        # patch the shared portal pointer
+        portalgraph = flowmodel.copygraph(tsportalgraph, shallow=True)
+        portal_fnptr = self.naked_tsfnptr(self.portalgraph)
+        portal_fnptr._obj.graph = portalgraph
+        
         portal_fn = PseudoHighLevelCallable(
             portal_fnptr,
             [self.s_JITState] + portal_args_s,
             self.s_JITState)
         FUNC = self.get_residual_functype(portalgraph)
+        RESTYPE = FUNC.RESULT
+        reskind = rgenop.kindToken(RESTYPE)
+        boxbuilder = rvalue.ll_redboxbuilder(RESTYPE)
         argcolors = unrolling_iterable(argcolors)
         fresh_jitstate = self.ll_fresh_jitstate
         finish_jitstate = self.ll_finish_jitstate
@@ -215,7 +240,6 @@
                 self.cache = {}
 
         state = PortalState()
-        rgenop = self.RGenOp()
 
         # debug helper
         def readportal(*args):
@@ -230,12 +254,15 @@
                 i = i + 1
             cache = state.cache
             try:
-                return cache[key]
+                gv_generated = cache[key]
             except KeyError:
                 return lltype.nullptr(FUNC)
-
+            fn = gv_generated.revealconst(lltype.Ptr(FUNC))
+            return fn
+            
         def readallportals():
-            return state.cache.values()
+            return [gv_gen.revealconst(lltype.Ptr(FUNC))
+                    for gv_gen in state.cache.values()]
         
         def portalentry(*args):
             i = 0
@@ -252,12 +279,13 @@
                 i = i + 1
             cache = state.cache
             try:
-                fn = cache[key]
+                gv_generated = cache[key]
             except KeyError:
                 portal_ts_args = ()
                 sigtoken = rgenop.sigToken(FUNC)
                 builder, gv_generated, inputargs_gv = rgenop.newgraph(sigtoken,
                                                              "generated")
+                cache[key] = gv_generated
                 i = 0
                 for color in argcolors:
                     if color == "green":
@@ -281,11 +309,11 @@
                     finish_jitstate(top_jitstate, sigtoken)
 
                 builder.end()
-                fn = gv_generated.revealconst(lltype.Ptr(FUNC))
                 builder.show_incremental_progress()
-                cache[key] = fn
+            fn = gv_generated.revealconst(lltype.Ptr(FUNC))
             return fn(*residualargs)
 
+
         args_s = [annmodel.lltype_to_annotation(v.concretetype) for
                   v in origportalgraph.getargs()]
         s_result = annmodel.lltype_to_annotation(
@@ -297,16 +325,87 @@
         self.readportalgraph = annhelper.getgraph(readportal, args_s,
                                    s_funcptr)
 
-        s_funcptrlist = annmodel.SomeList(listdef.ListDef(None, s_funcptr))
+        s_funcptrlist = annmodel.SomeList(listdef.ListDef(None, s_funcptr,
+                                                          resized=True))
         self.readallportalsgraph = annhelper.getgraph(readallportals, [],
                                                       s_funcptrlist)
 
+        TYPES = [v.concretetype for v in origportalgraph.getargs()]
+        argcolorandtypes = unrolling_iterable(zip(argcolors,
+                                                  TYPES))
+
+        def portalreentry(jitstate, *args):
+            i = 0
+            key = ()
+            curbuilder = jitstate.curbuilder
+            args_gv = []
+            for color in argcolors:
+                if color == "green":
+                    x = args[i]
+                    if isinstance(lltype.typeOf(x), lltype.Ptr): 
+                        x = llmemory.cast_ptr_to_adr(x)
+                    key = key + (x,)
+                else:
+                    box = args[i]
+                    args_gv.append(box.getgenvar(curbuilder))
+                i = i + 1
+            sigtoken = rgenop.sigToken(FUNC)
+            cache = state.cache
+            try:
+                gv_generated = cache[key]
+            except KeyError:
+                portal_ts_args = ()
+                builder, gv_generated, inputargs_gv = rgenop.newgraph(sigtoken,
+                                                                "generated")
+                cache[key] = gv_generated
+                i = 0
+                for color, T in argcolorandtypes:
+                    if color == "green":
+                        llvalue = args[0]
+                        args = args[1:]
+                        portal_ts_args += (llvalue,)
+                    else:
+                        args = args[1:]
+                        kind = rgenop.kindToken(T)
+                        boxcls = rvalue.ll_redboxcls(T)
+                        gv_arg = inputargs_gv[i]
+                        box = boxcls(kind, gv_arg)
+                        i += 1
+                        portal_ts_args += (box,)
+
+                top_jitstate = fresh_jitstate(builder)
+                top_jitstate = portal_fn(top_jitstate, *portal_ts_args)
+                if top_jitstate is not None:
+                    finish_jitstate(top_jitstate, sigtoken)
+
+                builder.end()
+                builder.show_incremental_progress()
+
+ 
+            gv_res = curbuilder.genop_call(sigtoken, gv_generated, args_gv)
+            if RESTYPE == lltype.Void:
+                retbox = None
+            else:
+                retbox = boxbuilder(reskind, gv_res)
+                
+            jitstate.returnbox = retbox
+            assert jitstate.next is None
+            return jitstate
+
+        portalreentrygraph = annhelper.getgraph(portalreentry,
+                [self.s_JITState] + portal_args_s, self.s_JITState)
+        portalreentrygraph.tag = "portal_reentry"
+
         annhelper.finish()
 
         origportalgraph.startblock = portalentrygraph.startblock
         origportalgraph.returnblock = portalentrygraph.returnblock
         origportalgraph.exceptblock = portalentrygraph.exceptblock
-        # name, func?
+
+        tsportalgraph.startblock = portalreentrygraph.startblock
+        tsportalgraph.returnblock = portalreentrygraph.returnblock
+        tsportalgraph.exceptblock = portalreentrygraph.exceptblock
+        
 
     def transform_graph(self, graph, is_portal=False):
         # prepare the graphs by inserting all bookkeeping/dispatching logic
@@ -972,7 +1071,7 @@
         attrname = hop.args_v[1].value
         N = mpfamily.resumepoint_after_mergepoint[attrname]
         tsgraph = mpfamily.tsgraph
-        ts_fnptr = self.gettscallable(tsgraph)
+        ts_fnptr = self.naked_tsfnptr(tsgraph)
         TS_FUNC = lltype.typeOf(ts_fnptr)
         dummy_args = [ARG._defl() for ARG in TS_FUNC.TO.ARGS[1:]]
         dummy_args = tuple(dummy_args)

Modified: pypy/dist/pypy/jit/timeshifter/test/test_portal.py
==============================================================================
--- pypy/dist/pypy/jit/timeshifter/test/test_portal.py	(original)
+++ pypy/dist/pypy/jit/timeshifter/test/test_portal.py	Mon Nov 27 20:21:09 2006
@@ -309,3 +309,47 @@
         res = self.timeshift_from_portal(ll_function, ll_function, [0], policy=P_NOVIRTUAL)
         assert res == ord('2')
         self.check_insns(indirect_call=0, malloc=0)
+
+    def test_simple_recursive_portal_call(self):
+
+        def main(code, x):
+            return evaluate(code, x)
+
+        def evaluate(y, x):
+            hint(y, concrete=True)
+            if y <= 0:
+                return x
+            z = 1 + evaluate(y - 1, x)
+            return z
+
+        res = self.timeshift_from_portal(main, evaluate, [3, 2])
+        assert res == 5
+
+        res = self.timeshift_from_portal(main, evaluate, [3, 5])
+        assert res == 8
+
+        res = self.timeshift_from_portal(main, evaluate, [4, 7])
+        assert res == 11
+    
+
+    def test_simple_recursive_portal_call2(self):
+
+        def main(code, x):
+            return evaluate(code, x)
+
+        def evaluate(y, x):
+            hint(y, concrete=True)
+            if x <= 0:
+                return y
+            z = evaluate(y, x - 1) + 1
+            return z
+
+        res = self.timeshift_from_portal(main, evaluate, [3, 2])
+        assert res == 5
+
+        res = self.timeshift_from_portal(main, evaluate, [3, 5])
+        assert res == 8
+
+        res = self.timeshift_from_portal(main, evaluate, [4, 7])
+        assert res == 11
+    



More information about the Pypy-commit mailing list