[pypy-svn] r32843 - in pypy/dist/pypy/jit/timeshifter: . test

pedronis at codespeak.net pedronis at codespeak.net
Tue Oct 3 15:42:08 CEST 2006


Author: pedronis
Date: Tue Oct  3 15:42:04 2006
New Revision: 32843

Modified:
   pypy/dist/pypy/jit/timeshifter/rtimeshift.py
   pypy/dist/pypy/jit/timeshifter/rtyper.py
   pypy/dist/pypy/jit/timeshifter/test/test_promotion.py
   pypy/dist/pypy/jit/timeshifter/transform.py
Log:
(arigo, arre, pedronis)

Restore support for yellow calls, including support for promotions, with a test.



Modified: pypy/dist/pypy/jit/timeshifter/rtimeshift.py
==============================================================================
--- pypy/dist/pypy/jit/timeshifter/rtimeshift.py	(original)
+++ pypy/dist/pypy/jit/timeshifter/rtimeshift.py	Tue Oct  3 15:42:04 2006
@@ -249,9 +249,9 @@
         node = jitstate.promotion_path
         while not node.cut_limit:
             node = node.next
-        dispatch_queue = jitstate.frame.dispatch_queue
-        count = dispatch_queue.mergecounter + 1
-        dispatch_queue.mergecounter = count
+        dispatchqueue = jitstate.frame.dispatchqueue
+        count = dispatchqueue.mergecounter + 1
+        dispatchqueue.mergecounter = count
         node = PromotionPathMergesToSee(node, count)
         jitstate.promotion_path = node
     else:
@@ -287,23 +287,41 @@
 def collect_split(jitstate_chain, resumepoint, *greens_gv):
     greens_gv = list(greens_gv)
     pending = jitstate_chain
+    resuming = jitstate_chain.resuming
+    if resuming is not None and resuming.mergesleft == 0:
+        node = resuming.path.pop()
+        assert isinstance(node, PromotionPathCollectSplit)
+        for i in range(node.n):
+            pending = pending.next
+        pending.greens.extend(greens_gv)
+        pending.next = None
+        return pending
+
+    n = 0
     while True:
         jitstate = pending
         pending = pending.next
         jitstate.greens.extend(greens_gv)   # item 0 is the return value
         jitstate.resumepoint = resumepoint
+        if resuming is None:
+            node = jitstate.promotion_path
+            jitstate.promotion_path = PromotionPathCollectSplit(node, n)
+            n += 1
         if pending is None:
             break
-    dispatch_queue = jitstate_chain.frame.dispatch_queue
-    jitstate.next = dispatch_queue.split_chain
-    dispatch_queue.split_chain = jitstate_chain.next
+
+    dispatchqueue = jitstate_chain.frame.dispatchqueue
+    jitstate.next = dispatchqueue.split_chain
+    dispatchqueue.split_chain = jitstate_chain.next
+    jitstate_chain.next = None
+    return jitstate_chain
     # XXX obscurity++ above
 
 def dispatch_next(oldjitstate):
-    dispatch_queue = oldjitstate.frame.dispatch_queue
-    if dispatch_queue.split_chain is not None:
-        jitstate = dispatch_queue.split_chain
-        dispatch_queue.split_chain = jitstate.next
+    dispatchqueue = oldjitstate.frame.dispatchqueue
+    if dispatchqueue.split_chain is not None:
+        jitstate = dispatchqueue.split_chain
+        dispatchqueue.split_chain = jitstate.next
         enter_block(jitstate)
         return jitstate
     else:
@@ -344,9 +362,9 @@
 
 def save_return(jitstate):
     # add 'jitstate' to the chain of return-jitstates
