[pypy-svn] r53754 - in pypy/dist/pypy/rlib: . test

arigo at codespeak.net arigo at codespeak.net
Mon Apr 14 15:16:45 CEST 2008


Author: arigo
Date: Mon Apr 14 15:16:44 2008
New Revision: 53754

Added:
   pypy/dist/pypy/rlib/rStringIO.py   (contents, props changed)
   pypy/dist/pypy/rlib/test/test_rStringIO.py   (contents, props changed)
Log:
An RPython version of StringIO.  The "fast path" through this code is
for the case of a bunch of write() followed by getvalue().


Added: pypy/dist/pypy/rlib/rStringIO.py
==============================================================================
--- (empty file)
+++ pypy/dist/pypy/rlib/rStringIO.py	Mon Apr 14 15:16:44 2008
@@ -0,0 +1,164 @@
+
# Number of slots in RStringIO.strings, i.e. how many write()s can be
# accumulated on the fast path before reduce() has to join them.
PIECES = 80
# Number of successive partial joins done by RStringIO.reduce() (each one
# keeping one more joined "big" string at the front of RStringIO.strings)
# before the next reduce() joins everything back into strings[0].
BIGPIECES = 32

# Special value of RStringIO.pos: the current position is the end of the
# data, so write() can take the fast path.
AT_END = -1


