[pypy-commit] pypy fix-bytearray-complexity: Add some special cases for stringmethods to avoid buffer-overhead

waedt noreply at buildbot.pypy.org
Mon Jun 2 19:47:16 CEST 2014


Author: Tyler Wade <wayedt at gmail.com>
Branch: fix-bytearray-complexity
Changeset: r71882:7f2b803f1319
Date: 2014-06-02 11:41 -0500
http://bitbucket.org/pypy/pypy/changeset/7f2b803f1319/

Log:	Add some special cases to stringmethods to avoid buffer overhead

diff --git a/pypy/objspace/std/bytearrayobject.py b/pypy/objspace/std/bytearrayobject.py
--- a/pypy/objspace/std/bytearrayobject.py
+++ b/pypy/objspace/std/bytearrayobject.py
@@ -12,6 +12,7 @@
 from pypy.objspace.std.sliceobject import W_SliceObject
 from pypy.objspace.std.stdtypedef import StdTypeDef
 from pypy.objspace.std.stringmethods import StringMethods, _get_buffer
+from pypy.objspace.std.bytesobject import W_BytesObject
 from pypy.objspace.std.util import get_positive_index
 
 NON_HEX_MSG = "non-hexadecimal number found in fromhex() arg at position %d"
@@ -43,8 +44,7 @@
         return W_BytearrayObject(value)
 
     def _new_from_buffer(self, buffer):
-        length = buffer.getlength()
-        return W_BytearrayObject([buffer.getitem(i) for i in range(length)])
+        return W_BytearrayObject([buffer[i] for i in range(len(buffer))])
 
     def _new_from_list(self, value):
         return W_BytearrayObject(value)
@@ -313,61 +313,52 @@
         min_length = min(len(value), buffer_len)
         return space.newbool(_memcmp(value, buffer, min_length) != 0)
 
+    def _comparison_helper(self, space, w_other):
+        value = self._val(space)
+
+        if isinstance(w_other, W_BytearrayObject):
+            other = w_other.data
+            other_len = len(other)
+            cmp = _memcmp(value, other, min(len(value), len(other)))
+        elif isinstance(w_other, W_BytesObject):
+            other = self._op_val(space, w_other)
+            other_len = len(other)
+            cmp = _memcmp(value, other, min(len(value), len(other)))
+        else:
+            try:
+                buffer = _get_buffer(space, w_other)
+            except OperationError as e:
+                if e.match(space, space.w_TypeError):
+                    return False, 0, 0
+                raise
+            other_len = len(buffer)
+            cmp = _memcmp(value, buffer, min(len(value), len(buffer)))
+
+        return True, cmp, other_len
+
     def descr_lt(self, space, w_other):
-        try:
-            buffer = _get_buffer(space, w_other)
-        except OperationError as e:
-            if e.match(space, space.w_TypeError):
-                return space.w_NotImplemented
-            raise
-
-        value = self._val(space)
-        buffer_len = buffer.getlength()
-
-        cmp = _memcmp(value, buffer, min(len(value), buffer_len))
-        return space.newbool(cmp < 0 or (cmp == 0 and len(value) < buffer_len))
+        success, cmp, other_len = self._comparison_helper(space, w_other)
+        if not success:
+            return space.w_NotImplemented
+        return space.newbool(cmp < 0 or (cmp == 0 and self._len() < other_len))
 
     def descr_le(self, space, w_other):
-        try:
-            buffer = _get_buffer(space, w_other)
-        except OperationError as e:
-            if e.match(space, space.w_TypeError):
-                return space.w_NotImplemented
-            raise
-
-        value = self._val(space)
-        buffer_len = buffer.getlength()
-
-        cmp = _memcmp(value, buffer, min(len(value), buffer_len))
-        return space.newbool(cmp < 0 or (cmp == 0 and len(value) <= buffer_len))
+        success, cmp, other_len = self._comparison_helper(space, w_other)
+        if not success:
+            return space.w_NotImplemented
+        return space.newbool(cmp < 0 or (cmp == 0 and self._len() <= other_len))
 
     def descr_gt(self, space, w_other):
