[Python-checkins] cpython (merge 3.5 -> default): Issue #25945: Fixed bugs in functools.partial.

serhiy.storchaka python-checkins at python.org
Tue Feb 2 11:46:31 EST 2016


https://hg.python.org/cpython/rev/33109176538d
changeset:   100148:33109176538d
parent:      100146:03708c680eca
parent:      100147:542b5744ddc3
user:        Serhiy Storchaka <storchaka at gmail.com>
date:        Tue Feb 02 18:45:47 2016 +0200
summary:
  Issue #25945: Fixed bugs in functools.partial.
Fixed a crash when unpickling a functools.partial object with an invalid state.
Fixed a leak when the functools.partial constructor fails.
The "args" and "keywords" attributes of functools.partial now always have the
types tuple and dict respectively.

files:
  Lib/test/test_functools.py |   93 ++++++++++++++++-
  Misc/NEWS                  |    5 +
  Modules/_functoolsmodule.c |  128 +++++++++++++-----------
  3 files changed, 160 insertions(+), 66 deletions(-)


diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py
--- a/Lib/test/test_functools.py
+++ b/Lib/test/test_functools.py
@@ -30,6 +30,16 @@
     """ return the signature of a partial object """
     return (part.func, part.args, part.keywords, part.__dict__)
 
+class MyTuple(tuple):
+    pass
+
+class BadTuple(tuple):
+    def __add__(self, other):
+        return list(self) + list(other)
+
+class MyDict(dict):
+    pass
+
 
 class TestPartial:
 
@@ -208,11 +218,84 @@
                        for kwargs_repr in kwargs_reprs])
 
     def test_pickle(self):
-        f = self.partial(signature, 'asdf', bar=True)
-        f.add_something_to__dict__ = True
+        f = self.partial(signature, ['asdf'], bar=[True])
+        f.attr = []
         for proto in range(pickle.HIGHEST_PROTOCOL + 1):
             f_copy = pickle.loads(pickle.dumps(f, proto))
-            self.assertEqual(signature(f), signature(f_copy))
+            self.assertEqual(signature(f_copy), signature(f))
+
+    def test_copy(self):
+        f = self.partial(signature, ['asdf'], bar=[True])
+        f.attr = []
+        f_copy = copy.copy(f)
+        self.assertEqual(signature(f_copy), signature(f))
+        self.assertIs(f_copy.attr, f.attr)
+        self.assertIs(f_copy.args, f.args)
+        self.assertIs(f_copy.keywords, f.keywords)
+
+    def test_deepcopy(self):
+        f = self.partial(signature, ['asdf'], bar=[True])
+        f.attr = []
+        f_copy = copy.deepcopy(f)
+        self.assertEqual(signature(f_copy), signature(f))
+        self.assertIsNot(f_copy.attr, f.attr)
+        self.assertIsNot(f_copy.args, f.args)
+        self.assertIsNot(f_copy.args[0], f.args[0])
+        self.assertIsNot(f_copy.keywords, f.keywords)
+        self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])
+
+    def test_setstate(self):
+        f = self.partial(signature)
+        f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))
+        self.assertEqual(signature(f),
+                         (capture, (1,), dict(a=10), dict(attr=[])))
+        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))
+
+        f.__setstate__((capture, (1,), dict(a=10), None))
+        self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
+        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))
+
+        f.__setstate__((capture, (1,), None, None))
+        #self.assertEqual(signature(f), (capture, (1,), {}, {}))
+        self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
+        self.assertEqual(f(2), ((1, 2), {}))
+        self.assertEqual(f(), ((1,), {}))
+
+        f.__setstate__((capture, (), {}, None))
+        self.assertEqual(signature(f), (capture, (), {}, {}))
+        self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
+        self.assertEqual(f(2), ((2,), {}))
+        self.assertEqual(f(), ((), {}))
+
+    def test_setstate_errors(self):
+        f = self.partial(signature)
+        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
+        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
+        self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
+        self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
+        self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
+        self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
+        self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))
+
+    def test_setstate_subclasses(self):
+        f = self.partial(signature)
+        f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
+        s = signature(f)
+        self.assertEqual(s, (capture, (1,), dict(a=10), {}))
+        self.assertIs(type(s[1]), tuple)
+        self.assertIs(type(s[2]), dict)
+        r = f()
+        self.assertEqual(r, ((1,), {'a': 10}))
+        self.assertIs(type(r[0]), tuple)
+        self.assertIs(type(r[1]), dict)
+
+        f.__setstate__((capture, BadTuple((1,)), {}, None))
+        s = signature(f)
+        self.assertEqual(s, (capture, (1,), {}, {}))
+        self.assertIs(type(s[1]), tuple)
+        r = f(2)
+        self.assertEqual(r, ((1, 2), {}))
+        self.assertIs(type(r[0]), tuple)
 
     # Issue 6083: Reference counting bug
     def test_setstate_refcount(self):
