[Python-checkins] bpo-38870: Add docstring support to ast.unparse (GH-17760)

Batuhan Taşkaya webhook-mailer at python.org
Mon Mar 2 13:59:26 EST 2020




To: python-checkins at python.org
Subject: bpo-38870: Add docstring support to ast.unparse (GH-17760)
Content-Type: text/plain; charset="utf-8"
Content-Transfer-Encoding: quoted-printable
MIME-Version: 1.0

https://github.com/python/cpython/commit/89aa4694fc8c6d190325ef8ed6ce6a6b8efb3e50
commit: 89aa4694fc8c6d190325ef8ed6ce6a6b8efb3e50
branch: master
author: Batuhan Taşkaya <47358913+isidentical at users.noreply.github.com>
committer: GitHub <noreply at github.com>
date: 2020-03-02T18:59:01Z
summary:

bpo-38870: Add docstring support to ast.unparse (GH-17760)

Allow ast.unparse to detect docstrings in functions, modules and classes and produce
nicely formatted unparsed output for said docstrings.

Co-Authored-By: Pablo Galindo <Pablogsal at gmail.com>

files:
M Lib/ast.py
M Lib/test/test_unparse.py

diff --git a/Lib/ast.py b/Lib/ast.py
index 4839201e2e234..93ffa1edc84d5 100644
--- a/Lib/ast.py
+++ b/Lib/ast.py
@@ -667,6 +667,22 @@ def set_precedence(self, precedence, *nodes):
         for node in nodes:
             self._precedences[node] =3D precedence
=20
+    def get_raw_docstring(self, node):
+        """If a docstring node is found in the body of the *node* parameter,
+        return that docstring node, None otherwise.
+
+        Logic mirrored from ``_PyAST_GetDocString``."""
+        if not isinstance(
+            node, (AsyncFunctionDef, FunctionDef, ClassDef, Module)
+        ) or len(node.body) < 1:
+            return None
+        node =3D node.body[0]
+        if not isinstance(node, Expr):
+            return None
+        node =3D node.value
+        if isinstance(node, Constant) and isinstance(node.value, str):
+            return node
+
     def traverse(self, node):
         if isinstance(node, list):
             for item in node:
@@ -681,9 +697,15 @@ def visit(self, node):
         self.traverse(node)
         return "".join(self._source)
=20
+    def _write_docstring_and_traverse_body(self, node):
+        if (docstring :=3D self.get_raw_docstring(node)):
+            self._write_docstring(docstring)
+            self.traverse(node.body[1:])
+        else:
+            self.traverse(node.body)
+
     def visit_Module(self, node):
-        for subnode in node.body:
-            self.traverse(subnode)
+        self._write_docstring_and_traverse_body(node)
=20
     def visit_Expr(self, node):
         self.fill()
@@ -850,15 +872,15 @@ def visit_ClassDef(self, node):
                 self.traverse(e)
=20
         with self.block():
-            self.traverse(node.body)
+            self._write_docstring_and_traverse_body(node)
=20
     def visit_FunctionDef(self, node):
-        self.__FunctionDef_helper(node, "def")
+        self._function_helper(node, "def")
=20
     def visit_AsyncFunctionDef(self, node):
-        self.__FunctionDef_helper(node, "async def")
+        self._function_helper(node, "async def")
=20
-    def __FunctionDef_helper(self, node, fill_suffix):
+    def _function_helper(self, node, fill_suffix):
         self.write("\n")
         for deco in node.decorator_list:
             self.fill("@")
@@ -871,15 +893,15 @@ def __FunctionDef_helper(self, node, fill_suffix):
             self.write(" -> ")
             self.traverse(node.returns)
         with self.block():
-            self.traverse(node.body)
+            self._write_docstring_and_traverse_body(node)
=20
     def visit_For(self, node):
-        self.__For_helper("for ", node)
+        self._for_helper("for ", node)
=20
     def visit_AsyncFor(self, node):
-        self.__For_helper("async for ", node)
+        self._for_helper("async for ", node)
=20
-    def __For_helper(self, fill, node):
+    def _for_helper(self, fill, node):
         self.fill(fill)
         self.traverse(node.target)
         self.write(" in ")
@@ -974,6 +996,19 @@ def _fstring_FormattedValue(self, node, write):
     def visit_Name(self, node):
         self.write(node.id)
=20
+    def _write_docstring(self, node):
+        self.fill()
+        if node.kind =3D=3D "u":
+            self.write("u")
+
+        # Preserve quotes in the docstring by escaping them
+        value =3D node.value.replace("\\", "\\\\")
+        value =3D value.replace('"""', '""\"')
+        if value[-1] =3D=3D '"':
+            value =3D value.replace('"', '\\"', -1)
+
+        self.write(f'"""{value}"""')
+
     def _write_constant(self, value):
         if isinstance(value, (float, complex)):
             # Substitute overflowing decimal literal for AST infinities.
