[Python-checkins] cpython (merge 3.2 -> 3.3): Issue #1470548: XMLGenerator now works with binary output streams.

serhiy.storchaka python-checkins at python.org
Sun Feb 10 13:38:04 CET 2013


http://hg.python.org/cpython/rev/03b878d636cf
changeset:   82124:03b878d636cf
branch:      3.3
parent:      82120:bfe9526606e2
parent:      82123:66f92f76b2ce
user:        Serhiy Storchaka <storchaka at gmail.com>
date:        Sun Feb 10 14:31:07 2013 +0200
summary:
  Issue #1470548: XMLGenerator now works with binary output streams.

files:
  Lib/test/test_sax.py    |  215 ++++++++++++++++++---------
  Lib/xml/sax/saxutils.py |   67 +++++--
  Misc/NEWS               |    2 +
  3 files changed, 192 insertions(+), 92 deletions(-)


diff --git a/Lib/test/test_sax.py b/Lib/test/test_sax.py
--- a/Lib/test/test_sax.py
+++ b/Lib/test/test_sax.py
@@ -13,7 +13,7 @@
 from xml.sax.expatreader import create_parser
 from xml.sax.handler import feature_namespaces
 from xml.sax.xmlreader import InputSource, AttributesImpl, AttributesNSImpl
-from io import StringIO
+from io import BytesIO, StringIO
 import os.path
 import shutil
 from test import support
@@ -173,31 +173,29 @@
 
 # ===== XMLGenerator
 
-start = '<?xml version="1.0" encoding="iso-8859-1"?>\n'
-
-class XmlgenTest(unittest.TestCase):
+class XmlgenTest:
     def test_xmlgen_basic(self):
-        result = StringIO()
+        result = self.ioclass()
         gen = XMLGenerator(result)
         gen.startDocument()
         gen.startElement("doc", {})
         gen.endElement("doc")
         gen.endDocument()
 
-        self.assertEqual(result.getvalue(), start + "<doc></doc>")
+        self.assertEqual(result.getvalue(), self.xml("<doc></doc>"))
 
     def test_xmlgen_basic_empty(self):
-        result = StringIO()
+        result = self.ioclass()
         gen = XMLGenerator(result, short_empty_elements=True)
         gen.startDocument()
         gen.startElement("doc", {})
         gen.endElement("doc")
         gen.endDocument()
 
-        self.assertEqual(result.getvalue(), start + "<doc/>")
+        self.assertEqual(result.getvalue(), self.xml("<doc/>"))
 
     def test_xmlgen_content(self):
-        result = StringIO()
+        result = self.ioclass()
         gen = XMLGenerator(result)
 
         gen.startDocument()
@@ -206,10 +204,10 @@
         gen.endElement("doc")
         gen.endDocument()
 
-        self.assertEqual(result.getvalue(), start + "<doc>huhei</doc>")
+        self.assertEqual(result.getvalue(), self.xml("<doc>huhei</doc>"))
 
     def test_xmlgen_content_empty(self):
-        result = StringIO()
+        result = self.ioclass()
         gen = XMLGenerator(result, short_empty_elements=True)
 
         gen.startDocument()
@@ -218,10 +216,10 @@
         gen.endElement("doc")
         gen.endDocument()
 
-        self.assertEqual(result.getvalue(), start + "<doc>huhei</doc>")
+        self.assertEqual(result.getvalue(), self.xml("<doc>huhei</doc>"))
 
     def test_xmlgen_pi(self):
-        result = StringIO()
+        result = self.ioclass()
         gen = XMLGenerator(result)
 
         gen.startDocument()
@@ -230,10 +228,11 @@
         gen.endElement("doc")
         gen.endDocument()
 
-        self.assertEqual(result.getvalue(), start + "<?test data?><doc></doc>")
+        self.assertEqual(result.getvalue(),
+            self.xml("<?test data?><doc></doc>"))
 
     def test_xmlgen_content_escape(self):
-        result = StringIO()
+        result = self.ioclass()
         gen = XMLGenerator(result)
 
         gen.startDocument()
