[py-svn] py-virtualenv commit 31b6c51fab40: Add a simple (hopefully) cross-python marshaller

commits-noreply at bitbucket.org commits-noreply at bitbucket.org
Mon Sep 28 22:47:03 CEST 2009


# HG changeset patch -- Bitbucket.org
# Project py-virtualenv
# URL http://bitbucket.org/RonnyPfannschmidt/py-virtualenv/overview/
# User Benjamin Peterson <benjamin at python.org>
# Date 1253671720 18000
# Node ID 31b6c51fab40ba0798c2726d799d3a3f631a2ef0
# Parent 3369fee9b19e8f1ccf2d5daf60ce046280698bd0
Add a simple (hopefully) cross-python marshaller

Will rewrite the tests soon...

--- /dev/null
+++ b/testing/execnet/test_serializer.py
@@ -0,0 +1,127 @@
+# -*- coding: utf-8 -*-
+import shutil
+import py
+from py.__.execnet import serializer
+
+def setup_module(mod):
+    mod._save_python3 = serializer._INPY3
+
+def teardown_module(mod):
+    serializer._setup_version_dependent_constants()
+
+def _dump(obj):
+    stream = py.io.BytesIO()
+    saver = serializer.Serializer(stream)
+    saver.save(obj)
+    return stream.getvalue()
+
+def _load(serialized, str_coerion):
+    stream = py.io.BytesIO(serialized)
+    opts = serializer.UnserializationOptions(str_coerion)
+    unserializer = serializer.Unserializer(stream, opts)
+    return unserializer.load()
+
+def _run_in_version(is_py3, func, *args):
+    serializer._INPY3 = is_py3
+    serializer._setup_version_dependent_constants()
+    try:
+        return func(*args)
+    finally:
+        serializer._INPY3 = _save_python3
+
+def dump_py2(obj):
+    return _run_in_version(False, _dump, obj)
+
+def dump_py3(obj):
+    return _run_in_version(True, _dump, obj)
+
+def load_py2(serialized, str_coercion=False):
+    return _run_in_version(False, _load, serialized, str_coercion)
+
+def load_py3(serialized, str_coercion=False):
+    return _run_in_version(True, _load, serialized, str_coercion)
+
+# Use the builtin ``bytes`` when the interpreter has one; otherwise fall back
+# to ``str`` so the name is always available in this module.
+try:
+    bytes
+except NameError:
+    bytes = str
+
+
+def pytest_funcarg__py2(request):
+    """Provide the ``py2`` funcarg."""
+    # NOTE(review): ``_py2_wrapper`` is not defined in this file, so any test
+    # requesting the ``py2`` funcarg would fail with NameError.  No test here
+    # requests it yet; define the wrapper or drop this hook.
+    return _py2_wrapper
+
+def pytest_funcarg__py3(request):
+    """Provide the ``py3`` funcarg."""
+    # NOTE(review): ``_py3_wrapper`` is likewise undefined in this file.
+    return _py3_wrapper
+
+class TestSerializer:
+
+    def test_int(self):
+        for dump in dump_py2, dump_py3:
+            p = dump_py2(4)
+            for load in load_py2, load_py3:
+                i = load(p)
+                assert isinstance(i, int)
+                assert i == 4
+            py.test.raises(serializer.SerializationError, dump, 123456678900)
+
+    def test_bytes(self):
+        for dump in dump_py2, dump_py3:
+            p = dump(serializer._b('hi'))
+            for load in load_py2, load_py3:
+                s = load(p)
+                assert isinstance(s, serializer.bytes)
+                assert s == serializer._b('hi')
+
+    def check_sequence(self, seq):
+        for dump in dump_py2, dump_py3:
+            p = dump(seq)
+            for load in load_py2, load_py3:
+                l = load(p)
+                assert l == seq
+
+    def test_list(self):
+        self.check_sequence([1, 2, 3])
+
+    @py.test.mark.xfail
+    # I'm not sure if we need the complexity.
+    def test_recursive_list(self):
+        l = [1, 2, 3]
+        l.append(l)
+        self.check_sequence(l)
+
+    def test_tuple(self):
+        self.check_sequence((1, 2, 3))
+
+    def test_dict(self):
+        for dump in dump_py2, dump_py3:
+            p = dump({"hi" : 2, (1, 2, 3) : 32})
+            for load in load_py2, load_py3:
+                d = load(p, True)
+                assert d == {"hi" : 2, (1, 2, 3) : 32}
+
+    def test_string(self):
+        py.test.skip("will rewrite")
+        p = dump_py2("xyz")
+        s = load_py2(p)
+        assert isinstance(s, str)
+        assert s == "xyz"
+        s = load_py3(p)
+        assert isinstance(s, bytes)
+        assert s == serializer.b("xyz")
+        p = dump_py2("xyz")
+        s = load_py3(p, True)
+        assert isinstance(s, serializer._unicode)
+        assert s == serializer.unicode("xyz")
+        p = dump_py3("xyz")
+        s = load_py2(p, True)
+        assert isinstance(s, str)
+        assert s == "xyz"
+
+    def test_unicode(self):
+        py.test.skip("will rewrite")
+        for dump, uni in (dump_py2, serializer._unicode), (dump_py3, str):
+            p = dump(uni("xyz"))
+            for load in load_py2, load_py3:
+                s = load(p)
+                assert isinstance(s, serializer._unicode)
+                assert s == serializer._unicode("xyz")

