[pypy-svn] r53754 - in pypy/dist/pypy/rlib: . test
arigo at codespeak.net
arigo at codespeak.net
Mon Apr 14 15:16:45 CEST 2008
Author: arigo
Date: Mon Apr 14 15:16:44 2008
New Revision: 53754
Added:
pypy/dist/pypy/rlib/rStringIO.py (contents, props changed)
pypy/dist/pypy/rlib/test/test_rStringIO.py (contents, props changed)
Log:
An RPython version of StringIO. The "fast path" through this code is
for the case of a bunch of write() followed by getvalue().
Added: pypy/dist/pypy/rlib/rStringIO.py
==============================================================================
--- (empty file)
+++ pypy/dist/pypy/rlib/rStringIO.py Mon Apr 14 15:16:44 2008
@@ -0,0 +1,164 @@
+
# Number of slots in RStringIO.strings; when all of them are in use,
# write() calls reduce() to join some of them together.
PIECES = 80
# Maximum number of already-joined ("big") strings kept at the front of
# RStringIO.strings before reduce() starts over from slot 0.
BIGPIECES = 32

# Sentinel value of RStringIO.pos meaning "positioned at the end of the data".
AT_END = -1


class RStringIO(object):
    """RPython-level StringIO object.

    The fastest path through this code is for the case of a bunch of write()
    followed by getvalue().  For at most PIECES write()s and one getvalue(),
    there is one copy of the data done, as if ''.join() was used.
    """
    _mixin_ = True        # for interp_stringio.py

    def __init__(self):
        # The real content is the join of the following data:
        #  * the list of characters self.bigbuffer;
        #  * each of the strings in self.strings.
        #
        # Invariants:
        #  * self.numbigstrings <= self.numstrings;
        #  * all strings in self.strings[self.numstrings:PIECES] are empty.
        #
        self.strings = [''] * PIECES
        self.numstrings = 0
        self.numbigstrings = 0
        self.bigbuffer = []
        self.pos = AT_END

    def getvalue(self):
        """Return the whole content as a single string.

        If self.bigbuffer is in use, everything is first moved into it and
        the result is the join of its characters.  Otherwise, if
        self.strings contains more than one string, they are joined and the
        result is cached in self.strings[0], so that a following getvalue()
        does not have to join again.
        """
        if len(self.bigbuffer) > 0:
            self.copy_into_bigbuffer()
            return ''.join(self.bigbuffer)
        if self.numstrings > 1:
            result = self.strings[0] = ''.join(self.strings)
            # Clear the slots that were just merged into slot 0.  Leaving
            # them in place would break the invariant that every slot at or
            # past self.numstrings is empty, and a later join would emit the
            # stale pieces a second time.
            for i in range(1, self.numstrings):
                self.strings[i] = ''
            self.numstrings = 1
            self.numbigstrings = 1
        else:
            result = self.strings[0]
        return result

    def getsize(self):
        """Return the total length of the content, without joining it."""
        result = len(self.bigbuffer)
        for i in range(0, self.numstrings):
            result += len(self.strings[i])
        return result

    def copy_into_bigbuffer(self):
        """Copy all the data into the list of characters self.bigbuffer."""
        for i in range(0, self.numstrings):
            self.bigbuffer += self.strings[i]
            self.strings[i] = ''
        self.numstrings = 0
        self.numbigstrings = 0

    def reduce(self):
        """Reduce the number of (non-empty) strings in self.strings.

        Returns the new value of self.numstrings, i.e. the index of the
        first free slot.
        """
        # When self.pos == AT_END, the calls to write(str) accumulate
        # the strings in self.strings until all PIECES slots are filled.
        # Then the reduce() method joins all the strings and puts the
        # result back into self.strings[0].  The next time all the slots
        # are filled, we only join self.strings[1:] and put the result
        # in self.strings[1]; and so on.  The purpose of this is that
        # the string resulting from a join is expected to be big, so the
        # next join operation should only join the newly added strings.
        # When we have done this BIGPIECES times, the next join collects
        # all strings again into self.strings[0] and we start from
        # scratch.
        limit = self.numbigstrings
        self.strings[limit] = ''.join(self.strings[limit:])
        for i in range(limit + 1, self.numstrings):
            self.strings[i] = ''
        self.numstrings = limit + 1
        if limit < BIGPIECES:
            self.numbigstrings = limit + 1
        else:
            self.numbigstrings = 0
        assert self.numstrings <= BIGPIECES + 1
        return self.numstrings

    def write(self, buffer):
        """Write the string 'buffer' at the current position.

        Data already present at that position is overwritten.  If the
        position lies beyond the end of the data, the gap is filled with
        '\\x00' characters.
        """
        # Idea: for the common case of a sequence of write() followed
        # by only getvalue(), self.bigbuffer remains empty.  It is only
        # used to handle the more complicated cases.
        p = self.pos
        if p != AT_END:    # slow or semi-fast paths
            endp = p + len(buffer)
            if len(self.bigbuffer) >= endp:
                # semi-fast path: the write is entirely inside self.bigbuffer
                for i in range(len(buffer)):
                    self.bigbuffer[p+i] = buffer[i]
                self.pos = endp
                return
            else:
                # slow path: collect all data into self.bigbuffer and
                # handle the various cases
                self.copy_into_bigbuffer()
                fitting = len(self.bigbuffer) - p
                if fitting > 0:
                    # the write starts before the end of the data
                    fitting = min(len(buffer), fitting)
                    for i in range(fitting):
                        self.bigbuffer[p+i] = buffer[i]
                    if len(buffer) > fitting:
                        # the write extends beyond the end of the data
                        self.bigbuffer += buffer[fitting:]
                        endp = AT_END
                    self.pos = endp
                    return
                else:
                    # the write starts at or beyond the end of the data:
                    # pad with NUL characters up to position p
                    self.bigbuffer += '\x00' * (-fitting)
                    self.pos = AT_END      # fall-through to the fast path
        # Fast path.
        # See comments in reduce().
        count = self.numstrings
        if count == PIECES:
            count = self.reduce()
        self.strings[count] = buffer
        self.numstrings = count + 1

    def seek(self, position, mode=0):
        """Move the current position.

        mode 0: 'position' is absolute; mode 1: relative to the current
        position; mode 2: relative to the end of the data.  A resulting
        negative position is clamped to 0.
        """
        if mode == 1:
            if self.pos == AT_END:
                self.pos = self.getsize()
            position += self.pos
        elif mode == 2:
            if position == 0:
                # stay on the fast path of write()
                self.pos = AT_END
                return
            position += self.getsize()
        if position < 0:
            position = 0
        self.pos = position

    def tell(self):
        """Return the current position as a non-negative integer."""
        if self.pos == AT_END:
            return self.getsize()
        else:
            return self.pos

    def read(self, n=-1):
        """Read up to 'n' characters from the current position.

        A negative 'n' means reading everything up to the end of the data.
        Returns '' if there is nothing left to read.
        """
        p = self.pos
        if p == 0 and n < 0:
            self.pos = AT_END
            return self.getvalue()     # reading everything
        if p == AT_END:
            return ''
        self.copy_into_bigbuffer()
        mysize = len(self.bigbuffer)
        count = mysize - p
        if n >= 0:
            count = min(n, count)
        if count <= 0:
            return ''
        if p == 0 and count == mysize:
            # the whole content is read: no need to slice
            self.pos = AT_END
            return ''.join(self.bigbuffer)
        else:
            self.pos = p + count
            return ''.join(self.bigbuffer[p:p+count])
