[pypy-svn] r51211 - in pypy/branch/jit-refactoring/pypy/jit: rainbow rainbow/test timeshifter

cfbolz at codespeak.net cfbolz at codespeak.net
Sat Feb 2 16:29:42 CET 2008


Author: cfbolz
Date: Sat Feb  2 16:29:41 2008
New Revision: 51211

Modified:
   pypy/branch/jit-refactoring/pypy/jit/rainbow/bytecode.py
   pypy/branch/jit-refactoring/pypy/jit/rainbow/test/test_interpreter.py
   pypy/branch/jit-refactoring/pypy/jit/rainbow/test/test_serializegraph.py
   pypy/branch/jit-refactoring/pypy/jit/timeshifter/greenkey.py
   pypy/branch/jit-refactoring/pypy/jit/timeshifter/hrtyper.py
   pypy/branch/jit-refactoring/pypy/jit/timeshifter/rtimeshift.py
Log:
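Implement (local) merge points in the rainbow interpreter.  This breaks many
(possibly most) of the existing JIT tests, because some details in rtimeshift.py
changed: BaseDispatchQueue/build_dispatch_subclass and ensure_queue were replaced
by a single DispatchQueue with per-merge-point local caches, and merge caches are
now green-key dicts keyed by GreenKey/empty_key instead of plain tuples.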


Modified: pypy/branch/jit-refactoring/pypy/jit/rainbow/bytecode.py
==============================================================================
--- pypy/branch/jit-refactoring/pypy/jit/rainbow/bytecode.py	(original)
+++ pypy/branch/jit-refactoring/pypy/jit/rainbow/bytecode.py	Sat Feb  2 16:29:41 2008
@@ -1,9 +1,10 @@
 from pypy.rlib.rarithmetic import intmask
+from pypy.rlib.unroll import unrolling_iterable
 from pypy.objspace.flow import model as flowmodel
 from pypy.rpython.lltypesystem import lltype
 from pypy.jit.hintannotator.model import originalconcretetype
 from pypy.jit.timeshifter import rtimeshift, rvalue
-from pypy.rlib.unroll import unrolling_iterable
+from pypy.jit.timeshifter.greenkey import KeyDesc, empty_key, GreenKey
 
 class JitCode(object):
     """
@@ -18,11 +19,14 @@
     green consts are negative indexes
     """
 
-    def __init__(self, code, constants, typekinds, redboxclasses):
+    def __init__(self, code, constants, typekinds, redboxclasses, keydescs,
+                 num_mergepoints):
         self.code = code
         self.constants = constants
         self.typekinds = typekinds
         self.redboxclasses = redboxclasses
+        self.keydescs = keydescs
+        self.num_mergepoints = num_mergepoints
 
     def _freeze_(self):
         return True
@@ -46,8 +50,7 @@
 
     def run(self, jitstate, bytecode, greenargs, redargs):
         self.jitstate = jitstate
-        self.queue = rtimeshift.ensure_queue(jitstate,
-                                             rtimeshift.BaseDispatchQueue)
+        self.queue = rtimeshift.DispatchQueue(bytecode.num_mergepoints)
         rtimeshift.enter_frame(self.jitstate, self.queue)
         self.frame = self.jitstate.frame
         self.frame.pc = 0
@@ -66,6 +69,17 @@
             else:
                 assert result is None
 
+    def dispatch(self):
+        newjitstate = rtimeshift.dispatch_next(self.queue)
+        resumepoint = rtimeshift.getresumepoint(newjitstate)
+        self.newjitstate(newjitstate)
+        if resumepoint == -1:
+            # XXX what about green returns?
+            newjitstate = rtimeshift.leave_graph_red(self.queue, is_portal=True)
+            self.newjitstate(newjitstate)
+            return STOP
+        else:
+            self.frame.pc = resumepoint
 
     # operation helper functions
 
@@ -138,21 +152,10 @@
 
     def opimpl_red_return(self):
         rtimeshift.save_return(self.jitstate)
-        newjitstate = rtimeshift.dispatch_next(self.queue)
-        resumepoint = rtimeshift.getresumepoint(newjitstate)
-        if resumepoint == -1:
-            # XXX for now
-            newjitstate = rtimeshift.leave_graph_red(self.queue, is_portal=True)
-            self.newjitstate(newjitstate)
-            return STOP
-        self.newjitstate(newjitstate)
-        self.frame.pc = resumepoint
+        return self.dispatch()
 
     def opimpl_green_return(self):
-        rtimeshift.save_return(self.jitstate)
-        newstate = rtimeshift.leave_graph_yellow(self.queue)
-        self.jitstate = newstate
-        return STOP
+        XXX
 
     def opimpl_make_new_redvars(self):
         # an opcode with a variable number of args
@@ -167,11 +170,28 @@
         # an opcode with a variable number of args
         # num_args arg_old_1 arg_new_1 ...
         num = self.load_2byte()
+        if num == 0 and len(self.frame.local_green) == 0:
+            # fast (very common) case
+            return
         newgreens = []
         for i in range(num):
             newgreens.append(self.get_greenarg())
         self.frame.local_green = newgreens
 
+    def opimpl_merge(self):
+        mergepointnum = self.load_2byte()
+        keydescnum = self.load_2byte()
+        if keydescnum == -1:
+            key = empty_key
+        else:
+            keydesc = self.frame.bytecode.keydescs[keydescnum]
+            key = GreenKey(self.frame.local_green[:keydesc.nb_vals], keydesc)
+        states_dic = self.queue.local_caches[mergepointnum]
+        done = rtimeshift.retrieve_jitstate_for_merge(states_dic, self.jitstate,
+                                                      key, None)
+        if done:
+            return self.dispatch()
+
     # construction-time interface
 
     def _add_implemented_opcodes(self):
@@ -219,7 +239,8 @@
         self.opcode_implementations.append(implementation)
         self.opcode_descs.append(opdesc)
         return index
-            
+
+
 
 
 class BytecodeWriter(object):
@@ -237,6 +258,7 @@
         self.constants = []
         self.typekinds = []
         self.redboxclasses = []
+        self.keydescs = []
         # mapping constant -> index in constants
         self.const_positions = {}
         # mapping blocks to True
@@ -249,6 +271,10 @@
         self.free_green = {}
         # mapping TYPE to index
         self.type_positions = {}
+        # mapping tuple of green TYPES to index
+        self.keydesc_positions = {}
+
+        self.num_mergepoints = 0
 
         self.graph = graph
         self.entrymap = flowmodel.mkentrymap(graph)
@@ -257,7 +283,9 @@
         return JitCode(assemble(self.interpreter, *self.assembler),
                        self.constants,
                        self.typekinds,
-                       self.redboxclasses)
+                       self.redboxclasses,
+                       self.keydescs,
+                       self.num_mergepoints)
 
     def make_bytecode_block(self, block, insert_goto=False):
         if block in self.seen_blocks:
@@ -271,16 +299,16 @@
         self.free_green[block] = 0
         self.free_red[block] = 0
         self.current_block = block
+
         self.emit(label(block))
         reds, greens = self.sort_by_color(block.inputargs)
         for arg in reds:
             self.register_redvar(arg)
         for arg in greens:
             self.register_greenvar(arg)
-        #self.insert_merges(block)
+        self.insert_merges(block)
         for op in block.operations:
             self.serialize_op(op)
-        #self.insert_splits(block)
         self.insert_exits(block)
         self.current_block = oldblock
 
@@ -288,6 +316,7 @@
         if block.exits == ():
             returnvar, = block.inputargs
             color = self.varcolor(returnvar)
+            assert color == "red" # XXX green return values not supported yet
             index = self.serialize_oparg(color, returnvar)
             self.emit("%s_return" % color)
             self.emit(index)
@@ -312,6 +341,30 @@
         else:
             XXX
 
+    def insert_merges(self, block):
+        if block is self.graph.returnblock:
+            return
+        if len(self.entrymap[block]) <= 1:
+            return
+        num = self.num_mergepoints
+        self.num_mergepoints += 1
+        # make keydesc
+        key = ()
+        for arg in self.sort_by_color(block.inputargs)[1]:
+            TYPE = arg.concretetype
+            key += (TYPE, )
+        if not key:
+            keyindex = -1 # use prebuilt empty_key
+        elif key not in self.keydesc_positions:
+            keyindex = len(self.keydesc_positions)
+            self.keydesc_positions[key] = keyindex
+            self.keydescs.append(KeyDesc(self.RGenOp, *key))
+        else:
+            keyindex = self.keydesc_positions[key]
+        self.emit("merge")
+        self.emit(num)
+        self.emit(keyindex)
+
     def insert_renaming(self, args):
         reds, greens = self.sort_by_color(args)
         for color, args in [("red", reds), ("green", greens)]:

Modified: pypy/branch/jit-refactoring/pypy/jit/rainbow/test/test_interpreter.py
==============================================================================
--- pypy/branch/jit-refactoring/pypy/jit/rainbow/test/test_interpreter.py	(original)
+++ pypy/branch/jit-refactoring/pypy/jit/rainbow/test/test_interpreter.py	Sat Feb  2 16:29:41 2008
@@ -181,7 +181,17 @@
                 return x
             return y
         res = self.interpret(f, [1, 2])
+        assert res == 1
 
+    def test_merge(self):
+        def f(x, y, z):
+            if x:
+                a = y - z
+            else:
+                a = y + z
+            return 1 + a
+        res = self.interpret(f, [1, 2, 3])
+        assert res == 0
 
 class TestLLType(AbstractInterpretationTest):
     type_system = "lltype"

Modified: pypy/branch/jit-refactoring/pypy/jit/rainbow/test/test_serializegraph.py
==============================================================================
--- pypy/branch/jit-refactoring/pypy/jit/rainbow/test/test_serializegraph.py	(original)
+++ pypy/branch/jit-refactoring/pypy/jit/rainbow/test/test_serializegraph.py	Sat Feb  2 16:29:41 2008
@@ -135,6 +135,40 @@
         assert len(jitcode.constants) == 0
         assert len(jitcode.typekinds) == 0
 
+    def test_merge(self):
+        def f(x, y, z):
+            if x:
+                a = y - z
+            else:
+                a = y + z
+            return 1 + a
+        writer, jitcode = self.serialize(f, [int, int, int])
+        expected = assemble(writer.interpreter,
+                            "red_int_is_true", 0,
+                            "red_goto_iftrue", 3, tlabel("add"),
+                            "make_new_redvars", 2, 1, 2,
+                            "make_new_greenvars", 0,
+                            "red_int_add", 0, 1,
+                            "make_new_redvars", 1, 2,
+                            "make_new_greenvars", 0,
+                            label("after"),
+                            "merge", 0, -1,
+                            "make_redbox", -1, 0,
+                            "red_int_add", 1, 0,
+                            "make_new_redvars", 1, 2,
+                            "make_new_greenvars", 0,
+                            "red_return", 0,
+                            label("add"),
+                            "make_new_redvars", 2, 1, 2,
+                            "make_new_greenvars", 0,
+                            "red_int_sub", 0, 1,
+                            "make_new_redvars", 1, 2,
+                            "make_new_greenvars", 0,
+                            "goto", tlabel("after"),
+                            )
+        assert jitcode.code == expected
+        assert len(jitcode.constants) == 1
+        assert len(jitcode.typekinds) == 1
 
 class TestLLType(AbstractSerializationTest):
     type_system = "lltype"

Modified: pypy/branch/jit-refactoring/pypy/jit/timeshifter/greenkey.py
==============================================================================
--- pypy/branch/jit-refactoring/pypy/jit/timeshifter/greenkey.py	(original)
+++ pypy/branch/jit-refactoring/pypy/jit/timeshifter/greenkey.py	Sat Feb  2 16:29:41 2008
@@ -10,13 +10,16 @@
     def __init__(self, RGenOp=None, *TYPES):
         self.RGenOp = RGenOp
         self.TYPES = TYPES
-        TARGETTYPES = []
+        self.nb_vals = len(TYPES)
+        if not TYPES:
+            assert RGenOp is None
 
         if RGenOp is None:
             assert len(TYPES) == 0
             self.hash = lambda self: 0
             self.compare = lambda self, other: True
 
+        TARGETTYPES = []
         for TYPE in TYPES:
             # XXX more cases?
             TARGET = lltype.Signed
@@ -53,6 +56,7 @@
 
 class GreenKey(object):
     def __init__(self, values, desc):
+        assert len(values) == desc.nb_vals
         self.desc = desc
         self.values = values
 

Modified: pypy/branch/jit-refactoring/pypy/jit/timeshifter/hrtyper.py
==============================================================================
--- pypy/branch/jit-refactoring/pypy/jit/timeshifter/hrtyper.py	(original)
+++ pypy/branch/jit-refactoring/pypy/jit/timeshifter/hrtyper.py	Sat Feb  2 16:29:41 2008
@@ -1269,6 +1269,7 @@
         attrname = hop.args_v[1].value
         DispatchQueueSubclass = self.get_dispatch_subclass(mpfamily)
 
+        py.test.skip("broken due to different key handling")
         if global_resumer is not None:
             states_dic = {}
             def merge_point(jitstate, *key):

Modified: pypy/branch/jit-refactoring/pypy/jit/timeshifter/rtimeshift.py
==============================================================================
--- pypy/branch/jit-refactoring/pypy/jit/timeshifter/rtimeshift.py	(original)
+++ pypy/branch/jit-refactoring/pypy/jit/timeshifter/rtimeshift.py	Sat Feb  2 16:29:41 2008
@@ -3,6 +3,7 @@
 from pypy.rpython.lltypesystem import lltype, lloperation, llmemory
 from pypy.jit.hintannotator.model import originalconcretetype
 from pypy.jit.timeshifter import rvalue, rcontainer, rvirtualizable
+from pypy.jit.timeshifter.greenkey import newgreendict, empty_key
 from pypy.rlib.unroll import unrolling_iterable
 from pypy.rpython.annlowlevel import cachedtype, base_ptr_lltype
 from pypy.rpython.annlowlevel import cast_instance_to_base_ptr
@@ -878,31 +879,26 @@
 
 # ____________________________________________________________
 
-class BaseDispatchQueue(object):
+class DispatchQueue(object):
     resuming = None
 
-    def __init__(self):
+    def __init__(self, num_local_caches=0):
         self.split_chain = None
         self.global_merge_chain = None
         self.return_chain = None
+        self.num_local_caches = num_local_caches
         self.clearlocalcaches()
 
     def clearlocalcaches(self):
         self.mergecounter = 0
+        self.local_caches = [newgreendict()
+                                 for i in range(self.num_local_caches)]
 
     def clear(self):
-        self.__init__()
+        self.__init__(self.num_local_caches)
 
 def build_dispatch_subclass(attrnames):
-    if len(attrnames) == 0:
-        return BaseDispatchQueue
-    attrnames = unrolling_iterable(attrnames)
-    class DispatchQueue(BaseDispatchQueue):
-        def clearlocalcaches(self):
-            BaseDispatchQueue.clearlocalcaches(self)
-            for name in attrnames:
-                setattr(self, name, {})     # the new dicts have various types!
-    return DispatchQueue
+    py.test.skip("no longer exists")
 
 
 class FrozenVirtualFrame(object):
@@ -1252,10 +1248,6 @@
     return jitstate
 
 
-def ensure_queue(jitstate, DispatchQueueClass):
-    return DispatchQueueClass()
-ensure_queue._annspecialcase_ = 'specialize:arg(1)'
-
 def replayable_ensure_queue(jitstate, DispatchQueueClass):
     if jitstate.frame is None:    # common case
         return DispatchQueueClass()
@@ -1291,14 +1283,14 @@
 
 def merge_returning_jitstates(dispatchqueue, force_merge=False):
     return_chain = dispatchqueue.return_chain
-    return_cache = {}
+    return_cache = newgreendict()
     still_pending = None
     opened = None
     while return_chain is not None:
         jitstate = return_chain
         return_chain = return_chain.next
         opened = start_writing(jitstate, opened)
-        res = retrieve_jitstate_for_merge(return_cache, jitstate, (),
+        res = retrieve_jitstate_for_merge(return_cache, jitstate, empty_key,
                                           return_marker,
                                           force_merge=force_merge)
         if res is False:    # not finished
@@ -1311,13 +1303,13 @@
     # more general one.
     return_chain = still_pending
     if return_chain is not None:
-        return_cache = {}
+        return_cache = newgreendict()
         still_pending = None
         while return_chain is not None:
             jitstate = return_chain
             return_chain = return_chain.next
             opened = start_writing(jitstate, opened)
-            res = retrieve_jitstate_for_merge(return_cache, jitstate, (),
+            res = retrieve_jitstate_for_merge(return_cache, jitstate, empty_key,
                                               return_marker,
                                               force_merge=force_merge)
             if res is False:    # not finished



More information about the Pypy-commit mailing list