[pypy-svn] r16170 - in pypy/dist/pypy: objspace/std objspace/std/test rpython rpython/test

arigo at codespeak.net arigo at codespeak.net
Fri Aug 19 18:59:54 CEST 2005


Author: arigo
Date: Fri Aug 19 18:59:48 2005
New Revision: 16170

Modified:
   pypy/dist/pypy/objspace/std/stringobject.py
   pypy/dist/pypy/objspace/std/test/test_stringobject.py
   pypy/dist/pypy/rpython/rarithmetic.py
   pypy/dist/pypy/rpython/rstr.py
   pypy/dist/pypy/rpython/test/test_rstr.py
Log:
Using the implementation provided by rpython for str.find(), str.rfind() and
hash(str).

* rarithmetic._hash_string() computes the hash of a string; it is used from
    stringobject.py instead of calling hash(s) directly, because otherwise we
    would get different hash values in some cases (the 0/-1 special cases).
    This would be a problem for strings that are used as dictionary keys when
    the dictionary is frozen by translation.

* rtyper support for the full str.find(substr, start=0, end=None) as well as
    str.rfind().  This allows us to clean up stringobject.py quite a bit, and
    probably gives a significant speed-up.

* fixed ll_find() for the case of an empty substring; added direct tests
    for ll_find() and ll_rfind().

* also removed the W_StringObject.w_hash cache, as rpython provides its
    own hash caching.  (Right now this caching does not take effect when PyPy
    runs on top of CPython -- it is unclear whether that causes a performance
    problem.)

* fixed a bug in str.count(), which counted overlapping matches (see added test)



Modified: pypy/dist/pypy/objspace/std/stringobject.py
==============================================================================
--- pypy/dist/pypy/objspace/std/stringobject.py	(original)
+++ pypy/dist/pypy/objspace/std/stringobject.py	Fri Aug 19 18:59:48 2005
@@ -2,7 +2,8 @@
 
 from pypy.objspace.std.objspace import *
 from pypy.interpreter import gateway
-from pypy.rpython.rarithmetic import intmask, ovfcheck
+from pypy.rpython.rarithmetic import intmask, ovfcheck, _hash_string
+from pypy.rpython.objectmodel import we_are_translated
 from pypy.objspace.std.intobject   import W_IntObject
 from pypy.objspace.std.sliceobject import W_SliceObject
 from pypy.objspace.std import slicetype
@@ -20,7 +21,6 @@
     def __init__(w_self, space, str):
         W_Object.__init__(w_self, space)
         w_self._value = str
-        w_self.w_hash = None
 
     def __repr__(w_self):
         """ representation for debugging purposes """
@@ -275,9 +275,7 @@
         splitcount = maxsplit
 
     while splitcount:             
-        next = _find(value, by, start, len(value), 1)
-        #next = value.find(by, start)    #of course we cannot use 
-                                         #the find method, 
+        next = value.find(by, start)
         if next < 0:
             break
         res_w.append(W_StringObject(space, value[start:next]))
@@ -339,9 +337,7 @@
         splitcount = maxsplit
 
     while splitcount:
-        next = _find(value, by, 0, end, -1)
-        #next = value.rfind(by, end)    #of course we cannot use 
-                                        #the find method, 
+        next = value.rfind(by, 0, end)
         if next < 0:
             break
         res_w.append(W_StringObject(space, value[next+bylen:end]))
@@ -445,32 +441,33 @@
 
     start = space.int_w(w_start)
     end = space.int_w(w_end)
+    assert start >= 0
+    assert end >= 0
 
     return (self, sub, start, end)
 
 def contains__String_String(space, w_self, w_sub):
     self = w_self._value
     sub = w_sub._value
-    return space.newbool(_find(self, sub, 0, len(self), 1) >= 0)
+    return space.newbool(self.find(sub) >= 0)
 
 def str_find__String_String_ANY_ANY(space, w_self, w_sub, w_start, w_end):
 
     (self, sub, start, end) =  _convert_idx_params(space, w_self, w_sub, w_start, w_end)
-    res = _find(self, sub, start, end, 1)
+    res = self.find(sub, start, end)
     return space.wrap(res)
 
 def str_rfind__String_String_ANY_ANY(space, w_self, w_sub, w_start, w_end):
 
     (self, sub, start, end) =  _convert_idx_params(space, w_self, w_sub, w_start, w_end)