-    dispatch_queue = jitstate.frame.dispatch_queue
-    jitstate.next = dispatch_queue.return_chain
-    dispatch_queue.return_chain = jitstate
+    dispatchqueue = jitstate.frame.dispatchqueue
+    jitstate.next = dispatchqueue.return_chain
+    dispatchqueue.return_chain = jitstate
 
 ##def ll_gvar_from_redbox(jitstate, redbox):
 ##    return redbox.getgenvar(jitstate.curbuilder)
@@ -391,6 +409,17 @@
         else:
             self.mergesleft = MC_IGNORE_UNTIL_RETURN
 
+    def leave_call(self, dispatchqueue):
+        parent_mergesleft = dispatchqueue.mergecounter
+        if parent_mergesleft == 0:
+            node = self.path.pop()
+            assert isinstance(node, PromotionPathBackFromReturn)
+            self.merges_to_see()
+        elif parent_mergesleft == MC_CALL_NOT_TAKEN:
+            self.mergesleft = 0
+        else:
+            self.mergesleft = parent_mergesleft
+
 
 class PromotionPoint(object):
     def __init__(self, flexswitch, switchblock, promotion_path):
@@ -446,6 +475,12 @@
 class PromotionPathNo(PromotionPathSplit):
     answer = False
 
+class PromotionPathCollectSplit(PromotionPathNode):
+
+    def __init__(self, next, n):
+        self.next = next
+        self.n = n
+
 class PromotionPathCallNotTaken(PromotionPathNode):
     pass
 
@@ -548,7 +583,7 @@
             # clear the complete state of dispatch queues
             f = jitstate.frame
             while f is not None:
-                f.dispatch_queue.clear()
+                f.dispatchqueue.clear()
                 f = f.backframe
 
             if len(resuming.path) == 0:
@@ -628,7 +663,7 @@
             backframe = self.fz_backframe.unfreeze(incomingvarboxes, memo)
         else:
             backframe = None
-        vframe = VirtualFrame(backframe, BaseDispatchQueue())
+        vframe = VirtualFrame(backframe, None) # dispatch queue to be patched
         vframe.local_boxes = local_boxes
         return vframe
 
@@ -663,9 +698,9 @@
 
 class VirtualFrame(object):
 
-    def __init__(self, backframe, dispatch_queue):
+    def __init__(self, backframe, dispatchqueue):
         self.backframe = backframe
-        self.dispatch_queue = dispatch_queue
+        self.dispatchqueue = dispatchqueue
         #self.local_boxes = ... set by callers
 
     def enter_block(self, incoming, memo):
@@ -687,7 +722,7 @@
             newbackframe = None
         else:
             newbackframe = self.backframe.copy(memo)
-        result = VirtualFrame(newbackframe, self.dispatch_queue)
+        result = VirtualFrame(newbackframe, self.dispatchqueue)
         result.local_boxes = [box.copy(memo) for box in self.local_boxes]
         return result
 
@@ -724,9 +759,9 @@
                                   newgreens,
                                   self.resuming)
         # add the later_jitstate to the chain of pending-for-dispatch_next()
-        dispatch_queue = self.frame.dispatch_queue
-        later_jitstate.next = dispatch_queue.split_chain
-        dispatch_queue.split_chain = later_jitstate
+        dispatchqueue = self.frame.dispatchqueue
+        later_jitstate.next = dispatchqueue.split_chain
+        dispatchqueue.split_chain = later_jitstate
         return later_jitstate
 
     def enter_block(self, incoming, memo):
@@ -776,7 +811,8 @@
     pass
 
 def merge_returning_jitstates(jitstate):
-    return_chain = jitstate.frame.dispatch_queue.return_chain
+    dispatchqueue = jitstate.frame.dispatchqueue
+    return_chain = dispatchqueue.return_chain
     return_cache = {}
     still_pending = None
     while return_chain is not None:
@@ -797,6 +833,11 @@
         res = retrieve_jitstate_for_merge(return_cache, jitstate, (),
                                           return_marker)
         assert res is True   # finished