Added: pypy/dist/pypy/rlib/test/test_rStringIO.py
==============================================================================
--- (empty file)
+++ pypy/dist/pypy/rlib/test/test_rStringIO.py Mon Apr 14 15:16:44 2008
@@ -0,0 +1,115 @@
+from pypy.rlib.rStringIO import RStringIO
+
+
def test_simple():
    """Two consecutive writes come back concatenated from getvalue()."""
    stream = RStringIO()
    for piece in ('hello', ' world'):
        stream.write(piece)
    assert stream.getvalue() == 'hello world'
+
def test_write_many():
    """Many one-character writes must come back in order from getvalue()."""
    stream = RStringIO()
    written = []
    for _round in range(10):
        for code in range(253):
            char = chr(code)
            stream.write(char)
            written.append(char)
    assert stream.getvalue() == ''.join(written)
+
def test_seek():
    """Overwrite data after absolute, end-relative and current-relative seeks."""
    stream = RStringIO()
    for piece in ('0123', '456', '789'):
        stream.write(piece)

    # Absolute seek into the middle; overwrite two characters in place.
    stream.seek(4)
    stream.write('AB')
    assert stream.getvalue() == '0123AB6789'

    # Seek relative to the end; the write runs one character past the end.
    stream.seek(-2, 2)
    stream.write('CDE')
    assert stream.getvalue() == '0123AB67CDE'

    # Absolute seek followed by a relative one: 2 + 5 gives position 7.
    stream.seek(2, 0)
    stream.seek(5, 1)
    stream.write('F')
    assert stream.getvalue() == '0123AB6FCDE'