-    res = _find(self, sub, start, end, -1)
+    res = self.rfind(sub, start, end)
     return space.wrap(res)
 
 def str_index__String_String_ANY_ANY(space, w_self, w_sub, w_start, w_end):
 
     (self, sub, start, end) =  _convert_idx_params(space, w_self, w_sub, w_start, w_end)
-    res = _find(self, sub, start, end, 1)
-
-    if res == -1:
+    res = self.find(sub, start, end)
+    if res < 0:
         raise OperationError(space.w_ValueError,
                              space.wrap("substring not found in string.index"))
 
@@ -480,8 +477,8 @@
 def str_rindex__String_String_ANY_ANY(space, w_self, w_sub, w_start, w_end):
 
     (self, sub, start, end) =  _convert_idx_params(space, w_self, w_sub, w_start, w_end)
-    res = _find(self, sub, start, end, -1)
-    if res == -1:
+    res = self.rfind(sub, start, end)
+    if res < 0:
         raise OperationError(space.w_ValueError,
                              space.wrap("substring not found in string.rindex"))
 
@@ -499,19 +496,17 @@
 
     #what do we have to replace?
     startidx = 0
-    endidx = len(input)
     indices = []
-    foundidx = _find(input, sub, startidx, endidx, 1)
-    while foundidx > -1 and (maxsplit == -1 or maxsplit > 0):
+    foundidx = input.find(sub, startidx)
+    while foundidx >= 0 and maxsplit != 0:
         indices.append(foundidx)
         if len(sub) == 0:
             #so that we go forward, even if sub is empty
             startidx = foundidx + 1
         else: 
             startidx = foundidx + len(sub)        
-        foundidx = _find(input, sub, startidx, endidx, 1)
-        if maxsplit != -1:
-            maxsplit = maxsplit - 1
+        foundidx = input.find(sub, startidx)
+        maxsplit = maxsplit - 1
     indiceslen = len(indices)
     buf = [' '] * (len(input) - indiceslen * len(sub) + indiceslen * len(by))
     startidx = 0
@@ -534,52 +529,52 @@
         bufpos = bufpos + 1 
     return space.wrap("".join(buf))
 
-def _find(self, sub, start, end, dir):
-
-    length = len(self)
+##def _find(self, sub, start, end, dir):
 
-    #adjust_indicies
-    if (end > length):
-        end = length
-    elif (end < 0):
-        end += length
-    if (end < 0):
-        end = 0
-    if (start < 0):
-        start += length
-    if (start < 0):
-        start = 0
+##    length = len(self)
 
-    if dir > 0:
-        if len(sub) == 0 and start < end:
-            return start
-
-        end = end - len(sub) + 1
-
-        for i in range(start, end):
-            match = 1
-            for idx in range(len(sub)):
-                if sub[idx] != self[idx+i]:
-                    match = 0
-                    break
-            if match: 
-                return i
-        return -1
-    else:
-        if len(sub) == 0 and start < end:
-            return end
+##    #adjust_indicies
+##    if (end > length):
+##        end = length
+##    elif (end < 0):
+##        end += length
+##    if (end < 0):
+##        end = 0
+##    if (start < 0):
+##        start += length
+##    if (start < 0):
+##        start = 0
+
+##    if dir > 0:
+##        if len(sub) == 0 and start < end:
+##            return start
+
+##        end = end - len(sub) + 1
+
+##        for i in range(start, end):
+##            match = 1
+##            for idx in range(len(sub)):
+##                if sub[idx] != self[idx+i]:
+##                    match = 0
+##                    break
+##            if match: 
+##                return i
+##        return -1
+##    else:
+##        if len(sub) == 0 and start < end:
+##            return end
 
-        end = end - len(sub)
+##        end = end - len(sub)
 
-        for j in range(end, start-1, -1):
-            match = 1
-            for idx in range(len(sub)):
-                if sub[idx] != self[idx+j]:
-                    match = 0
-                    break
-            if match:
-                return j
-        return -1        
+##        for j in range(end, start-1, -1):
+##            match = 1
+##            for idx in range(len(sub)):
+##                if sub[idx] != self[idx+j]:
+##                    match = 0
+##                    break
+##            if match:
+##                return j
+##        return -1        
 
 
 def _strip(space, w_self, w_chars, left, right):
