[pypy-svn] rev 1509 - in pypy/trunk/src/pypy/translator: . test

arigo at codespeak.net arigo at codespeak.net
Wed Oct 1 19:25:17 CEST 2003


Author: arigo
Date: Wed Oct  1 19:25:17 2003
New Revision: 1509

Added:
   pypy/trunk/src/pypy/translator/test/test_typedpyrex.py   (contents, props changed)
Modified:
   pypy/trunk/src/pypy/translator/annotation.py
   pypy/trunk/src/pypy/translator/genpyrex.py
Log:
Emitting type-optimized Pyrex code successfully


Modified: pypy/trunk/src/pypy/translator/annotation.py
==============================================================================
--- pypy/trunk/src/pypy/translator/annotation.py	(original)
+++ pypy/trunk/src/pypy/translator/annotation.py	Wed Oct  1 19:25:17 2003
@@ -39,7 +39,9 @@
             self.annotated[block] = annotations[:]
         else:
             oldannotations = self.annotated[block]
+            #import sys; print >> sys.stderr, block, oldannotations, annotations,
             newannotations = self.unify(oldannotations,annotations)
+            #import sys; print >> sys.stderr, newannotations
             if newannotations == oldannotations:
                 return
             self.annotated[block] = newannotations
@@ -73,32 +75,44 @@
         
         for w_from,w_to in zip(branch.args,branch.target.input_args):
             if isinstance(w_from,Variable):
-                renaming[w_from] = w_to
+                renaming.setdefault(w_from, []).append(w_to)
             else:
                 self.consider_const(w_to,w_from,newannotations)        
 
         def rename(w):
             if isinstance(w,Constant):
-                return w
+                return [w]
             if isinstance(w,GraphGlobalVariable):
-                return w
+                return [w]
             else:
-                return renaming[w]
+                return renaming.get(w, [])
 
-        for ann in self.annotated[curblock]:
-            try:
-                result = rename(ann.result)
-                args = [ rename(arg) for arg in ann.args ]
-            except KeyError:
-                pass
+        def renameall(list_w):
+            if list_w:
+                for w in rename(list_w[0]):
+                    for tail_w in renameall(list_w[1:]):
+                        yield [w] + tail_w
             else:
+                yield []
+
+        for ann in self.annotated[curblock]:
+            # we translate a single SpaceOperation(...) into either
+            # 0 or 1 or multiple ones, by replacing each variable
+            # used in the original operation by (in turn) any of
+            # the variables it can be renamed into
+            for list_w in renameall([ann.result] + ann.args):
+                result = list_w[0]
+                args = list_w[1:]
                 newannotations.append(SpaceOperation(ann.opname,args,result))
 
         self.flowin(branch.target,newannotations)
          
     def flownext_ConditionalBranch(self,branch,curblock):
-        self.flownext(branch.ifbranch,curblock)
-        self.flownext(branch.elsebranch,curblock)
+        # XXX this hack depends on the fact that ConditionalBranches
+        # XXX point to two EggBlocks with *no* renaming necessary
+        curannotations = self.annotated[curblock]
+        self.flowin(branch.ifbranch.target,curannotations)
+        self.flowin(branch.elsebranch.target,curannotations)
 
     def flownext_EndBranch(self,branch,curblock):
         branch = Branch([branch.returnvalue], self.endblock)

Modified: pypy/trunk/src/pypy/translator/genpyrex.py
==============================================================================
--- pypy/trunk/src/pypy/translator/genpyrex.py	(original)
+++ pypy/trunk/src/pypy/translator/genpyrex.py	Wed Oct  1 19:25:17 2003
@@ -6,6 +6,7 @@
 from pypy.tool import test
 from pypy.interpreter.baseobjspace import ObjSpace
 from pypy.translator.flowmodel import *
+from pypy.translator.annotation import Annotator, set_type, get_type
 
 class GenPyrex:
     def __init__(self, functiongraph):
@@ -17,9 +18,19 @@
             oparity[opname] = arity
         self.ops = ops  
         self.oparity = oparity
