[Python-checkins] bpo-42809: Improve pickle tests for recursive data. (GH-24060)

serhiy-storchaka webhook-mailer at python.org
Sat Jan 2 12:32:57 EST 2021


https://github.com/python/cpython/commit/a25011be8c6f62cb3333903befe6295d57f0bd30
commit: a25011be8c6f62cb3333903befe6295d57f0bd30
branch: master
author: Serhiy Storchaka <storchaka at gmail.com>
committer: serhiy-storchaka <storchaka at gmail.com>
date: 2021-01-02T19:32:47+02:00
summary:

bpo-42809: Improve pickle tests for recursive data. (GH-24060)

files:
M Lib/test/pickletester.py

diff --git a/Lib/test/pickletester.py b/Lib/test/pickletester.py
index ae288f5d01250..fd05e7af94a1a 100644
--- a/Lib/test/pickletester.py
+++ b/Lib/test/pickletester.py
@@ -69,6 +69,10 @@ def count_opcode(code, pickle):
     return n
 
 
+def identity(x):
+    return x
+
+
 class UnseekableIO(io.BytesIO):
     def peek(self, *args):
         raise NotImplementedError
@@ -138,11 +142,12 @@ class E(C):
     def __getinitargs__(self):
         return ()
 
-class H(object):
+# Simple mutable object.
+class Object:
     pass
 
-# Hashable mutable key
-class K(object):
+# Hashable immutable key object containing unhashable mutable data.
+class K:
     def __init__(self, value):
         self.value = value
 
@@ -157,10 +162,6 @@ def __reduce__(self):
 D.__module__ = "__main__"
 __main__.E = E
 E.__module__ = "__main__"
-__main__.H = H
-H.__module__ = "__main__"
-__main__.K = K
-K.__module__ = "__main__"
 
 class myint(int):
     def __init__(self, x):
@@ -1496,54 +1497,182 @@ def dont_test_disassembly(self):
             got = filelike.getvalue()
             self.assertEqual(expected, got)
 
-    def test_recursive_list(self):
-        l = []
+    def _test_recursive_list(self, cls, aslist=identity, minprotocol=0):
+        # List containing itself.
+        l = cls()
         l.append(l)
-        for proto in protocols:
+        for proto in range(minprotocol, pickle.HIGHEST_PROTOCOL + 1):
             s = self.dumps(l, proto)
             x = self.loads(s)
-            self.assertIsInstance(x, list)
-            self.assertEqual(len(x), 1)
-            self.assertIs(x[0], x)
+            self.assertIsInstance(x, cls)
+            y = aslist(x)
+            self.assertEqual(len(y), 1)
+            self.assertIs(y[0], x)
 
-    def test_recursive_tuple_and_list(self):
-        t = ([],)
+    def test_recursive_list(self):
+        self._test_recursive_list(list)
+
+    def test_recursive_list_subclass(self):
+        self._test_recursive_list(MyList, minprotocol=2)
+
+    def test_recursive_list_like(self):
+        self._test_recursive_list(REX_six, aslist=lambda x: x.items)
+
+    def _test_recursive_tuple_and_list(self, cls, aslist=identity, minprotocol=0):
+        # Tuple containing a list containing the original tuple.
+        t = (cls(),)
         t[0].append(t)
-        for proto in protocols:
+        for proto in range(minprotocol, pickle.HIGHEST_PROTOCOL + 1):
             s = self.dumps(t, proto)
             x = self.loads(s)
             self.assertIsInstance(x, tuple)
             self.assertEqual(len(x), 1)
-            self.assertIsInstance(x[0], list)
-            self.assertEqual(len(x[0]), 1)
-            self.assertIs(x[0][0], x)
+            self.assertIsInstance(x[0], cls)
+            y = aslist(x[0])
+            self.assertEqual(len(y), 1)
+            self.assertIs(y[0], x)
+
+        # List containing a tuple containing the original list.
+        t, = t
+        for proto in range(minprotocol, pickle.HIGHEST_PROTOCOL + 1):
+            s = self.dumps(t, proto)
+            x = self.loads(s)
+            self.assertIsInstance(x, cls)
+            y = aslist(x)
+            self.assertEqual(len(y), 1)
+            self.assertIsInstance(y[0], tuple)
+            self.assertEqual(len(y[0]), 1)
+            self.assertIs(y[0][0], x)
 