+
def test_write_beyond_end():
    """Writing past the end of the data pads the gap with NUL characters."""
    stream = RStringIO()
    gap = 20
    stream.seek(gap, 1)
    assert stream.tell() == gap
    stream.write('X')
    padding = '\x00' * gap
    assert stream.getvalue() == padding + 'X'
+
def test_tell():
    """tell() follows the position through appends, seeks and overwrites."""
    stream = RStringIO()
    stream.write('0123')
    stream.write('456')
    assert stream.tell() == 7

    # Rewind to position 2, then write one character at a time: the
    # position must advance by one after each write, up to 19.
    stream.seek(2)
    position = 2
    while position < 19:
        stream.write('X')
        position += 1
        assert stream.tell() == position
    assert stream.getvalue() == '01XXXXXXXXXXXXXXXXX'
+
def test_read():
    """read() from various positions, with and without a size limit."""
    stream = RStringIO()

    # Nothing to read from an empty stream.
    assert stream.read() == ''

    # After writing, the position is at the end: still nothing to read.
    stream.write('0123')
    stream.write('456')
    assert stream.read() == ''
    assert stream.read(5) == ''
    assert stream.tell() == 7

    # Unbounded read from the middle.
    stream.seek(1)
    assert stream.read() == '123456'
    assert stream.tell() == 7

    # A limit larger than the remaining data is truncated.
    stream.seek(1)
    assert stream.read(12) == '123456'
    assert stream.tell() == 7

    # Consecutive bounded reads continue where the previous one stopped.
    stream.seek(1)
    assert stream.read(2) == '12'
    assert stream.read(1) == '3'
    assert stream.tell() == 4

    # Reading everything from the start, without a limit...
    stream.seek(0)
    assert stream.read() == '0123456'
    assert stream.tell() == 7

    # ...and with a limit equal to the size of the data.
    stream.seek(0)
    assert stream.read(7) == '0123456'
    assert stream.tell() == 7

    # Reading from beyond the end returns nothing and does not move.
    stream.seek(15)
    assert stream.read(2) == ''
    assert stream.tell() == 15
+
def test_stress():
    """Random seek/write/read/tell/getvalue operations, checked against
    cStringIO.StringIO as the reference implementation."""
    import cStringIO, random
    stream = RStringIO()
    reference = cStringIO.StringIO()
    for _step in range(2000):
        r = random.random()
        if r >= 0.97:
            print('check getvalue()')
            assert stream.getvalue() == reference.getvalue()
        elif r >= 0.92:
            print('check tell()')
            assert stream.tell() == reference.tell()
        elif r >= 0.6:
            n = random.randrange(0, 100)
            print('read %d bytes' % n)
            data1 = stream.read(n)
            data2 = reference.read(n)
            assert data1 == data2
        elif r >= 0.15:
            buf = str(random.random())
            print('write %d bytes' % len(buf))
            stream.write(buf)
            reference.write(buf)
        else:
            p = random.randrange(-5000, 10000)
            # split the seek range evenly between the three modes
            if r >= 0.1:
                mode = 2
            elif r >= 0.05:
                mode = 1
            else:
                mode = 0
            # a single joined argument prints the same text as the
            # statement form "print 'seek', p, mode"
            print(' '.join(['seek', str(p), str(mode)]))
            stream.seek(p, mode)
            reference.seek(p, mode)
    assert stream.getvalue() == reference.getvalue()
    assert stream.tell() == reference.tell()
More information about the Pypy-commit
mailing list