class RStringIO(object):
    """RPython-level StringIO object.
    The fastest path through this code is for the case of a bunch of write()
    followed by getvalue().  For at most PIECES write()s and one getvalue(),
    there is one copy of the data done, as if ''.join() was used.
    """
    _mixin_ = True        # for interp_stringio.py

    def __init__(self):
        # The real content is the join of the following data:
        #  * the list of characters self.bigbuffer;
        #  * each of the strings in self.strings.
        #
        # Invariants:
        #  * self.numbigstrings <= self.numstrings;
        #  * all strings in self.strings[self.numstrings:PIECES] are empty.
        #
        # self.pos is AT_END when the current position is the end of the
        # data; otherwise it is an explicit non-negative offset (which may
        # also be equal to or larger than the size of the data).
        self.strings = [''] * PIECES
        self.numstrings = 0
        self.numbigstrings = 0
        self.bigbuffer = []
        self.pos = AT_END

    def getvalue(self):
        """Return the whole content as a single string.

        If self.strings contains more than 1 string, join all the strings
        together and keep the result in self.strings[0], so that a later
        getvalue() does not need to join them again.
        """
        if len(self.bigbuffer) > 0:
            self.copy_into_bigbuffer()
            return ''.join(self.bigbuffer)
        count = self.numstrings
        if count > 1:
            result = self.strings[0] = ''.join(self.strings)
            # The pieces now live in self.strings[0]; clear their old
            # slots.  Otherwise the invariant "self.strings[numstrings:] is
            # empty" breaks and the stale pieces would be joined again by
            # a later getvalue() or reduce(), duplicating data.
            for i in range(1, count):
                self.strings[i] = ''
            self.numstrings = 1
            self.numbigstrings = 1
        else:
            result = self.strings[0]
        return result

    def getsize(self):
        """Return the total number of characters stored."""
        result = len(self.bigbuffer)
        for i in range(0, self.numstrings):
            result += len(self.strings[i])
        return result

    def copy_into_bigbuffer(self):
        """Copy all the data into the list of characters self.bigbuffer."""
        for i in range(0, self.numstrings):
            self.bigbuffer += self.strings[i]
            self.strings[i] = ''
        self.numstrings = 0
        self.numbigstrings = 0

    def reduce(self):
        """Reduce the number of (non-empty) strings in self.strings.

        Returns the new value of self.numstrings, i.e. the index of the
        first free slot.
        """
        # When self.pos == AT_END, the calls to write(str) accumulate
        # the strings in self.strings until all PIECES slots are filled.
        # Then the reduce() method joins all the strings and puts the
        # result back into self.strings[0].  The next time all the slots
        # are filled, we only join self.strings[1:] and put the result
        # in self.strings[1]; and so on.  The purpose of this is that
        # the string resulting from a join is expected to be big, so the
        # next join operation should only join the newly added strings.
        # When we have done this BIGPIECES times, the next join collects
        # all strings again into self.strings[0] and we start from
        # scratch.
        limit = self.numbigstrings
        self.strings[limit] = ''.join(self.strings[limit:])
        for i in range(limit + 1, self.numstrings):
            self.strings[i] = ''
        self.numstrings = limit + 1
        if limit < BIGPIECES:
            self.numbigstrings = limit + 1
        else:
            self.numbigstrings = 0
        assert self.numstrings <= BIGPIECES + 1
        return self.numstrings

    def write(self, buffer):
        """Write the string 'buffer' at the current position.

        Existing data is overwritten and/or extended.  Writing at a
        position past the end first pads the gap with NUL characters.
        """
        # Idea: for the common case of a sequence of write() followed
        # by only getvalue(), self.bigbuffer remains empty.  It is only
        # used to handle the more complicated cases.
        p = self.pos
        if p != AT_END:    # slow or semi-fast paths
            endp = p + len(buffer)
            if len(self.bigbuffer) >= endp:
                # semi-fast path: the write is entirely inside self.bigbuffer
                for i in range(len(buffer)):
                    self.bigbuffer[p+i] = buffer[i]
                self.pos = endp
                return
            else:
                # slow path: collect all data into self.bigbuffer and
                # handle the various cases
                self.copy_into_bigbuffer()
                fitting = len(self.bigbuffer) - p
                if fitting > 0:
                    # the write starts before the end of the data
                    fitting = min(len(buffer), fitting)
                    for i in range(fitting):
                        self.bigbuffer[p+i] = buffer[i]
                    if len(buffer) > fitting:
                        # the write extends beyond the end of the data
                        self.bigbuffer += buffer[fitting:]
                        endp = AT_END
                    self.pos = endp
                    return
                else:
                    # the write starts at or beyond the end of the data
                    self.bigbuffer += '\x00' * (-fitting)
                    self.pos = AT_END      # fall-through to the fast path
        # Fast path.
        # See comments in reduce().
        count = self.numstrings
        if count == PIECES:
            count = self.reduce()
        self.strings[count] = buffer
        self.numstrings = count + 1

    def seek(self, position, mode=0):
        """Set the current position.

        'mode' is 0 (absolute, the default), 1 (relative to the current
        position) or 2 (relative to the end of the data); any other value
        is treated like 0.  A negative resulting position is clamped to 0.
        """
        if mode == 1:
            if self.pos == AT_END:
                self.pos = self.getsize()
            position += self.pos
        elif mode == 2:
            if position == 0:
                # stay in AT_END mode so that write() keeps the fast path
                self.pos = AT_END
                return
            position += self.getsize()
        if position < 0:
            position = 0
        self.pos = position

    def tell(self):
        """Return the current position as a character offset."""
        if self.pos == AT_END:
            return self.getsize()
        else:
            return self.pos

    def read(self, n=-1):
        """Read up to n characters (all remaining ones if n < 0).

        The position advances past the returned data.  Returns '' when the
        position is at or past the end of the data.
        """
        p = self.pos
        if p == 0 and n < 0:
            self.pos = AT_END
            return self.getvalue()     # reading everything
        if p == AT_END:
            return ''
        self.copy_into_bigbuffer()
        mysize = len(self.bigbuffer)
        count = mysize - p
        if n >= 0:
            count = min(n, count)
        if count <= 0:
            return ''
        if p == 0 and count == mysize:
            self.pos = AT_END
            return ''.join(self.bigbuffer)
        else:
            self.pos = p + count
            return ''.join(self.bigbuffer[p:p+count])