--- /dev/null
+++ b/py/execnet/serializer.py
@@ -0,0 +1,291 @@
+"""
+Simple marshal format (based on pickle) designed to work across Python versions.
+"""
+
+import sys
+import struct
+
+import py
+
+_INPY3 = _REALLY_PY3 = sys.version_info > (3, 0)
+
+class SerializeError(Exception):
+    pass
+
+class SerializationError(SerializeError):
+    """Error while serializing an object."""
+
+class UnserializableType(SerializationError):
+    """Can't serialize a type."""
+
+class UnserializationError(SerializeError):
+    """Error while unserializing an object."""
+
+class VersionMismatch(UnserializationError):
+    """Data from a previous or later format."""
+
+class Corruption(UnserializationError):
+    """The pickle format appears to have been corrupted."""
+
+if _INPY3:
+    def b(s):
+        return s.encode("ascii")
+    _b = b
+    class _unicode(str):
+        pass
+    bytes = bytes
+else:
+    class bytes(str):
+        pass
+    b = str
+    _b = bytes
+    _unicode = unicode
+
+FOUR_BYTE_INT_MAX = 2147483647
+
+_int4_format = struct.Struct("!i")
+
+# Protocol constants
+VERSION_NUMBER = 1
+VERSION = b(chr(VERSION_NUMBER))
+PY2STRING = b('s')
+PY3STRING = b('t')
+UNICODE = b('u')
+BYTES = b('b')
+NEWLIST = b('l')
+BUILDTUPLE = b('T')
+SETITEM = b('m')
+NEWDICT = b('d')
+INT = b('i')
+STOP = b('S')
+
+class CrossVersionOptions(object):
+    pass
+
+class Serializer(object):
+
+    def __init__(self, stream):
+        self.stream = stream
+
+    def save(self, obj):
+        self.stream.write(VERSION)
+        self._save(obj)
+        self.stream.write(STOP)
+
+    def _save(self, obj):
+        tp = type(obj)
+        try:
+            dispatch = self.dispatch[tp]
+        except KeyError:
+            raise UnserializableType("can't serialize %s" % (tp,))
+        dispatch(self, obj)
+
+    def save_bytes(self, bytes_):
+        self.stream.write(BYTES)
+        self._write_byte_sequence(bytes_)
+
+    def save_unicode(self, s):
+        self.stream.write(UNICODE)
+        self._write_unicode_string(s)
+
+    def save_string(self, s):
+        if _INPY3:
+            self.stream.write(PY3STRING)
+            self._write_unicode_string(s)
+        else:
+            # Case for tests
+            if _REALLY_PY3 and isinstance(s, str):
+                s = s.encode("latin-1")
+            self.stream.write(PY2STRING)
+            self._write_byte_sequence(s)
+
+    def _write_unicode_string(self, s):
+        try:
+            as_bytes = s.encode("utf-8")
+        except UnicodeEncodeError:
+            raise SerializationError("strings must be utf-8 encodable")
+        self._write_byte_sequence(as_bytes)
+
+    def _write_byte_sequence(self, bytes_):
+        self._write_int4(len(bytes_), "string is too long")
+        self.stream.write(bytes_)
+
+    def save_int(self, i):
+        self.stream.write(INT)
+        self._write_int4(i)
+
+    def _write_int4(self, i, error="int must be less than %i" %
+                    (FOUR_BYTE_INT_MAX,)):
+        if i > FOUR_BYTE_INT_MAX:
+            raise SerializationError(error)
+        self.stream.write(_int4_format.pack(i))
+
+    def save_list(self, L):
+        self.stream.write(NEWLIST)
+        self._write_int4(len(L), "list is too long")
+        for i, item in enumerate(L):
+            self._write_setitem(i, item)
+
+    def _write_setitem(self, key, value):
+        self._save(key)
+        self._save(value)
+        self.stream.write(SETITEM)
+
+    def save_dict(self, d):
+        self.stream.write(NEWDICT)
+        for key, value in d.items():
+            self._write_setitem(key, value)
+
+    def save_tuple(self, tup):
+        for item in tup:
+            self._save(item)
+        self.stream.write(BUILDTUPLE)
+        self._write_int4(len(tup), "tuple is too long")
+
+
+class _UnserializationOptions(object):
+    pass
+
+class _Py2UnserializationOptions(_UnserializationOptions):
+
+    def __init__(self, py3_strings_as_str=False):
+        self.py3_strings_as_str = py3_strings_as_str
+
+class _Py3UnserializationOptions(_UnserializationOptions):
+
+    def __init__(self, py2_strings_as_str=False):
+        self.py2_strings_as_str = py2_strings_as_str
+
+
+_unchanging_dispatch = {}
+for tp in (dict, list, tuple, int):
+    name = "save_%s" % (tp.__name__,)
+    _unchanging_dispatch[tp] = getattr(Serializer, name)
+del tp, name
+
+def _setup_dispatch():
+    """Rebuild ``Serializer.dispatch`` for the current version settings.
+
+    Starts from a copy of ``_unchanging_dispatch``, adds handlers for the
+    byte and text string types, and installs the table on the class.
+    Relies on the module global ``unicode`` having been bound by
+    _setup_version_dependent_constants.
+    """
+    dispatch = _unchanging_dispatch.copy()
+    # Order matters: when two names refer to the same type, the later
+    # assignment wins.  While emulating Python 3, ``unicode`` is ``str``, so
+    # the final ``str`` entry replaces save_unicode with save_string.  This
+    # module's ``bytes`` is never the same type as ``str`` (under Python 2
+    # it is a ``str`` subclass defined above), so those two keys stay
+    # distinct.
+    dispatch[bytes] = Serializer.save_bytes
+    dispatch[unicode] = Serializer.save_unicode
+    dispatch[str] = Serializer.save_string
+    Serializer.dispatch = dispatch
+
+def _setup_version_dependent_constants(leave_unicode_alone=False):
+    global unicode, UnserializationOptions
+    if _INPY3:
+        unicode = str
+        UnserializationOptions = _Py3UnserializationOptions
+    else:
+        UnserializationOptions = _Py2UnserializationOptions
+        unicode = _unicode
+    _setup_dispatch()
+_setup_version_dependent_constants()
+
+
+class _Stop(Exception):
+    pass
+
+class Unserializer(object):
+
+    def __init__(self, stream, options=None):
+        self.stream = stream
+        if options is None:
+            options = UnserializationOptions()
+        self.options = options
+
+    def load(self):
+        self.stack = []
+        version = ord(self.stream.read(1))
+        if version != VERSION_NUMBER:
+            raise VersionMismatch("%i != %i" % (version, VERSION_NUMBER))
+        try:
+            while True:
+                opcode = self.stream.read(1)
+                if not opcode:
+                    raise EOFError
+                try:
+                    loader = self.opcodes[opcode]
+                except KeyError:
+                    raise Corruption("unkown opcode %s" % (opcode,))
+                loader(self)
+        except _Stop:
+            if len(self.stack) != 1:
+                raise UnserializationError("internal unserialization error")
+            return self.stack[0]
+        else:
+            raise Corruption("didn't get STOP")
+
+    opcodes = {}
+
+    def load_int(self):
+        i = self._read_int4()
+        self.stack.append(i)
+    opcodes[INT] = load_int
+
+    def _read_int4(self):
+        return _int4_format.unpack(self.stream.read(4))[0]
+
+    def _read_byte_string(self):
+        length = self._read_int4()
+        as_bytes = self.stream.read(length)
+        return as_bytes
+
+    def load_py3string(self):
+        as_bytes = self._read_byte_string()
+        if (not _INPY3 and self.options.py3_strings_as_str) and not _REALLY_PY3:
+            # XXX Should we try to decode into latin-1?
+            self.stack.append(as_bytes)
+        else:
+            self.stack.append(as_bytes.decode("utf-8"))
+    opcodes[PY3STRING] = load_py3string
+
+    def load_py2string(self):
+        as_bytes = self._read_byte_string()
+        if (_INPY3 and self.options.py2_strings_as_str) or \
+               (_REALLY_PY3 and not _INPY3):
+            s = as_bytes.decode("latin-1")
+        else:
+            s = as_bytes
+        self.stack.append(s)
+    opcodes[PY2STRING] = load_py2string
+
+    def load_bytes(self):
+        s = bytes(self._read_byte_string())
+        self.stack.append(s)
+    opcodes[BYTES] = load_bytes
+
+    def load_unicode(self):
+        self.stack.append(self._read_byte_string().decode("utf-8"))
+    opcodes[UNICODE] = load_unicode
+
+    def load_newlist(self):
+        length = self._read_int4()
+        self.stack.append([None] * length)
+    opcodes[NEWLIST] = load_newlist
+
+    def load_setitem(self):
+        if len(self.stack) < 3:
+            raise Corruption("not enough items for setitem")
+        value = self.stack.pop()
+        key = self.stack.pop()
+        self.stack[-1][key] = value
+    opcodes[SETITEM] = load_setitem
+
+    def load_newdict(self):
+        self.stack.append({})
+    opcodes[NEWDICT] = load_newdict
+
+    def load_buildtuple(self):
+        length = self._read_int4()
+        tup = tuple(self.stack[-length:])
+        del self.stack[-length:]
+        self.stack.append(tup)
+    opcodes[BUILDTUPLE] = load_buildtuple
+
+    def load_stop(self):
+        raise _Stop
+    opcodes[STOP] = load_stop



More information about the pytest-commit mailing list