diff --git a/Lib/test/test_unparse.py b/Lib/test/test_unparse.py
index f7fcb2bffe891..d04db4d5f46e1 100644
--- a/Lib/test/test_unparse.py
+++ b/Lib/test/test_unparse.py
@@ -111,12 +111,18 @@ class Foo: pass
     suite1
 """
=20
+docstring_prefixes =3D [
+    "",
+    "class foo():\n    ",
+    "def foo():\n    ",
+    "async def foo():\n    ",
+]
=20
 class ASTTestCase(unittest.TestCase):
     def assertASTEqual(self, ast1, ast2):
         self.assertEqual(ast.dump(ast1), ast.dump(ast2))
=20
-    def check_roundtrip(self, code1):
+    def check_ast_roundtrip(self, code1):
         ast1 =3D ast.parse(code1)
         code2 =3D ast.unparse(ast1)
         ast2 =3D ast.parse(code2)
@@ -125,147 +131,154 @@ def check_roundtrip(self, code1):
     def check_invalid(self, node, raises=3DValueError):
         self.assertRaises(raises, ast.unparse, node)
=20
-    def check_src_roundtrip(self, code1, code2=3DNone, strip=3DTrue):
+    def get_source(self, code1, code2=3DNone, strip=3DTrue):
         code2 =3D code2 or code1
         code1 =3D ast.unparse(ast.parse(code1))
         if strip:
             code1 =3D code1.strip()
+        return code1, code2
+
+    def check_src_roundtrip(self, code1, code2=3DNone, strip=3DTrue):
+        code1, code2 =3D self.get_source(code1, code2, strip)
         self.assertEqual(code2, code1)
=20
+    def check_src_dont_roundtrip(self, code1, code2=3DNone, strip=3DTrue):
+        code1, code2 =3D self.get_source(code1, code2, strip)
+        self.assertNotEqual(code2, code1)
=20
 class UnparseTestCase(ASTTestCase):
     # Tests for specific bugs found in earlier versions of unparse
=20
     def test_fstrings(self):
         # See issue 25180
-        self.check_roundtrip(r"""f'{f"{0}"*3}'""")
-        self.check_roundtrip(r"""f'{f"{y}"*3}'""")
+        self.check_ast_roundtrip(r"""f'{f"{0}"*3}'""")
+        self.check_ast_roundtrip(r"""f'{f"{y}"*3}'""")
=20
     def test_strings(self):
-        self.check_roundtrip("u'foo'")
-        self.check_roundtrip("r'foo'")
-        self.check_roundtrip("b'foo'")
+        self.check_ast_roundtrip("u'foo'")
+        self.check_ast_roundtrip("r'foo'")
+        self.check_ast_roundtrip("b'foo'")
=20
     def test_del_statement(self):
-        self.check_roundtrip("del x, y, z")
+        self.check_ast_roundtrip("del x, y, z")
=20
     def test_shifts(self):
-        self.check_roundtrip("45 << 2")
-        self.check_roundtrip("13 >> 7")
+        self.check_ast_roundtrip("45 << 2")
+        self.check_ast_roundtrip("13 >> 7")
=20
     def test_for_else(self):
-        self.check_roundtrip(for_else)
+        self.check_ast_roundtrip(for_else)
=20
     def test_while_else(self):
-        self.check_roundtrip(while_else)
+        self.check_ast_roundtrip(while_else)
=20
     def test_unary_parens(self):
-        self.check_roundtrip("(-1)**7")
-        self.check_roundtrip("(-1.)**8")
-        self.check_roundtrip("(-1j)**6")
-        self.check_roundtrip("not True or False")
-        self.check_roundtrip("True or not False")
+        self.check_ast_roundtrip("(-1)**7")
+        self.check_ast_roundtrip("(-1.)**8")
+        self.check_ast_roundtrip("(-1j)**6")
+        self.check_ast_roundtrip("not True or False")
+        self.check_ast_roundtrip("True or not False")
=20
     def test_integer_parens(self):
-        self.check_roundtrip("3 .__abs__()")
+        self.check_ast_roundtrip("3 .__abs__()")
=20
     def test_huge_float(self):
-        self.check_roundtrip("1e1000")
-        self.check_roundtrip("-1e1000")
-        self.check_roundtrip("1e1000j")
-        self.check_roundtrip("-1e1000j")
+        self.check_ast_roundtrip("1e1000")
+        self.check_ast_roundtrip("-1e1000")
+        self.check_ast_roundtrip("1e1000j")
+        self.check_ast_roundtrip("-1e1000j")
=20
     def test_min_int(self):
-        self.check_roundtrip(str(-(2 ** 31)))
-        self.check_roundtrip(str(-(2 ** 63)))
+        self.check_ast_roundtrip(str(-(2 ** 31)))
+        self.check_ast_roundtrip(str(-(2 ** 63)))
=20
     def test_imaginary_literals(self):
-        self.check_roundtrip("7j")
-        self.check_roundtrip("-7j")
-        self.check_roundtrip("0j")
-        self.check_roundtrip("-0j")
+        self.check_ast_roundtrip("7j")
+        self.check_ast_roundtrip("-7j")
+        self.check_ast_roundtrip("0j")
+        self.check_ast_roundtrip("-0j")
=20
     def test_lambda_parentheses(self):
-        self.check_roundtrip("(lambda: int)()")
+        self.check_ast_roundtrip("(lambda: int)()")
=20
     def test_chained_comparisons(self):
-        self.check_roundtrip("1 < 4 <=3D 5")
-        self.check_roundtrip("a is b is c is not d")
+        self.check_ast_roundtrip("1 < 4 <=3D 5")
+        self.check_ast_roundtrip("a is b is c is not d")
=20
     def test_function_arguments(self):
-        self.check_roundtrip("def f(): pass")
-        self.check_roundtrip("def f(a): pass")
-        self.check_roundtrip("def f(b =3D 2): pass")
-        self.check_roundtrip("def f(a, b): pass")
-        self.check_roundtrip("def f(a, b =3D 2): pass")
-        self.check_roundtrip("def f(a =3D 5, b =3D 2): pass")
-        self.check_roundtrip("def f(*, a =3D 1, b =3D 2): pass")
-        self.check_roundtrip("def f(*, a =3D 1, b): pass")
-        self.check_roundtrip("def f(*, a, b =3D 2): pass")
-        self.check_roundtrip("def f(a, b =3D None, *, c, **kwds): pass")
-        self.check_roundtrip("def f(a=3D2, *args, c=3D5, d, **kwds): pass")
-        self.check_roundtrip("def f(*args, **kwargs): pass")
+        self.check_ast_roundtrip("def f(): pass")
+        self.check_ast_roundtrip("def f(a): pass")
+        self.check_ast_roundtrip("def f(b =3D 2): pass")
+        self.check_ast_roundtrip("def f(a, b): pass")
+        self.check_ast_roundtrip("def f(a, b =3D 2): pass")
+        self.check_ast_roundtrip("def f(a =3D 5, b =3D 2): pass")
+        self.check_ast_roundtrip("def f(*, a =3D 1, b =3D 2): pass")
+        self.check_ast_roundtrip("def f(*, a =3D 1, b): pass")
+        self.check_ast_roundtrip("def f(*, a, b =3D 2): pass")
+        self.check_ast_roundtrip("def f(a, b =3D None, *, c, **kwds): pass")
+        self.check_ast_roundtrip("def f(a=3D2, *args, c=3D5, d, **kwds): pas=
s")
+        self.check_ast_roundtrip("def f(*args, **kwargs): pass")
=20
     def test_relative_import(self):
-        self.check_roundtrip(relative_import)
+        self.check_ast_roundtrip(relative_import)
=20
     def test_nonlocal(self):
-        self.check_roundtrip(nonlocal_ex)
+        self.check_ast_roundtrip(nonlocal_ex)
=20
     def test_raise_from(self):
-        self.check_roundtrip(raise_from)
+        self.check_ast_roundtrip(raise_from)
=20
     def test_bytes(self):
-        self.check_roundtrip("b'123'")
+        self.check_ast_roundtrip("b'123'")
=20
     def test_annotations(self):
-        self.check_roundtrip("def f(a : int): pass")
-        self.check_roundtrip("def f(a: int =3D 5): pass")
-        self.check_roundtrip("def f(*args: [int]): pass")
-        self.check_roundtrip("def f(**kwargs: dict): pass")
-        self.check_roundtrip("def f() -> None: pass")
+        self.check_ast_roundtrip("def f(a : int): pass")
+        self.check_ast_roundtrip("def f(a: int =3D 5): pass")
+        self.check_ast_roundtrip("def f(*args: [int]): pass")
+        self.check_ast_roundtrip("def f(**kwargs: dict): pass")
+        self.check_ast_roundtrip("def f() -> None: pass")
=20
     def test_set_literal(self):
-        self.check_roundtrip("{'a', 'b', 'c'}")
+        self.check_ast_roundtrip("{'a', 'b', 'c'}")
=20
     def test_set_comprehension(self):
-        self.check_roundtrip("{x for x in range(5)}")
+        self.check_ast_roundtrip("{x for x in range(5)}")
=20
     def test_dict_comprehension(self):
-        self.check_roundtrip("{x: x*x for x in range(10)}")
+        self.check_ast_roundtrip("{x: x*x for x in range(10)}")
=20
     def test_class_decorators(self):
-        self.check_roundtrip(class_decorator)
+        self.check_ast_roundtrip(class_decorator)
=20
     def test_class_definition(self):
-        self.check_roundtrip("class A(metaclass=3Dtype, *[], **{}): pass")
+        self.check_ast_roundtrip("class A(metaclass=3Dtype, *[], **{}): pass=
")
=20
     def test_elifs(self):
-        self.check_roundtrip(elif1)
-        self.check_roundtrip(elif2)
+        self.check_ast_roundtrip(elif1)
+        self.check_ast_roundtrip(elif2)
=20
     def test_try_except_finally(self):
-        self.check_roundtrip(try_except_finally)
+        self.check_ast_roundtrip(try_except_finally)
=20
     def test_starred_assignment(self):
-        self.check_roundtrip("a, *b, c =3D seq")
-        self.check_roundtrip("a, (*b, c) =3D seq")
-        self.check_roundtrip("a, *b[0], c =3D seq")
-        self.check_roundtrip("a, *(b, c) =3D seq")
+        self.check_ast_roundtrip("a, *b, c =3D seq")
+        self.check_ast_roundtrip("a, (*b, c) =3D seq")
+        self.check_ast_roundtrip("a, *b[0], c =3D seq")
+        self.check_ast_roundtrip("a, *(b, c) =3D seq")
=20
     def test_with_simple(self):
-        self.check_roundtrip(with_simple)
+        self.check_ast_roundtrip(with_simple)
=20
     def test_with_as(self):
-        self.check_roundtrip(with_as)
+        self.check_ast_roundtrip(with_as)
=20
     def test_with_two_items(self):
-        self.check_roundtrip(with_two_items)
+        self.check_ast_roundtrip(with_two_items)
=20
     def test_dict_unpacking_in_dict(self):
         # See issue 26489
-        self.check_roundtrip(r"""{**{'y': 2}, 'x': 1}""")
-        self.check_roundtrip(r"""{**{'y': 2}, **{'x': 1}}""")
+        self.check_ast_roundtrip(r"""{**{'y': 2}, 'x': 1}""")
+        self.check_ast_roundtrip(r"""{**{'y': 2}, **{'x': 1}}""")
=20
     def test_invalid_raise(self):
         self.check_invalid(ast.Raise(exc=3DNone, cause=3Dast.Name(id=3D"X")))
@@ -288,6 +301,16 @@ def test_invalid_set(self):
     def test_invalid_yield_from(self):
         self.check_invalid(ast.YieldFrom(value=3DNone))
=20
+    def test_docstrings(self):
+        docstrings =3D (
+            'this ends with double quote"',
+            'this includes a """triple quote"""'
+        )
+        for docstring in docstrings:
+            # check as Module docstrings for easy testing
+            self.check_ast_roundtrip(f"'{docstring}'")
+
+
 class CosmeticTestCase(ASTTestCase):
     """Test if there are cosmetic issues caused by unnecesary additions"""
=20
@@ -321,6 +344,39 @@ def test_simple_expressions_parens(self):
         self.check_src_roundtrip("call((yield x))")
         self.check_src_roundtrip("return x + (yield x)")
=20
+    def test_docstrings(self):
+        docstrings =3D (
+            '"""simple doc string"""',
+            '''"""A more complex one
+            with some newlines"""''',
+            '''"""Foo bar baz
+
+            empty newline"""''',
+            '"""With some \t"""',
+            '"""Foo "bar" baz """',
+        )
+
+        for prefix in docstring_prefixes:
+            for docstring in docstrings:
+                self.check_src_roundtrip(f"{prefix}{docstring}")
+
+    def test_docstrings_negative_cases(self):
+        # Test some cases that involve strings in the children of the
+        # first node but aren't docstrings to make sure we don't have
+        # False positives.
+        docstrings_negative =3D (
+            'a =3D """false"""',
+            '"""false""" + """unless its optimized"""',
+            '1 + 1\n"""false"""',
+            'f"""no, top level but f-fstring"""'
+        )
+        for prefix in docstring_prefixes:
+            for negative in docstrings_negative:
+                # these cases should result in single-quoted strings
+                # rather than a triple-quoted docstring
+                src =3D f"{prefix}{negative}"
+                self.check_ast_roundtrip(src)
+                self.check_src_dont_roundtrip(src)
=20
 class DirectoryTestCase(ASTTestCase):
     """Test roundtrip behaviour on all files in Lib and Lib/test."""
@@ -379,7 +435,7 @@ def test_files(self):
=20
             with self.subTest(filename=3Ditem):
                 source =3D read_pyfile(item)
-                self.check_roundtrip(source)
+                self.check_ast_roundtrip(source)
=20
=20
 if __name__ =3D=3D "__main__":



More information about the Python-checkins mailing list