-        try:
-            buffer = _get_buffer(space, w_other)
-        except OperationError as e:
-            if e.match(space, space.w_TypeError):
-                return space.w_NotImplemented
-            raise
-
-        value = self._val(space)
-        buffer_len = buffer.getlength()
-
-        cmp = _memcmp(value, buffer, min(len(value), buffer_len))
-        return space.newbool(cmp > 0 or (cmp == 0 and len(value) > buffer_len))
+        success, cmp, other_len = self._comparison_helper(space, w_other)
+        if not success:
+            return space.w_NotImplemented
+        return space.newbool(cmp > 0 or (cmp == 0 and self._len() > other_len))
 
     def descr_ge(self, space, w_other):
-        try:
-            buffer = _get_buffer(space, w_other)
-        except OperationError as e:
-            if e.match(space, space.w_TypeError):
-                return space.w_NotImplemented
-            raise
-
-        value = self._val(space)
-        buffer_len = buffer.getlength()
-
-        cmp = _memcmp(value, buffer, min(len(value), buffer_len))
-        return space.newbool(cmp > 0 or (cmp == 0 and len(value) >= buffer_len))
+        success, cmp, other_len = self._comparison_helper(space, w_other)
+        if not success:
+            return space.w_NotImplemented
+        return space.newbool(cmp > 0 or (cmp == 0 and self._len() >= other_len))
 
     def descr_iter(self, space):
         return space.newseqiter(self)
@@ -377,11 +368,17 @@
             self.data += w_other.data
             return self
 
-        buffer = _get_buffer(space, w_other)
-        for i in range(buffer.getlength()):
-            self.data.append(buffer.getitem(i))
+        if isinstance(w_other, W_BytesObject):
+            self._inplace_add(self._op_val(space, w_other))
+        else:
+            self._inplace_add(_get_buffer(space, w_other))
         return self
 
+    @specialize.argtype(1)
+    def _inplace_add(self, other):
+        for i in range(len(other)):
+            self.data.append(other[i])
+
     def descr_inplace_mul(self, space, w_times):
         try:
             times = space.getindex_w(w_times, space.w_OverflowError)
@@ -469,18 +466,20 @@
         if isinstance(w_other, W_BytearrayObject):
             return self._new(self.data + w_other.data)
 
+        if isinstance(w_other, W_BytesObject):
+            return self._add(self._op_val(space, w_other))
+
         try:
             buffer = _get_buffer(space, w_other)
         except OperationError as e:
             if e.match(space, space.w_TypeError):
                 return space.w_NotImplemented
             raise
+        return self._add(buffer)
 
-        buffer_len = buffer.getlength()
-        data = list(self.data + ['\0'] * buffer_len)
-        for i in range(buffer_len):
-            data[len(self.data) + i] = buffer.getitem(i)
-        return self._new(data)
+    @specialize.argtype(1)
+    def _add(self, other):
+        return self._new(self.data + [other[i] for i in range(len(other))])
 
     def descr_reverse(self, space):
         self.data.reverse()
@@ -1232,11 +1231,11 @@
         self.data[index] = char
 
 
-@specialize.argtype(0)
+@specialize.argtype(1)
 def _memcmp(selfvalue, buffer, length):
     for i in range(length):
-        if selfvalue[i] < buffer.getitem(i):
+        if selfvalue[i] < buffer[i]:
             return -1
-        if selfvalue[i] > buffer.getitem(i):
+        if selfvalue[i] > buffer[i]:
             return 1
     return 0
diff --git a/pypy/objspace/std/stringmethods.py b/pypy/objspace/std/stringmethods.py
--- a/pypy/objspace/std/stringmethods.py
+++ b/pypy/objspace/std/stringmethods.py
@@ -4,8 +4,7 @@
 from rpython.rlib.objectmodel import specialize, newlist_hint
 from rpython.rlib.rarithmetic import ovfcheck
 from rpython.rlib.rstring import (
-    search, SEARCH_FIND, SEARCH_RFIND, SEARCH_COUNT, endswith, replace, rsplit,
-    split, startswith)
+    find, rfind, count, endswith, replace, rsplit, split, startswith)
 from rpython.rlib.buffer import Buffer
 
 from pypy.interpreter.error import OperationError, oefmt
@@ -46,8 +45,14 @@
             other = self._op_val(space, w_sub)
             return space.newbool(value.find(other) >= 0)
 