Added: pypy/dist/pypy/rlib/test/test_rStringIO.py
==============================================================================
--- (empty file)
+++ pypy/dist/pypy/rlib/test/test_rStringIO.py	Mon Apr 14 15:16:44 2008
@@ -0,0 +1,115 @@
+from pypy.rlib.rStringIO import RStringIO
+
+
def test_simple():
    # two plain writes, then fetch everything at once
    s = RStringIO()
    for piece in ('hello', ' world'):
        s.write(piece)
    assert s.getvalue() == 'hello world'
+
def test_write_many():
    # 10 * 253 one-character writes: many more than PIECES, so the
    # fast path has to go through reduce() repeatedly.
    s = RStringIO()
    pieces = []
    for _ in range(10):
        for code in range(253):
            ch = chr(code)
            s.write(ch)
            pieces.append(ch)
    assert s.getvalue() == ''.join(pieces)
+
def test_seek():
    s = RStringIO()
    for piece in ['0123', '456', '789']:
        s.write(piece)
    # absolute seek, then overwrite inside the existing data
    s.seek(4)
    s.write('AB')
    assert s.getvalue() == '0123AB6789'
    # seek relative to the end, then write across the end
    s.seek(-2, 2)
    s.write('CDE')
    assert s.getvalue() == '0123AB67CDE'
    # seek relative to the current position
    s.seek(2, 0)
    s.seek(5, 1)
    s.write('F')
    assert s.getvalue() == '0123AB6FCDE'
+
def test_write_beyond_end():
    s = RStringIO()
    s.seek(20, 1)            # relative seek on an empty object
    assert s.tell() == 20
    s.write('X')
    # the gap before the written data is filled with NUL characters
    padding = '\x00' * 20
    assert s.getvalue() == padding + 'X'
+
def test_tell():
    s = RStringIO()
    s.write('0123')
    s.write('456')
    assert s.tell() == 7
    # from position 2, each one-character write advances tell() by one,
    # first overwriting and then extending the data
    s.seek(2)
    expected_pos = 3
    while expected_pos < 20:
        s.write('X')
        assert s.tell() == expected_pos
        expected_pos += 1
    assert s.getvalue() == '01' + 'X' * 17
+
def test_read():
    s = RStringIO()
    assert s.read() == ''
    s.write('0123')
    s.write('456')
    # after writing, the position is at the end: nothing left to read
    assert s.read() == ''
    assert s.read(5) == ''
    assert s.tell() == 7

    def check(start, size, data, pos_after):
        # seek to 'start', read 'size' chars, verify data and position
        s.seek(start)
        assert s.read(size) == data
        assert s.tell() == pos_after

    check(1, -1, '123456', 7)
    check(1, 12, '123456', 7)
    s.seek(1)
    assert s.read(2) == '12'
    assert s.read(1) == '3'
    assert s.tell() == 4
    check(0, -1, '0123456', 7)
    check(0, 7, '0123456', 7)
    check(15, 2, '', 15)
+
+def test_stress():
+    import cStringIO, random
+    f = RStringIO()
+    expected = cStringIO.StringIO()
+    for i in range(2000):
+        r = random.random()
+        if r < 0.15:
+            p = random.randrange(-5000, 10000)
+            if r < 0.05:
+                mode = 0
+            elif r < 0.1:
+                mode = 1
+            else:
+                mode = 2
+            print 'seek', p, mode
+            f.seek(p, mode)
+            expected.seek(p, mode)
+        elif r < 0.6:
+            buf = str(random.random())
+            print 'write %d bytes' % len(buf)
+            f.write(buf)
+            expected.write(buf)
+        elif r < 0.92:
+            n = random.randrange(0, 100)
+            print 'read %d bytes' % n
+            data1 = f.read(n)
+            data2 = expected.read(n)
+            assert data1 == data2
+        elif r < 0.97:
+            print 'check tell()'
+            assert f.tell() == expected.tell()
+        else:
+            print 'check getvalue()'
+            assert f.getvalue() == expected.getvalue()
+    assert f.getvalue() == expected.getvalue()
+    assert f.tell() == expected.tell()



More information about the Pypy-commit mailing list