-    def test_recursive_dict(self):
-        d = {}
+    def test_recursive_tuple_and_list(self):
+        self._test_recursive_tuple_and_list(list)
+
+    def test_recursive_tuple_and_list_subclass(self):
+        self._test_recursive_tuple_and_list(MyList, minprotocol=2)
+
+    def test_recursive_tuple_and_list_like(self):
+        self._test_recursive_tuple_and_list(REX_six, aslist=lambda x: x.items)
+
+    def _test_recursive_dict(self, cls, asdict=identity, minprotocol=0):
+        # Dict containing itself.
+        d = cls()
         d[1] = d
-        for proto in protocols:
+        for proto in range(minprotocol, pickle.HIGHEST_PROTOCOL + 1):
             s = self.dumps(d, proto)
             x = self.loads(s)
-            self.assertIsInstance(x, dict)
-            self.assertEqual(list(x.keys()), [1])
-            self.assertIs(x[1], x)
+            self.assertIsInstance(x, cls)
+            y = asdict(x)
+            self.assertEqual(list(y.keys()), [1])
+            self.assertIs(y[1], x)
 
-    def test_recursive_dict_key(self):
-        d = {}
-        k = K(d)
-        d[k] = 1
-        for proto in protocols:
+    def test_recursive_dict(self):
+        self._test_recursive_dict(dict)
+
+    def test_recursive_dict_subclass(self):
+        self._test_recursive_dict(MyDict, minprotocol=2)
+
+    def test_recursive_dict_like(self):
+        self._test_recursive_dict(REX_seven, asdict=lambda x: x.table)
+
+    def _test_recursive_tuple_and_dict(self, cls, asdict=identity, minprotocol=0):
+        # Tuple containing a dict containing the original tuple.
+        t = (cls(),)
+        t[0][1] = t
+        for proto in range(minprotocol, pickle.HIGHEST_PROTOCOL + 1):
+            s = self.dumps(t, proto)
+            x = self.loads(s)
+            self.assertIsInstance(x, tuple)
+            self.assertEqual(len(x), 1)
+            self.assertIsInstance(x[0], cls)
+            y = asdict(x[0])
+            self.assertEqual(list(y), [1])
+            self.assertIs(y[1], x)
+
+        # Dict containing a tuple containing the original dict.
+        t, = t
+        for proto in range(minprotocol, pickle.HIGHEST_PROTOCOL + 1):
+            s = self.dumps(t, proto)
+            x = self.loads(s)
+            self.assertIsInstance(x, cls)
+            y = asdict(x)
+            self.assertEqual(list(y), [1])
+            self.assertIsInstance(y[1], tuple)
+            self.assertEqual(len(y[1]), 1)
+            self.assertIs(y[1][0], x)
+
+    def test_recursive_tuple_and_dict(self):
+        self._test_recursive_tuple_and_dict(dict)
+
+    def test_recursive_tuple_and_dict_subclass(self):
+        self._test_recursive_tuple_and_dict(MyDict, minprotocol=2)
+
+    def test_recursive_tuple_and_dict_like(self):
+        self._test_recursive_tuple_and_dict(REX_seven, asdict=lambda x: x.table)
+
+    def _test_recursive_dict_key(self, cls, asdict=identity, minprotocol=0):
+        # Dict containing an immutable object (as key) containing the original
+        # dict.
+        d = cls()
+        d[K(d)] = 1
+        for proto in range(minprotocol, pickle.HIGHEST_PROTOCOL + 1):
             s = self.dumps(d, proto)
             x = self.loads(s)
-            self.assertIsInstance(x, dict)
-            self.assertEqual(len(x.keys()), 1)
-            self.assertIsInstance(list(x.keys())[0], K)
-            self.assertIs(list(x.keys())[0].value, x)
+            self.assertIsInstance(x, cls)
+            y = asdict(x)
+            self.assertEqual(len(y.keys()), 1)
+            self.assertIsInstance(list(y.keys())[0], K)
+            self.assertIs(list(y.keys())[0].value, x)
+
+    def test_recursive_dict_key(self):
+        self._test_recursive_dict_key(dict)
+
+    def test_recursive_dict_subclass_key(self):
+        self._test_recursive_dict_key(MyDict, minprotocol=2)
+
+    def test_recursive_dict_like_key(self):
+        self._test_recursive_dict_key(REX_seven, asdict=lambda x: x.table)
+
+    def _test_recursive_tuple_and_dict_key(self, cls, asdict=identity, minprotocol=0):
+        # Tuple containing a dict containing an immutable object (as key)
+        # containing the original tuple.
+        t = (cls(),)
+        t[0][K(t)] = 1
+        for proto in range(minprotocol, pickle.HIGHEST_PROTOCOL + 1):
+            s = self.dumps(t, proto)
+            x = self.loads(s)
+            self.assertIsInstance(x, tuple)
+            self.assertEqual(len(x), 1)
+            self.assertIsInstance(x[0], cls)
+            y = asdict(x[0])
+            self.assertEqual(len(y), 1)
+            self.assertIsInstance(list(y.keys())[0], K)
+            self.assertIs(list(y.keys())[0].value, x)
+
+        # Dict containing an immutable object (as key) containing a tuple
+        # containing the original dict.
+        t, = t
+        for proto in range(minprotocol, pickle.HIGHEST_PROTOCOL + 1):
+            s = self.dumps(t, proto)
+            x = self.loads(s)
+            self.assertIsInstance(x, cls)
+            y = asdict(x)
+            self.assertEqual(len(y), 1)
+            self.assertIsInstance(list(y.keys())[0], K)
+            self.assertIs(list(y.keys())[0].value[0], x)
+
+    def test_recursive_tuple_and_dict_key(self):
+        self._test_recursive_tuple_and_dict_key(dict)
+
+    def test_recursive_tuple_and_dict_subclass_key(self):
+        self._test_recursive_tuple_and_dict_key(MyDict, minprotocol=2)
+
+    def test_recursive_tuple_and_dict_like_key(self):
+        self._test_recursive_tuple_and_dict_key(REX_seven, asdict=lambda x: x.table)
 
     def test_recursive_set(self):