@@ -668,15 +663,19 @@
     w_end = slicetype.adapt_bound(space, w_end, space.wrap(len(u_self)))
     u_start = space.int_w(w_start)
     u_end = space.int_w(w_end)
-    
-    count = 0  
+    assert u_start >= 0
+    assert u_end >= 0
 
-    pos = u_start - 1 
-    while 1: 
-       pos = _find(u_self, u_arg, pos+1, u_end, 1)
-       if pos == -1:
-          break
-       count += 1
+    count = 0
+    if len(u_arg) == 0:    # '' matches once at every index in [start, end]
+        if u_start <= u_end: count = u_end - u_start + 1
+    else:                  # count non-overlapping occurrences, like CPython
+        while 1:
+            pos = u_self.find(u_arg, u_start, u_end)
+            if pos < 0:
+                break
+            count += 1
+            u_start = pos + len(u_arg)
        
     return W_IntObject(space, count)
 
@@ -804,21 +803,14 @@
     return w_str._value
 
 def hash__String(space, w_str):
-    w_hash = w_str.w_hash
-    if w_hash is None:
-        s = w_str._value
-        try:
-            x = ord(s[0]) << 7
-        except IndexError:
-            x = 0
-        else:
-            for c in s:
-                x = (1000003*x) ^ ord(c)
-            x ^= len(s)
-        # unlike CPython, there is no reason to avoid to return -1
-        w_hash = W_IntObject(space, intmask(x))
-        w_str.w_hash = w_hash
-    return w_hash
+    s = w_str._value
+    if we_are_translated():
+        x = hash(s)            # to use the hash cache in rpython strings
+    else:
+        x = _hash_string(s)    # to make sure we get the same hash as rpython
+        # (otherwise translation will freeze W_DictObjects where we can't find
+        #  the keys any more!)
+    return W_IntObject(space, x)
 
 
 ##EQ = 1

Modified: pypy/dist/pypy/objspace/std/test/test_stringobject.py
==============================================================================
--- pypy/dist/pypy/objspace/std/test/test_stringobject.py	(original)
+++ pypy/dist/pypy/objspace/std/test/test_stringobject.py	Fri Aug 19 18:59:48 2005
@@ -313,6 +313,7 @@
         assert 'aaa'.count('a', -10) == 3
         assert 'aaa'.count('a', 0, -1) == 2
         assert 'aaa'.count('a', 0, -10) == 0
+        assert 'ababa'.count('aba') == 1
      
     def test_startswith(self):
         assert 'ab'.startswith('ab') is True
@@ -592,7 +593,7 @@
     def test_hash(self):
         # check that we have the same hash as CPython for at least 31 bits
         # (but don't go checking CPython's special case -1)
-        assert hash('') == 0
+        # disabled: assert hash('') == 0 --- different special case
         assert hash('hello') & 0x7fffffff == 0x347697fd
         assert hash('hello world!') & 0x7fffffff == 0x2f0bb411
 

Modified: pypy/dist/pypy/rpython/rarithmetic.py
==============================================================================
--- pypy/dist/pypy/rpython/rarithmetic.py	(original)
+++ pypy/dist/pypy/rpython/rarithmetic.py	Fri Aug 19 18:59:48 2005
@@ -383,3 +383,20 @@
 
 def formatd(fmt, x):
     return fmt % (x,)
+
+# a common string hash function
+
+def _hash_string(s):
+    length = len(s)
+    if length == 0:
+        x = -1
+    else:
+        x = ord(s[0]) << 7
+        i = 0
+        while i < length:
+            x = (1000003*x) ^ ord(s[i])
+            i += 1
+        x ^= length
+        if x == 0:
+            x = -1
+    return intmask(x)

Modified: pypy/dist/pypy/rpython/rstr.py
==============================================================================
--- pypy/dist/pypy/rpython/rstr.py	(original)
+++ pypy/dist/pypy/rpython/rstr.py	Fri Aug 19 18:59:48 2005
@@ -3,7 +3,7 @@
 from pypy.annotation import model as annmodel
 from pypy.rpython.rmodel import Repr, TyperError, IntegerRepr
 from pypy.rpython.rmodel import StringRepr, CharRepr, inputconst, UniCharRepr
