[Python-checkins] cpython: Closes #16551. Cleanup pickle.py.

serhiy.storchaka python-checkins at python.org
Sun Apr 14 12:38:07 CEST 2013


http://hg.python.org/cpython/rev/3dff836cedef
changeset:   83359:3dff836cedef
user:        Serhiy Storchaka <storchaka at gmail.com>
date:        Sun Apr 14 13:37:02 2013 +0300
summary:
  Closes #16551. Cleanup pickle.py.

files:
  Lib/pickle.py |  225 ++++++++++++++-----------------------
  1 files changed, 86 insertions(+), 139 deletions(-)


diff --git a/Lib/pickle.py b/Lib/pickle.py
--- a/Lib/pickle.py
+++ b/Lib/pickle.py
@@ -26,9 +26,10 @@
 from types import FunctionType, BuiltinFunctionType
 from copyreg import dispatch_table
 from copyreg import _extension_registry, _inverted_registry, _extension_cache
-import marshal
+from itertools import islice
 import sys
-import struct
+from sys import maxsize
+from struct import pack, unpack
 import re
 import io
 import codecs
@@ -58,11 +59,6 @@
 # there are too many issues with that.
 DEFAULT_PROTOCOL = 3
 
-# Why use struct.pack() for pickling but marshal.loads() for
-# unpickling?  struct.pack() is 40% faster than marshal.dumps(), but
-# marshal.loads() is twice as fast as struct.unpack()!
-mloads = marshal.loads
-
 class PickleError(Exception):
     """A common base class for the other pickling exceptions."""
     pass
@@ -231,7 +227,7 @@
             raise PicklingError("Pickler.__init__() was not called by "
                                 "%s.__init__()" % (self.__class__.__name__,))
         if self.proto >= 2:
-            self.write(PROTO + bytes([self.proto]))
+            self.write(PROTO + pack("<B", self.proto))
         self.save(obj)
         self.write(STOP)
 
@@ -258,20 +254,20 @@
         self.memo[id(obj)] = memo_len, obj
 
     # Return a PUT (BINPUT, LONG_BINPUT) opcode string, with argument i.
-    def put(self, i, pack=struct.pack):
+    def put(self, i):
         if self.bin:
             if i < 256:
-                return BINPUT + bytes([i])
+                return BINPUT + pack("<B", i)
             else:
                 return LONG_BINPUT + pack("<I", i)
 
         return PUT + repr(i).encode("ascii") + b'\n'
 
     # Return a GET (BINGET, LONG_BINGET) opcode string, with argument i.
-    def get(self, i, pack=struct.pack):
+    def get(self, i):
         if self.bin:
             if i < 256:
-                return BINGET + bytes([i])
+                return BINGET + pack("<B", i)
             else:
                 return LONG_BINGET + pack("<I", i)
 
@@ -286,20 +282,20 @@
 
         # Check the memo
         x = self.memo.get(id(obj))
-        if x:
+        if x is not None:
             self.write(self.get(x[0]))
             return
 
         # Check the type dispatch table
         t = type(obj)
         f = self.dispatch.get(t)
-        if f:
+        if f is not None:
             f(self, obj) # Call unbound method with explicit self
             return
 
         # Check private dispatch table if any, or else copyreg.dispatch_table
         reduce = getattr(self, 'dispatch_table', dispatch_table).get(t)
-        if reduce:
+        if reduce is not None:
             rv = reduce(obj)
         else:
             # Check for a class with a custom metaclass; treat as regular class
@@ -313,11 +309,11 @@
 
             # Check for a __reduce_ex__ method, fall back to __reduce__
             reduce = getattr(obj, "__reduce_ex__", None)
-            if reduce:
+            if reduce is not None:
                 rv = reduce(self.proto)
             else:
                 reduce = getattr(obj, "__reduce__", None)