+
+    resuming = most_general_jitstate.resuming
+    if resuming is not None:
+        resuming.leave_call(dispatchqueue)
+        
     return most_general_jitstate
 
 def leave_graph_red(jitstate):
@@ -815,10 +856,8 @@
 def leave_frame(jitstate):
     myframe = jitstate.frame
     backframe = myframe.backframe
-    jitstate.frame = backframe
-    mydispatchqueue = myframe.dispatch_queue
-    resuming = jitstate.resuming
-    if resuming is None:
+    jitstate.frame = backframe    
+    if jitstate.resuming is None:
         #debug_view(jitstate)
         node = jitstate.promotion_path
         while not node.cut_limit:
@@ -829,42 +868,17 @@
             node = PromotionPathBackFromReturn(node)
             node = PromotionPathMergesToSee(node, 0)
         jitstate.promotion_path = node
-    else:
-        parent_mergesleft = mydispatchqueue.mergecounter
-        if parent_mergesleft == 0:
-            node = resuming.path.pop()
-            assert isinstance(node, PromotionPathBackFromReturn)
-            resuming.merges_to_see()
-        elif parent_mergesleft == MC_CALL_NOT_TAKEN:
-            resuming.mergesleft = 0
-        else:
-            resuming.mergesleft = parent_mergesleft
+
 
 def leave_graph_yellow(jitstate):
-    mydispatchqueue = jitstate.frame.dispatch_queue
+    mydispatchqueue = jitstate.frame.dispatchqueue
     return_chain = mydispatchqueue.return_chain
     jitstate = return_chain
-    resuming = mydispatchqueue.parent_resuming
-    if resuming is None:
-        n = 0
-        parent_promotion_path = mydispatchqueue.parent_promotion_path
-        while jitstate is not None:
-            assert jitstate.resuming is None
-            node = PromotionPathNoWithArg(parent_promotion_path, n)
-            jitstate.promotion_path = node
-            n += 1
-            jitstate.frame = jitstate.frame.backframe
-            jitstate = jitstate.next
-        return return_chain    # a jitstate, which is the head of the chain
-    else:
-        node = resuming.path.pop()
-        assert isinstance(node, PromotionPathNoWithArg)
-        n = node.arg
-        for i in range(n):
-            assert jitstate.resuming is None
-            jitstate = jitstate.next
-        jitstate.resuming = resuming
-        jitstate.promotion_path = None
-        jitstate.frame = jitstate.frame.backframe
-        jitstate.next = None
-        return jitstate
+    resuming = jitstate.resuming
+    if resuming is not None:
+        resuming.leave_call(mydispatchqueue)
+    while jitstate is not None:
+        leave_frame(jitstate)
+        jitstate = jitstate.next
+    return return_chain    # a jitstate, which is the head of the chain
+

Modified: pypy/dist/pypy/jit/timeshifter/rtyper.py
==============================================================================
--- pypy/dist/pypy/jit/timeshifter/rtyper.py	(original)
+++ pypy/dist/pypy/jit/timeshifter/rtyper.py	Tue Oct  3 15:42:04 2006
@@ -239,11 +239,8 @@
         try:
             return self.dispatchsubclasses[mergepointfamily]
         except KeyError:
-            if mergepointfamily.is_global:
-                subclass = rtimeshift.BaseDispatchQueue
-            else:
-                attrnames = mergepointfamily.getattrnames()
-                subclass = rtimeshift.build_dispatch_subclass(attrnames)
+            attrnames = mergepointfamily.getlocalattrnames()
+            subclass = rtimeshift.build_dispatch_subclass(attrnames)
             self.dispatchsubclasses[mergepointfamily] = subclass
             return subclass
 
@@ -729,16 +726,17 @@
         args_s += [self.s_ConstOrVar] * len(greens_v)
         args_v = [v_jitstate, c_resumepoint]
         args_v += greens_v
