[pypy-svn] r15037 - in pypy/dist/pypy/rpython: . test

arigo at codespeak.net arigo at codespeak.net
Mon Jul 25 17:00:04 CEST 2005


Author: arigo
Date: Mon Jul 25 17:00:01 2005
New Revision: 15037

Modified:
   pypy/dist/pypy/rpython/rconstantdict.py
   pypy/dist/pypy/rpython/rdict.py
   pypy/dist/pypy/rpython/test/test_rconstantdict.py
   pypy/dist/pypy/rpython/test/test_rtuple.py
Log:
* Support unicode chars as keys in constant dicts
* Fixed a bug in constant dicts (dicts of len 1 or 2 didn't get any unused entry in their entry table)


Modified: pypy/dist/pypy/rpython/rconstantdict.py
==============================================================================
--- pypy/dist/pypy/rpython/rconstantdict.py	(original)
+++ pypy/dist/pypy/rpython/rconstantdict.py	Mon Jul 25 17:00:01 2005
@@ -54,7 +54,7 @@
         except KeyError:
             self.setup()
             dictlen = len(dictobj)
-            minentrylen = dictlen * 4 / 3
+            minentrylen = (dictlen * 4 + 2) / 3
             entrylen = 1
             while entrylen < minentrylen: 
                 entrylen *= 2
@@ -62,20 +62,21 @@
             self.dict_cache[key] = result
             r_key = self.key_repr
             r_value = self.value_repr
-            #hashcompute = self.get_key_hash_function()
+            hashcompute = self.get_key_hash_function()
             for dictkey, dictvalue in dictobj.items():
                 llkey = r_key.convert_const(dictkey)
                 llvalue = r_value.convert_const(dictvalue)
-                ll_constantdict_setnewitem(result, llkey, llvalue)#,hashcompute)
+                ll_constantdict_setnewitem(result, llkey, llvalue, hashcompute)
+            assert result.num_items < len(result.entries)
             return result
 
-##   def get_key_hash_function(self):
-##       if isinstance(self.key_repr, rmodel.IntegerRepr):
-##           return ll_hash_identity
-##       elif isinstance(self.key_repr, rmodel.CharRepr):
-##           return ll_hash_char
-##       else:
-##           raise TyperError("no easy hash function for %r" % (self.key_repr,))
+    def get_key_hash_function(self):
+        if isinstance(self.key_repr, rmodel.IntegerRepr):
+            return ll_hash_identity
+        elif isinstance(self.key_repr, rmodel.UniCharRepr):
+            return ll_hash_unichar
+        else:
+            raise TyperError("no easy hash function for %r" % (self.key_repr,))
 
     def rtype_len(self, hop):
         v_dict, = hop.inputargs(self)
@@ -87,23 +88,26 @@
     def rtype_method_get(self, hop):
         v_dict, v_key, v_default = hop.inputargs(self, self.key_repr,
                                                  self.value_repr)
-        return hop.gendirectcall(ll_constantdict_get, v_dict, v_key, v_default)
+        hashcompute = self.get_key_hash_function()
+        chashcompute = hop.inputconst(lltype.Void, hashcompute)
+        return hop.gendirectcall(ll_constantdict_get, v_dict, v_key, v_default,
+                                 chashcompute)
 
 class __extend__(pairtype(ConstantDictRepr, rmodel.Repr)): 
 
     def rtype_getitem((r_dict, r_key), hop):
         v_dict, v_key = hop.inputargs(r_dict, r_dict.key_repr) 
-        #hashcompute = r_dict.get_key_hash_function()
-        #chashcompute = hop.inputconst(lltype.Void, hashcompute)
-        return hop.gendirectcall(ll_constantdict_getitem, v_dict, v_key)
-                                 #chashcompute)
+        hashcompute = r_dict.get_key_hash_function()
+        chashcompute = hop.inputconst(lltype.Void, hashcompute)
+        return hop.gendirectcall(ll_constantdict_getitem, v_dict, v_key,
+                                 chashcompute)
 
     def rtype_contains((r_dict, r_key), hop):
         v_dict, v_key = hop.inputargs(r_dict, r_dict.key_repr)
-        #hashcompute = r_dict.get_key_hash_function()
-        #chashcompute = hop.inputconst(lltype.Void, hashcompute)
-        return hop.gendirectcall(ll_constantdict_contains, v_dict, v_key)
-                                 #chashcompute)
+        hashcompute = r_dict.get_key_hash_function()
+        chashcompute = hop.inputconst(lltype.Void, hashcompute)
+        return hop.gendirectcall(ll_constantdict_contains, v_dict, v_key,
+                                 chashcompute)
 
 # ____________________________________________________________
 #
@@ -114,26 +118,26 @@
 def ll_constantdict_len(d):
     return d.num_items 
 
-def ll_constantdict_getitem(d, key):#, hashcompute): 
-    entry = ll_constantdict_lookup(d, key)#, hashcompute)
+def ll_constantdict_getitem(d, key, hashcompute): 
+    entry = ll_constantdict_lookup(d, key, hashcompute)
     if entry.valid:
         return entry.value 
     else: 
         raise KeyError 
 
-def ll_constantdict_contains(d, key):#, hashcompute):
-    entry = ll_constantdict_lookup(d, key)#, hashcompute)
+def ll_constantdict_contains(d, key, hashcompute):
+    entry = ll_constantdict_lookup(d, key, hashcompute)
     return entry.valid
 
-def ll_constantdict_get(d, key, default):#, hashcompute):
-    entry = ll_constantdict_lookup(d, key)#, hashcompute)
+def ll_constantdict_get(d, key, default, hashcompute):
+    entry = ll_constantdict_lookup(d, key, hashcompute)
     if entry.valid:
         return entry.value
     else: 
         return default
 
-def ll_constantdict_setnewitem(d, key, value):#, hashcompute): 
-    entry = ll_constantdict_lookup(d, key)#, hashcompute)
+def ll_constantdict_setnewitem(d, key, value, hashcompute): 
+    entry = ll_constantdict_lookup(d, key, hashcompute)
     assert not entry.valid 
     entry.key = key
     entry.valid = True 
@@ -143,12 +147,12 @@
 # the below is a port of CPython's dictobject.c's lookdict implementation 
 PERTURB_SHIFT = 5
 
-def ll_constantdict_lookup(d, key):#, hashcompute): 
-    hash = key #hashcompute(key)
+def ll_constantdict_lookup(d, key, hashcompute): 
+    hash = r_uint(hashcompute(key))
     entries = d.entries
     mask = len(entries) - 1
-    perturb = r_uint(hash) 
-    i = r_uint(hash) 
+    perturb = hash
+    i = hash
     while 1: 
         entry = entries[i & mask]
         if not entry.valid: 
@@ -158,8 +162,8 @@
         perturb >>= PERTURB_SHIFT
         i = (i << 2) + i + perturb + 1
 
-##def ll_hash_identity(x): 
-##    return x
+def ll_hash_identity(x): 
+    return x
 
-##def ll_hash_char(x): 
-##    return ord(x) 
+def ll_hash_unichar(x): 
+    return ord(x) 

Modified: pypy/dist/pypy/rpython/rdict.py
==============================================================================
--- pypy/dist/pypy/rpython/rdict.py	(original)
+++ pypy/dist/pypy/rpython/rdict.py	Mon Jul 25 17:00:01 2005
@@ -35,7 +35,8 @@
             dictvalue = self.dictdef.dictvalue 
             return StrDictRepr(lambda: rtyper.getrepr(dictvalue.s_value), 
                                dictvalue)
-        elif isinstance(s_key, annmodel.SomeInteger):
+        elif isinstance(s_key, (annmodel.SomeInteger,
+                                annmodel.SomeUnicodeCodePoint)):
             dictkey = self.dictdef.dictkey
             dictvalue = self.dictdef.dictvalue 
             return rconstantdict.ConstantDictRepr(

Modified: pypy/dist/pypy/rpython/test/test_rconstantdict.py
==============================================================================
--- pypy/dist/pypy/rpython/test/test_rconstantdict.py	(original)
+++ pypy/dist/pypy/rpython/test/test_rconstantdict.py	Mon Jul 25 17:00:01 2005
@@ -25,3 +25,14 @@
     assert res == 62
     res = interpret(func, [4, 25])
     assert res == -44
+
+def test_unichar_dict():
+    d = {u'a': 5, u'b': 123, u'?': 321}
+    def func(i):
+        return d[unichr(i)]
+    res = interpret(func, [97])
+    assert res == 5
+    res = interpret(func, [98])
+    assert res == 123
+    res = interpret(func, [63])
+    assert res == 321

Modified: pypy/dist/pypy/rpython/test/test_rtuple.py
==============================================================================
--- pypy/dist/pypy/rpython/test/test_rtuple.py	(original)
+++ pypy/dist/pypy/rpython/test/test_rtuple.py	Mon Jul 25 17:00:01 2005
@@ -78,6 +78,14 @@
     res = interpret(f, [0])
     assert res is False 
 
+def test_constant_unichar_tuple_contains():
+    def f(i):
+        return unichr(i) in (u'1', u'9')
+    res = interpret(f, [49])
+    assert res is True 
+    res = interpret(f, [50])
+    assert res is False 
+
 def test_tuple_iterator_length1():
     def f(i):
         total = 0



More information about the Pypy-commit mailing list