-from pypy.rpython.rarithmetic import intmask
+from pypy.rpython.rarithmetic import intmask, _hash_string
 from pypy.rpython.robject import PyObjRepr, pyobj_repr
 from pypy.rpython.rtuple import TupleRepr
 from pypy.rpython import rint
@@ -107,9 +107,29 @@
         v_str, v_value = hop.inputargs(string_repr, string_repr)
         return hop.gendirectcall(ll_endswith, v_str, v_value)
 
-    def rtype_method_find(_, hop):
-        v_str, v_value = hop.inputargs(string_repr, string_repr)
-        return hop.gendirectcall(ll_find, v_str, v_value)
+    def rtype_method_find(_, hop, reverse=False):
+        v_str = hop.inputarg(string_repr, arg=0)
+        v_value = hop.inputarg(string_repr, arg=1)
+        if hop.nb_args > 2:
+            v_start = hop.inputarg(Signed, arg=2)
+            if not hop.args_s[2].nonneg:
+                raise TyperError("str.find() start must be proven non-negative")
+        else:
+            v_start = hop.inputconst(Signed, 0)
+        if hop.nb_args > 3:
+            v_end = hop.inputarg(Signed, arg=3)
+            if not hop.args_s[3].nonneg:    # annotation of the 'end' argument
+                raise TyperError("str.find() end must be proven non-negative")
+        else:
+            v_end = hop.gendirectcall(ll_strlen, v_str)    # default end: len(s)
+        if reverse:
+            llfn = ll_rfind
+        else:
+            llfn = ll_find
+        return hop.gendirectcall(llfn, v_str, v_value, v_start, v_end)
+
+    def rtype_method_rfind(self, hop):
+        return self.rtype_method_find(hop, reverse=True)
 
     def rtype_method_upper(_, hop):
         v_str, = hop.inputargs(string_repr)
@@ -541,19 +561,8 @@
     # special non-computed-yet value.
     x = s.hash
     if x == 0:
-        length = len(s.chars)
-        if length == 0:
-            x = -1
-        else:
-            x = ord(s.chars[0]) << 7
-            i = 0
-            while i < length:
-                x = (1000003*x) ^ ord(s.chars[i])
-                i += 1
-            x ^= length
-            if x == 0:
-                x = -1
-        s.hash = intmask(x)
+        x = _hash_string(s.chars)
+        s.hash = x
     return x
 
 def ll_strconcat(s1, s2):
@@ -643,18 +652,20 @@
 
     return True
 
-def ll_find(s1, s2):
+def ll_find(s1, s2, start, end):
     """Knuth Morris Prath algorithm for substring match"""
-    len1 = len(s1.chars)
     len2 = len(s2.chars)
+    if start > end: return -1       # empty search range: nothing matches
+    if len2 == 0: return start      # '' matches at the start of the range
     # Construct the array of possible restarting positions
     # T = Array_of_ints [-1..len2]
     # T[-1] = -1 s2.chars[-1] is supposed to be unequal to everything else
     T = malloc( SIGNED_ARRAY, len2 )
-    i = 0
-    j = -1
+    T[0] = 0
+    i = 1
+    j = 0
     while i<len2:
-        if j>=0 and s2.chars[i] == s2.chars[j]:
+        if s2.chars[i] == s2.chars[j]:
             j += 1
             T[i] = j
             i += 1
@@ -667,8 +678,8 @@
 
     # Now the find algorithm
     i = 0
-    m = 0
-    while m+i<len1:
+    m = start
+    while m+i<end:
         if s1.chars[m+i]==s2.chars[i]:
             i += 1
             if i==len2:
@@ -676,14 +687,53 @@
         else:
             # mismatch, go back to the last possible starting pos
             if i==0:
-                e = -1
+                m += 1
             else:
                 e = T[i-1]
-            m = m + i - e
-            if i>0:
+                m = m + i - e
                 i = e
     return -1