-        hop.llops.genmixlevelhelpercall(rtimeshift.collect_split,
-                                        args_s, args_v,
-                                        annmodel.s_None)
+        v_newjs = hop.llops.genmixlevelhelpercall(rtimeshift.collect_split,
+                                                  args_s, args_v,
+                                                  self.s_JITState)
+        hop.llops.setjitstate(v_newjs)
 
     def translate_op_merge_point(self, hop, global_resumer=None):
         mpfamily = hop.args_v[0].value
         attrname = hop.args_v[1].value
         DispatchQueueSubclass = self.get_dispatch_subclass(mpfamily)
 
-        if mpfamily.is_global:
+        if global_resumer is not None:
             states_dic = {}
             def merge_point(jitstate, *key):
                 return rtimeshift.retrieve_jitstate_for_merge(states_dic,
@@ -746,9 +744,9 @@
                                                               global_resumer)
         else:
             def merge_point(jitstate, *key):
-                dispatch_queue = jitstate.frame.dispatch_queue
-                assert isinstance(dispatch_queue, DispatchQueueSubclass)
-                states_dic = getattr(dispatch_queue, attrname)
+                dispatchqueue = jitstate.frame.dispatchqueue
+                assert isinstance(dispatchqueue, DispatchQueueSubclass)
+                states_dic = getattr(dispatchqueue, attrname)
                 return rtimeshift.retrieve_jitstate_for_merge(states_dic,
                                                               jitstate, key,
                                                               global_resumer)
@@ -789,7 +787,10 @@
         s_res = self.s_JITState
         tsfn = annlowlevel.PseudoHighLevelCallable(ts_fnptr, args_s, s_res)
 
+        DispatchQueueSubclass = self.get_dispatch_subclass(mpfamily)
+
         def call_for_global_resuming(jitstate):
+            jitstate.frame.dispatchqueue = DispatchQueueSubclass()
             jitstate.resumepoint = N
             try:
                 finaljitstate = tsfn(jitstate, *dummy_args)

Modified: pypy/dist/pypy/jit/timeshifter/test/test_promotion.py
==============================================================================
--- pypy/dist/pypy/jit/timeshifter/test/test_promotion.py	(original)
+++ pypy/dist/pypy/jit/timeshifter/test/test_promotion.py	Tue Oct  3 15:42:04 2006
@@ -81,6 +81,28 @@
         assert res == 4*17 + 10
         self.check_insns(int_mul=0, int_add=1)
 
+    def test_promote_after_yellow_call(self):
+        S = lltype.GcStruct('S', ('x', lltype.Signed))
+        def ll_two(k, s):
+            if k > 5:
+                s.x = 20*k
+                return 7
+            else:
+                s.x = 10*k
+                return 9
+            
+        def ll_function(n):
+            s = lltype.malloc(S)
+            c = ll_two(n, s)
+            k = hint(s.x, promote=True)
+            k += c
+            return hint(k, variable=True)
+        ll_function._global_merge_points_ = True
+
+        res = self.timeshift(ll_function, [4], [], policy=P_NOVIRTUAL)
+        assert res == 49
+        self.check_insns(int_add=0)
+
     def test_promote_inside_call(self):
         def ll_two(n):
             k = hint(n, promote=True)

Modified: pypy/dist/pypy/jit/timeshifter/transform.py
==============================================================================
--- pypy/dist/pypy/jit/timeshifter/transform.py	(original)
+++ pypy/dist/pypy/jit/timeshifter/transform.py	Tue Oct  3 15:42:04 2006
@@ -13,17 +13,25 @@
 
 
 class MergePointFamily(object):
-    def __init__(self, tsgraph, is_global=False):
+    def __init__(self, tsgraph):
         self.tsgraph = tsgraph
-        self.is_global = is_global
         self.count = 0
         self.resumepoint_after_mergepoint = {}
-    def add(self):
+        self.localmergepoints = []
+        
+    def add(self, kind):
         result = self.count
         self.count += 1
