[pypy-commit] pypy fix-bytearray-complexity: Add some special cases for stringmethods to avoid buffer-overhead
waedt
noreply at buildbot.pypy.org
Mon Jun 2 19:47:16 CEST 2014
Author: Tyler Wade <wayedt at gmail.com>
Branch: fix-bytearray-complexity
Changeset: r71882:7f2b803f1319
Date: 2014-06-02 11:41 -0500
http://bitbucket.org/pypy/pypy/changeset/7f2b803f1319/
Log: Add some special cases for stringmethods to avoid buffer-overhead
diff --git a/pypy/objspace/std/bytearrayobject.py b/pypy/objspace/std/bytearrayobject.py
--- a/pypy/objspace/std/bytearrayobject.py
+++ b/pypy/objspace/std/bytearrayobject.py
@@ -12,6 +12,7 @@
from pypy.objspace.std.sliceobject import W_SliceObject
from pypy.objspace.std.stdtypedef import StdTypeDef
from pypy.objspace.std.stringmethods import StringMethods, _get_buffer
+from pypy.objspace.std.bytesobject import W_BytesObject
from pypy.objspace.std.util import get_positive_index
NON_HEX_MSG = "non-hexadecimal number found in fromhex() arg at position %d"
@@ -43,8 +44,7 @@
return W_BytearrayObject(value)
def _new_from_buffer(self, buffer):
- length = buffer.getlength()
- return W_BytearrayObject([buffer.getitem(i) for i in range(length)])
+ return W_BytearrayObject([buffer[i] for i in range(len(buffer))])
def _new_from_list(self, value):
return W_BytearrayObject(value)
@@ -313,61 +313,52 @@
min_length = min(len(value), buffer_len)
return space.newbool(_memcmp(value, buffer, min_length) != 0)
+ def _comparison_helper(self, space, w_other):
+ value = self._val(space)
+
+ if isinstance(w_other, W_BytearrayObject):
+ other = w_other.data
+ other_len = len(other)
+ cmp = _memcmp(value, other, min(len(value), len(other)))
+ elif isinstance(w_other, W_BytesObject):
+ other = self._op_val(space, w_other)
+ other_len = len(other)
+ cmp = _memcmp(value, other, min(len(value), len(other)))
+ else:
+ try:
+ buffer = _get_buffer(space, w_other)
+ except OperationError as e:
+ if e.match(space, space.w_TypeError):
+ return False, 0, 0
+ raise
+ other_len = len(buffer)
+ cmp = _memcmp(value, buffer, min(len(value), len(buffer)))
+
+ return True, cmp, other_len
+
def descr_lt(self, space, w_other):
- try:
- buffer = _get_buffer(space, w_other)
- except OperationError as e:
- if e.match(space, space.w_TypeError):
- return space.w_NotImplemented
- raise
-
- value = self._val(space)
- buffer_len = buffer.getlength()
-
- cmp = _memcmp(value, buffer, min(len(value), buffer_len))
- return space.newbool(cmp < 0 or (cmp == 0 and len(value) < buffer_len))
+ success, cmp, other_len = self._comparison_helper(space, w_other)
+ if not success:
+ return space.w_NotImplemented
+ return space.newbool(cmp < 0 or (cmp == 0 and self._len() < other_len))
def descr_le(self, space, w_other):
- try:
- buffer = _get_buffer(space, w_other)
- except OperationError as e:
- if e.match(space, space.w_TypeError):
- return space.w_NotImplemented
- raise
-
- value = self._val(space)
- buffer_len = buffer.getlength()
-
- cmp = _memcmp(value, buffer, min(len(value), buffer_len))
- return space.newbool(cmp < 0 or (cmp == 0 and len(value) <= buffer_len))
+ success, cmp, other_len = self._comparison_helper(space, w_other)
+ if not success:
+ return space.w_NotImplemented
+ return space.newbool(cmp < 0 or (cmp == 0 and self._len() <= other_len))
def descr_gt(self, space, w_other):
- try:
- buffer = _get_buffer(space, w_other)
- except OperationError as e:
- if e.match(space, space.w_TypeError):
- return space.w_NotImplemented
- raise
-
- value = self._val(space)
- buffer_len = buffer.getlength()
-
- cmp = _memcmp(value, buffer, min(len(value), buffer_len))
- return space.newbool(cmp > 0 or (cmp == 0 and len(value) > buffer_len))
+ success, cmp, other_len = self._comparison_helper(space, w_other)
+ if not success:
+ return space.w_NotImplemented
+ return space.newbool(cmp > 0 or (cmp == 0 and self._len() > other_len))
def descr_ge(self, space, w_other):
- try:
- buffer = _get_buffer(space, w_other)
- except OperationError as e:
- if e.match(space, space.w_TypeError):
- return space.w_NotImplemented
- raise
-
- value = self._val(space)
- buffer_len = buffer.getlength()
-
- cmp = _memcmp(value, buffer, min(len(value), buffer_len))
- return space.newbool(cmp > 0 or (cmp == 0 and len(value) >= buffer_len))
+ success, cmp, other_len = self._comparison_helper(space, w_other)
+ if not success:
+ return space.w_NotImplemented
+ return space.newbool(cmp > 0 or (cmp == 0 and self._len() >= other_len))
def descr_iter(self, space):
return space.newseqiter(self)
@@ -377,11 +368,17 @@
self.data += w_other.data
return self
- buffer = _get_buffer(space, w_other)
- for i in range(buffer.getlength()):
- self.data.append(buffer.getitem(i))
+ if isinstance(w_other, W_BytesObject):
+ self._inplace_add(self._op_val(space, w_other))
+ else:
+ self._inplace_add(_get_buffer(space, w_other))
return self
+ @specialize.argtype(1)
+ def _inplace_add(self, other):
+ for i in range(len(other)):
+ self.data.append(other[i])
+
def descr_inplace_mul(self, space, w_times):
try:
times = space.getindex_w(w_times, space.w_OverflowError)
@@ -469,18 +466,20 @@
if isinstance(w_other, W_BytearrayObject):
return self._new(self.data + w_other.data)
+ if isinstance(w_other, W_BytesObject):
+ return self._add(self._op_val(space, w_other))
+
try:
buffer = _get_buffer(space, w_other)
except OperationError as e:
if e.match(space, space.w_TypeError):
return space.w_NotImplemented
raise
+ return self._add(buffer)
- buffer_len = buffer.getlength()
- data = list(self.data + ['\0'] * buffer_len)
- for i in range(buffer_len):
- data[len(self.data) + i] = buffer.getitem(i)
- return self._new(data)
+ @specialize.argtype(1)
+ def _add(self, other):
+ return self._new(self.data + [other[i] for i in range(len(other))])
def descr_reverse(self, space):
self.data.reverse()
@@ -1232,11 +1231,11 @@
self.data[index] = char
-@specialize.argtype(0)
+@specialize.argtype(1)
def _memcmp(selfvalue, buffer, length):
for i in range(length):
- if selfvalue[i] < buffer.getitem(i):
+ if selfvalue[i] < buffer[i]:
return -1
- if selfvalue[i] > buffer.getitem(i):
+ if selfvalue[i] > buffer[i]:
return 1
return 0
diff --git a/pypy/objspace/std/stringmethods.py b/pypy/objspace/std/stringmethods.py
--- a/pypy/objspace/std/stringmethods.py
+++ b/pypy/objspace/std/stringmethods.py
@@ -4,8 +4,7 @@
from rpython.rlib.objectmodel import specialize, newlist_hint
from rpython.rlib.rarithmetic import ovfcheck
from rpython.rlib.rstring import (
- search, SEARCH_FIND, SEARCH_RFIND, SEARCH_COUNT, endswith, replace, rsplit,
- split, startswith)
+ find, rfind, count, endswith, replace, rsplit, split, startswith)
from rpython.rlib.buffer import Buffer
from pypy.interpreter.error import OperationError, oefmt
@@ -46,8 +45,14 @@
other = self._op_val(space, w_sub)
return space.newbool(value.find(other) >= 0)
- buffer = _get_buffer(space, w_sub)
- res = search(value, buffer, 0, len(value), SEARCH_FIND)
+ from pypy.objspace.std.bytesobject import W_BytesObject
+ if isinstance(w_sub, W_BytesObject):
+ other = self._op_val(space, w_sub)
+ res = find(value, other, 0, len(value))
+ else:
+ buffer = _get_buffer(space, w_sub)
+ res = find(value, buffer, 0, len(value))
+
return space.newbool(res >= 0)
def descr_add(self, space, w_other):
@@ -149,8 +154,16 @@
return space.newint(value.count(self._op_val(space, w_sub), start,
end))
- buffer = _get_buffer(space, w_sub)
- res = search(value, buffer, start, end, SEARCH_COUNT)
+ from pypy.objspace.std.bytearrayobject import W_BytearrayObject
+ from pypy.objspace.std.bytesobject import W_BytesObject
+ if isinstance(w_sub, W_BytearrayObject):
+ res = count(value, w_sub.data, start, end)
+ elif isinstance(w_sub, W_BytesObject):
+ res = count(value, w_sub._value, start, end)
+ else:
+ buffer = _get_buffer(space, w_sub)
+ res = count(value, buffer, start, end)
+
return space.wrap(max(res, 0))
def descr_decode(self, space, w_encoding=None, w_errors=None):
@@ -226,8 +239,16 @@
res = value.find(self._op_val(space, w_sub), start, end)
return space.wrap(res)
- buffer = _get_buffer(space, w_sub)
- res = search(value, buffer, start, end, SEARCH_FIND)
+ from pypy.objspace.std.bytearrayobject import W_BytearrayObject
+ from pypy.objspace.std.bytesobject import W_BytesObject
+ if isinstance(w_sub, W_BytearrayObject):
+ res = find(value, w_sub.data, start, end)
+ elif isinstance(w_sub, W_BytesObject):
+ res = find(value, w_sub._value, start, end)
+ else:
+ buffer = _get_buffer(space, w_sub)
+ res = find(value, buffer, start, end)
+
return space.wrap(res)
def descr_rfind(self, space, w_sub, w_start=None, w_end=None):
@@ -237,18 +258,32 @@
res = value.rfind(self._op_val(space, w_sub), start, end)
return space.wrap(res)
- buffer = _get_buffer(space, w_sub)
- res = search(value, buffer, start, end, SEARCH_RFIND)
+ from pypy.objspace.std.bytearrayobject import W_BytearrayObject
+ from pypy.objspace.std.bytesobject import W_BytesObject
+ if isinstance(w_sub, W_BytearrayObject):
+ res = rfind(value, w_sub.data, start, end)
+ elif isinstance(w_sub, W_BytesObject):
+ res = rfind(value, w_sub._value, start, end)
+ else:
+ buffer = _get_buffer(space, w_sub)
+ res = rfind(value, buffer, start, end)
+
return space.wrap(res)
def descr_index(self, space, w_sub, w_start=None, w_end=None):
(value, start, end) = self._convert_idx_params(space, w_start, w_end)
+ from pypy.objspace.std.bytearrayobject import W_BytearrayObject
+ from pypy.objspace.std.bytesobject import W_BytesObject
if self._use_rstr_ops(space, w_sub):
res = value.find(self._op_val(space, w_sub), start, end)
+ elif isinstance(w_sub, W_BytearrayObject):
+ res = find(value, w_sub.data, start, end)
+ elif isinstance(w_sub, W_BytesObject):
+ res = find(value, w_sub._value, start, end)
else:
buffer = _get_buffer(space, w_sub)
- res = search(value, buffer, start, end, SEARCH_FIND)
+ res = find(value, buffer, start, end)
if res < 0:
raise oefmt(space.w_ValueError,
@@ -258,11 +293,17 @@
def descr_rindex(self, space, w_sub, w_start=None, w_end=None):
(value, start, end) = self._convert_idx_params(space, w_start, w_end)
+ from pypy.objspace.std.bytearrayobject import W_BytearrayObject
+ from pypy.objspace.std.bytesobject import W_BytesObject
if self._use_rstr_ops(space, w_sub):
res = value.rfind(self._op_val(space, w_sub), start, end)
+ elif isinstance(w_sub, W_BytearrayObject):
+ res = rfind(value, w_sub.data, start, end)
+ elif isinstance(w_sub, W_BytesObject):
+ res = rfind(value, w_sub._value, start, end)
else:
buffer = _get_buffer(space, w_sub)
- res = search(value, buffer, start, end, SEARCH_RFIND)
+ res = rfind(value, buffer, start, end)
if res < 0:
raise oefmt(space.w_ValueError,
@@ -456,7 +497,7 @@
if sublen == 0:
raise oefmt(space.w_ValueError, "empty separator")
- pos = search(value, sub, 0, len(value), SEARCH_FIND)
+ pos = find(value, sub, 0, len(value))
if pos != -1 and isinstance(self, W_BytearrayObject):
w_sub = self._new_from_buffer(sub)
@@ -486,7 +527,7 @@
if sublen == 0:
raise oefmt(space.w_ValueError, "empty separator")
- pos = search(value, sub, 0, len(value), SEARCH_RFIND)
+ pos = rfind(value, sub, 0, len(value))
if pos != -1 and isinstance(self, W_BytearrayObject):
w_sub = self._new_from_buffer(sub)
@@ -502,12 +543,14 @@
@unwrap_spec(count=int)
def descr_replace(self, space, w_old, w_new, count=-1):
input = self._val(space)
+
sub = self._op_val(space, w_old)
by = self._op_val(space, w_new)
try:
res = replace(input, sub, by, count)
except OverflowError:
raise oefmt(space.w_OverflowError, "replace string is too long")
+
return self._new(res)
@unwrap_spec(maxsplit=int)
@@ -518,11 +561,17 @@
res = split(value, maxsplit=maxsplit)
return self._newlist_unwrapped(space, res)
- by = self._op_val(space, w_sep)
- bylen = len(by)
- if bylen == 0:
- raise oefmt(space.w_ValueError, "empty separator")
- res = split(value, by, maxsplit)
+ if self._use_rstr_ops(space, w_sep):
+ by = self._op_val(space, w_sep)
+ if len(by) == 0:
+ raise oefmt(space.w_ValueError, "empty separator")
+ res = split(value, by, maxsplit)
+ else:
+ by = _get_buffer(space, w_sep)
+ if len(by) == 0:
+ raise oefmt(space.w_ValueError, "empty separator")
+ res = split(value, by, maxsplit)
+
return self._newlist_unwrapped(space, res)
@unwrap_spec(maxsplit=int)
@@ -533,11 +582,17 @@
res = rsplit(value, maxsplit=maxsplit)
return self._newlist_unwrapped(space, res)
- by = self._op_val(space, w_sep)
- bylen = len(by)
- if bylen == 0:
- raise oefmt(space.w_ValueError, "empty separator")
- res = rsplit(value, by, maxsplit)
+ if self._use_rstr_ops(space, w_sep):
+ by = self._op_val(space, w_sep)
+ if len(by) == 0:
+ raise oefmt(space.w_ValueError, "empty separator")
+ res = rsplit(value, by, maxsplit)
+ else:
+ by = _get_buffer(space, w_sep)
+ if len(by) == 0:
+ raise oefmt(space.w_ValueError, "empty separator")
+ res = rsplit(value, by, maxsplit)
+
return self._newlist_unwrapped(space, res)
@unwrap_spec(keepends=bool)
@@ -574,7 +629,10 @@
end))
def _startswith(self, space, value, w_prefix, start, end):
- return startswith(value, self._op_val(space, w_prefix), start, end)
+ if self._use_rstr_ops(space, w_prefix):
+ return startswith(value, self._op_val(space, w_prefix), start, end)
+ else:
+ return startswith(value, _get_buffer(space, w_prefix), start, end)
def descr_endswith(self, space, w_suffix, w_start=None, w_end=None):
(value, start, end) = self._convert_idx_params(space, w_start, w_end,
@@ -588,7 +646,10 @@
end))
def _endswith(self, space, value, w_prefix, start, end):
- return endswith(value, self._op_val(space, w_prefix), start, end)
+ if self._use_rstr_ops(space, w_prefix):
+ return endswith(value, self._op_val(space, w_prefix), start, end)
+ else:
+ return endswith(value, _get_buffer(space, w_prefix), start, end)
def _strip(self, space, w_chars, left, right):
"internal function called by str_xstrip methods"
More information about the pypy-commit
mailing list