[pypy-svn] r40770 - in pypy/dist/pypy/jit/timeshifter: . test

arigo at codespeak.net arigo at codespeak.net
Mon Mar 19 14:45:37 CET 2007


Author: arigo
Date: Mon Mar 19 14:45:31 2007
New Revision: 40770

Modified:
   pypy/dist/pypy/jit/timeshifter/test/test_timeshift.py
   pypy/dist/pypy/jit/timeshifter/transform.py
Log:
Support for red switches, handled by turning them into a binary search tree.


Modified: pypy/dist/pypy/jit/timeshifter/test/test_timeshift.py
==============================================================================
--- pypy/dist/pypy/jit/timeshifter/test/test_timeshift.py	(original)
+++ pypy/dist/pypy/jit/timeshifter/test/test_timeshift.py	Mon Mar 19 14:45:31 2007
@@ -1592,3 +1592,48 @@
 
         res = self.timeshift(f, [42], policy=P_NOVIRTUAL)
         assert res == -7
+
+    def test_switch(self):  # int switch on a red value (x) and on a green value (y)
+        def g(n):  # if/elif chain over ints; per f's comments it becomes a switch
+            if n == 0:
+                return 12
+            elif n == 1:
+                return 34
+            elif n == 3:
+                return 56
+            elif n == 7:
+                return 78
+            else:
+                return 90  # default case
+        def f(n, m):
+            x = g(n)   # gives a red switch
+            y = g(hint(m, concrete=True))   # gives a green switch
+            return x - y
+
+        res = self.timeshift(f, [7, 2], backendoptimize=True)  # n=7 hits a case; m=2 takes the default
+        assert res == 78 - 90
+        res = self.timeshift(f, [8, 1], backendoptimize=True)  # n=8 takes the default; m=1 hits a case
+        assert res == 90 - 34
+
+    def test_switch_char(self):  # same as test_switch, but switching on a Char value
+        def g(n):
+            n = chr(n)  # the switch value is a Char, not a Signed
+            if n == '\x00':
+                return 12
+            elif n == '\x01':
+                return 34
+            elif n == '\x02':
+                return 56
+            elif n == '\x03':
+                return 78
+            else:
+                return 90  # default case
+        def f(n, m):
+            x = g(n)   # gives a red switch
+            y = g(hint(m, concrete=True))   # gives a green switch
+            return x - y
+
+        res = self.timeshift(f, [3, 0], backendoptimize=True)  # chr(3) and chr(0) both hit cases
+        assert res == 78 - 12
+        res = self.timeshift(f, [2, 4], backendoptimize=True)  # chr(2) hits a case; chr(4) takes the default
+        assert res == 56 - 90

Modified: pypy/dist/pypy/jit/timeshifter/transform.py
==============================================================================
--- pypy/dist/pypy/jit/timeshifter/transform.py	(original)
+++ pypy/dist/pypy/jit/timeshifter/transform.py	Mon Mar 19 14:45:31 2007
@@ -1,3 +1,4 @@
+import sys
 from pypy.objspace.flow.model import Variable, Constant, Block, Link
 from pypy.objspace.flow.model import SpaceOperation, mkentrymap
 from pypy.annotation        import model as annmodel
@@ -267,12 +268,19 @@
 
     def insert_splits(self):
         hannotator = self.hannotator
