[pypy-svn] pypy default: remove some oopspecs in rdict to make the JIT trace the hash functions in dicts

cfbolz commits-noreply at bitbucket.org
Fri Mar 25 12:47:49 CET 2011


Author: Carl Friedrich Bolz <cfbolz at gmx.de>
Branch: 
Changeset: r42933:ba5a9e3972e8
Date: 2011-03-25 12:47 +0100
http://bitbucket.org/pypy/pypy/changeset/ba5a9e3972e8/

Log:	Remove some oopspecs in rdict to make the JIT trace the hash
	functions in dicts. This makes it necessary to hide some interior
	field manipulation in helper functions marked @jit.dont_look_inside
	(the new ll_get_value and the existing _ll_dict_setitem_lookup_done).

diff --git a/pypy/jit/codewriter/support.py b/pypy/jit/codewriter/support.py
--- a/pypy/jit/codewriter/support.py
+++ b/pypy/jit/codewriter/support.py
@@ -399,12 +399,7 @@
         return ll_rdict.ll_newdict(DICT)
     _ll_0_newdict.need_result_type = True
 
-    _ll_2_dict_getitem = ll_rdict.ll_dict_getitem
-    _ll_3_dict_setitem = ll_rdict.ll_dict_setitem
     _ll_2_dict_delitem = ll_rdict.ll_dict_delitem
-    _ll_3_dict_setdefault = ll_rdict.ll_setdefault
-    _ll_2_dict_contains = ll_rdict.ll_contains
-    _ll_3_dict_get = ll_rdict.ll_get
     _ll_1_dict_copy = ll_rdict.ll_copy
     _ll_1_dict_clear = ll_rdict.ll_clear
     _ll_2_dict_update = ll_rdict.ll_update

diff --git a/pypy/rpython/lltypesystem/rdict.py b/pypy/rpython/lltypesystem/rdict.py
--- a/pypy/rpython/lltypesystem/rdict.py
+++ b/pypy/rpython/lltypesystem/rdict.py
@@ -7,7 +7,7 @@
 from pypy.rlib.rarithmetic import r_uint, intmask, LONG_BIT
 from pypy.rlib.objectmodel import hlinvoke
 from pypy.rpython import robject
-from pypy.rlib import objectmodel
+from pypy.rlib import objectmodel, jit
 from pypy.rpython import rmodel
 
 HIGHEST_BIT = intmask(1 << (LONG_BIT - 1))
@@ -408,6 +408,10 @@
     ENTRIES = lltype.typeOf(entries).TO
     return ENTRIES.fasthashfn(entries[i].key)
 
+@jit.dont_look_inside
+def ll_get_value(d, i):
+    return d.entries[i].value
+
 def ll_keyhash_custom(d, key):
     DICT = lltype.typeOf(d).TO
     return hlinvoke(DICT.r_rdict_hashfn, d.fnkeyhash, key)
@@ -426,17 +430,16 @@
 def ll_dict_getitem(d, key):
     i = ll_dict_lookup(d, key, d.keyhash(key))
     if not i & HIGHEST_BIT:
-        return d.entries[i].value
+        return ll_get_value(d, i)
     else:
         raise KeyError
-ll_dict_getitem.oopspec = 'dict.getitem(d, key)'
 
 def ll_dict_setitem(d, key, value):
     hash = d.keyhash(key)
     i = ll_dict_lookup(d, key, hash)
     return _ll_dict_setitem_lookup_done(d, key, value, hash, i)
-ll_dict_setitem.oopspec = 'dict.setitem(d, key, value)'
 
+@jit.dont_look_inside
 def _ll_dict_setitem_lookup_done(d, key, value, hash, i):
     valid = (i & HIGHEST_BIT) == 0
     i = i & MASK
@@ -717,23 +720,19 @@
 
 def ll_get(dict, key, default):
     i = ll_dict_lookup(dict, key, dict.keyhash(key))
