[pypy-svn] pypy default: remove some oopspecs in rdict to make the JIT trace the hash functions in dicts
cfbolz
commits-noreply at bitbucket.org
Fri Mar 25 12:47:49 CET 2011
Author: Carl Friedrich Bolz <cfbolz at gmx.de>
Branch:
Changeset: r42933:ba5a9e3972e8
Date: 2011-03-25 12:47 +0100
http://bitbucket.org/pypy/pypy/changeset/ba5a9e3972e8/
Log: remove some oopspecs in rdict to make the JIT trace the hash
functions in dicts. This makes it necessary to hide some interior
field manipulation in a helper function (ll_get_value, marked
jit.dont_look_inside).
diff --git a/pypy/jit/codewriter/support.py b/pypy/jit/codewriter/support.py
--- a/pypy/jit/codewriter/support.py
+++ b/pypy/jit/codewriter/support.py
@@ -399,12 +399,7 @@
return ll_rdict.ll_newdict(DICT)
_ll_0_newdict.need_result_type = True
- _ll_2_dict_getitem = ll_rdict.ll_dict_getitem
- _ll_3_dict_setitem = ll_rdict.ll_dict_setitem
_ll_2_dict_delitem = ll_rdict.ll_dict_delitem
- _ll_3_dict_setdefault = ll_rdict.ll_setdefault
- _ll_2_dict_contains = ll_rdict.ll_contains
- _ll_3_dict_get = ll_rdict.ll_get
_ll_1_dict_copy = ll_rdict.ll_copy
_ll_1_dict_clear = ll_rdict.ll_clear
_ll_2_dict_update = ll_rdict.ll_update
diff --git a/pypy/rpython/lltypesystem/rdict.py b/pypy/rpython/lltypesystem/rdict.py
--- a/pypy/rpython/lltypesystem/rdict.py
+++ b/pypy/rpython/lltypesystem/rdict.py
@@ -7,7 +7,7 @@
from pypy.rlib.rarithmetic import r_uint, intmask, LONG_BIT
from pypy.rlib.objectmodel import hlinvoke
from pypy.rpython import robject
-from pypy.rlib import objectmodel
+from pypy.rlib import objectmodel, jit
from pypy.rpython import rmodel
HIGHEST_BIT = intmask(1 << (LONG_BIT - 1))
@@ -408,6 +408,10 @@
ENTRIES = lltype.typeOf(entries).TO
return ENTRIES.fasthashfn(entries[i].key)
+@jit.dont_look_inside
+def ll_get_value(d, i):
+ return d.entries[i].value
+
def ll_keyhash_custom(d, key):
DICT = lltype.typeOf(d).TO
return hlinvoke(DICT.r_rdict_hashfn, d.fnkeyhash, key)
@@ -426,17 +430,16 @@
def ll_dict_getitem(d, key):
i = ll_dict_lookup(d, key, d.keyhash(key))
if not i & HIGHEST_BIT:
- return d.entries[i].value
+ return ll_get_value(d, i)
else:
raise KeyError
-ll_dict_getitem.oopspec = 'dict.getitem(d, key)'
def ll_dict_setitem(d, key, value):
hash = d.keyhash(key)
i = ll_dict_lookup(d, key, hash)
return _ll_dict_setitem_lookup_done(d, key, value, hash, i)
-ll_dict_setitem.oopspec = 'dict.setitem(d, key, value)'
+@jit.dont_look_inside
def _ll_dict_setitem_lookup_done(d, key, value, hash, i):
valid = (i & HIGHEST_BIT) == 0
i = i & MASK
@@ -717,23 +720,19 @@
def ll_get(dict, key, default):
i = ll_dict_lookup(dict, key, dict.keyhash(key))
- entries = dict.entries
if not i & HIGHEST_BIT:
- return entries[i].value
+ return ll_get_value(dict, i)
else:
return default
-ll_get.oopspec = 'dict.get(dict, key, default)'
def ll_setdefault(dict, key, default):
hash = dict.keyhash(key)
i = ll_dict_lookup(dict, key, hash)
- entries = dict.entries
if not i & HIGHEST_BIT:
- return entries[i].value
+ return ll_get_value(dict, i)
else:
_ll_dict_setitem_lookup_done(dict, key, default, hash, i)
return default
-ll_setdefault.oopspec = 'dict.setdefault(dict, key, default)'
def ll_copy(dict):
DICT = lltype.typeOf(dict).TO
@@ -829,7 +828,6 @@
def ll_contains(d, key):
i = ll_dict_lookup(d, key, d.keyhash(key))
return not i & HIGHEST_BIT
-ll_contains.oopspec = 'dict.contains(d, key)'
POPITEMINDEX = lltype.Struct('PopItemIndex', ('nextindex', lltype.Signed))
global_popitem_index = lltype.malloc(POPITEMINDEX, zero=True, immortal=True)
diff --git a/pypy/jit/metainterp/test/test_dict.py b/pypy/jit/metainterp/test/test_dict.py
--- a/pypy/jit/metainterp/test/test_dict.py
+++ b/pypy/jit/metainterp/test/test_dict.py
@@ -1,6 +1,7 @@
import py
from pypy.jit.metainterp.test.test_basic import LLJitMixin, OOJitMixin
from pypy.rlib.jit import JitDriver
+from pypy.rlib import objectmodel
class DictTests:
@@ -69,6 +70,66 @@
res = self.meta_interp(f, [10], listops=True)
assert res == expected
+ def test_dict_trace_hash(self):
+ myjitdriver = JitDriver(greens = [], reds = ['total', 'dct'])
+ def key(x):
+ return x % 2
+ def eq(x, y):
+ return (x % 2) == (y % 2)
+
+ def f(n):
+ dct = objectmodel.r_dict(eq, key)
+ total = n
+ while total:
+ myjitdriver.jit_merge_point(total=total, dct=dct)
+ if total not in dct:
+ dct[total] = []
+ dct[total].append(total)
+ total -= 1
+ return len(dct[0])
+
+ res1 = f(100)
+ res2 = self.meta_interp(f, [100], listops=True)
+ assert res1 == res2
+ self.check_loops(int_mod=1) # the hash was traced
+
+ def test_dict_setdefault(self):
+ myjitdriver = JitDriver(greens = [], reds = ['total', 'dct'])
+ def f(n):
+ dct = {}
+ total = n
+ while total:
+ myjitdriver.jit_merge_point(total=total, dct=dct)
+ dct.setdefault(total % 2, []).append(total)
+ total -= 1
+ return len(dct[0])
+
+ assert f(100) == 50
+ res = self.meta_interp(f, [100], listops=True)
+ assert res == 50
+ self.check_loops(new=0, new_with_vtable=0)
+
+ def test_dict_as_counter(self):
+ myjitdriver = JitDriver(greens = [], reds = ['total', 'dct'])
+ def key(x):
+ return x % 2
+ def eq(x, y):
+ return (x % 2) == (y % 2)
+
+ def f(n):
+ dct = objectmodel.r_dict(eq, key)
+ total = n
+ while total:
+ myjitdriver.jit_merge_point(total=total, dct=dct)
+ dct[total] = dct.get(total, 0) + 1
+ total -= 1
+ return dct[0]
+
+ assert f(100) == 50
+ res = self.meta_interp(f, [100], listops=True)
+ assert res == 50
+ self.check_loops(int_mod=1)
+
class TestOOtype(DictTests, OOJitMixin):
pass
More information about the Pypy-commit
mailing list