[Python-checkins] cpython (merge 3.3 -> default): Issue #16076: make _elementtree.Element pickle-able in a way that is compatible with the Python version of the class.

eli.bendersky python-checkins at python.org
Thu Jan 10 15:07:36 CET 2013


http://hg.python.org/cpython/rev/4c268b7c86e6
changeset:   81360:4c268b7c86e6
parent:      81357:6478c4259ce3
parent:      81359:8d6dadfecf22
user:        Eli Bendersky <eliben at gmail.com>
date:        Thu Jan 10 06:06:01 2013 -0800
summary:
  Issue #16076: make _elementtree.Element pickle-able in a way that is compatible
with the Python version of the class.

Patch by Daniel Shahaf.

files:
  Lib/test/test_xml_etree.py |   81 ++++++++--
  Modules/_elementtree.c     |  180 ++++++++++++++++++++++++-
  2 files changed, 238 insertions(+), 23 deletions(-)


diff --git a/Lib/test/test_xml_etree.py b/Lib/test/test_xml_etree.py
--- a/Lib/test/test_xml_etree.py
+++ b/Lib/test/test_xml_etree.py
@@ -16,14 +16,20 @@
 
 import html
 import io
+import operator
 import pickle
 import sys
 import unittest
 import weakref
 
+from itertools import product
 from test import support
 from test.support import TESTFN, findfile, unlink, import_fresh_module, gc_collect
 
+# pyET is the pure-Python implementation.
+# 
+# ET is pyET in test_xml_etree and is the C accelerated version in
+# test_xml_etree_c.
 pyET = None
 ET = None
 
@@ -171,6 +177,38 @@
     for elem in element:
         check_element(elem)
 
+class ElementTestCase:
+    @classmethod
+    def setUpClass(cls):
+        cls.modules = {pyET, ET}
+
+    def pickleRoundTrip(self, obj, name, dumper, loader):
+        save_m = sys.modules[name]
+        try:
+            sys.modules[name] = dumper
+            temp = pickle.dumps(obj)
+            sys.modules[name] = loader
+            result = pickle.loads(temp)
+        except pickle.PicklingError as pe:
+            # pyET must be second, because pyET may be (equal to) ET.
+            human = dict([(ET, "cET"), (pyET, "pyET")])
+            raise support.TestFailed("Failed to round-trip %r from %r to %r"
+                                     % (obj,
+                                        human.get(dumper, dumper),
+                                        human.get(loader, loader))) from pe
+        finally:
+            sys.modules[name] = save_m
+        return result
+
+    def assertEqualElements(self, alice, bob):
+        self.assertIsInstance(alice, (ET.Element, pyET.Element))
+        self.assertIsInstance(bob, (ET.Element, pyET.Element))
+        self.assertEqual(len(list(alice)), len(list(bob)))
+        for x, y in zip(alice, bob):
+            self.assertEqualElements(x, y)
+        properties = operator.attrgetter('tag', 'tail', 'text', 'attrib')
+        self.assertEqual(properties(alice), properties(bob))
+
 # --------------------------------------------------------------------
 # element tree tests
 
@@ -1715,7 +1753,7 @@
 # --------------------------------------------------------------------
 
 
-class BasicElementTest(unittest.TestCase):
+class BasicElementTest(ElementTestCase, unittest.TestCase):
     def test_augmentation_type_errors(self):
         e = ET.Element('joe')
         self.assertRaises(TypeError, e.append, 'b')
@@ -1775,19 +1813,22 @@
         self.assertEqual(e1.get('w', default=7), 7)
 
     def test_pickle(self):
-        # For now this test only works for the Python version of ET,
-        # so set sys.modules accordingly because pickle uses __import__
-        # to load the __module__ of the class.
-        if pyET:
-            sys.modules['xml.etree.ElementTree'] = pyET
-        else:
-            raise unittest.SkipTest('only for the Python version')
-        e1 = ET.Element('foo', bar=42)
-        s = pickle.dumps(e1)
-        e2 = pickle.loads(s)
-        self.assertEqual(e2.tag, 'foo')
-        self.assertEqual(e2.attrib['bar'], 42)
-
+        # issue #16076: the C implementation wasn't pickleable.
+        for dumper, loader in product(self.modules, repeat=2):
+            e = dumper.Element('foo', bar=42)
+            e.text = "text goes here"
+            e.tail = "opposite of head"
+            dumper.SubElement(e, 'child').append(dumper.Element('grandchild'))
+            e.append(dumper.Element('child'))
+            e.findall('.//grandchild')[0].set('attr', 'other value')
+
+            e2 = self.pickleRoundTrip(e, 'xml.etree.ElementTree',
+                                      dumper, loader)
+
+            self.assertEqual(e2.tag, 'foo')
+            self.assertEqual(e2.attrib['bar'], 42)
+            self.assertEqual(len(e2), 2)
+            self.assertEqualElements(e, e2)
 
 class ElementTreeTest(unittest.TestCase):
     def test_istype(self):
