[py-svn] py-virtualenv commit 31b6c51fab40: Add a simple (hopefully) cross-python marshaller
commits-noreply at bitbucket.org
commits-noreply at bitbucket.org
Mon Sep 28 22:47:03 CEST 2009
# HG changeset patch -- Bitbucket.org
# Project py-virtualenv
# URL http://bitbucket.org/RonnyPfannschmidt/py-virtualenv/overview/
# User Benjamin Peterson <benjamin at python.org>
# Date 1253671720 18000
# Node ID 31b6c51fab40ba0798c2726d799d3a3f631a2ef0
# Parent 3369fee9b19e8f1ccf2d5daf60ce046280698bd0
Add a simple (hopefully) cross-python marshaller
Will rewrite the tests soon...
--- /dev/null
+++ b/testing/execnet/test_serializer.py
@@ -0,0 +1,127 @@
+# -*- coding: utf-8 -*-
+import shutil
+import py
+from py.__.execnet import serializer
+
+def setup_module(mod):
+ mod._save_python3 = serializer._INPY3
+
+def teardown_module(mod):
+ serializer._setup_version_dependent_constants()
+
+def _dump(obj):
+ stream = py.io.BytesIO()
+ saver = serializer.Serializer(stream)
+ saver.save(obj)
+ return stream.getvalue()
+
+def _load(serialized, str_coerion):
+ stream = py.io.BytesIO(serialized)
+ opts = serializer.UnserializationOptions(str_coerion)
+ unserializer = serializer.Unserializer(stream, opts)
+ return unserializer.load()
+
+def _run_in_version(is_py3, func, *args):
+ serializer._INPY3 = is_py3
+ serializer._setup_version_dependent_constants()
+ try:
+ return func(*args)
+ finally:
+ serializer._INPY3 = _save_python3
+
+def dump_py2(obj):
+ return _run_in_version(False, _dump, obj)
+
+def dump_py3(obj):
+ return _run_in_version(True, _dump, obj)
+
+def load_py2(serialized, str_coercion=False):
+ return _run_in_version(False, _load, serialized, str_coercion)
+
+def load_py3(serialized, str_coercion=False):
+ return _run_in_version(True, _load, serialized, str_coercion)
+
+try:
+ bytes
+except NameError:
+ bytes = str
+
+
+def pytest_funcarg__py2(request):
+ return _py2_wrapper
+
+def pytest_funcarg__py3(request):
+ return _py3_wrapper
+
+class TestSerializer:
+
+ def test_int(self):
+ for dump in dump_py2, dump_py3:
+ p = dump_py2(4)
+ for load in load_py2, load_py3:
+ i = load(p)
+ assert isinstance(i, int)
+ assert i == 4
+ py.test.raises(serializer.SerializationError, dump, 123456678900)
+
+ def test_bytes(self):
+ for dump in dump_py2, dump_py3:
+ p = dump(serializer._b('hi'))
+ for load in load_py2, load_py3:
+ s = load(p)
+ assert isinstance(s, serializer.bytes)
+ assert s == serializer._b('hi')
+
+ def check_sequence(self, seq):
+ for dump in dump_py2, dump_py3:
+ p = dump(seq)
+ for load in load_py2, load_py3:
+ l = load(p)
+ assert l == seq
+
+ def test_list(self):
+ self.check_sequence([1, 2, 3])
+
+ @py.test.mark.xfail
+ # I'm not sure if we need the complexity.
+ def test_recursive_list(self):
+ l = [1, 2, 3]
+ l.append(l)
+ self.check_sequence(l)
+
+ def test_tuple(self):
+ self.check_sequence((1, 2, 3))
+
+ def test_dict(self):
+ for dump in dump_py2, dump_py3:
+ p = dump({"hi" : 2, (1, 2, 3) : 32})
+ for load in load_py2, load_py3:
+ d = load(p, True)
+ assert d == {"hi" : 2, (1, 2, 3) : 32}
+
+ def test_string(self):
+ py.test.skip("will rewrite")
+ p = dump_py2("xyz")
+ s = load_py2(p)
+ assert isinstance(s, str)
+ assert s == "xyz"
+ s = load_py3(p)
+ assert isinstance(s, bytes)
+ assert s == serializer.b("xyz")
+ p = dump_py2("xyz")
+ s = load_py3(p, True)
+ assert isinstance(s, serializer._unicode)
+ assert s == serializer.unicode("xyz")
+ p = dump_py3("xyz")
+ s = load_py2(p, True)
+ assert isinstance(s, str)
+ assert s == "xyz"
+
+ def test_unicode(self):
+ py.test.skip("will rewrite")
+ for dump, uni in (dump_py2, serializer._unicode), (dump_py3, str):
+ p = dump(uni("xyz"))
+ for load in load_py2, load_py3:
+ s = load(p)
+ assert isinstance(s, serializer._unicode)
+ assert s == serializer._unicode("xyz")
--- /dev/null
+++ b/py/execnet/serializer.py
@@ -0,0 +1,291 @@
+"""
+Simple marshal format (based on pickle) designed to work across Python versions.
+"""
+
+import sys
+import struct
+
+import py
+
+_INPY3 = _REALLY_PY3 = sys.version_info > (3, 0)
+
+class SerializeError(Exception):
+ pass
+
+class SerializationError(SerializeError):
+ """Error while serializing an object."""
+
+class UnserializableType(SerializationError):
+ """Can't serialize a type."""
+
+class UnserializationError(SerializeError):
+ """Error while unserializing an object."""
+
+class VersionMismatch(UnserializationError):
+ """Data from a previous or later format."""
+
+class Corruption(UnserializationError):
+ """The pickle format appears to have been corrupted."""
+
+if _INPY3:
+ def b(s):
+ return s.encode("ascii")
+ _b = b
+ class _unicode(str):
+ pass
+ bytes = bytes
+else:
+ class bytes(str):
+ pass
+ b = str
+ _b = bytes
+ _unicode = unicode
+
+FOUR_BYTE_INT_MAX = 2147483647
+
+_int4_format = struct.Struct("!i")
+
+# Protocol constants
+VERSION_NUMBER = 1
+VERSION = b(chr(VERSION_NUMBER))
+PY2STRING = b('s')
+PY3STRING = b('t')
+UNICODE = b('u')
+BYTES = b('b')
+NEWLIST = b('l')
+BUILDTUPLE = b('T')
+SETITEM = b('m')
+NEWDICT = b('d')
+INT = b('i')
+STOP = b('S')
+
+class CrossVersionOptions(object):
+ pass
+
+class Serializer(object):
+
+ def __init__(self, stream):
+ self.stream = stream
+
+ def save(self, obj):
+ self.stream.write(VERSION)
+ self._save(obj)
+ self.stream.write(STOP)
+
+ def _save(self, obj):
+ tp = type(obj)
+ try:
+ dispatch = self.dispatch[tp]
+ except KeyError:
+ raise UnserializableType("can't serialize %s" % (tp,))
+ dispatch(self, obj)
+
+ def save_bytes(self, bytes_):
+ self.stream.write(BYTES)
+ self._write_byte_sequence(bytes_)
+
+ def save_unicode(self, s):
+ self.stream.write(UNICODE)
+ self._write_unicode_string(s)
+
+ def save_string(self, s):
+ if _INPY3:
+ self.stream.write(PY3STRING)
+ self._write_unicode_string(s)
+ else:
+ # Case for tests
+ if _REALLY_PY3 and isinstance(s, str):
+ s = s.encode("latin-1")
+ self.stream.write(PY2STRING)
+ self._write_byte_sequence(s)
+
+ def _write_unicode_string(self, s):
+ try:
+ as_bytes = s.encode("utf-8")
+ except UnicodeEncodeError:
+ raise SerializationError("strings must be utf-8 encodable")
+ self._write_byte_sequence(as_bytes)
+
+ def _write_byte_sequence(self, bytes_):
+ self._write_int4(len(bytes_), "string is too long")
+ self.stream.write(bytes_)
+
+ def save_int(self, i):
+ self.stream.write(INT)
+ self._write_int4(i)
+
+ def _write_int4(self, i, error="int must be less than %i" %
+ (FOUR_BYTE_INT_MAX,)):
+ if i > FOUR_BYTE_INT_MAX:
+ raise SerializationError(error)
+ self.stream.write(_int4_format.pack(i))
+
+ def save_list(self, L):
+ self.stream.write(NEWLIST)
+ self._write_int4(len(L), "list is too long")
+ for i, item in enumerate(L):
+ self._write_setitem(i, item)
+
+ def _write_setitem(self, key, value):
+ self._save(key)
+ self._save(value)
+ self.stream.write(SETITEM)
+
+ def save_dict(self, d):
+ self.stream.write(NEWDICT)
+ for key, value in d.items():
+ self._write_setitem(key, value)
+
+ def save_tuple(self, tup):
+ for item in tup:
+ self._save(item)
+ self.stream.write(BUILDTUPLE)
+ self._write_int4(len(tup), "tuple is too long")
+
+
+class _UnserializationOptions(object):
+ pass
+
+class _Py2UnserializationOptions(_UnserializationOptions):
+
+ def __init__(self, py3_strings_as_str=False):
+ self.py3_strings_as_str = py3_strings_as_str
+
+class _Py3UnserializationOptions(_UnserializationOptions):
+
+ def __init__(self, py2_strings_as_str=False):
+ self.py2_strings_as_str = py2_strings_as_str
+
+
+_unchanging_dispatch = {}
+for tp in (dict, list, tuple, int):
+ name = "save_%s" % (tp.__name__,)
+ _unchanging_dispatch[tp] = getattr(Serializer, name)
+del tp, name
+
+def _setup_dispatch():
+ dispatch = _unchanging_dispatch.copy()
+ # This is subtle. bytes is aliased to str in 2.6, so
+ # dispatch[bytes] is overwritten. Additionally, we alias unicode
+ # to str in 3.x, so dispatch[unicode] is overwritten with
+ # save_string.
+ dispatch[bytes] = Serializer.save_bytes
+ dispatch[unicode] = Serializer.save_unicode
+ dispatch[str] = Serializer.save_string
+ Serializer.dispatch = dispatch
+
+def _setup_version_dependent_constants(leave_unicode_alone=False):
+ global unicode, UnserializationOptions
+ if _INPY3:
+ unicode = str
+ UnserializationOptions = _Py3UnserializationOptions
+ else:
+ UnserializationOptions = _Py2UnserializationOptions
+ unicode = _unicode
+ _setup_dispatch()
+_setup_version_dependent_constants()
+
+
+class _Stop(Exception):
+ pass
+
+class Unserializer(object):
+
+ def __init__(self, stream, options=None):
+ self.stream = stream
+ if options is None:
+ options = UnserializationOptions()
+ self.options = options
+
+ def load(self):
+ self.stack = []
+ version = ord(self.stream.read(1))
+ if version != VERSION_NUMBER:
+ raise VersionMismatch("%i != %i" % (version, VERSION_NUMBER))
+ try:
+ while True:
+ opcode = self.stream.read(1)
+ if not opcode:
+ raise EOFError
+ try:
+ loader = self.opcodes[opcode]
+ except KeyError:
+ raise Corruption("unkown opcode %s" % (opcode,))
+ loader(self)
+ except _Stop:
+ if len(self.stack) != 1:
+ raise UnserializationError("internal unserialization error")
+ return self.stack[0]
+ else:
+ raise Corruption("didn't get STOP")
+
+ opcodes = {}
+
+ def load_int(self):
+ i = self._read_int4()
+ self.stack.append(i)
+ opcodes[INT] = load_int
+
+ def _read_int4(self):
+ return _int4_format.unpack(self.stream.read(4))[0]
+
+ def _read_byte_string(self):
+ length = self._read_int4()
+ as_bytes = self.stream.read(length)
+ return as_bytes
+
+ def load_py3string(self):
+ as_bytes = self._read_byte_string()
+ if (not _INPY3 and self.options.py3_strings_as_str) and not _REALLY_PY3:
+ # XXX Should we try to decode into latin-1?
+ self.stack.append(as_bytes)
+ else:
+ self.stack.append(as_bytes.decode("utf-8"))
+ opcodes[PY3STRING] = load_py3string
+
+ def load_py2string(self):
+ as_bytes = self._read_byte_string()
+ if (_INPY3 and self.options.py2_strings_as_str) or \
+ (_REALLY_PY3 and not _INPY3):
+ s = as_bytes.decode("latin-1")
+ else:
+ s = as_bytes
+ self.stack.append(s)
+ opcodes[PY2STRING] = load_py2string
+
+ def load_bytes(self):
+ s = bytes(self._read_byte_string())
+ self.stack.append(s)
+ opcodes[BYTES] = load_bytes
+
+ def load_unicode(self):
+ self.stack.append(self._read_byte_string().decode("utf-8"))
+ opcodes[UNICODE] = load_unicode
+
+ def load_newlist(self):
+ length = self._read_int4()
+ self.stack.append([None] * length)
+ opcodes[NEWLIST] = load_newlist
+
+ def load_setitem(self):
+ if len(self.stack) < 3:
+ raise Corruption("not enough items for setitem")
+ value = self.stack.pop()
+ key = self.stack.pop()
+ self.stack[-1][key] = value
+ opcodes[SETITEM] = load_setitem
+
+ def load_newdict(self):
+ self.stack.append({})
+ opcodes[NEWDICT] = load_newdict
+
+ def load_buildtuple(self):
+ length = self._read_int4()
+ tup = tuple(self.stack[-length:])
+ del self.stack[-length:]
+ self.stack.append(tup)
+ opcodes[BUILDTUPLE] = load_buildtuple
+
+ def load_stop(self):
+ raise _Stop
+ opcodes[STOP] = load_stop
More information about the pytest-commit
mailing list