-        return 'mp%d' % result
-    def getattrnames(self):
-        return ['mp%d' % i for i in range(self.count)]
+        attrname = 'mp%d' % result
+        if kind == 'local':
+            self.localmergepoints.append(attrname)
+        return attrname
+
+    def getlocalattrnames(self):
+        return self.localmergepoints
+
+    def has_global_mergepoints(self):
+        return bool(self.resumepoint_after_mergepoint)
 
 
 class HintGraphTransformer(object):
@@ -33,11 +41,9 @@
         self.hannotator = hannotator
         self.graph = graph
         self.graphcolor = self.graph_calling_color(graph)
-        self.global_merge_points = self.graph_global_mps(self.graph)
         self.resumepoints = {}
         self.mergepoint_set = {}    # set of blocks
-        self.mergepointfamily = MergePointFamily(graph,
-                                                 self.global_merge_points)
+        self.mergepointfamily = MergePointFamily(graph)
         self.c_mpfamily = inputconst(lltype.Void, self.mergepointfamily)
         self.tsgraphs_seen = []
 
@@ -54,11 +60,15 @@
 
     def compute_merge_points(self):
         entrymap = mkentrymap(self.graph)
+        if self.graph_global_mps(self.graph):
+            kind = 'global'
+        else:
+            kind = 'local'
         for block, links in entrymap.items():
             if len(links) > 1 and block is not self.graph.returnblock:
-                self.mergepoint_set[block] = True
-        if self.global_merge_points:
-            self.mergepoint_set[self.graph.startblock] = True
+                self.mergepoint_set[block] = kind
+        if kind == 'global':
+            self.mergepoint_set[self.graph.startblock] = 'global'
 
     def graph_calling_color(self, tsgraph):
         args_hs, hs_res = self.hannotator.bookkeeper.tsgraphsigs[tsgraph]
@@ -280,17 +290,17 @@
         self.go_to_if(block, self.graph.returnblock, v_finished_flag)
 
     def insert_merge_points(self):
-        for block in self.mergepoint_set:
-            self.insert_merge(block)
+        for block, kind in self.mergepoint_set.items():
+            self.insert_merge(block, kind)
 
-    def insert_merge(self, block):
+    def insert_merge(self, block, kind):
         reds, greens = self.sort_by_color(block.inputargs)
         nextblock = self.naive_split_block(block, 0)
 
         self.genop(block, 'save_locals', reds)
-        mp   = self.mergepointfamily.add()
+        mp   = self.mergepointfamily.add(kind)
         c_mp = inputconst(lltype.Void, mp)
-        if self.global_merge_points:
+        if kind == 'global':
             self.genop(block, 'save_greens', greens)
             prefix = 'global_'
         else:
@@ -313,15 +323,15 @@
         SSA_to_SSI({block    : True,    # reachable from outside
                     nextblock: False}, self.hannotator)
 
-        if self.global_merge_points:
+        if kind == 'global':
             N = self.get_resume_point(nextblock)
             self.mergepointfamily.resumepoint_after_mergepoint[mp] = N
 
     def insert_dispatcher(self):
-        if self.global_merge_points or self.resumepoints:
+        if self.resumepoints:
             block = self.before_return_block()
             self.genop(block, 'dispatch_next', [])
-            if self.global_merge_points:
+            if self.mergepointfamily.has_global_mergepoints():
                 block = self.before_return_block()
                 entryblock = self.before_start_block()
                 v_rp = self.genop(entryblock, 'getresumepoint', [],
@@ -576,8 +586,9 @@
         link.args = []
         link.target = self.get_resume_point_link(nextblock).target
 
-        self.mergepoint_set[nextblock] = True  # to merge some of the possibly
-                                               # many return jitstates
+        # to merge some of the possibly many return jitstates
+        self.mergepoint_set[nextblock] = 'local'  
+        
 
     # __________ hints __________
 



More information about the Pypy-commit mailing list