[pypy-svn] r61737 - in pypy/branch/pyjitpl5/pypy/jit/metainterp: . test

fijal at codespeak.net fijal at codespeak.net
Wed Feb 11 16:40:34 CET 2009


Author: fijal
Date: Wed Feb 11 16:40:30 2009
New Revision: 61737

Modified:
   pypy/branch/pyjitpl5/pypy/jit/metainterp/optimize.py
   pypy/branch/pyjitpl5/pypy/jit/metainterp/test/test_virtual.py
Log:
(arigo, fijal)
Fix a couple of bugs to pass most of test_virtual


Modified: pypy/branch/pyjitpl5/pypy/jit/metainterp/optimize.py
==============================================================================
--- pypy/branch/pyjitpl5/pypy/jit/metainterp/optimize.py	(original)
+++ pypy/branch/pyjitpl5/pypy/jit/metainterp/optimize.py	Wed Feb 11 16:40:30 2009
@@ -38,8 +38,12 @@
             for key, value in self.fields.items():
                 if key not in other.fields:
                     return False
-                if not value.equals(other.fields[key]):
-                    return False
+                if value is None:
+                    if other.fields[key] is not None:
+                        return False
+                else:
+                    if not value.equals(other.fields[key]):
+                        return False
             return True
 
     def matches(self, instnode):
@@ -48,7 +52,7 @@
         for key, value in self.fields.items():
             if key not in instnode.curfields:
                 return False
-            if not value.matches(instnode.curfields[key]):
+            if value is not None and not value.matches(instnode.curfields[key]):
                 return False
         return True
 
@@ -151,10 +155,20 @@
                 return None
             return FixedClassSpecNode(known_class)
         fields = {}
-        for ofs, node in self.origfields.items():
-            if ofs in other.curfields:
-                specnode = node.intersect(other.curfields[ofs])
+        for ofs, node in other.curfields.items():
+            if ofs in self.origfields:
+                specnode = self.origfields[ofs].intersect(node)
                 fields[ofs] = specnode
+            else:
+                fields[ofs] = None
+                self.origfields[ofs] = InstanceNode(node.source.clonebox())
+        
+#         for ofs, node in self.origfields.items():
+#             if ofs in other.curfields:
+#                 specnode = node.intersect(other.curfields[ofs])
+#                 fields[ofs] = specnode
+#             else:
+#                fields[ofs] = None
         return VirtualInstanceSpecNode(known_class, fields)
 
     def adapt_to(self, specnode):
@@ -291,7 +305,8 @@
             if isinstance(specnode, VirtualInstanceSpecNode):
                 curfields = {}
                 for ofs, subspecnode in specnode.fields.items():
-                    subinstnode = instnode.origfields[ofs]   # should be there
+                    subinstnode = instnode.origfields[ofs]
+                    # should really be there
                     self.mutate_nodes(subinstnode, subspecnode)
                     curfields[ofs] = subinstnode
                 instnode.curfields = curfields

Modified: pypy/branch/pyjitpl5/pypy/jit/metainterp/test/test_virtual.py
==============================================================================
--- pypy/branch/pyjitpl5/pypy/jit/metainterp/test/test_virtual.py	(original)
+++ pypy/branch/pyjitpl5/pypy/jit/metainterp/test/test_virtual.py	Wed Feb 11 16:40:30 2009
@@ -1,45 +1,9 @@
 import py
-py.test.skip("look later")
 from pypy.rlib.jit import JitDriver, hint
-from pypy.jit.hintannotator.policy import StopAtXPolicy
-from pyjitpl import oo_meta_interp, get_stats
-from test.test_basic import LLJitMixin, OOJitMixin
+from pypy.jit.metainterp.policy import StopAtXPolicy
+from pypy.jit.metainterp.test.test_basic import LLJitMixin, OOJitMixin
 from pypy.rpython.lltypesystem import lltype, rclass