-    entries = dict.entries
     if not i & HIGHEST_BIT:
-        return entries[i].value
+        return ll_get_value(dict, i)
     else:
         return default
-ll_get.oopspec = 'dict.get(dict, key, default)'
 
 def ll_setdefault(dict, key, default):
     hash = dict.keyhash(key)
     i = ll_dict_lookup(dict, key, hash)
-    entries = dict.entries
     if not i & HIGHEST_BIT:
-        return entries[i].value
+        return ll_get_value(dict, i)
     else:
         _ll_dict_setitem_lookup_done(dict, key, default, hash, i)
         return default
-ll_setdefault.oopspec = 'dict.setdefault(dict, key, default)'
 
 def ll_copy(dict):
     DICT = lltype.typeOf(dict).TO
@@ -829,7 +828,6 @@
 def ll_contains(d, key):
     i = ll_dict_lookup(d, key, d.keyhash(key))
     return not i & HIGHEST_BIT
-ll_contains.oopspec = 'dict.contains(d, key)'
 
 POPITEMINDEX = lltype.Struct('PopItemIndex', ('nextindex', lltype.Signed))
 global_popitem_index = lltype.malloc(POPITEMINDEX, zero=True, immortal=True)

diff --git a/pypy/jit/metainterp/test/test_dict.py b/pypy/jit/metainterp/test/test_dict.py
--- a/pypy/jit/metainterp/test/test_dict.py
+++ b/pypy/jit/metainterp/test/test_dict.py
@@ -1,6 +1,7 @@
 import py
 from pypy.jit.metainterp.test.test_basic import LLJitMixin, OOJitMixin
 from pypy.rlib.jit import JitDriver
+from pypy.rlib import objectmodel
 
 class DictTests:
 
@@ -69,6 +70,66 @@
             res = self.meta_interp(f, [10], listops=True)
             assert res == expected
 
+    def test_dict_trace_hash(self):
+        myjitdriver = JitDriver(greens = [], reds = ['total', 'dct'])
+        def key(x):
+            return x % 2
+        def eq(x, y):
+            return (x % 2) == (y % 2)
+
+        def f(n):
+            dct = objectmodel.r_dict(eq, key)
+            total = n
+            while total:
+                myjitdriver.jit_merge_point(total=total, dct=dct)
+                if total not in dct:
+                    dct[total] = []
+                dct[total].append(total)
+                total -= 1
+            return len(dct[0])
+
+        res1 = f(100)
+        res2 = self.meta_interp(f, [100], listops=True)
+        assert res1 == res2
+        self.check_loops(int_mod=1) # the hash was traced
+
+    def test_dict_setdefault(self):
+        myjitdriver = JitDriver(greens = [], reds = ['total', 'dct'])
+        def f(n):
+            dct = {}
+            total = n
+            while total:
+                myjitdriver.jit_merge_point(total=total, dct=dct)
+                dct.setdefault(total % 2, []).append(total)
+                total -= 1
+            return len(dct[0])
+
+        assert f(100) == 50
+        res = self.meta_interp(f, [100], listops=True)
+        assert res == 50
+        self.check_loops(new=0, new_with_vtable=0)
+
+    def test_dict_as_counter(self):
+        myjitdriver = JitDriver(greens = [], reds = ['total', 'dct'])
+        def key(x):
+            return x % 2
+        def eq(x, y):
+            return (x % 2) == (y % 2)
+
+        def f(n):
+            dct = objectmodel.r_dict(eq, key)
+            total = n
+            while total:
+                myjitdriver.jit_merge_point(total=total, dct=dct)
+                dct[total] = dct.get(total, 0) + 1
+                total -= 1
+            return dct[0]
+
+        assert f(100) == 50
+        res = self.meta_interp(f, [100], listops=True)
+        assert res == 50
+        self.check_loops(int_mod=1)
+
 
 class TestOOtype(DictTests, OOJitMixin):
     pass


More information about the Pypy-commit mailing list