-        buffer = _get_buffer(space, w_sub)
-        res = search(value, buffer, 0, len(value), SEARCH_FIND)
+        from pypy.objspace.std.bytesobject import W_BytesObject
+        if isinstance(w_sub, W_BytesObject):
+            other = self._op_val(space, w_sub)
+            res = find(value, other, 0, len(value))
+        else:
+            buffer = _get_buffer(space, w_sub)
+            res = find(value, buffer, 0, len(value))
+
         return space.newbool(res >= 0)
 
     def descr_add(self, space, w_other):
@@ -149,8 +154,16 @@
             return space.newint(value.count(self._op_val(space, w_sub), start,
                                             end))
 
-        buffer = _get_buffer(space, w_sub)
-        res = search(value, buffer, start, end, SEARCH_COUNT)
+        from pypy.objspace.std.bytearrayobject import W_BytearrayObject
+        from pypy.objspace.std.bytesobject import W_BytesObject
+        if isinstance(w_sub, W_BytearrayObject):
+            res = count(value, w_sub.data, start, end)
+        elif isinstance(w_sub, W_BytesObject):
+            res = count(value, w_sub._value, start, end)
+        else:
+            buffer = _get_buffer(space, w_sub)
+            res = count(value, buffer, start, end)
+
         return space.wrap(max(res, 0))
 
     def descr_decode(self, space, w_encoding=None, w_errors=None):
@@ -226,8 +239,16 @@
             res = value.find(self._op_val(space, w_sub), start, end)
             return space.wrap(res)
 
-        buffer = _get_buffer(space, w_sub)
-        res = search(value, buffer, start, end, SEARCH_FIND)
+        from pypy.objspace.std.bytearrayobject import W_BytearrayObject
+        from pypy.objspace.std.bytesobject import W_BytesObject
+        if isinstance(w_sub, W_BytearrayObject):
+            res = find(value, w_sub.data, start, end)
+        elif isinstance(w_sub, W_BytesObject):
+            res = find(value, w_sub._value, start, end)
+        else:
+            buffer = _get_buffer(space, w_sub)
+            res = find(value, buffer, start, end)
+
         return space.wrap(res)
 
     def descr_rfind(self, space, w_sub, w_start=None, w_end=None):
@@ -237,18 +258,32 @@
             res = value.rfind(self._op_val(space, w_sub), start, end)
             return space.wrap(res)
 
-        buffer = _get_buffer(space, w_sub)
-        res = search(value, buffer, start, end, SEARCH_RFIND)
+        from pypy.objspace.std.bytearrayobject import W_BytearrayObject
+        from pypy.objspace.std.bytesobject import W_BytesObject
+        if isinstance(w_sub, W_BytearrayObject):
+            res = rfind(value, w_sub.data, start, end)
+        elif isinstance(w_sub, W_BytesObject):
+            res = rfind(value, w_sub._value, start, end)
+        else:
+            buffer = _get_buffer(space, w_sub)
+            res = rfind(value, buffer, start, end)
+
         return space.wrap(res)
 
     def descr_index(self, space, w_sub, w_start=None, w_end=None):
         (value, start, end) = self._convert_idx_params(space, w_start, w_end)
 
+        from pypy.objspace.std.bytearrayobject import W_BytearrayObject
+        from pypy.objspace.std.bytesobject import W_BytesObject
         if self._use_rstr_ops(space, w_sub):
             res = value.find(self._op_val(space, w_sub), start, end)
+        elif isinstance(w_sub, W_BytearrayObject):
+            res = find(value, w_sub.data, start, end)
+        elif isinstance(w_sub, W_BytesObject):
+            res = find(value, w_sub._value, start, end)
         else:
             buffer = _get_buffer(space, w_sub)