+        self.annotations = {}
+
+    def annotate(self, input_arg_types):
+        a = Annotator(self.functiongraph)
+        input_ann = []
+        for arg, arg_type in zip(self.functiongraph.get_args(),
+                                 input_arg_types):
+            set_type(arg, arg_type, input_ann)
+        self.annotations = a.build_annotations(input_ann)
 
     def emitcode(self):
         self.blockids = {}
+        self.variablelocations = {}
         self.lines = []
         self.indent = 0
         self.gen_Graph()
@@ -30,16 +41,49 @@
 
     def gen_Graph(self):
         fun = self.functiongraph
-        inputargnames = [ var.pseudoname for var in fun.startblock.input_args ]
-        params = ", ".join(inputargnames)
-        self.putline("def %s(%s):" % (fun.functionname, params))
+        currentlines = self.lines
+        self.lines = []
         self.indent += 1 
         self.gen_BasicBlock(fun.startblock)
         self.indent -= 1
+        # emit the header after the body
+        functionbodylines = self.lines
+        self.lines = currentlines
+        inputargnames = [ self._declvar(var) for var in fun.startblock.input_args ]
+        params = ", ".join(inputargnames)
+        self.putline("def %s(%s):" % (fun.functionname, params))
+        self.indent += 1
+        #self.putline("# %r" % self.annotations)
+        for var in self.variablelocations:
+            if var not in fun.startblock.input_args:
+                self.putline("cdef %s" % self._declvar(var))
+        self.indent -= 1
+        self.lines.extend(functionbodylines)
+
+    def get_type(self, var):
+        block = self.variablelocations.get(var)
+        ann = self.annotations.get(block, [])
+        return get_type(var, ann)
+
+    def get_varname(self, var):
+        if self.get_type(var) == int:
+            prefix = "i_"
+        else:
+            prefix = ""
+        return prefix + var.pseudoname
+
+    def _declvar(self, var):
+        vartype = self.get_type(var)
+        if vartype == int:
+            ctype = "int "
+        else:
+            ctype = "object "
+        return ctype + self.get_varname(var)
 
-    def _str(self, obj):
+    def _str(self, obj, block):
         if isinstance(obj, Variable):
-            return obj.pseudoname
+            self.variablelocations[obj] = block
+            return self.get_varname(obj)
         elif isinstance(obj, Constant):
             return repr(obj.value)
         else:
@@ -59,40 +103,39 @@
             opsymbol = self.ops[op.opname] 
             arity = self.oparity[op.opname]
             assert(arity == len(op.args))
-            argnames = [self._str(arg) for arg in op.args]
+            argnames = [self._str(arg, block) for arg in op.args]
             if arity == 1 or arity == 3 or "a" <= opsymbol[0] <= "z":
-                
-                self.putline("%s = %s(%s)" % (op.result.pseudoname, opsymbol, ", ".join([argnames])))
+                self.putline("%s = %s(%s)" % (self._str(op.result, block), opsymbol, ", ".join(argnames)))
             else:
-                self.putline("%s = %s %s %s" % (op.result.pseudoname, argnames[0], opsymbol, argnames[1]))
+                self.putline("%s = %s %s %s" % (self._str(op.result, block), argnames[0], opsymbol, argnames[1]))
 
-        self.dispatchBranch(block.branch)
+        self.dispatchBranch(block, block.branch)
 
-    def dispatchBranch(self, branch):
+    def dispatchBranch(self, prevblock, branch):
         method = getattr(self, "gen_" + branch.__class__.__name__)
-        method(branch)
+        method(prevblock, branch)
 
-    def gen_Branch(self, branch):
+    def gen_Branch(self, prevblock, branch):
         _str = self._str
         block = branch.target
-        sourceargs = [_str(arg) for arg in branch.args]       
-        targetargs = [arg.pseudoname for arg in block.input_args]
+        sourceargs = [_str(arg, prevblock) for arg in branch.args]       
+        targetargs = [_str(arg, branch.target) for arg in block.input_args]
         assert(len(sourceargs) == len(targetargs))
         if sourceargs and sourceargs != targetargs: 
             self.putline("%s = %s" % (", ".join(targetargs), ", ".join(sourceargs)))
 
         self.gen_BasicBlock(block)    
 