@@ -243,10 +242,10 @@
         gen.endDocument()
 
         self.assertEqual(result.getvalue(),
-            start + "<doc>&lt;huhei&amp;</doc>")
+            self.xml("<doc>&lt;huhei&amp;</doc>"))
 
     def test_xmlgen_attr_escape(self):
-        result = StringIO()
+        result = self.ioclass()
         gen = XMLGenerator(result)
 
         gen.startDocument()
@@ -260,13 +259,43 @@
         gen.endElement("doc")
         gen.endDocument()
 
-        self.assertEqual(result.getvalue(), start +
-            ("<doc a='\"'><e a=\"'\"></e>"
-             "<e a=\"'&quot;\"></e>"
-             "<e a=\"&#10;&#13;&#9;\"></e></doc>"))
+        self.assertEqual(result.getvalue(), self.xml(
+            "<doc a='\"'><e a=\"'\"></e>"
+            "<e a=\"'&quot;\"></e>"
+            "<e a=\"&#10;&#13;&#9;\"></e></doc>"))
+
+    def test_xmlgen_encoding(self):
+        encodings = ('iso-8859-15', 'utf-8', 'utf-8-sig',
+                     'utf-16', 'utf-16be', 'utf-16le',
+                     'utf-32', 'utf-32be', 'utf-32le')
+        for encoding in encodings:
+            result = self.ioclass()
+            gen = XMLGenerator(result, encoding=encoding)
+
+            gen.startDocument()
+            gen.startElement("doc", {"a": '\u20ac'})
+            gen.characters("\u20ac")
+            gen.endElement("doc")
+            gen.endDocument()
+
+            self.assertEqual(result.getvalue(),
+                self.xml('<doc a="\u20ac">\u20ac</doc>', encoding=encoding))
+
+    def test_xmlgen_unencodable(self):
+        result = self.ioclass()
+        gen = XMLGenerator(result, encoding='ascii')
+
+        gen.startDocument()
+        gen.startElement("doc", {"a": '\u20ac'})
+        gen.characters("\u20ac")
+        gen.endElement("doc")
+        gen.endDocument()
+
+        self.assertEqual(result.getvalue(),
+            self.xml('<doc a="&#8364;">&#8364;</doc>', encoding='ascii'))
 
     def test_xmlgen_ignorable(self):
-        result = StringIO()
+        result = self.ioclass()
         gen = XMLGenerator(result)
 
         gen.startDocument()
@@ -275,10 +304,10 @@
         gen.endElement("doc")
         gen.endDocument()
 
-        self.assertEqual(result.getvalue(), start + "<doc> </doc>")
+        self.assertEqual(result.getvalue(), self.xml("<doc> </doc>"))
 
     def test_xmlgen_ignorable_empty(self):
-        result = StringIO()
+        result = self.ioclass()
         gen = XMLGenerator(result, short_empty_elements=True)
 
         gen.startDocument()
@@ -287,10 +316,10 @@
         gen.endElement("doc")
         gen.endDocument()
 
-        self.assertEqual(result.getvalue(), start + "<doc> </doc>")
+        self.assertEqual(result.getvalue(), self.xml("<doc> </doc>"))
 
     def test_xmlgen_ns(self):
-        result = StringIO()
+        result = self.ioclass()
         gen = XMLGenerator(result)
 
         gen.startDocument()
@@ -303,12 +332,12 @@
         gen.endPrefixMapping("ns1")
         gen.endDocument()
 
