[Python-checkins] bpo-36819: Fix crashes in built-in encoders with weird error handlers (GH-28593)

miss-islington webhook-mailer at python.org
Mon May 2 05:59:03 EDT 2022


https://github.com/python/cpython/commit/d985c8e2e0d38e6905cc27124008eb926254247b
commit: d985c8e2e0d38e6905cc27124008eb926254247b
branch: 3.10
author: Miss Islington (bot) <31488909+miss-islington at users.noreply.github.com>
committer: miss-islington <31488909+miss-islington at users.noreply.github.com>
date: 2022-05-02T02:58:41-07:00
summary:

bpo-36819: Fix crashes in built-in encoders with weird error handlers (GH-28593)


If the error handler returns a position less than or equal to the starting
position of the non-encodable characters, most of the built-in encoders did
not properly resize the output buffer. This led to out-of-bounds writes and
segfaults.
(cherry picked from commit 18b07d773e09a2719e69aeaa925d5abb7ba0c068)

Co-authored-by: Serhiy Storchaka <storchaka at gmail.com>
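
For context, here is a minimal sketch (not part of the commit) of the kind of
error handler that used to trigger the crash. The handler name "test.badpos"
and the BadPosHandler class are illustrative only; they mirror the
RepeatedPosReturn helper added to the test suite below.

    import codecs

    class BadPosHandler:
        """Points the encoder back at position 0 a bounded number of times."""
        def __init__(self, repl="x", retries=50):
            self.repl = repl
            self.retries = retries

        def handle(self, exc):
            if self.retries > 0:
                self.retries -= 1
                return (self.repl, 0)       # position at or before exc.start
            return (self.repl, exc.end)     # finally skip past the character

    codecs.register_error("test.badpos", BadPosHandler().handle)

    # Before this fix, the encoder could write past the end of its
    # preallocated buffer and crash; with the fix it simply keeps encoding
    # from the returned position, producing b"abcdx" repeated 51 times.
    print("abcd\udc80".encode("utf-8", "test.badpos"))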

files:
A Misc/NEWS.d/next/Core and Builtins/2021-09-28-10-58-30.bpo-36819.cyV50C.rst
M Lib/test/test_codeccallbacks.py
M Objects/stringlib/codecs.h
M Objects/unicodeobject.c

diff --git a/Lib/test/test_codeccallbacks.py b/Lib/test/test_codeccallbacks.py
index 243f002c4ecad..4991330489d13 100644
--- a/Lib/test/test_codeccallbacks.py
+++ b/Lib/test/test_codeccallbacks.py
@@ -1,5 +1,6 @@
 import codecs
 import html.entities
+import itertools
 import sys
 import unicodedata
 import unittest
@@ -22,6 +23,18 @@ def handle(self, exc):
             self.pos = len(exc.object)
         return ("<?>", oldpos)
 
+class RepeatedPosReturn:
+    def __init__(self, repl="<?>"):
+        self.repl = repl
+        self.pos = 0
+        self.count = 0
+
+    def handle(self, exc):
+        if self.count > 0:
+            self.count -= 1
+            return (self.repl, self.pos)
+        return (self.repl, exc.end)
+
 # A UnicodeEncodeError object with a bad start attribute
 class BadStartUnicodeEncodeError(UnicodeEncodeError):
     def __init__(self):
@@ -783,20 +796,104 @@ def test_lookup(self):
             codecs.lookup_error("namereplace")
         )
 
-    def test_unencodablereplacement(self):
+    def test_encode_nonascii_replacement(self):
+        def handle(exc):
+            if isinstance(exc, UnicodeEncodeError):
+                return (repl, exc.end)
+            raise TypeError("don't know how to handle %r" % exc)
+        codecs.register_error("test.replacing", handle)
+
+        for enc, input, repl in (
+                ("ascii", "[¤]", "abc"),
+                ("iso-8859-1", "[€]", "½¾"),
+                ("iso-8859-15", "[¤]", "œŸ"),
+        ):
+            res = input.encode(enc, "test.replacing")
+            self.assertEqual(res, ("[" + repl + "]").encode(enc))
+
+        for enc, input, repl in (
+                ("utf-8", "[\udc80]", "\U0001f40d"),
+                ("utf-16", "[\udc80]", "\U0001f40d"),
+                ("utf-32", "[\udc80]", "\U0001f40d"),
+        ):
+            with self.subTest(encoding=enc):
+                with self.assertRaises(UnicodeEncodeError) as cm:
+                    input.encode(enc, "test.replacing")
+                exc = cm.exception
+                self.assertEqual(exc.start, 1)
+                self.assertEqual(exc.end, 2)
+                self.assertEqual(exc.object, input)
+
+    def test_encode_unencodable_replacement(self):
         def unencrepl(exc):
             if isinstance(exc, UnicodeEncodeError):
-                return ("\u4242", exc.end)
+                return (repl, exc.end)
             else:
                 raise TypeError("don't know how to handle %r" % exc)
         codecs.register_error("test.unencreplhandler", unencrepl)
-        for enc in ("ascii", "iso-8859-1", "iso-8859-15"):
-            self.assertRaises(
-                UnicodeEncodeError,
-                "\u4242".encode,
-                enc,
-                "test.unencreplhandler"
-            )
+
+        for enc, input, repl in (
+                ("ascii", "[¤]", "½"),
+                ("iso-8859-1", "[€]", "œ"),
+                ("iso-8859-15", "[¤]", "½"),
+                ("utf-8", "[\udc80]", "\udcff"),
+                ("utf-16", "[\udc80]", "\udcff"),
+                ("utf-32", "[\udc80]", "\udcff"),
+        ):
+            with self.subTest(encoding=enc):
+                with self.assertRaises(UnicodeEncodeError) as cm:
+                    input.encode(enc, "test.unencreplhandler")
+                exc = cm.exception
+                self.assertEqual(exc.start, 1)
+                self.assertEqual(exc.end, 2)
+                self.assertEqual(exc.object, input)
+
+    def test_encode_bytes_replacement(self):
+        def handle(exc):
+            if isinstance(exc, UnicodeEncodeError):
+                return (repl, exc.end)
+            raise TypeError("don't know how to handle %r" % exc)
+        codecs.register_error("test.replacing", handle)
+
+        # It works even if the bytes sequence is not decodable.
+        for enc, input, repl in (
+                ("ascii", "[¤]", b"\xbd\xbe"),
+                ("iso-8859-1", "[€]", b"\xbd\xbe"),
+                ("iso-8859-15", "[¤]", b"\xbd\xbe"),
+                ("utf-8", "[\udc80]", b"\xbd\xbe"),
+                ("utf-16le", "[\udc80]", b"\xbd\xbe"),
+                ("utf-16be", "[\udc80]", b"\xbd\xbe"),
+                ("utf-32le", "[\udc80]", b"\xbc\xbd\xbe\xbf"),
+                ("utf-32be", "[\udc80]", b"\xbc\xbd\xbe\xbf"),
+        ):
+            with self.subTest(encoding=enc):
+                res = input.encode(enc, "test.replacing")
+                self.assertEqual(res, "[".encode(enc) + repl + "]".encode(enc))
+
+    def test_encode_odd_bytes_replacement(self):
+        def handle(exc):
+            if isinstance(exc, UnicodeEncodeError):
+                return (repl, exc.end)
+            raise TypeError("don't know how to handle %r" % exc)
+        codecs.register_error("test.replacing", handle)
+
+        input = "[\udc80]"
+        # Tests in which the replacement bytestring does not contain a whole
+        # number of code units.
+        for enc, repl in (
+            *itertools.product(("utf-16le", "utf-16be"),
+                               [b"a", b"abc"]),
+            *itertools.product(("utf-32le", "utf-32be"),
+                               [b"a", b"ab", b"abc", b"abcde"]),
+        ):
+            with self.subTest(encoding=enc, repl=repl):
+                with self.assertRaises(UnicodeEncodeError) as cm:
+                    input.encode(enc, "test.replacing")
+                exc = cm.exception
+                self.assertEqual(exc.start, 1)
+                self.assertEqual(exc.end, 2)
+                self.assertEqual(exc.object, input)
+                self.assertEqual(exc.reason, "surrogates not allowed")
 
     def test_badregistercall(self):
         # enhance coverage of:
@@ -940,6 +1037,68 @@ def __getitem__(self, key):
             self.assertRaises(ValueError, codecs.charmap_encode, "\xff", err, D())
             self.assertRaises(TypeError, codecs.charmap_encode, "\xff", err, {0xff: 300})
 
+    def test_decodehelper_bug36819(self):
+        handler = RepeatedPosReturn("x")
+        codecs.register_error("test.bug36819", handler.handle)
+
+        testcases = [
+            ("ascii", b"\xff"),
+            ("utf-8", b"\xff"),
+            ("utf-16be", b'\xdc\x80'),
+            ("utf-32be", b'\x00\x00\xdc\x80'),
+            ("iso-8859-6", b"\xff"),
+        ]
+        for enc, bad in testcases:
+            input = "abcd".encode(enc) + bad
+            with self.subTest(encoding=enc):
+                handler.count = 50
+                decoded = input.decode(enc, "test.bug36819")
+                self.assertEqual(decoded, 'abcdx' * 51)
+
+    def test_encodehelper_bug36819(self):
+        handler = RepeatedPosReturn()
+        codecs.register_error("test.bug36819", handler.handle)
+
+        input = "abcd\udc80"
+        encodings = ["ascii", "latin1", "utf-8", "utf-16", "utf-32"]  # built-in
+        encodings += ["iso-8859-15"]  # charmap codec
+        if sys.platform == 'win32':
+            encodings = ["mbcs", "oem"]  # code page codecs
+
+        handler.repl = "\udcff"
+        for enc in encodings:
+            with self.subTest(encoding=enc):
+                handler.count = 50
+                with self.assertRaises(UnicodeEncodeError) as cm:
+                    input.encode(enc, "test.bug36819")
+                exc = cm.exception
+                self.assertEqual(exc.start, 4)
+                self.assertEqual(exc.end, 5)
+                self.assertEqual(exc.object, input)
+        if sys.platform == "win32":
+            handler.count = 50
+            with self.assertRaises(UnicodeEncodeError) as cm:
+                codecs.code_page_encode(437, input, "test.bug36819")
+            exc = cm.exception
+            self.assertEqual(exc.start, 4)
+            self.assertEqual(exc.end, 5)
+            self.assertEqual(exc.object, input)
+
+        handler.repl = "x"
+        for enc in encodings:
+            with self.subTest(encoding=enc):
+                # The interpreter should segfault after a handful of attempts.
+                # 50 was chosen to try to ensure a segfault without a fix,
+                # but not OOM a machine with one.
+                handler.count = 50
+                encoded = input.encode(enc, "test.bug36819")
+                self.assertEqual(encoded.decode(enc), "abcdx" * 51)
+        if sys.platform == "win32":
+            handler.count = 50
+            encoded = codecs.code_page_encode(437, input, "test.bug36819")
+            self.assertEqual(encoded[0].decode(), "abcdx" * 51)
+            self.assertEqual(encoded[1], len(input))
+
     def test_translatehelper(self):
         # enhance coverage of:
         # Objects/unicodeobject.c::unicode_encode_call_errorhandler()
diff --git a/Misc/NEWS.d/next/Core and Builtins/2021-09-28-10-58-30.bpo-36819.cyV50C.rst b/Misc/NEWS.d/next/Core and Builtins/2021-09-28-10-58-30.bpo-36819.cyV50C.rst
new file mode 100644
index 0000000000000..32bb55a90e6c4
--- /dev/null
+++ b/Misc/NEWS.d/next/Core and Builtins/2021-09-28-10-58-30.bpo-36819.cyV50C.rst	
@@ -0,0 +1,2 @@
+Fix crashes in built-in encoders with error handlers that return a position
+less than or equal to the starting position of non-encodable characters.
diff --git a/Objects/stringlib/codecs.h b/Objects/stringlib/codecs.h
index b17cda18f54b3..958cc86147815 100644
--- a/Objects/stringlib/codecs.h
+++ b/Objects/stringlib/codecs.h
@@ -387,8 +387,19 @@ STRINGLIB(utf8_encoder)(_PyBytesWriter *writer,
                 if (!rep)
                     goto error;
 
-                /* subtract preallocated bytes */
-                writer->min_size -= max_char_size * (newpos - startpos);
+                if (newpos < startpos) {
+                    writer->overallocate = 1;
+                    p = _PyBytesWriter_Prepare(writer, p,
+                                               max_char_size * (startpos - newpos));
+                    if (p == NULL)
+                        goto error;
+                }
+                else {
+                    /* subtract preallocated bytes */
+                    writer->min_size -= max_char_size * (newpos - startpos);
+                    /* Only overallocate the buffer if it's not the last write */
+                    writer->overallocate = (newpos < size);
+                }
 
                 if (PyBytes_Check(rep)) {
                     p = _PyBytesWriter_WriteBytes(writer, p,
diff --git a/Objects/unicodeobject.c b/Objects/unicodeobject.c
index 5ddde5054428a..b7ec1f28d7e86 100644
--- a/Objects/unicodeobject.c
+++ b/Objects/unicodeobject.c
@@ -5959,7 +5959,7 @@ _PyUnicode_EncodeUTF32(PyObject *str,
 
     pos = 0;
     while (pos < len) {
-        Py_ssize_t repsize, moreunits;
+        Py_ssize_t newpos, repsize, moreunits;
 
         if (kind == PyUnicode_2BYTE_KIND) {
             pos += ucs2lib_utf32_encode((const Py_UCS2 *)data + pos, len - pos,
@@ -5976,7 +5976,7 @@ _PyUnicode_EncodeUTF32(PyObject *str,
         rep = unicode_encode_call_errorhandler(
                 errors, &errorHandler,
                 encoding, "surrogates not allowed",
-                str, &exc, pos, pos + 1, &pos);
+                str, &exc, pos, pos + 1, &newpos);
         if (!rep)
             goto error;
 
@@ -5984,7 +5984,7 @@ _PyUnicode_EncodeUTF32(PyObject *str,
             repsize = PyBytes_GET_SIZE(rep);
             if (repsize & 3) {
                 raise_encode_exception(&exc, encoding,
-                                       str, pos - 1, pos,
+                                       str, pos, pos + 1,
                                        "surrogates not allowed");
                 goto error;
             }
@@ -5997,28 +5997,30 @@ _PyUnicode_EncodeUTF32(PyObject *str,
             moreunits = repsize = PyUnicode_GET_LENGTH(rep);
             if (!PyUnicode_IS_ASCII(rep)) {
                 raise_encode_exception(&exc, encoding,
-                                       str, pos - 1, pos,
+                                       str, pos, pos + 1,
                                        "surrogates not allowed");
                 goto error;
             }
         }
+        moreunits += pos - newpos;
+        pos = newpos;
 
         /* four bytes are reserved for each surrogate */
-        if (moreunits > 1) {
+        if (moreunits > 0) {
             Py_ssize_t outpos = out - (uint32_t*) PyBytes_AS_STRING(v);
             if (moreunits >= (PY_SSIZE_T_MAX - PyBytes_GET_SIZE(v)) / 4) {
                 /* integer overflow */
                 PyErr_NoMemory();
                 goto error;
             }
-            if (_PyBytes_Resize(&v, PyBytes_GET_SIZE(v) + 4 * (moreunits - 1)) < 0)
+            if (_PyBytes_Resize(&v, PyBytes_GET_SIZE(v) + 4 * moreunits) < 0)
                 goto error;
             out = (uint32_t*) PyBytes_AS_STRING(v) + outpos;
         }
 
         if (PyBytes_Check(rep)) {
             memcpy(out, PyBytes_AS_STRING(rep), repsize);
-            out += moreunits;
+            out += repsize / 4;
         } else /* rep is unicode */ {
             assert(PyUnicode_KIND(rep) == PyUnicode_1BYTE_KIND);
             ucs1lib_utf32_encode(PyUnicode_1BYTE_DATA(rep), repsize,
@@ -6311,7 +6313,7 @@ _PyUnicode_EncodeUTF16(PyObject *str,
 
     pos = 0;
     while (pos < len) {
-        Py_ssize_t repsize, moreunits;
+        Py_ssize_t newpos, repsize, moreunits;
 
         if (kind == PyUnicode_2BYTE_KIND) {
             pos += ucs2lib_utf16_encode((const Py_UCS2 *)data + pos, len - pos,
@@ -6328,7 +6330,7 @@ _PyUnicode_EncodeUTF16(PyObject *str,
         rep = unicode_encode_call_errorhandler(
                 errors, &errorHandler,
                 encoding, "surrogates not allowed",
-                str, &exc, pos, pos + 1, &pos);
+                str, &exc, pos, pos + 1, &newpos);
         if (!rep)
             goto error;
 
@@ -6336,7 +6338,7 @@ _PyUnicode_EncodeUTF16(PyObject *str,
             repsize = PyBytes_GET_SIZE(rep);
             if (repsize & 1) {
                 raise_encode_exception(&exc, encoding,
-                                       str, pos - 1, pos,
+                                       str, pos, pos + 1,
                                        "surrogates not allowed");
                 goto error;
             }
@@ -6349,28 +6351,30 @@ _PyUnicode_EncodeUTF16(PyObject *str,
             moreunits = repsize = PyUnicode_GET_LENGTH(rep);
             if (!PyUnicode_IS_ASCII(rep)) {
                 raise_encode_exception(&exc, encoding,
-                                       str, pos - 1, pos,
+                                       str, pos, pos + 1,
                                        "surrogates not allowed");
                 goto error;
             }
         }
+        moreunits += pos - newpos;
+        pos = newpos;
 
         /* two bytes are reserved for each surrogate */
-        if (moreunits > 1) {
+        if (moreunits > 0) {
             Py_ssize_t outpos = out - (unsigned short*) PyBytes_AS_STRING(v);
             if (moreunits >= (PY_SSIZE_T_MAX - PyBytes_GET_SIZE(v)) / 2) {
                 /* integer overflow */
                 PyErr_NoMemory();
                 goto error;
             }
-            if (_PyBytes_Resize(&v, PyBytes_GET_SIZE(v) + 2 * (moreunits - 1)) < 0)
+            if (_PyBytes_Resize(&v, PyBytes_GET_SIZE(v) + 2 * moreunits) < 0)
                 goto error;
             out = (unsigned short*) PyBytes_AS_STRING(v) + outpos;
         }
 
         if (PyBytes_Check(rep)) {
             memcpy(out, PyBytes_AS_STRING(rep), repsize);
-            out += moreunits;
+            out += repsize / 2;
         } else /* rep is unicode */ {
             assert(PyUnicode_KIND(rep) == PyUnicode_1BYTE_KIND);
             ucs1lib_utf16_encode(PyUnicode_1BYTE_DATA(rep), repsize,
@@ -7297,8 +7301,19 @@ unicode_encode_ucs1(PyObject *unicode,
                 if (rep == NULL)
                     goto onError;
 
-                /* subtract preallocated bytes */
-                writer.min_size -= newpos - collstart;
+                if (newpos < collstart) {
+                    writer.overallocate = 1;
+                    str = _PyBytesWriter_Prepare(&writer, str,
+                                                 collstart - newpos);
+                    if (str == NULL)
+                        goto onError;
+                }
+                else {
+                    /* subtract preallocated bytes */
+                    writer.min_size -= newpos - collstart;
+                    /* Only overallocate the buffer if it's not the last write */
+                    writer.overallocate = (newpos < size);
+                }
 
                 if (PyBytes_Check(rep)) {
                     /* Directly copy bytes result to output. */
@@ -8104,13 +8119,14 @@ encode_code_page_errors(UINT code_page, PyObject **outbytes,
                   pos, pos + 1, &newpos);
         if (rep == NULL)
             goto error;
-        pos = newpos;
 
+        Py_ssize_t morebytes = pos - newpos;
         if (PyBytes_Check(rep)) {
             outsize = PyBytes_GET_SIZE(rep);
-            if (outsize != 1) {
+            morebytes += outsize;
+            if (morebytes > 0) {
                 Py_ssize_t offset = out - PyBytes_AS_STRING(*outbytes);
-                newoutsize = PyBytes_GET_SIZE(*outbytes) + (outsize - 1);
+                newoutsize = PyBytes_GET_SIZE(*outbytes) + morebytes;
                 if (_PyBytes_Resize(outbytes, newoutsize) < 0) {
                     Py_DECREF(rep);
                     goto error;
@@ -8131,9 +8147,10 @@ encode_code_page_errors(UINT code_page, PyObject **outbytes,
             }
 
             outsize = PyUnicode_GET_LENGTH(rep);
-            if (outsize != 1) {
+            morebytes += outsize;
+            if (morebytes > 0) {
                 Py_ssize_t offset = out - PyBytes_AS_STRING(*outbytes);
-                newoutsize = PyBytes_GET_SIZE(*outbytes) + (outsize - 1);
+                newoutsize = PyBytes_GET_SIZE(*outbytes) + morebytes;
                 if (_PyBytes_Resize(outbytes, newoutsize) < 0) {
                     Py_DECREF(rep);
                     goto error;
@@ -8156,6 +8173,7 @@ encode_code_page_errors(UINT code_page, PyObject **outbytes,
                 out++;
             }
         }
+        pos = newpos;
         Py_DECREF(rep);
     }
     /* write a NUL byte */
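
The new tests above also exercise error handlers that return bytes rather than
str. A minimal sketch (not from the commit; the handler name "test.bytesrepl"
is made up) of what that looks like on the resized UTF-16 path:

    import codecs

    def replace_with_bytes(exc):
        # Mirrors test_encode_bytes_replacement: hand back raw bytes that are
        # copied verbatim into the output, even if they are not decodable.
        if isinstance(exc, UnicodeEncodeError):
            return (b"\xbd\xbe", exc.end)
        raise TypeError("don't know how to handle %r" % exc)

    codecs.register_error("test.bytesrepl", replace_with_bytes)

    out = "[\udc80]".encode("utf-16le", "test.bytesrepl")
    assert out == "[".encode("utf-16le") + b"\xbd\xbe" + "]".encode("utf-16le")

A replacement whose length is not a whole number of UTF-16 or UTF-32 code
units still raises UnicodeEncodeError, as test_encode_odd_bytes_replacement
above checks.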