-            res = search(value, buffer, start, end, SEARCH_FIND)
+            res = find(value, buffer, start, end)
 
         if res < 0:
             raise oefmt(space.w_ValueError,
@@ -258,11 +293,17 @@
     def descr_rindex(self, space, w_sub, w_start=None, w_end=None):
         (value, start, end) = self._convert_idx_params(space, w_start, w_end)
 
+        from pypy.objspace.std.bytearrayobject import W_BytearrayObject
+        from pypy.objspace.std.bytesobject import W_BytesObject
         if self._use_rstr_ops(space, w_sub):
             res = value.rfind(self._op_val(space, w_sub), start, end)
+        elif isinstance(w_sub, W_BytearrayObject):
+            res = rfind(value, w_sub.data, start, end)
+        elif isinstance(w_sub, W_BytesObject):
+            res = rfind(value, w_sub._value, start, end)
         else:
             buffer = _get_buffer(space, w_sub)
-            res = search(value, buffer, start, end, SEARCH_RFIND)
+            res = rfind(value, buffer, start, end)
 
         if res < 0:
             raise oefmt(space.w_ValueError,
@@ -456,7 +497,7 @@
             if sublen == 0:
                 raise oefmt(space.w_ValueError, "empty separator")
 
-            pos = search(value, sub, 0, len(value), SEARCH_FIND)
+            pos = find(value, sub, 0, len(value))
             if pos != -1 and isinstance(self, W_BytearrayObject):
                 w_sub = self._new_from_buffer(sub)
 
@@ -486,7 +527,7 @@
             if sublen == 0:
                 raise oefmt(space.w_ValueError, "empty separator")
 
-            pos = search(value, sub, 0, len(value), SEARCH_RFIND)
+            pos = rfind(value, sub, 0, len(value))
             if pos != -1 and isinstance(self, W_BytearrayObject):
                 w_sub = self._new_from_buffer(sub)
 
@@ -502,12 +543,14 @@
     @unwrap_spec(count=int)
     def descr_replace(self, space, w_old, w_new, count=-1):
         input = self._val(space)
+
         sub = self._op_val(space, w_old)
         by = self._op_val(space, w_new)
         try:
             res = replace(input, sub, by, count)
         except OverflowError:
             raise oefmt(space.w_OverflowError, "replace string is too long")
+
         return self._new(res)
 
     @unwrap_spec(maxsplit=int)
@@ -518,11 +561,17 @@
             res = split(value, maxsplit=maxsplit)
             return self._newlist_unwrapped(space, res)
 
-        by = self._op_val(space, w_sep)
-        bylen = len(by)
-        if bylen == 0:
-            raise oefmt(space.w_ValueError, "empty separator")
-        res = split(value, by, maxsplit)
+        if self._use_rstr_ops(space, w_sep):
+            by = self._op_val(space, w_sep)
+            if len(by) == 0:
+                raise oefmt(space.w_ValueError, "empty separator")
+            res = split(value, by, maxsplit)
+        else:
+            by = _get_buffer(space, w_sep)
+            if len(by) == 0:
+                raise oefmt(space.w_ValueError, "empty separator")
+            res = split(value, by, maxsplit)
+
         return self._newlist_unwrapped(space, res)
 
     @unwrap_spec(maxsplit=int)
@@ -533,11 +582,17 @@
             res = rsplit(value, maxsplit=maxsplit)
             return self._newlist_unwrapped(space, res)
 
-        by = self._op_val(space, w_sep)
-        bylen = len(by)
-        if bylen == 0:
-            raise oefmt(space.w_ValueError, "empty separator")
-        res = rsplit(value, by, maxsplit)
+        if self._use_rstr_ops(space, w_sep):
+            by = self._op_val(space, w_sep)
+            if len(by) == 0:
+                raise oefmt(space.w_ValueError, "empty separator")
+            res = rsplit(value, by, maxsplit)
+        else:
+            by = _get_buffer(space, w_sep)
+            if len(by) == 0:
+                raise oefmt(space.w_ValueError, "empty separator")
+            res = rsplit(value, by, maxsplit)
+
         return self._newlist_unwrapped(space, res)
 
     @unwrap_spec(keepends=bool)
@@ -574,7 +629,10 @@
                                               end))
 
     def _startswith(self, space, value, w_prefix, start, end):
-        return startswith(value, self._op_val(space, w_prefix), start, end)
+        if self._use_rstr_ops(space, w_prefix):
+            return startswith(value, self._op_val(space, w_prefix), start, end)
+        else:
+            return startswith(value, _get_buffer(space, w_prefix), start, end)
 
     def descr_endswith(self, space, w_suffix, w_start=None, w_end=None):
         (value, start, end) = self._convert_idx_params(space, w_start, w_end,
@@ -588,7 +646,10 @@
                                             end))
 
     def _endswith(self, space, value, w_prefix, start, end):
-        return endswith(value, self._op_val(space, w_prefix), start, end)
+        if self._use_rstr_ops(space, w_prefix):
+            return endswith(value, self._op_val(space, w_prefix), start, end)
+        else:
+            return endswith(value, _get_buffer(space, w_prefix), start, end)
 
     def _strip(self, space, w_chars, left, right):
         "internal function called by str_xstrip methods"


More information about the pypy-commit mailing list