+        # Set containing an immutable object containing the original set.
         y = set()
-        k = K(y)
-        y.add(k)
+        y.add(K(y))
         for proto in range(4, pickle.HIGHEST_PROTOCOL + 1):
             s = self.dumps(y, proto)
             x = self.loads(s)
@@ -1552,52 +1681,31 @@ def test_recursive_set(self):
             self.assertIsInstance(list(x)[0], K)
             self.assertIs(list(x)[0].value, x)
 
-    def test_recursive_list_subclass(self):
-        y = MyList()
-        y.append(y)
-        for proto in range(2, pickle.HIGHEST_PROTOCOL + 1):
+        # Immutable object containing a set containing the original object.
+        y, = y
+        for proto in range(4, pickle.HIGHEST_PROTOCOL + 1):
             s = self.dumps(y, proto)
             x = self.loads(s)
-            self.assertIsInstance(x, MyList)
-            self.assertEqual(len(x), 1)
-            self.assertIs(x[0], x)
-
-    def test_recursive_dict_subclass(self):
-        d = MyDict()
-        d[1] = d
-        for proto in range(2, pickle.HIGHEST_PROTOCOL + 1):
-            s = self.dumps(d, proto)
-            x = self.loads(s)
-            self.assertIsInstance(x, MyDict)
-            self.assertEqual(list(x.keys()), [1])
-            self.assertIs(x[1], x)
-
-    def test_recursive_dict_subclass_key(self):
-        d = MyDict()
-        k = K(d)
-        d[k] = 1
-        for proto in range(2, pickle.HIGHEST_PROTOCOL + 1):
-            s = self.dumps(d, proto)
-            x = self.loads(s)
-            self.assertIsInstance(x, MyDict)
-            self.assertEqual(len(list(x.keys())), 1)
-            self.assertIsInstance(list(x.keys())[0], K)
-            self.assertIs(list(x.keys())[0].value, x)
+            self.assertIsInstance(x, K)
+            self.assertIsInstance(x.value, set)
+            self.assertEqual(len(x.value), 1)
+            self.assertIs(list(x.value)[0], x)
 
     def test_recursive_inst(self):
-        i = C()
+        # Mutable object containing itself.
+        i = Object()
         i.attr = i
         for proto in protocols:
             s = self.dumps(i, proto)
             x = self.loads(s)
-            self.assertIsInstance(x, C)
+            self.assertIsInstance(x, Object)
             self.assertEqual(dir(x), dir(i))
             self.assertIs(x.attr, x)
 
     def test_recursive_multi(self):
         l = []
         d = {1:l}
-        i = C()
+        i = Object()
         i.attr = d
         l.append(i)
         for proto in protocols:
@@ -1607,49 +1715,94 @@ def test_recursive_multi(self):
             self.assertEqual(len(x), 1)
             self.assertEqual(dir(x[0]), dir(i))
             self.assertEqual(list(x[0].attr.keys()), [1])
-            self.assertTrue(x[0].attr[1] is x)
-
-    def check_recursive_collection_and_inst(self, factory):
-        h = H()
-        y = factory([h])
-        h.attr = y
+            self.assertIs(x[0].attr[1], x)
+
+    def _test_recursive_collection_and_inst(self, factory):
+        # Mutable object containing a collection containing the original
+        # object.
+        o = Object()
+        o.attr = factory([o])
+        t = type(o.attr)
         for proto in protocols:
-            s = self.dumps(y, proto)
+            s = self.dumps(o, proto)
             x = self.loads(s)
-            self.assertIsInstance(x, type(y))
+            self.assertIsInstance(x.attr, t)
+            self.assertEqual(len(x.attr), 1)
+            self.assertIsInstance(list(x.attr)[0], Object)
+            self.assertIs(list(x.attr)[0], x)
+
+        # Collection containing a mutable object containing the original
+        # collection.
+        o = o.attr
+        for proto in protocols:
+            s = self.dumps(o, proto)
+            x = self.loads(s)
+            self.assertIsInstance(x, t)
             self.assertEqual(len(x), 1)
-            self.assertIsInstance(list(x)[0], H)
+            self.assertIsInstance(list(x)[0], Object)
             self.assertIs(list(x)[0].attr, x)
 
     def test_recursive_list_and_inst(self):
-        self.check_recursive_collection_and_inst(list)
+        self._test_recursive_collection_and_inst(list)
 
     def test_recursive_tuple_and_inst(self):
-        self.check_recursive_collection_and_inst(tuple)
+        self._test_recursive_collection_and_inst(tuple)
 
     def test_recursive_dict_and_inst(self):
-        self.check_recursive_collection_and_inst(dict.fromkeys)
+        self._test_recursive_collection_and_inst(dict.fromkeys)
 
     def test_recursive_set_and_inst(self):
-        self.check_recursive_collection_and_inst(set)
+        self._test_recursive_collection_and_inst(set)
 
     def test_recursive_frozenset_and_inst(self):
-        self.check_recursive_collection_and_inst(frozenset)
+        self._test_recursive_collection_and_inst(frozenset)
 
     def test_recursive_list_subclass_and_inst(self):
-        self.check_recursive_collection_and_inst(MyList)
+        self._test_recursive_collection_and_inst(MyList)
 
     def test_recursive_tuple_subclass_and_inst(self):
-        self.check_recursive_collection_and_inst(MyTuple)
+        self._test_recursive_collection_and_inst(MyTuple)
 
     def test_recursive_dict_subclass_and_inst(self):
-        self.check_recursive_collection_and_inst(MyDict.fromkeys)
+        self._test_recursive_collection_and_inst(MyDict.fromkeys)
 
     def test_recursive_set_subclass_and_inst(self):
-        self.check_recursive_collection_and_inst(MySet)
+        self._test_recursive_collection_and_inst(MySet)
 
     def test_recursive_frozenset_subclass_and_inst(self):
-        self.check_recursive_collection_and_inst(MyFrozenSet)
+        self._test_recursive_collection_and_inst(MyFrozenSet)
+
+    def test_recursive_inst_state(self):
+        # Mutable object containing itself.
+        y = REX_state()
+        y.state = y
+        for proto in protocols:
+            s = self.dumps(y, proto)
+            x = self.loads(s)
+            self.assertIsInstance(x, REX_state)
+            self.assertIs(x.state, x)
+
+    def test_recursive_tuple_and_inst_state(self):
+        # Tuple containing a mutable object containing the original tuple.
+        t = (REX_state(),)
+        t[0].state = t
+        for proto in protocols:
+            s = self.dumps(t, proto)
+            x = self.loads(s)
+            self.assertIsInstance(x, tuple)
+            self.assertEqual(len(x), 1)
+            self.assertIsInstance(x[0], REX_state)
+            self.assertIs(x[0].state, x)
+
+        # Mutable object containing a tuple containing the original object.
+        t, = t
+        for proto in protocols:
+            s = self.dumps(t, proto)
+            x = self.loads(s)
+            self.assertIsInstance(x, REX_state)
+            self.assertIsInstance(x.state, tuple)
+            self.assertEqual(len(x.state), 1)
+            self.assertIs(x.state[0], x)
 
     def test_unicode(self):
         endcases = ['', '<\\u>', '<\\\u1234>', '<\n>',
@@ -3062,6 +3215,19 @@ def __setitem__(self, key, value):
     def __reduce__(self):
         return type(self), (), None, None, iter(self.table.items())
 
+class REX_state(object):
+    """This class is used to check the 3rd argument (state) of
+    the reduce protocol.
+    """
+    def __init__(self, state=None):
+        self.state = state
+    def __eq__(self, other):
+        return type(self) is type(other) and self.state == other.state
+    def __setstate__(self, state):
+        self.state = state
+    def __reduce__(self):
+        return type(self), (), self.state
+
 
 # Test classes for newobj
 



More information about the Python-checkins mailing list