@@ -2433,7 +2474,7 @@
 class NoAcceleratorTest(unittest.TestCase):
     def setUp(self):
         if not pyET:
-            raise SkipTest('only for the Python version')
+            raise unittest.SkipTest('only for the Python version')
 
     # Test that the C accelerator was not imported for pyET
     def test_correct_import_pyET(self):
@@ -2486,10 +2527,10 @@
 def test_main(module=None):
     # When invoked without a module, runs the Python ET tests by loading pyET.
     # Otherwise, uses the given module as the ET.
+    global pyET
+    pyET = import_fresh_module('xml.etree.ElementTree',
+                               blocked=['_elementtree'])
     if module is None:
-        global pyET
-        pyET = import_fresh_module('xml.etree.ElementTree',
-                                   blocked=['_elementtree'])
         module = pyET
 
     global ET
@@ -2509,7 +2550,7 @@
     # These tests will only run for the pure-Python version that doesn't import
     # _elementtree. We can't use skipUnless here, because pyET is filled in only
     # after the module is loaded.
-    if pyET:
+    if pyET is not ET:
         test_classes.extend([
             NoAcceleratorTest,
             ])
@@ -2518,7 +2559,7 @@
         support.run_unittest(*test_classes)
 
         # XXX the C module should give the same warnings as the Python module
-        with CleanContext(quiet=(module is not pyET)):
+        with CleanContext(quiet=(pyET is not ET)):
             support.run_doctest(sys.modules[__name__], verbosity=True)
     finally:
         # don't interfere with subsequent tests
diff --git a/Modules/_elementtree.c b/Modules/_elementtree.c
--- a/Modules/_elementtree.c
+++ b/Modules/_elementtree.c
@@ -814,6 +814,176 @@
     return PyLong_FromSsize_t(result);
 }
 
+/* dict keys for getstate/setstate. */
+#define PICKLED_TAG "tag"
+#define PICKLED_CHILDREN "_children"
+#define PICKLED_ATTRIB "attrib"
+#define PICKLED_TAIL "tail"
+#define PICKLED_TEXT "text"
+
+/* __getstate__ returns a fabricated instance dict as in the pure-Python
+ * Element implementation, for interoperability/interchangeability.  This
+ * makes the pure-Python implementation details an API, but (a) there aren't
+ * any unnecessary structures there; and (b) it buys compatibility with 3.2
+ * pickles.  See issue #16076.
+ */
+static PyObject *
+element_getstate(ElementObject *self)
+{
+    int i, noattrib;
+    PyObject *instancedict = NULL, *children;
+
+    /* Build a list of children. */
+    children = PyList_New(self->extra ? self->extra->length : 0);
+    if (!children)
+        return NULL;
+    for (i = 0; i < PyList_GET_SIZE(children); i++) {
+        PyObject *child = self->extra->children[i];
+        Py_INCREF(child);
+        PyList_SET_ITEM(children, i, child);
+    }
+
+    /* Construct the state object. */
+    noattrib = (self->extra == NULL || self->extra->attrib == Py_None);
+    if (noattrib)
+        instancedict = Py_BuildValue("{sOsOs{}sOsO}",
+                                     PICKLED_TAG, self->tag,
+                                     PICKLED_CHILDREN, children,
+                                     PICKLED_ATTRIB,
+                                     PICKLED_TEXT, self->text,
+                                     PICKLED_TAIL, self->tail);
+    else
+        instancedict = Py_BuildValue("{sOsOsOsOsO}",
+                                     PICKLED_TAG, self->tag,
+                                     PICKLED_CHILDREN, children,
+                                     PICKLED_ATTRIB, self->extra->attrib,
+                                     PICKLED_TEXT, self->text,
+                                     PICKLED_TAIL, self->tail);
+    if (instancedict)
+        return instancedict;
+    else {
+        for (i = 0; i < PyList_GET_SIZE(children); i++)
+            Py_DECREF(PyList_GET_ITEM(children, i));
+        Py_DECREF(children);
+
+        return NULL;
+    }
+}
+
+static PyObject *
+element_setstate_from_attributes(ElementObject *self,
+                                 PyObject *tag,
+                                 PyObject *attrib,
+                                 PyObject *text,
+                                 PyObject *tail,
+                                 PyObject *children)
+{
+    Py_ssize_t i, nchildren;
+
+    if (!tag) {
+        PyErr_SetString(PyExc_TypeError, "tag may not be NULL");
+        return NULL;
+    }
+    if (!text) {
+        Py_INCREF(Py_None);
+        text = Py_None;
+    }
+    if (!tail) {
+        Py_INCREF(Py_None);
+        tail = Py_None;
+    }
+
+    Py_CLEAR(self->tag);
+    self->tag = tag;
+    Py_INCREF(self->tag);
+
+    Py_CLEAR(self->text);
+    self->text = text;
+    Py_INCREF(self->text);
+
+    Py_CLEAR(self->tail);
+    self->tail = tail;
+    Py_INCREF(self->tail);
+
+    /* Handle ATTRIB and CHILDREN. */
+    if (!children && !attrib)
+        Py_RETURN_NONE;
+
+    /* Compute 'nchildren'. */
+    if (children) {
+        if (!PyList_Check(children)) {
+            PyErr_SetString(PyExc_TypeError, "'_children' is not a list");
+            return NULL;
+        }
+        nchildren = PyList_Size(children);
+    }
+    else {
+        nchildren = 0;
+    }
+
+    /* Allocate 'extra'. */
+    if (element_resize(self, nchildren)) {
+        return NULL;
+    }
+    assert(self->extra && self->extra->allocated >= nchildren);
+
+    /* Copy children */
+    for (i = 0; i < nchildren; i++) {
+        self->extra->children[i] = PyList_GET_ITEM(children, i);
+        Py_INCREF(self->extra->children[i]);
+    }
+
+    self->extra->length = nchildren;
+    self->extra->allocated = nchildren;
+
+    /* Stash attrib. */
+    if (attrib) {
+        Py_CLEAR(self->extra->attrib);
+        self->extra->attrib = attrib;
+        Py_INCREF(attrib);
+    }
+
+    Py_RETURN_NONE;
+}
+
+/* __setstate__ for Element instance from the Python implementation.
+ * 'state' should be the instance dict.
+ */
+static PyObject *
+element_setstate_from_Python(ElementObject *self, PyObject *state)
+{
+    static char *kwlist[] = {PICKLED_TAG, PICKLED_ATTRIB, PICKLED_TEXT,
+                             PICKLED_TAIL, PICKLED_CHILDREN, 0};
+    PyObject *args;
+    PyObject *tag, *attrib, *text, *tail, *children;
+    int error;
+
+    /* More instance dict members than we know to handle? */
+    tag = attrib = text = tail = children = NULL;
+    args = PyTuple_New(0);
+    error = ! PyArg_ParseTupleAndKeywords(args, state, "|$OOOOO", kwlist, &tag,
+                                          &attrib, &text, &tail, &children);
+    Py_DECREF(args);
+    if (error)
+        return NULL;
+    else
+        return element_setstate_from_attributes(self, tag, attrib, text,
+                                                tail, children);
+}
+
+static PyObject *
+element_setstate(ElementObject *self, PyObject *state)
+{
+    if (!PyDict_CheckExact(state)) {
+        PyErr_Format(PyExc_TypeError,
+                     "Don't know how to unpickle \"%.200R\" as an Element",
+                     state);
+        return NULL;
+    }
+    else
+        return element_setstate_from_Python(self, state);
+}
+
 LOCAL(int)
 checkpath(PyObject* tag)
 {
@@ -1587,6 +1757,8 @@
     {"__copy__", (PyCFunction) element_copy, METH_VARARGS},
     {"__deepcopy__", (PyCFunction) element_deepcopy, METH_VARARGS},
     {"__sizeof__", element_sizeof, METH_NOARGS},
+    {"__getstate__", (PyCFunction)element_getstate, METH_NOARGS},
+    {"__setstate__", (PyCFunction)element_setstate, METH_O},
 
     {NULL, NULL}
 };