-    
+
+def ll_rfind(s1, s2, start, end):
+    """Reversed ll_find(): highest index of s2 in s1[start:end], or -1."""
+    len2 = len(s2.chars)
+    if start > end: return -1       # empty search range: nothing matches
+    if len2 == 0: return end        # '' matches at the end of the range
+    # Construct the array of possible restarting positions
+    T = malloc( SIGNED_ARRAY, len2 )
+    T[0] = 1
+    i = 1
+    j = 1
+    while i<len2:
+        if s2.chars[len2-i-1] == s2.chars[len2-j]:
+            j += 1
+            T[i] = j
+            i += 1
+        elif j>1:
+            j = T[j-2]
+        else:
+            T[i] = 1
+            i += 1
+            j = 1
+
+    # Now the find algorithm (scanning s1 from right to left)
+    i = 1
+    m = end
+    while m-i>=start:
+        if s1.chars[m-i]==s2.chars[len2-i]:
+            if i==len2:
+                return m-i
+            i += 1
+        else:
+            # mismatch, go back to the last possible starting pos
+            if i==1:
+                m -= 1
+            else:
+                e = T[i-2]
+                m = m - i + e
+                i = e
+    return -1
+
 emptystr = string_repr.convert_const("")
 
 def ll_upper(s):

Modified: pypy/dist/pypy/rpython/test/test_rstr.py
==============================================================================
--- pypy/dist/pypy/rpython/test/test_rstr.py	(original)
+++ pypy/dist/pypy/rpython/test/test_rstr.py	Fri Aug 19 18:59:48 2005
@@ -1,10 +1,28 @@
+import random
 from pypy.translator.translator import Translator
 from pypy.rpython.lltype import *
-from pypy.rpython.rstr import parse_fmt_string
+from pypy.rpython.rstr import parse_fmt_string, ll_find, ll_rfind, STR
 from pypy.rpython.rtyper import RPythonTyper, TyperError
 from pypy.rpython.test.test_llinterp import interpret, interpret_raises
 from pypy.rpython.llinterp import LLException
 
+def llstr(s):
+    p = malloc(STR, len(s))
+    for i in range(len(s)):
+        p.chars[i] = s[i]
+    return p
+
+def test_ll_find_rfind():
+    for i in range(50):
+        n1 = random.randint(0, 10)
+        s1 = ''.join([random.choice("ab") for i in range(n1)])
+        n2 = random.randint(0, 5)
+        s2 = ''.join([random.choice("ab") for i in range(n2)])
+        res = ll_find(llstr(s1), llstr(s2), 0, n1)
+        assert res == s1.find(s2)
+        res = ll_rfind(llstr(s1), llstr(s2), 0, n1)
+        assert res == s1.rfind(s2)
+
 def test_simple():
     def fn(i):
         s = 'hello'
@@ -200,13 +218,36 @@
 def test_find():
     def fn(i, j):
         s1 = ['one two three', 'abc abcdab abcdabcdabde']
-        s2 = ['one', 'two', 'abcdab', 'one tou', 'abcdefgh', 'fortytwo']
+        s2 = ['one', 'two', 'abcdab', 'one tou', 'abcdefgh', 'fortytwo', '']
         return s1[i].find(s2[j])
     for i in range(2):
-        for j in range(6):
+        for j in range(7):
             res = interpret(fn, [i,j])
             assert res == fn(i, j)
 
+def test_find_with_start():
+    def fn(i):
+        assert i >= 0
+        return 'ababcabc'.find('abc', i)
+    for i in range(9):
+        res = interpret(fn, [i])
+        assert res == fn(i)
+
+def test_find_with_start_end():
+    def fn(i, j):
+        assert i >= 0
+        assert j >= 0
+        return 'ababcabc'.find('abc', i, j)
+    for (i, j) in [(1,7), (2,6), (3,7), (3,8)]:
+        res = interpret(fn, [i, j])
+        assert res == fn(i, j)
+
+def test_rfind():
+    def fn():
+        return 'aaa'.rfind('a') + 'aaa'.rfind('a', 1) + 'aaa'.rfind('a', 1, 2)
+    res = interpret(fn, [])
+    assert res == 2 + 2 + 1
+
 def test_upper():
     def fn(i):
         strings = ['', ' ', 'upper', 'UpPeR', ',uppEr,']



More information about the Pypy-commit mailing list