-        self.assertEqual(result.getvalue(), start + \
-           ('<ns1:doc xmlns:ns1="%s"><udoc></udoc></ns1:doc>' %
+        self.assertEqual(result.getvalue(), self.xml(
+           '<ns1:doc xmlns:ns1="%s"><udoc></udoc></ns1:doc>' %
                                          ns_uri))
 
     def test_xmlgen_ns_empty(self):
-        result = StringIO()
+        result = self.ioclass()
         gen = XMLGenerator(result, short_empty_elements=True)
 
         gen.startDocument()
@@ -321,12 +350,12 @@
         gen.endPrefixMapping("ns1")
         gen.endDocument()
 
-        self.assertEqual(result.getvalue(), start + \
-           ('<ns1:doc xmlns:ns1="%s"><udoc/></ns1:doc>' %
+        self.assertEqual(result.getvalue(), self.xml(
+           '<ns1:doc xmlns:ns1="%s"><udoc/></ns1:doc>' %
                                          ns_uri))
 
     def test_1463026_1(self):
-        result = StringIO()
+        result = self.ioclass()
         gen = XMLGenerator(result)
 
         gen.startDocument()
@@ -334,10 +363,10 @@
         gen.endElementNS((None, 'a'), 'a')
         gen.endDocument()
 
-        self.assertEqual(result.getvalue(), start+'<a b="c"></a>')
+        self.assertEqual(result.getvalue(), self.xml('<a b="c"></a>'))
 
     def test_1463026_1_empty(self):
-        result = StringIO()
+        result = self.ioclass()
         gen = XMLGenerator(result, short_empty_elements=True)
 
         gen.startDocument()
@@ -345,10 +374,10 @@
         gen.endElementNS((None, 'a'), 'a')
         gen.endDocument()
 
-        self.assertEqual(result.getvalue(), start+'<a b="c"/>')
+        self.assertEqual(result.getvalue(), self.xml('<a b="c"/>'))
 
     def test_1463026_2(self):
-        result = StringIO()
+        result = self.ioclass()
         gen = XMLGenerator(result)
 
         gen.startDocument()
@@ -358,10 +387,10 @@
         gen.endPrefixMapping(None)
         gen.endDocument()
 
-        self.assertEqual(result.getvalue(), start+'<a xmlns="qux"></a>')
+        self.assertEqual(result.getvalue(), self.xml('<a xmlns="qux"></a>'))
 
     def test_1463026_2_empty(self):
-        result = StringIO()
+        result = self.ioclass()
         gen = XMLGenerator(result, short_empty_elements=True)
 
         gen.startDocument()
@@ -371,10 +400,10 @@
         gen.endPrefixMapping(None)
         gen.endDocument()
 
-        self.assertEqual(result.getvalue(), start+'<a xmlns="qux"/>')
+        self.assertEqual(result.getvalue(), self.xml('<a xmlns="qux"/>'))
 
     def test_1463026_3(self):
-        result = StringIO()
+        result = self.ioclass()
         gen = XMLGenerator(result)
 
         gen.startDocument()
@@ -385,10 +414,10 @@
         gen.endDocument()
 
         self.assertEqual(result.getvalue(),
-            start+'<my:a xmlns:my="qux" b="c"></my:a>')
+            self.xml('<my:a xmlns:my="qux" b="c"></my:a>'))
 
     def test_1463026_3_empty(self):
-        result = StringIO()
+        result = self.ioclass()
         gen = XMLGenerator(result, short_empty_elements=True)
 
         gen.startDocument()
@@ -399,7 +428,7 @@
         gen.endDocument()
 
         self.assertEqual(result.getvalue(),
-            start+'<my:a xmlns:my="qux" b="c"/>')
+            self.xml('<my:a xmlns:my="qux" b="c"/>'))
 
     def test_5027_1(self):
         # The xml prefix (as in xml:lang below) is reserved and bound by
@@ -416,13 +445,13 @@
 
         parser = make_parser()
         parser.setFeature(feature_namespaces, True)
-        result = StringIO()
+        result = self.ioclass()
         gen = XMLGenerator(result)
         parser.setContentHandler(gen)
         parser.parse(test_xml)
 
         self.assertEqual(result.getvalue(),
-                         start + (
+                         self.xml(
                          '<a:g1 xmlns:a="http://example.com/ns">'
                           '<a:g2 xml:lang="en">Hello</a:g2>'
                          '</a:g1>'))
@@ -435,7 +464,7 @@
         #
         # This test demonstrates the bug by direct manipulation of the
         # XMLGenerator.
-        result = StringIO()
+        result = self.ioclass()
         gen = XMLGenerator(result)
 
         gen.startDocument()
@@ -450,15 +479,57 @@
         gen.endDocument()
 
         self.assertEqual(result.getvalue(),
-                         start + (
+                         self.xml(
                          '<a:g1 xmlns:a="http://example.com/ns">'
                           '<a:g2 xml:lang="en">Hello</a:g2>'
                          '</a:g1>'))
 
+    def test_no_close_file(self):
+        result = self.ioclass()
+        def func(out):
+            gen = XMLGenerator(out)
+            gen.startDocument()
+            gen.startElement("doc", {})
+        func(result)
+        self.assertFalse(result.closed)
+
+class StringXmlgenTest(XmlgenTest, unittest.TestCase):
+    ioclass = StringIO
+
+    def xml(self, doc, encoding='iso-8859-1'):
+        return '<?xml version="1.0" encoding="%s"?>\n%s' % (encoding, doc)
+
+    test_xmlgen_unencodable = None
+
+class BytesXmlgenTest(XmlgenTest, unittest.TestCase):
+    ioclass = BytesIO
+
+    def xml(self, doc, encoding='iso-8859-1'):
+        return ('<?xml version="1.0" encoding="%s"?>\n%s' %
+                (encoding, doc)).encode(encoding, 'xmlcharrefreplace')
+
+class WriterXmlgenTest(BytesXmlgenTest):
+    class ioclass(list):
+        write = list.append
+        closed = False
+
+        def seekable(self):
+            return True
+
+        def tell(self):
+            # return 0 at start and not 0 after start
+            return len(self)
+
+        def getvalue(self):
+            return b''.join(self)
+
+
+start = b'<?xml version="1.0" encoding="iso-8859-1"?>\n'
+
 
 class XMLFilterBaseTest(unittest.TestCase):
     def test_filter_basic(self):
-        result = StringIO()
+        result = BytesIO()
         gen = XMLGenerator(result)
         filter = XMLFilterBase()
         filter.setContentHandler(gen)
@@ -470,7 +541,7 @@
         filter.endElement("doc")
         filter.endDocument()
 
-        self.assertEqual(result.getvalue(), start + "<doc>content </doc>")
+        self.assertEqual(result.getvalue(), start + b"<doc>content </doc>")
 
 # ===========================================================================
 #
@@ -478,7 +549,7 @@
 #
 # ===========================================================================
 
-with open(TEST_XMLFILE_OUT) as f:
+with open(TEST_XMLFILE_OUT, 'rb') as f:
     xml_test_out = f.read()
 
 class ExpatReaderTest(XmlTestBase):
@@ -487,11 +558,11 @@
 
     def test_expat_file(self):
         parser = create_parser()
-        result = StringIO()
+        result = BytesIO()
         xmlgen = XMLGenerator(result)
 
         parser.setContentHandler(xmlgen)
-        with open(TEST_XMLFILE) as f:
+        with open(TEST_XMLFILE, 'rb') as f:
             parser.parse(f)
 
         self.assertEqual(result.getvalue(), xml_test_out)
@@ -503,7 +574,7 @@
         self.addCleanup(support.unlink, fname)
 
         parser = create_parser()
-        result = StringIO()
+        result = BytesIO()
         xmlgen = XMLGenerator(result)
 
         parser.setContentHandler(xmlgen)
@@ -547,13 +618,13 @@
 
         def resolveEntity(self, publicId, systemId):
             inpsrc = InputSource()
-            inpsrc.setByteStream(StringIO("<entity/>"))
+            inpsrc.setByteStream(BytesIO(b"<entity/>"))
             return inpsrc
 
     def test_expat_entityresolver(self):
         parser = create_parser()
         parser.setEntityResolver(self.TestEntityResolver())
-        result = StringIO()
+        result = BytesIO()
         parser.setContentHandler(XMLGenerator(result))
 
         parser.feed('<!DOCTYPE doc [\n')
@@ -563,7 +634,7 @@
         parser.close()
 
         self.assertEqual(result.getvalue(), start +
-                         "<doc><entity></entity></doc>")
+                         b"<doc><entity></entity></doc>")
 
     # ===== Attributes support
 
@@ -632,7 +703,7 @@
 
     def test_expat_inpsource_filename(self):
         parser = create_parser()
-        result = StringIO()
+        result = BytesIO()
         xmlgen = XMLGenerator(result)
 
         parser.setContentHandler(xmlgen)
@@ -642,7 +713,7 @@
 
     def test_expat_inpsource_sysid(self):
         parser = create_parser()
-        result = StringIO()
+        result = BytesIO()
         xmlgen = XMLGenerator(result)
 
         parser.setContentHandler(xmlgen)
@@ -657,7 +728,7 @@
         self.addCleanup(support.unlink, fname)
 
         parser = create_parser()
-        result = StringIO()
+        result = BytesIO()
         xmlgen = XMLGenerator(result)
 
         parser.setContentHandler(xmlgen)
@@ -667,12 +738,12 @@
 
     def test_expat_inpsource_stream(self):
         parser = create_parser()
-        result = StringIO()
+        result = BytesIO()
         xmlgen = XMLGenerator(result)
 
         parser.setContentHandler(xmlgen)
         inpsrc = InputSource()
-        with open(TEST_XMLFILE) as f:
+        with open(TEST_XMLFILE, 'rb') as f:
             inpsrc.setByteStream(f)
             parser.parse(inpsrc)
 
@@ -681,7 +752,7 @@
     # ===== IncrementalParser support
 
     def test_expat_incremental(self):
-        result = StringIO()
+        result = BytesIO()
         xmlgen = XMLGenerator(result)
         parser = create_parser()
         parser.setContentHandler(xmlgen)
@@ -690,10 +761,10 @@
         parser.feed("</doc>")
         parser.close()
 
-        self.assertEqual(result.getvalue(), start + "<doc></doc>")
+        self.assertEqual(result.getvalue(), start + b"<doc></doc>")
 
     def test_expat_incremental_reset(self):
-        result = StringIO()
+        result = BytesIO()
         xmlgen = XMLGenerator(result)
         parser = create_parser()
         parser.setContentHandler(xmlgen)
@@ -701,7 +772,7 @@
         parser.feed("<doc>")
         parser.feed("text")
 
-        result = StringIO()
+        result = BytesIO()
         xmlgen = XMLGenerator(result)
         parser.setContentHandler(xmlgen)
         parser.reset()
@@ -711,12 +782,12 @@
         parser.feed("</doc>")
         parser.close()
 
-        self.assertEqual(result.getvalue(), start + "<doc>text</doc>")
+        self.assertEqual(result.getvalue(), start + b"<doc>text</doc>")
 
     # ===== Locator support
 
     def test_expat_locator_noinfo(self):
-        result = StringIO()
+        result = BytesIO()
         xmlgen = XMLGenerator(result)
         parser = create_parser()
         parser.setContentHandler(xmlgen)
@@ -730,7 +801,7 @@
         self.assertEqual(parser.getLineNumber(), 1)
 
     def test_expat_locator_withinfo(self):
-        result = StringIO()
+        result = BytesIO()
         xmlgen = XMLGenerator(result)
         parser = create_parser()
         parser.setContentHandler(xmlgen)
@@ -745,7 +816,7 @@
         shutil.copyfile(TEST_XMLFILE, fname)
         self.addCleanup(support.unlink, fname)
 
-        result = StringIO()
+        result = BytesIO()
         xmlgen = XMLGenerator(result)
         parser = create_parser()
         parser.setContentHandler(xmlgen)
@@ -766,7 +837,7 @@
         parser = create_parser()
         parser.setContentHandler(ContentHandler()) # do nothing
         source = InputSource()
-        source.setByteStream(StringIO("<foo bar foobar>"))   #ill-formed
+        source.setByteStream(BytesIO(b"<foo bar foobar>"))   #ill-formed
         name = "a file name"
         source.setSystemId(name)
         try:
@@ -857,7 +928,9 @@
 def test_main():
     run_unittest(MakeParserTest,
                  SaxutilsTest,
-                 XmlgenTest,
+                 StringXmlgenTest,
+                 BytesXmlgenTest,
+                 WriterXmlgenTest,
                  ExpatReaderTest,
                  ErrorReportingTest,
                  XmlReaderTest)
diff --git a/Lib/xml/sax/saxutils.py b/Lib/xml/sax/saxutils.py
--- a/Lib/xml/sax/saxutils.py
+++ b/Lib/xml/sax/saxutils.py
@@ -4,18 +4,10 @@
 """
 
 import os, urllib.parse, urllib.request
+import io
 from . import handler
 from . import xmlreader
 
-# See whether the xmlcharrefreplace error handler is
-# supported
-try:
-    from codecs import xmlcharrefreplace_errors
-    _error_handling = "xmlcharrefreplace"
-    del xmlcharrefreplace_errors
-except ImportError:
-    _error_handling = "strict"
-
 def __dict_replace(s, d):
     """Replace substrings of a string using a dictionary."""
     for key, value in d.items():
@@ -76,14 +68,50 @@
     return data
 
 
+def _gettextwriter(out, encoding):
+    if out is None:
+        import sys
+        return sys.stdout
+
+    if isinstance(out, io.TextIOBase):
+        # use a text writer as is
+        return out
+
+    # wrap a binary writer with TextIOWrapper
+    if isinstance(out, io.RawIOBase):
+        # Keep the original file open when the TextIOWrapper is
+        # destroyed
+        class _wrapper:
+            __class__ = out.__class__
+            def __getattr__(self, name):
+                return getattr(out, name)
+        buffer = _wrapper()
+        buffer.close = lambda: None
+    else:
+        # This is to handle passed objects that aren't in the
+        # IOBase hierarchy, but just have a write method
+        buffer = io.BufferedIOBase()
+        buffer.writable = lambda: True
+        buffer.write = out.write
+        try:
+            # TextIOWrapper uses this methods to determine
+            # if BOM (for UTF-16, etc) should be added
+            buffer.seekable = out.seekable
+            buffer.tell = out.tell
+        except AttributeError:
+            pass
+    return io.TextIOWrapper(buffer, encoding=encoding,
+                            errors='xmlcharrefreplace',
+                            newline='\n',
+                            write_through=True)
+
 class XMLGenerator(handler.ContentHandler):
 
     def __init__(self, out=None, encoding="iso-8859-1", short_empty_elements=False):
-        if out is None:
-            import sys
-            out = sys.stdout
         handler.ContentHandler.__init__(self)
-        self._out = out
+        out = _gettextwriter(out, encoding)
+        self._write = out.write
+        self._flush = out.flush
         self._ns_contexts = [{}] # contains uri -> prefix dicts
         self._current_context = self._ns_contexts[-1]
         self._undeclared_ns_maps = []
@@ -91,12 +119,6 @@
         self._short_empty_elements = short_empty_elements
         self._pending_start_element = False
 
-    def _write(self, text):
-        if isinstance(text, str):
-            self._out.write(text)
-        else:
-            self._out.write(text.encode(self._encoding, _error_handling))
-
     def _qname(self, name):
         """Builds a qualified name from a (ns_url, localname) pair"""
         if name[0]:
@@ -125,6 +147,9 @@
         self._write('<?xml version="1.0" encoding="%s"?>\n' %
                         self._encoding)
 
+    def endDocument(self):
+        self._flush()
+
     def startPrefixMapping(self, prefix, uri):
         self._ns_contexts.append(self._current_context.copy())
         self._current_context[uri] = prefix
@@ -157,9 +182,9 @@
 
         for prefix, uri in self._undeclared_ns_maps:
             if prefix:
-                self._out.write(' xmlns:%s="%s"' % (prefix, uri))
+                self._write(' xmlns:%s="%s"' % (prefix, uri))
             else:
-                self._out.write(' xmlns="%s"' % uri)
+                self._write(' xmlns="%s"' % uri)
         self._undeclared_ns_maps = []
 
         for (name, value) in attrs.items():
diff --git a/Misc/NEWS b/Misc/NEWS
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -172,6 +172,8 @@
 Library
 -------
 
+- Issue #1470548: XMLGenerator now works with binary output streams.
+
 - Issue #6975: os.path.realpath() now correctly resolves multiple nested
   symlinks on POSIX platforms.
 

-- 
Repository URL: http://hg.python.org/cpython


More information about the Python-checkins mailing list