@@ -1691,7 +1863,7 @@
 
 static PyTypeObject Element_Type = {
     PyVarObject_HEAD_INIT(NULL, 0)
-    "Element", sizeof(ElementObject), 0,
+    "xml.etree.ElementTree.Element", sizeof(ElementObject), 0,
     /* methods */
     (destructor)element_dealloc,                    /* tp_dealloc */
     0,                                              /* tp_print */
@@ -1913,6 +2085,8 @@
 
 static PyTypeObject ElementIter_Type = {
     PyVarObject_HEAD_INIT(NULL, 0)
+    /* Using the module's name since the pure-Python implementation does not
+       have such a type. */
     "_elementtree._element_iterator",           /* tp_name */
     sizeof(ElementIterObject),                  /* tp_basicsize */
     0,                                          /* tp_itemsize */
@@ -2458,7 +2632,7 @@
 
 static PyTypeObject TreeBuilder_Type = {
     PyVarObject_HEAD_INIT(NULL, 0)
-    "TreeBuilder", sizeof(TreeBuilderObject), 0,
+    "xml.etree.ElementTree.TreeBuilder", sizeof(TreeBuilderObject), 0,
     /* methods */
     (destructor)treebuilder_dealloc,                /* tp_dealloc */
     0,                                              /* tp_print */
@@ -3420,7 +3594,7 @@
 
 static PyTypeObject XMLParser_Type = {
     PyVarObject_HEAD_INIT(NULL, 0)
-    "XMLParser", sizeof(XMLParserObject), 0,
+    "xml.etree.ElementTree.XMLParser", sizeof(XMLParserObject), 0,
     /* methods */
     (destructor)xmlparser_dealloc,                  /* tp_dealloc */
     0,                                              /* tp_print */

-- 
Repository URL: http://hg.python.org/cpython


More information about the Python-checkins mailing list