-    def gen_EndBranch(self, branch):
-        self.putline("return %s" % self._str(branch.returnvalue))
+    def gen_EndBranch(self, prevblock, branch):
+        self.putline("return %s" % self._str(branch.returnvalue, prevblock))
  
-    def gen_ConditionalBranch(self, branch):
-        self.putline("if %s:" % self._str(branch.condition))
+    def gen_ConditionalBranch(self, prevblock, branch):
+        self.putline("if %s:" % self._str(branch.condition, prevblock))
         self.indent += 1
-        self.dispatchBranch(branch.ifbranch)
+        self.dispatchBranch(prevblock, branch.ifbranch)
         self.indent -= 1
         self.putline("else:")
         self.indent += 1
-        self.dispatchBranch(branch.elsebranch)
+        self.dispatchBranch(prevblock, branch.elsebranch)
         self.indent -= 1
 

Added: pypy/trunk/src/pypy/translator/test/test_typedpyrex.py
==============================================================================
--- (empty file)
+++ pypy/trunk/src/pypy/translator/test/test_typedpyrex.py	Wed Oct  1 19:25:17 2003
@@ -0,0 +1,100 @@
+
+import autopath
+from pypy.tool import test
+from pypy.tool.udir import udir
+from pypy.translator.genpyrex import GenPyrex
+from pypy.translator.flowmodel import *
+from pypy.translator.test.buildpyxmodule import make_module_from_pyxstring
+
+make_dot = 1
+
+if make_dot: 
+    from pypy.translator.test.make_dot import make_dot
+else:
+    def make_dot(*args): pass
+
+class TestCase(test.IntTestCase):
+    def setUp(self):
+        self.space = test.objspace('flow')
+
+    def make_cfunc(self, func, input_arg_types):
+        """ make a pyrex-generated cfunction from the given func """
+        import inspect
+        try:
+            func = func.im_func
+        except AttributeError:
+            pass
+        name = func.func_name
+        funcgraph = self.space.build_flow(func)
+        funcgraph.source = inspect.getsource(func)
+        genpyrex = GenPyrex(funcgraph)
+        genpyrex.annotate(input_arg_types)
+        result = genpyrex.emitcode()
+        make_dot(funcgraph, udir, 'ps')
+        mod = make_module_from_pyxstring(name, udir, result)
+        return getattr(mod, name)
+
+    #____________________________________________________
+    def simple_func(i):
+        return i+1
+
+    def test_simple_func(self):
+        cfunc = self.make_cfunc(self.simple_func, [int])
+        self.assertEquals(cfunc(1), 2)
+
+    #____________________________________________________
+    def while_func(i):
+        total = 0
+        while i > 0:
+            total = total + i
+            i = i - 1
+        return total
+
+    def test_while_func(self):
+        while_func = self.make_cfunc(self.while_func, [int])
+        self.assertEquals(while_func(10), 55)
+
+    #____________________________________________________
+    def yast(lst):
+        total = 0
+        for z in lst:
+            total += z
+        return total
+
+    def dont_test_yast(self):
+        yast = self.make_cfunc(self.yast, [list])
+        self.assertEquals(yast(range(11)), 55)  # 0+1+...+10 == 55
+
+    #____________________________________________________
+    def nested_whiles(i, j):
+        s = ''
+        z = 5
+        while z > 0:
+            z = z - 1
+            u = i
+            while u < j:
+                u = u + 1
+                s = s + '.'
+            s = s + '!'
+        return s
+
+    def test_nested_whiles(self):
+        nested_whiles = self.make_cfunc(self.nested_whiles, [int, int])
+        self.assertEquals(nested_whiles(111, 114),
+                          '...!...!...!...!...!')
+
+    #____________________________________________________
+    def poor_man_range(i):
+        lst = []
+        while i > 0:
+            i = i - 1
+            lst.append(i)
+        lst.reverse()
+        return lst
+
+    def dont_yet_test_poor_man_range(self):
+        poor_man_range = self.make_cfunc(self.poor_man_range, [int])
+        self.assertEquals(poor_man_range(10), range(10))
+
+if __name__ == '__main__':
+    test.main()


More information about the Pypy-commit mailing list