-                if reduce:
+                if reduce is not None:
                     rv = reduce()
                 else:
                     raise PicklingError("Can't pickle %r object: %r" %
@@ -448,12 +444,12 @@
 
     def save_bool(self, obj):
         if self.proto >= 2:
-            self.write(obj and NEWTRUE or NEWFALSE)
+            self.write(NEWTRUE if obj else NEWFALSE)
         else:
-            self.write(obj and TRUE or FALSE)
+            self.write(TRUE if obj else FALSE)
     dispatch[bool] = save_bool
 
-    def save_long(self, obj, pack=struct.pack):
+    def save_long(self, obj):
         if self.bin:
             # If the int is small enough to fit in a signed 4-byte 2's-comp
             # format, we can store it more efficiently than the general
@@ -461,39 +457,36 @@
             # First one- and two-byte unsigned ints:
             if obj >= 0:
                 if obj <= 0xff:
-                    self.write(BININT1 + bytes([obj]))
+                    self.write(BININT1 + pack("<B", obj))
                     return
                 if obj <= 0xffff:
-                    self.write(BININT2 + bytes([obj&0xff, obj>>8]))
+                    self.write(BININT2 + pack("<H", obj))
                     return
             # Next check for 4-byte signed ints:
-            high_bits = obj >> 31  # note that Python shift sign-extends
-            if high_bits == 0 or high_bits == -1:
-                # All high bits are copies of bit 2**31, so the value
-                # fits in a 4-byte signed int.
+            if -0x80000000 <= obj <= 0x7fffffff:
                 self.write(BININT + pack("<i", obj))
                 return
         if self.proto >= 2:
             encoded = encode_long(obj)
             n = len(encoded)
             if n < 256:
-                self.write(LONG1 + bytes([n]) + encoded)
+                self.write(LONG1 + pack("<B", n) + encoded)
             else:
                 self.write(LONG4 + pack("<i", n) + encoded)
             return
         self.write(LONG + repr(obj).encode("ascii") + b'L\n')
     dispatch[int] = save_long
 
-    def save_float(self, obj, pack=struct.pack):
+    def save_float(self, obj):
         if self.bin:
             self.write(BINFLOAT + pack('>d', obj))
         else:
             self.write(FLOAT + repr(obj).encode("ascii") + b'\n')
     dispatch[float] = save_float
 
-    def save_bytes(self, obj, pack=struct.pack):
+    def save_bytes(self, obj):
         if self.proto < 3:
-            if len(obj) == 0:
+            if not obj: # bytes object is empty
                 self.save_reduce(bytes, (), obj=obj)
             else:
                 self.save_reduce(codecs.encode,
@@ -501,13 +494,13 @@
             return
         n = len(obj)
         if n < 256:
-            self.write(SHORT_BINBYTES + bytes([n]) + bytes(obj))
+            self.write(SHORT_BINBYTES + pack("<B", n) + obj)
         else:
-            self.write(BINBYTES + pack("<I", n) + bytes(obj))
+            self.write(BINBYTES + pack("<I", n) + obj)
         self.memoize(obj)
     dispatch[bytes] = save_bytes
 
-    def save_str(self, obj, pack=struct.pack):
+    def save_str(self, obj):
         if self.bin:
             encoded = obj.encode('utf-8', 'surrogatepass')
             n = len(encoded)
@@ -515,39 +508,36 @@
         else:
             obj = obj.replace("\\", "\\u005c")
             obj = obj.replace("\n", "\\u000a")
-            self.write(UNICODE + bytes(obj.encode('raw-unicode-escape')) +
-                       b'\n')
+            self.write(UNICODE + obj.encode('raw-unicode-escape') + b'\n')
         self.memoize(obj)
     dispatch[str] = save_str
 
     def save_tuple(self, obj):
-        write = self.write
-        proto = self.proto
+        if not obj: # tuple is empty
+            if self.bin:
+                self.write(EMPTY_TUPLE)
+            else:
+                self.write(MARK + TUPLE)
+            return
 
         n = len(obj)
-        if n == 0:
-            if proto:
-                write(EMPTY_TUPLE)
-            else:
-                write(MARK + TUPLE)
-            return
-
         save = self.save
         memo = self.memo
-        if n <= 3 and proto >= 2:
+        if n <= 3 and self.proto >= 2:
             for element in obj:
                 save(element)
             # Subtle.  Same as in the big comment below.
             if id(obj) in memo:
                 get = self.get(memo[id(obj)][0])
-                write(POP * n + get)
+                self.write(POP * n + get)
             else:
-                write(_tuplesize2code[n])
+                self.write(_tuplesize2code[n])
                 self.memoize(obj)
             return
 
         # proto 0 or proto 1 and tuple isn't empty, or proto > 1 and tuple
         # has more than 3 elements.
+        write = self.write
         write(MARK)
         for element in obj:
             save(element)
@@ -561,25 +551,23 @@
             # could have been done in the "for element" loop instead, but
             # recursive tuples are a rare thing.
             get = self.get(memo[id(obj)][0])
-            if proto:
+            if self.bin:
                 write(POP_MARK + get)
             else:   # proto 0 -- POP_MARK not available
                 write(POP * (n+1) + get)
             return
 
         # No recursion.
-        self.write(TUPLE)
+        write(TUPLE)
         self.memoize(obj)
 
     dispatch[tuple] = save_tuple
 
     def save_list(self, obj):
-        write = self.write
-
         if self.bin:
-            write(EMPTY_LIST)
+            self.write(EMPTY_LIST)
         else:   # proto 0 -- can't use EMPTY_LIST
-            write(MARK + LIST)
+            self.write(MARK + LIST)
 
         self.memoize(obj)
         self._batch_appends(obj)
@@ -599,17 +587,9 @@
                 write(APPEND)
             return
 
-        items = iter(items)
-        r = range(self._BATCHSIZE)
-        while items is not None:
-            tmp = []
-            for i in r:
-                try:
-                    x = next(items)
-                    tmp.append(x)
-                except StopIteration:
-                    items = None
-                    break
+        it = iter(items)
+        while True:
+            tmp = list(islice(it, self._BATCHSIZE))
             n = len(tmp)
             if n > 1:
                 write(MARK)
@@ -620,14 +600,14 @@
                 save(tmp[0])
                 write(APPEND)
             # else tmp is empty, and we're done
+            if n < self._BATCHSIZE:
+                return
 
     def save_dict(self, obj):
-        write = self.write
-
         if self.bin:
-            write(EMPTY_DICT)
+            self.write(EMPTY_DICT)
         else:   # proto 0 -- can't use EMPTY_DICT
-            write(MARK + DICT)
+            self.write(MARK + DICT)
 
         self.memoize(obj)
         self._batch_setitems(obj.items())
@@ -648,16 +628,9 @@
                 write(SETITEM)
             return
 
-        items = iter(items)
-        r = range(self._BATCHSIZE)
-        while items is not None:
-            tmp = []
-            for i in r:
-                try:
-                    tmp.append(next(items))
-                except StopIteration:
-                    items = None
-                    break
+        it = iter(items)
+        while True:
+            tmp = list(islice(it, self._BATCHSIZE))
             n = len(tmp)
             if n > 1:
                 write(MARK)
@@ -671,8 +644,10 @@
                 save(v)
                 write(SETITEM)
             # else tmp is empty, and we're done
+            if n < self._BATCHSIZE:
+                return
 
-    def save_global(self, obj, name=None, pack=struct.pack):
+    def save_global(self, obj, name=None):
         write = self.write
         memo = self.memo
 
@@ -702,9 +677,9 @@
             if code:
                 assert code > 0
                 if code <= 0xff:
-                    write(EXT1 + bytes([code]))
+                    write(EXT1 + pack("<B", code))
                 elif code <= 0xffff:
-                    write(EXT2 + bytes([code&0xff, code>>8]))
+                    write(EXT2 + pack("<H", code))
                 else:
                     write(EXT4 + pack("<i", code))
                 return
@@ -732,25 +707,6 @@
     dispatch[BuiltinFunctionType] = save_global
     dispatch[type] = save_global
 
-# Pickling helpers
-
-def _keep_alive(x, memo):
-    """Keeps a reference to the object x in the memo.
-
-    Because we remember objects by their id, we have
-    to assure that possibly temporary objects are kept
-    alive by referencing them.
-    We store a reference at the id of the memo, which should
-    normally not be used unless someone tries to deepcopy
-    the memo itself...
-    """
-    try:
-        memo[id(memo)].append(x)
-    except KeyError:
-        # aha, this is the first one :-)
-        memo[id(memo)]=[x]
-
-
 # A cache for whichmodule(), mapping a function object to the name of
 # the module in which the function was found.
 
@@ -832,7 +788,7 @@
         read = self.read
         dispatch = self.dispatch
         try:
-            while 1:
+            while True:
                 key = read(1)
                 if not key:
                     raise EOFError
@@ -862,7 +818,7 @@
     dispatch = {}
 
     def load_proto(self):
-        proto = ord(self.read(1))
+        proto = self.read(1)[0]
         if not 0 <= proto <= HIGHEST_PROTOCOL:
             raise ValueError("unsupported pickle protocol: %d" % proto)
         self.proto = proto
@@ -897,40 +853,37 @@
         elif data == TRUE[1:]:
             val = True
         else:
-            try:
-                val = int(data, 0)
-            except ValueError:
-                val = int(data, 0)
+            val = int(data, 0)
         self.append(val)
     dispatch[INT[0]] = load_int
 
     def load_binint(self):
-        self.append(mloads(b'i' + self.read(4)))
+        self.append(unpack('<i', self.read(4))[0])
     dispatch[BININT[0]] = load_binint
 
     def load_binint1(self):
-        self.append(ord(self.read(1)))
+        self.append(self.read(1)[0])
     dispatch[BININT1[0]] = load_binint1
 
     def load_binint2(self):
-        self.append(mloads(b'i' + self.read(2) + b'\000\000'))
+        self.append(unpack('<H', self.read(2))[0])
     dispatch[BININT2[0]] = load_binint2
 
     def load_long(self):
-        val = self.readline()[:-1].decode("ascii")
-        if val and val[-1] == 'L':
+        val = self.readline()[:-1]
+        if val and val[-1] == b'L'[0]:
             val = val[:-1]
         self.append(int(val, 0))
     dispatch[LONG[0]] = load_long
 
     def load_long1(self):
-        n = ord(self.read(1))
+        n = self.read(1)[0]
         data = self.read(n)
         self.append(decode_long(data))
     dispatch[LONG1[0]] = load_long1
 
     def load_long4(self):
-        n = mloads(b'i' + self.read(4))
+        n, = unpack('<i', self.read(4))
         if n < 0:
             # Corrupt or hostile pickle -- we never write one like this
             raise UnpicklingError("LONG pickle has negative byte count")
@@ -942,28 +895,25 @@
         self.append(float(self.readline()[:-1]))
     dispatch[FLOAT[0]] = load_float
 
-    def load_binfloat(self, unpack=struct.unpack):
+    def load_binfloat(self):
         self.append(unpack('>d', self.read(8))[0])
     dispatch[BINFLOAT[0]] = load_binfloat
 
     def load_string(self):
         orig = self.readline()
         rep = orig[:-1]
-        for q in (b'"', b"'"): # double or single quote
-            if rep.startswith(q):
-                if not rep.endswith(q):
-                    raise ValueError("insecure string pickle")
-                rep = rep[len(q):-len(q)]
-                break
+        # Strip outermost quotes
+        if rep[0] == rep[-1] and rep[0] in b'"\'':
+            rep = rep[1:-1]
         else:
-            raise ValueError("insecure string pickle: %r" % orig)
+            raise ValueError("insecure string pickle")
         self.append(codecs.escape_decode(rep)[0]
                     .decode(self.encoding, self.errors))
     dispatch[STRING[0]] = load_string
 
     def load_binstring(self):
         # Deprecated BINSTRING uses signed 32-bit length
-        len = mloads(b'i' + self.read(4))
+        len, = unpack('<i', self.read(4))
         if len < 0:
             raise UnpicklingError("BINSTRING pickle has negative byte count")
         data = self.read(len)
@@ -971,7 +921,7 @@
         self.append(value)
     dispatch[BINSTRING[0]] = load_binstring
 
-    def load_binbytes(self, unpack=struct.unpack, maxsize=sys.maxsize):
+    def load_binbytes(self):
         len, = unpack('<I', self.read(4))
         if len > maxsize:
             raise UnpicklingError("BINBYTES exceeds system's maximum size "
@@ -983,7 +933,7 @@
         self.append(str(self.readline()[:-1], 'raw-unicode-escape'))
     dispatch[UNICODE[0]] = load_unicode
 
-    def load_binunicode(self, unpack=struct.unpack, maxsize=sys.maxsize):
+    def load_binunicode(self):
         len, = unpack('<I', self.read(4))
         if len > maxsize:
             raise UnpicklingError("BINUNICODE exceeds system's maximum size "
@@ -992,15 +942,15 @@
     dispatch[BINUNICODE[0]] = load_binunicode
 
     def load_short_binstring(self):
-        len = ord(self.read(1))
-        data = bytes(self.read(len))
+        len = self.read(1)[0]
+        data = self.read(len)
         value = str(data, self.encoding, self.errors)
         self.append(value)
     dispatch[SHORT_BINSTRING[0]] = load_short_binstring
 
     def load_short_binbytes(self):
-        len = ord(self.read(1))
-        self.append(bytes(self.read(len)))
+        len = self.read(1)[0]
+        self.append(self.read(len))
     dispatch[SHORT_BINBYTES[0]] = load_short_binbytes
 
     def load_tuple(self):
@@ -1039,12 +989,9 @@
 
     def load_dict(self):
         k = self.marker()
-        d = {}
         items = self.stack[k+1:]
-        for i in range(0, len(items), 2):
-            key = items[i]
-            value = items[i+1]
-            d[key] = value
+        d = {items[i]: items[i+1]
+             for i in range(0, len(items), 2)}
         self.stack[k:] = [d]
     dispatch[DICT[0]] = load_dict
 
@@ -1096,17 +1043,17 @@
     dispatch[GLOBAL[0]] = load_global
 
     def load_ext1(self):
-        code = ord(self.read(1))
+        code = self.read(1)[0]
         self.get_extension(code)
     dispatch[EXT1[0]] = load_ext1
 
     def load_ext2(self):
-        code = mloads(b'i' + self.read(2) + b'\000\000')
+        code, = unpack('<H', self.read(2))
         self.get_extension(code)
     dispatch[EXT2[0]] = load_ext2
 
     def load_ext4(self):
-        code = mloads(b'i' + self.read(4))
+        code, = unpack('<i', self.read(4))
         self.get_extension(code)
     dispatch[EXT4[0]] = load_ext4
 
@@ -1174,7 +1121,7 @@
         self.append(self.memo[i])
     dispatch[BINGET[0]] = load_binget
 
-    def load_long_binget(self, unpack=struct.unpack):
+    def load_long_binget(self):
         i, = unpack('<I', self.read(4))
         self.append(self.memo[i])
     dispatch[LONG_BINGET[0]] = load_long_binget
@@ -1193,7 +1140,7 @@
         self.memo[i] = self.stack[-1]
     dispatch[BINPUT[0]] = load_binput
 
-    def load_long_binput(self, unpack=struct.unpack, maxsize=sys.maxsize):
+    def load_long_binput(self):
         i, = unpack('<I', self.read(4))
         if i > maxsize:
             raise ValueError("negative LONG_BINPUT argument")
@@ -1238,7 +1185,7 @@
         state = stack.pop()
         inst = stack[-1]
         setstate = getattr(inst, "__setstate__", None)
-        if setstate:
+        if setstate is not None:
             setstate(state)
             return
         slotstate = None

-- 
Repository URL: http://hg.python.org/cpython


More information about the Python-checkins mailing list