-
-from vinst import find_sorted_list
-import heaptracker
-
-
-##def test_heaptracker():
-##    def f():
-##        n = lltype.malloc(NODE)
-##        n.value = 42
-##        return n
-##    res = oo_meta_interp(f, [])
-##    assert lltype.typeOf(res) is NODE
-##    assert res.value == 42
-##    assert get_stats().heaptracker.known_unescaped(res)
-
-def test_find_sorted_list():
-    assert find_sorted_list(5, []) == 0
-    assert find_sorted_list(4, [(5,)]) == 0
-    assert find_sorted_list(5, [(5,)]) == -1
-    assert find_sorted_list(6, [(5,)]) == 1
-    assert find_sorted_list(2,   [(5,), (6,), (7,)]) == 0
-    assert find_sorted_list(5.2, [(5,), (6,), (7,)]) == 1
-    assert find_sorted_list(6.5, [(5,), (6,), (7,)]) == 2
-    assert find_sorted_list(11,  [(5,), (6,), (7,)]) == 3
-    assert find_sorted_list(5, [(5,), (6,), (7,)]) == -3
-    assert find_sorted_list(6, [(5,), (6,), (7,)]) == -2
-    assert find_sorted_list(7, [(5,), (6,), (7,)]) == -1
-    lst = [(j,) for j in range(1, 50, 2)]
-    for i in range(50):
-        res = find_sorted_list(i, lst)
-        if i % 2 == 0:
-            assert res == i // 2
-        else:
-            assert res == (i // 2) - (50 // 2)
-
+from pypy.jit.metainterp import heaptracker
 
 class VirtualTests:
     def _freeze_(self):
@@ -52,6 +16,7 @@
             node.value = 0
             node.extra = 0
             while n > 0:
+                myjitdriver.can_enter_jit(n=n, node=node)
                 myjitdriver.jit_merge_point(n=n, node=node)
                 next = self._new()
                 next.value = node.value + n
@@ -60,12 +25,12 @@
                 n -= 1
             return node.value * node.extra
         assert f(10) == 55 * 10
-        res = self.meta_interp(f, [10], exceptions=False)
+        res = self.meta_interp(f, [10])
         assert res == 55 * 10
-        assert len(get_stats().loops) == 1
-        get_stats().check_loops(new=0, new_with_vtable=0,
-                                getfield_int=0, getfield_ptr=0,
-                                setfield_int=0, setfield_ptr=0)
+        self.check_loop_count(1)
+        self.check_loops(new=0, new_with_vtable=0,
+                                getfield_gc__4=0, getfield_gc_ptr=0,
+                                setfield_gc__4=0, setfield_gc_ptr=0)
 
     def test_virtualized_2(self):
         myjitdriver = JitDriver(greens = [], reds = ['n', 'node'])
@@ -74,6 +39,7 @@
             node.value = 0
             node.extra = 0
             while n > 0:
+                myjitdriver.can_enter_jit(n=n, node=node)
                 myjitdriver.jit_merge_point(n=n, node=node)
                 next = self._new()
                 next.value = node.value
@@ -85,12 +51,12 @@
                 node = next
                 n -= 1
             return node.value * node.extra
-        res = self.meta_interp(f, [10], exceptions=False)
+        res = self.meta_interp(f, [10])
         assert res == 55 * 30
-        assert len(get_stats().loops) == 1
-        get_stats().check_loops(new=0, new_with_vtable=0,
-                                getfield_int=0, getfield_ptr=0,
-                                setfield_int=0, setfield_ptr=0)
+        self.check_loop_count(1)
+        self.check_loops(new=0, new_with_vtable=0,
+                                getfield_gc__4=0, getfield_gc_ptr=0,
+                                setfield_gc__4=0, setfield_gc_ptr=0)
 
     def test_nonvirtual_obj_delays_loop(self):
         myjitdriver = JitDriver(greens = [], reds = ['n', 'node'])
@@ -99,6 +65,7 @@
         def f(n):
             node = node0
             while True:
+                myjitdriver.can_enter_jit(n=n, node=node)
                 myjitdriver.jit_merge_point(n=n, node=node)
                 i = node.value
                 if i >= n:
@@ -106,17 +73,12 @@
                 node = self._new()
                 node.value = i * 2
             return node.value
-        res = self.meta_interp(f, [500], exceptions=False)
+        res = self.meta_interp(f, [500])
         assert res == 640
-        # The only way to make an efficient loop (in which the node is
-        # virtual) is to keep the first iteration out of the residual loop's
-        # body.  Indeed, the initial value 'node0' cannot be passed inside
-        # the loop as a virtual.  It's hard to test that this is what occurred,
-        # though.
-        assert len(get_stats().loops) == 1
-        get_stats().check_loops(new=0, new_with_vtable=0,
-                                getfield_int=0, getfield_ptr=0,
-                                setfield_int=0, setfield_ptr=0)
+        self.check_loop_count(1)
+        self.check_loops(new=0, new_with_vtable=0,
+                                getfield_gc__4=0, getfield_gc_ptr=0,
+                                setfield_gc__4=0, setfield_gc_ptr=0)
 
     def test_two_loops_with_virtual(self):
         myjitdriver = JitDriver(greens = [], reds = ['n', 'node'])
@@ -125,6 +87,7 @@
             node.value = 0
             node.extra = 0
             while n > 0:
+                myjitdriver.can_enter_jit(n=n, node=node)
                 myjitdriver.jit_merge_point(n=n, node=node)
                 next = self._new()
                 next.value = node.value + n
@@ -135,12 +98,12 @@
                 node = next
                 n -= 1
             return node.value
-        res = self.meta_interp(f, [10], exceptions=False)
+        res = self.meta_interp(f, [10])
         assert res == 255
-        assert len(get_stats().loops) == 2
-        get_stats().check_loops(new=0, new_with_vtable=0,
-                                getfield_int=0, getfield_ptr=0,
-                                setfield_int=0, setfield_ptr=0)
+        self.check_loop_count(2)
+        self.check_loops(new=0, new_with_vtable=0,
+                                getfield_gc__4=0, getfield_gc_ptr=0,
+                                setfield_gc__4=0, setfield_gc_ptr=0)
 
     def test_two_loops_with_escaping_virtual(self):
         myjitdriver = JitDriver(greens = [], reds = ['n', 'node'])
@@ -151,6 +114,7 @@
             node.value = 0
             node.extra = 0
             while n > 0:
+                myjitdriver.can_enter_jit(n=n, node=node)
                 myjitdriver.jit_merge_point(n=n, node=node)
                 next = self._new()
                 next.value = node.value + n
@@ -161,32 +125,14 @@
                 node = next
                 n -= 1
             return node.value
-        res = self.meta_interp(f, [10], policy=StopAtXPolicy(externfn),
-                                        exceptions=False)
+        res = self.meta_interp(f, [10], policy=StopAtXPolicy(externfn))
         assert res == f(10)
-        assert len(get_stats().loops) == 2
-        get_stats().check_loops(**{self._new_op: 1})
-        get_stats().check_loops(int_mul=0, call__4=1)
-
-    def test_virtual_if_unescaped_so_far(self):
-        class Foo(object):
-            def __init__(self, x, y):
-                self.x = x
-                self.y = y
-
-        def f(n):
-            foo = Foo(n, 0)
-            while foo.x > 0:
-                foo.y += foo.x
-                foo.x -= 1
-            return foo.y
-
-        res = self.meta_interp(f, [10], exceptions=False)
-        assert res == 55
-        py.test.skip("unsure yet if we want to be clever about this")
-        get_stats().check_loops(getfield_int=0, setfield_int=0)
+        self.check_loop_count(2)
+        self.check_loops(**{self._new_op: 1})
+        self.check_loops(int_mul=0, call__4=1)
 
     def test_two_virtuals(self):
+        myjitdriver = JitDriver(greens = [], reds = ['n', 'prev'])
         class Foo(object):
             def __init__(self, x, y):
                 self.x = x
@@ -196,20 +142,26 @@
             prev = Foo(n, 0)
             n -= 1
             while n >= 0:
+                myjitdriver.can_enter_jit(n=n, prev=prev)
+                myjitdriver.jit_merge_point(n=n, prev=prev)
                 foo = Foo(n, 0)
                 foo.x += prev.x
                 prev = foo
                 n -= 1
             return prev.x
 
-        res = self.meta_interp(f, [12], exceptions=False)
+        res = self.meta_interp(f, [12])
         assert res == 78
+        self.check_loops(new_with_vtable=0, new=0)
 
     def test_both_virtual_and_field_variable(self):
+        myjitdriver = JitDriver(greens = [], reds = ['n'])
         class Foo(object):
             pass
         def f(n):
             while n >= 0:
+                myjitdriver.can_enter_jit(n=n)
+                myjitdriver.jit_merge_point(n=n)
                 foo = Foo()
                 foo.n = n
                 if n < 10:
@@ -217,7 +169,7 @@
                 n = foo.n - 1
             return n
 
-        res = self.meta_interp(f, [20], exceptions=False)
+        res = self.meta_interp(f, [20])
         assert res == 9
 
 



More information about the Pypy-commit mailing list