@@ -229,9 +312,7 @@
                 raise IndexError
 
         f = self.partial(object)
-        self.assertRaisesRegex(SystemError,
-                "new style getargs format but argument is not a tuple",
-                f.__setstate__, BadSequence())
+        self.assertRaises(TypeError, f.__setstate__, BadSequence())
 
 
 class TestPartialPy(TestPartial, unittest.TestCase):
diff --git a/Misc/NEWS b/Misc/NEWS
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -166,6 +166,11 @@
 Library
 -------
 
+- Issue #25945: Fixed a crash when unpickling a functools.partial object with
+  an invalid state.  Fixed a leak when the functools.partial constructor fails.
+  The "args" and "keywords" attributes of functools.partial now always have
+  the types tuple and dict respectively.
+
 - Issue #26202: copy.deepcopy() now correctly copies range() objects with
   non-atomic attributes.
 
diff --git a/Modules/_functoolsmodule.c b/Modules/_functoolsmodule.c
--- a/Modules/_functoolsmodule.c
+++ b/Modules/_functoolsmodule.c
@@ -34,7 +34,7 @@
         return NULL;
     }
 
-    pargs = pkw = Py_None;
+    pargs = pkw = NULL;
     func = PyTuple_GET_ITEM(args, 0);
     if (Py_TYPE(func) == &partial_type && type == &partial_type) {
         partialobject *part = (partialobject *)func;
@@ -42,6 +42,8 @@
             pargs = part->args;
             pkw = part->kw;
             func = part->fn;
+            assert(PyTuple_Check(pargs));
+            assert(PyDict_Check(pkw));
         }
     }
     if (!PyCallable_Check(func)) {
@@ -60,12 +62,10 @@
 
     nargs = PyTuple_GetSlice(args, 1, PY_SSIZE_T_MAX);
     if (nargs == NULL) {
-        pto->args = NULL;
-        pto->kw = NULL;
         Py_DECREF(pto);
         return NULL;
     }
-    if (pargs == Py_None || PyTuple_GET_SIZE(pargs) == 0) {
+    if (pargs == NULL || PyTuple_GET_SIZE(pargs) == 0) {
         pto->args = nargs;
         Py_INCREF(nargs);
     }
@@ -76,47 +76,36 @@
     else {
         pto->args = PySequence_Concat(pargs, nargs);
         if (pto->args == NULL) {
-            pto->kw = NULL;
+            Py_DECREF(nargs);
             Py_DECREF(pto);
             return NULL;
         }
+        assert(PyTuple_Check(pto->args));
     }
     Py_DECREF(nargs);
 
-    if (kw != NULL) {
-        if (pkw == Py_None) {
-            pto->kw = PyDict_Copy(kw);
+    if (pkw == NULL || PyDict_Size(pkw) == 0) {
+        if (kw == NULL) {
+            pto->kw = PyDict_New();
         }
         else {
-            pto->kw = PyDict_Copy(pkw);
-            if (pto->kw != NULL) {
-                if (PyDict_Merge(pto->kw, kw, 1) != 0) {
-                    Py_DECREF(pto);
-                    return NULL;
-                }
-            }
-        }
-        if (pto->kw == NULL) {
-            Py_DECREF(pto);
-            return NULL;
+            Py_INCREF(kw);
+            pto->kw = kw;
         }
     }
     else {
-        if (pkw == Py_None) {
-            pto->kw = PyDict_New();
-            if (pto->kw == NULL) {
+        pto->kw = PyDict_Copy(pkw);
+        if (kw != NULL && pto->kw != NULL) {
+            if (PyDict_Merge(pto->kw, kw, 1) != 0) {
                 Py_DECREF(pto);
                 return NULL;
             }
         }
-        else {
-            pto->kw = pkw;
-            Py_INCREF(pkw);
-        }
     }
-
-    pto->weakreflist = NULL;
-    pto->dict = NULL;
+    if (pto->kw == NULL) {
+        Py_DECREF(pto);
+        return NULL;
+    }
 
     return (PyObject *)pto;
 }
@@ -138,11 +127,11 @@
 partial_call(partialobject *pto, PyObject *args, PyObject *kw)
 {
     PyObject *ret;
-    PyObject *argappl = NULL, *kwappl = NULL;
+    PyObject *argappl, *kwappl;
 
     assert (PyCallable_Check(pto->fn));
     assert (PyTuple_Check(pto->args));
-    assert (pto->kw == Py_None  ||  PyDict_Check(pto->kw));
+    assert (PyDict_Check(pto->kw));
 
     if (PyTuple_GET_SIZE(pto->args) == 0) {
         argappl = args;
@@ -154,11 +143,12 @@
         argappl = PySequence_Concat(pto->args, args);
         if (argappl == NULL)
             return NULL;
+        assert(PyTuple_Check(argappl));
     }
 
-    if (pto->kw == Py_None) {
+    if (PyDict_Size(pto->kw) == 0) {
         kwappl = kw;
-        Py_XINCREF(kw);
+        Py_XINCREF(kwappl);
     } else {
         kwappl = PyDict_Copy(pto->kw);
         if (kwappl == NULL) {
@@ -217,6 +207,7 @@
     PyObject *arglist;
     PyObject *tmp;
     Py_ssize_t i, n;
+    PyObject *key, *value;
 
     arglist = PyUnicode_FromString("");
     if (arglist == NULL) {
@@ -234,17 +225,14 @@
         arglist = tmp;
     }
     /* Pack keyword arguments */
-    assert (pto->kw == Py_None  ||  PyDict_Check(pto->kw));
-    if (pto->kw != Py_None) {
-        PyObject *key, *value;
-        for (i = 0; PyDict_Next(pto->kw, &i, &key, &value);) {
-            tmp = PyUnicode_FromFormat("%U, %U=%R", arglist,
-                                       key, value);
-            Py_DECREF(arglist);
-            if (tmp == NULL)
-                return NULL;
-            arglist = tmp;
-        }
+    assert (PyDict_Check(pto->kw));
+    for (i = 0; PyDict_Next(pto->kw, &i, &key, &value);) {
+        tmp = PyUnicode_FromFormat("%U, %U=%R", arglist,
+                                    key, value);
+        Py_DECREF(arglist);
+        if (tmp == NULL)
+            return NULL;
+        arglist = tmp;
     }
     result = PyUnicode_FromFormat("%s(%R%U)", Py_TYPE(pto)->tp_name,
                                   pto->fn, arglist);
@@ -271,25 +259,45 @@
 partial_setstate(partialobject *pto, PyObject *state)
 {
     PyObject *fn, *fnargs, *kw, *dict;
-    if (!PyArg_ParseTuple(state, "OOOO",
-                          &fn, &fnargs, &kw, &dict))
+
+    if (!PyTuple_Check(state) ||
+        !PyArg_ParseTuple(state, "OOOO", &fn, &fnargs, &kw, &dict) ||
+        !PyCallable_Check(fn) ||
+        !PyTuple_Check(fnargs) ||
+        (kw != Py_None && !PyDict_Check(kw)))
+    {
+        PyErr_SetString(PyExc_TypeError, "invalid partial state");
         return NULL;
-    Py_XDECREF(pto->fn);
-    Py_XDECREF(pto->args);
-    Py_XDECREF(pto->kw);
-    Py_XDECREF(pto->dict);
-    pto->fn = fn;
-    pto->args = fnargs;
-    pto->kw = kw;
-    if (dict != Py_None) {
-      pto->dict = dict;
-      Py_INCREF(dict);
-    } else {
-      pto->dict = NULL;
     }
+
+    if(!PyTuple_CheckExact(fnargs))
+        fnargs = PySequence_Tuple(fnargs);
+    else
+        Py_INCREF(fnargs);
+    if (fnargs == NULL)
+        return NULL;
+
+    if (kw == Py_None)
+        kw = PyDict_New();
+    else if(!PyDict_CheckExact(kw))
+        kw = PyDict_Copy(kw);
+    else
+        Py_INCREF(kw);
+    if (kw == NULL) {
+        Py_DECREF(fnargs);
+        return NULL;
+    }
+
     Py_INCREF(fn);
-    Py_INCREF(fnargs);
-    Py_INCREF(kw);
+    if (dict == Py_None)
+        dict = NULL;
+    else
+        Py_INCREF(dict);
+
+    Py_SETREF(pto->fn, fn);
+    Py_SETREF(pto->args, fnargs);
+    Py_SETREF(pto->kw, kw);
+    Py_SETREF(pto->dict, dict);
     Py_RETURN_NONE;
 }
 

-- 
Repository URL: https://hg.python.org/cpython


More information about the Python-checkins mailing list