-        for block in self.graph.iterblocks():
-            if block.exitswitch is not None:
-                assert isinstance(block.exitswitch, Variable)
-                hs_switch = hannotator.binding(block.exitswitch)
-                if not hs_switch.is_green():
-                    self.insert_split_handling(block)
+        retry = True  # switch lowering creates new blocks with red Bool exitswitches
+        while retry:
+            retry = False
+            for block in list(self.graph.iterblocks()):  # snapshot: the graph is mutated inside the loop
+                if block.exitswitch is not None:
+                    assert isinstance(block.exitswitch, Variable)
+                    hs_switch = hannotator.binding(block.exitswitch)
+                    if not hs_switch.is_green():
+                        if block.exitswitch.concretetype is lltype.Bool:
+                            self.insert_split_handling(block)  # red Bool exitswitch
+                        else:
+                            self.insert_switch_handling(block)  # red non-Bool: lower into Bool tests
+                            retry = True  # rescan so the newly created Bool test blocks are handled too
 
     def trace_back_bool_var(self, block, v):
         """Return the (opname, arguments) that created the exitswitch of
@@ -326,6 +334,117 @@
                         resulttype = lltype.Bool)
         block.exitswitch = v_flag
 
+    def insert_switch_handling(self, block):  # rewrite a red non-Bool switch as a tree of red Bool tests
+        v_redswitch = block.exitswitch
+        T = v_redswitch.concretetype
+        range_start = -sys.maxint-1  # default: the full Signed range
+        range_stop  = sys.maxint+1
+        if T is not lltype.Signed:  # first cast the switch value to a Signed
+            if T is lltype.Char:
+                opcast = 'cast_char_to_int'
+                range_start = 0  # a Char is one of 256 values
+                range_stop = 256
+            elif T is lltype.UniChar:
+                opcast = 'cast_unichar_to_int'
+                range_start = 0  # upper bound deliberately left at the Signed maximum
+            elif T is lltype.Unsigned:
+                opcast = 'cast_uint_to_int'  # range stays the full Signed range
+            else:
+                raise AssertionError(T)  # unsupported switch type
+            v_redswitch = self.genop(block, opcast, [v_redswitch],
+                                     resulttype=lltype.Signed, red=True)
+            block.exitswitch = v_redswitch
+        # for now, we always turn the switch back into a chain of tests
+        # that perform a binary search
+        blockset = {block: True}   # reachable from outside
+        cases = {}  # Signed case value -> original exit Link
+        defaultlink = None
+        for link in block.exits:
+            if link.exitcase == 'default':
+                defaultlink = link
+                blockset[link.target] = False   # not reachable from outside
+            else:
+                assert lltype.typeOf(link.exitcase) == T
+                intval = lltype.cast_primitive(lltype.Signed, link.exitcase)  # presumably matches opcast's mapping
+                cases[intval] = link
+                link.exitcase = None  # cleared; insert_integer_search sets True if the link ends up behind an int_eq
+                link.llexitcase = None
+        self.insert_integer_search(block, cases, defaultlink, blockset,
+                                   range_start, range_stop)
+        SSA_to_SSI(blockset, self.hannotator)  # new test blocks read the switch var without inputargs; presumably this threads it along links
+
+    def insert_integer_search(self, block, cases, defaultlink, blockset,
+                              range_start, range_stop):
+        # fix the exit of the 'block' to check for the given remaining
+        # 'cases', knowing that if we get there then the value must
+        # be contained in range(range_start, range_stop).
+        if not cases:
+            assert defaultlink is not None
+            block.exitswitch = None
+            block.recloseblock(Link(defaultlink.args, defaultlink.target))
+        elif len(cases) == 1 and (defaultlink is None or
+                                  range_start == range_stop-1):
+            block.exitswitch = None
+            block.recloseblock(cases.values()[0])
+        else:
+            intvalues = cases.keys()
+            intvalues.sort()
+            if len(intvalues) <= 3:
+                # not much point in being clever with no more than 3 cases
+                intval = intvalues[-1]
+                remainingcases = cases.copy()
+                link = remainingcases.pop(intval)
+                c_intval = inputconst(lltype.Signed, intval)
+                v = self.genop(block, 'int_eq', [block.exitswitch, c_intval],
+                               resulttype=lltype.Bool, red=True)
+                link.exitcase = True
+                link.llexitcase = True
+                falseblock = Block([])
+                falseblock.exitswitch = block.exitswitch
+                blockset[falseblock] = False
+                falselink = Link([], falseblock)
+                falselink.exitcase = False
+                falselink.llexitcase = False
+                block.exitswitch = v
+                block.recloseblock(falselink, link)
+                if defaultlink is None or intval == range_stop-1:
+                    range_stop = intval
+                self.insert_integer_search(falseblock, remainingcases,
+                                           defaultlink, blockset,
+                                           range_start, range_stop)
+            else:
+                intval = intvalues[len(intvalues) // 2]
+                c_intval = inputconst(lltype.Signed, intval)
+                v = self.genop(block, 'int_ge', [block.exitswitch, c_intval],
+                               resulttype=lltype.Bool, red=True)
+                falseblock = Block([])
+                falseblock.exitswitch = block.exitswitch
+                trueblock  = Block([])
+                trueblock.exitswitch = block.exitswitch
+                blockset[falseblock] = False
+                blockset[trueblock]  = False
+                falselink = Link([], falseblock)
+                falselink.exitcase = False
+                falselink.llexitcase = False
+                truelink = Link([], trueblock)
+                truelink.exitcase = True
+                truelink.llexitcase = True
+                block.exitswitch = v
+                block.recloseblock(falselink, truelink)
+                falsecases = {}
+                truecases = {}
+                for intval1, link1 in cases.items():
+                    if intval1 < intval:
+                        falsecases[intval1] = link1
+                    else:
+                        truecases[intval1] = link1
+                self.insert_integer_search(falseblock, falsecases,
+                                           defaultlink, blockset,
+                                           range_start, intval)
+                self.insert_integer_search(trueblock, truecases,
+                                           defaultlink, blockset,
+                                           intval, range_stop)
+
     def get_resume_point_link(self, block):
         try:
             return self.resumepoints[block]



More information about the Pypy-commit mailing list