[pypy-svn] r20566 - in pypy/dist/pypy/rpython: . test

arigo at codespeak.net arigo at codespeak.net
Fri Dec 2 11:43:45 CET 2005


Author: arigo
Date: Fri Dec  2 11:43:44 2005
New Revision: 20566

Modified:
   pypy/dist/pypy/rpython/rpbc.py
   pypy/dist/pypy/rpython/test/test_rpbc.py
Log:
(mwh, arre-and-ericvrp-overlooking, arigo)

Implemented multiple-specializations, multiple-functions PBCs.



Modified: pypy/dist/pypy/rpython/rpbc.py
==============================================================================
--- pypy/dist/pypy/rpython/rpbc.py	(original)
+++ pypy/dist/pypy/rpython/rpbc.py	Fri Dec  2 11:43:44 2005
@@ -5,7 +5,7 @@
 from pypy.annotation import description
 from pypy.objspace.flow.model import Constant
 from pypy.rpython.lltypesystem.lltype import \
-     typeOf, Void, Bool, nullptr, frozendict
+     typeOf, Void, Bool, nullptr, frozendict, Ptr, Struct, malloc
 from pypy.rpython.error import TyperError
 from pypy.rpython.rmodel import Repr, inputconst, HalfConcreteWrapper
 from pypy.rpython import rclass
@@ -91,6 +91,8 @@
         # a 'matching' row is one that has the same llfn, except
         # that it may have more or less 'holes'
         for existingindex, existingrow in enumerate(uniquerows):
+            if row.fntype != existingrow.fntype:
+                continue   # not the same pointer type, cannot match
             for funcdesc, llfn in row.items():
                 if funcdesc in existingrow:
                     if llfn != existingrow[funcdesc]:
@@ -99,6 +101,7 @@
                 # potential match, unless the two rows have no common funcdesc
                 merged = ConcreteCallTableRow(row)
                 merged.update(existingrow)
+                merged.fntype = row.fntype
                 if len(merged) == len(row) + len(existingrow):
                     pass   # no common funcdesc, not a match
                 else:
@@ -125,14 +128,19 @@
             for funcdesc, graph in row.items():
                 llfn = rtyper.getcallable(graph)
                 concreterow[funcdesc] = llfn
+            assert len(concreterow) > 0
+            concreterow.fntype = typeOf(llfn)   # 'llfn' from the loop above
+                                         # (they should all have the same type)
             concreterows[shape, index] = concreterow
 
     for row in concreterows.values():
         addrow(row)
 
     for (shape, index), row in concreterows.items():
-        _, biggerrow = lookuprow(row)
-        concretetable[shape, index] = biggerrow
+        existingindex, biggerrow = lookuprow(row)
+        row = uniquerows[existingindex]
+        assert biggerrow == row   # otherwise, addrow() is broken
+        concretetable[shape, index] = row
 
     for finalindex, row in enumerate(uniquerows):
         row.attrname = 'variant%d' % finalindex
@@ -174,10 +182,16 @@
             self.uniquerows = uniquerows
             if len(uniquerows) == 1:
                 row = uniquerows[0]
-                examplellfn = row.itervalues().next()
-                self.lowleveltype = typeOf(examplellfn)
+                self.lowleveltype = row.fntype
             else:
-                XXX_later
+                # several functions, each with several specialized variants.
+                # each function becomes a pointer to a Struct containing
+                # pointers to its variants.
+                fields = []
+                for row in uniquerows:
+                    fields.append((row.attrname, row.fntype))
+                self.lowleveltype = Ptr(Struct('specfunc', *fields))
+        self.funccache = {}
 
     def get_s_callable(self):
         return self.s_pbc
@@ -201,25 +215,36 @@
 
     def convert_desc(self, funcdesc):
         # get the whole "column" of the call table corresponding to this desc
+        try:
+            return self.funccache[funcdesc]
+        except KeyError:
+            pass
         if self.lowleveltype is Void:
-            return HalfConcreteWrapper(self.get_unique_llfn)
-        llfns = {}
-        found_anything = False
-        for row in self.uniquerows:
-            if funcdesc in row:
-                llfn = row[funcdesc]
-                found_anything = True
-            else:
-                null = self.rtyper.type_system.null_callable(self.lowleveltype)
-                llfn = null
-            llfns[row.attrname] = llfn
-        if not found_anything:
-            raise TyperError("%r not in %r" % (funcdesc,
-                                               self.s_pbc.descriptions))
-        if len(self.uniquerows) == 1:
-            return llfn   # from the loop above
+            result = HalfConcreteWrapper(self.get_unique_llfn)
         else:
-            XXX_later
+            llfns = {}
+            found_anything = False
+            for row in self.uniquerows:
+                if funcdesc in row:
+                    llfn = row[funcdesc]
+                    found_anything = True
+                else:
+                    # missing entry -- need a 'null' of the type that matches
+                    # this row
+                    llfn = self.rtyper.type_system.null_callable(row.fntype)
+                llfns[row.attrname] = llfn
+            if not found_anything:
+                raise TyperError("%r not in %r" % (funcdesc,
+                                                   self.s_pbc.descriptions))
+            if len(self.uniquerows) == 1:
+                result = llfn   # from the loop above
+            else:
+                # build a Struct with all the values collected in 'llfns'
+                result = malloc(self.lowleveltype.TO, immortal=True)
+                for attrname, llfn in llfns.items():
+                    setattr(result, attrname, llfn)
+        self.funccache[funcdesc] = result
+        return result
 
     def convert_const(self, value):
         if isinstance(value, types.MethodType) and value.im_self is None:
@@ -237,6 +262,7 @@
         low-level function.  In case the call table contains multiple rows,
         'index' and 'shape' tells which of its items we are interested in.
         """
+        assert v.concretetype == self.lowleveltype
         if self.lowleveltype is Void:
             assert len(self.s_pbc.descriptions) == 1
                                       # lowleveltype wouldn't be Void otherwise
@@ -248,7 +274,10 @@
         elif len(self.uniquerows) == 1:
             return v
         else:
-            XXX_later
+            # 'v' is a Struct pointer, read the corresponding field
+            row = self.concretetable[shape, index]
+            cname = inputconst(Void, row.attrname)
+            return llop.genop('getfield', [v, cname], resulttype = row.fntype)
 
     def get_unique_llfn(self):
         # try to build a unique low-level function.  Avoid to use

Modified: pypy/dist/pypy/rpython/test/test_rpbc.py
==============================================================================
--- pypy/dist/pypy/rpython/test/test_rpbc.py	(original)
+++ pypy/dist/pypy/rpython/test/test_rpbc.py	Fri Dec  2 11:43:44 2005
@@ -1169,3 +1169,52 @@
         assert type(res.item1) is float
         assert res.item0 == f(i)[0]
         assert res.item1 == f(i)[1]
+
+def test_multiple_specialized_functions():
+    def myadder(x, y):   # int,int->int or str,str->str
+        return x+y
+    def myfirst(x, y):   # int,int->int or str,str->str
+        return x
+    def mysecond(x, y):  # int,int->int or str,str->str
+        return y
+    myadder._annspecialcase_ = 'specialize:argtype0'
+    myfirst._annspecialcase_ = 'specialize:argtype0'
+    mysecond._annspecialcase_ = 'specialize:argtype0'
+    def f(i):
+        if i == 0:
+            g = myfirst
+        elif i == 1:
+            g = mysecond
+        else:
+            g = myadder
+        s = g("hel", "lo")
+        n = g(40, 2)
+        return len(s) * n
+    for i in range(3):
+        res = interpret(f, [i])
+        assert res == f(i)
+
+def test_specialized_method_of_frozen():
+    class space:
+        def __init__(self, tag):
+            self.tag = tag
+        def wrap(self, x):
+            if isinstance(x, int):
+                return self.tag + '< %d >' % x
+            else:
+                return self.tag + x
+        wrap._annspecialcase_ = 'specialize:argtype1'
+    space1 = space("tag1:")
+    space2 = space("tag2:")
+    def f(i):
+        if i == 1:
+            sp = space1
+        else:
+            sp = space2
+        w1 = sp.wrap('hello')
+        w2 = sp.wrap(42)
+        return w1 + w2
+    res = interpret(f, [1])
+    assert ''.join(res.chars) == 'tag1:hellotag1:< 42 >'
+    res = interpret(f, [0])
+    assert ''.join(res.chars) == 'tag2:hellotag2:< 42 >'



More information about the Pypy-commit mailing list