[Python-checkins] bpo-38870: Refactor delimiting with context managers in ast.unparse (GH-17612)

Pablo Galindo webhook-mailer at python.org
Mon Dec 23 11:11:08 EST 2019


https://github.com/python/cpython/commit/4b3b1226e86df6cd45e921c8f2ad23c3639c43b2
commit: 4b3b1226e86df6cd45e921c8f2ad23c3639c43b2
branch: master
author: Batuhan Taşkaya <47358913+isidentical at users.noreply.github.com>
committer: Pablo Galindo <Pablogsal at gmail.com>
date: 2019-12-23T16:11:00Z
summary:

bpo-38870: Refactor delimiting with context managers in ast.unparse (GH-17612)

Co-Authored-By: Victor Stinner <vstinner at python.org>
Co-authored-by: Pablo Galindo <pablogsal at gmail.com>

files:
M Lib/ast.py

diff --git a/Lib/ast.py b/Lib/ast.py
index ee3f74358ee12..76e0cac838b92 100644
--- a/Lib/ast.py
+++ b/Lib/ast.py
@@ -26,6 +26,7 @@
 """
 import sys
 from _ast import *
+from contextlib import contextmanager, nullcontext
 
 
 def parse(source, filename='<unknown>', mode='exec', *,
@@ -613,6 +614,21 @@ def __exit__(self, exc_type, exc_value, traceback):
     def block(self):
         return self._Block(self)
 
+    @contextmanager
+    def delimit(self, start, end):
+        """A context manager for delimiting the source of an expression. It
+        writes *start* to the buffer on entry and *end* once the body finishes."""
+
+        self.write(start)
+        yield
+        self.write(end)
+
+    def delimit_if(self, start, end, condition):
+        if condition:
+            return self.delimit(start, end)
+        else:
+            return nullcontext()
+
     def traverse(self, node):
         if isinstance(node, list):
             for item in node:
@@ -636,11 +652,10 @@ def visit_Expr(self, node):
         self.traverse(node.value)
 
     def visit_NamedExpr(self, node):
-        self.write("(")
-        self.traverse(node.target)
-        self.write(" := ")
-        self.traverse(node.value)
-        self.write(")")
+        with self.delimit("(", ")"):
+            self.traverse(node.target)
+            self.write(" := ")
+            self.traverse(node.value)
 
     def visit_Import(self, node):
         self.fill("import ")
@@ -669,11 +684,8 @@ def visit_AugAssign(self, node):
 
     def visit_AnnAssign(self, node):
         self.fill()
-        if not node.simple and isinstance(node.target, Name):
-            self.write("(")
-        self.traverse(node.target)
-        if not node.simple and isinstance(node.target, Name):
-            self.write(")")
+        with self.delimit_if("(", ")", not node.simple and isinstance(node.target, Name)):
+            self.traverse(node.target)
         self.write(": ")
         self.traverse(node.annotation)
         if node.value:
@@ -715,28 +727,25 @@ def visit_Nonlocal(self, node):
         self.interleave(lambda: self.write(", "), self.write, node.names)
 
     def visit_Await(self, node):
-        self.write("(")
-        self.write("await")
-        if node.value:
-            self.write(" ")
-            self.traverse(node.value)
-        self.write(")")
+        with self.delimit("(", ")"):
+            self.write("await")
+            if node.value:
+                self.write(" ")
+                self.traverse(node.value)
 
     def visit_Yield(self, node):
-        self.write("(")
-        self.write("yield")
-        if node.value:
-            self.write(" ")
-            self.traverse(node.value)
-        self.write(")")
+        with self.delimit("(", ")"):
+            self.write("yield")
+            if node.value:
+                self.write(" ")
+                self.traverse(node.value)
 
     def visit_YieldFrom(self, node):
-        self.write("(")
-        self.write("yield from")
-        if node.value:
-            self.write(" ")
-            self.traverse(node.value)
-        self.write(")")
+        with self.delimit("(", ")"):
+            self.write("yield from")
+            if node.value:
+                self.write(" ")
+                self.traverse(node.value)
 
     def visit_Raise(self, node):
         self.fill("raise")
@@ -782,21 +791,20 @@ def visit_ClassDef(self, node):
             self.fill("@")
             self.traverse(deco)
         self.fill("class " + node.name)
-        self.write("(")
-        comma = False
-        for e in node.bases:
-            if comma:
-                self.write(", ")
-            else:
-                comma = True
-            self.traverse(e)
-        for e in node.keywords:
-            if comma:
-                self.write(", ")
-            else:
-                comma = True
-            self.traverse(e)
-        self.write(")")
+        with self.delimit("(", ")"):
+            comma = False
+            for e in node.bases:
+                if comma:
+                    self.write(", ")
+                else:
+                    comma = True
+                self.traverse(e)
+            for e in node.keywords:
+                if comma:
+                    self.write(", ")
+                else:
+                    comma = True
+                self.traverse(e)
 
         with self.block():
             self.traverse(node.body)
@@ -812,10 +820,10 @@ def __FunctionDef_helper(self, node, fill_suffix):
         for deco in node.decorator_list:
             self.fill("@")
             self.traverse(deco)
-        def_str = fill_suffix + " " + node.name + "("
+        def_str = fill_suffix + " " + node.name
         self.fill(def_str)
-        self.traverse(node.args)
-        self.write(")")
+        with self.delimit("(", ")"):
+            self.traverse(node.args)
         if node.returns:
             self.write(" -> ")
             self.traverse(node.returns)
@@ -931,13 +939,12 @@ def _write_constant(self, value):
     def visit_Constant(self, node):
         value = node.value
         if isinstance(value, tuple):
-            self.write("(")
-            if len(value) == 1:
-                self._write_constant(value[0])
-                self.write(",")
-            else:
-                self.interleave(lambda: self.write(", "), self._write_constant, value)
-            self.write(")")
+            with self.delimit("(", ")"):
+                if len(value) == 1:
+                    self._write_constant(value[0])
+                    self.write(",")
+                else:
+                    self.interleave(lambda: self.write(", "), self._write_constant, value)
         elif value is ...:
             self.write("...")
         else:
@@ -946,39 +953,34 @@ def visit_Constant(self, node):
             self._write_constant(node.value)
 
     def visit_List(self, node):
-        self.write("[")
-        self.interleave(lambda: self.write(", "), self.traverse, node.elts)
-        self.write("]")
+        with self.delimit("[", "]"):
+            self.interleave(lambda: self.write(", "), self.traverse, node.elts)
 
     def visit_ListComp(self, node):
-        self.write("[")
-        self.traverse(node.elt)
-        for gen in node.generators:
-            self.traverse(gen)
-        self.write("]")
+        with self.delimit("[", "]"):
+            self.traverse(node.elt)
+            for gen in node.generators:
+                self.traverse(gen)
 
     def visit_GeneratorExp(self, node):
-        self.write("(")
-        self.traverse(node.elt)
-        for gen in node.generators:
-            self.traverse(gen)
-        self.write(")")
+        with self.delimit("(", ")"):
+            self.traverse(node.elt)
+            for gen in node.generators:
+                self.traverse(gen)
 
     def visit_SetComp(self, node):
-        self.write("{")
-        self.traverse(node.elt)
-        for gen in node.generators:
-            self.traverse(gen)
-        self.write("}")
+        with self.delimit("{", "}"):
+            self.traverse(node.elt)
+            for gen in node.generators:
+                self.traverse(gen)
 
     def visit_DictComp(self, node):
-        self.write("{")
-        self.traverse(node.key)
-        self.write(": ")
-        self.traverse(node.value)
-        for gen in node.generators:
-            self.traverse(gen)
-        self.write("}")
+        with self.delimit("{", "}"):
+            self.traverse(node.key)
+            self.write(": ")
+            self.traverse(node.value)
+            for gen in node.generators:
+                self.traverse(gen)
 
     def visit_comprehension(self, node):
         if node.is_async:
@@ -993,24 +995,20 @@ def visit_comprehension(self, node):
             self.traverse(if_clause)
 
     def visit_IfExp(self, node):
-        self.write("(")
-        self.traverse(node.body)
-        self.write(" if ")
-        self.traverse(node.test)
-        self.write(" else ")
-        self.traverse(node.orelse)
-        self.write(")")
+        with self.delimit("(", ")"):
+            self.traverse(node.body)
+            self.write(" if ")
+            self.traverse(node.test)
+            self.write(" else ")
+            self.traverse(node.orelse)
 
     def visit_Set(self, node):
         if not node.elts:
             raise ValueError("Set node should has at least one item")
-        self.write("{")
-        self.interleave(lambda: self.write(", "), self.traverse, node.elts)
-        self.write("}")
+        with self.delimit("{", "}"):
+            self.interleave(lambda: self.write(", "), self.traverse, node.elts)
 
     def visit_Dict(self, node):
-        self.write("{")
-
         def write_key_value_pair(k, v):
             self.traverse(k)
             self.write(": ")
@@ -1026,29 +1024,27 @@ def write_item(item):
             else:
                 write_key_value_pair(k, v)
 
-        self.interleave(
-            lambda: self.write(", "), write_item, zip(node.keys, node.values)
-        )
-        self.write("}")
+        with self.delimit("{", "}"):
+            self.interleave(
+                lambda: self.write(", "), write_item, zip(node.keys, node.values)
+            )
 
     def visit_Tuple(self, node):
-        self.write("(")
-        if len(node.elts) == 1:
-            elt = node.elts[0]
-            self.traverse(elt)
-            self.write(",")
-        else:
-            self.interleave(lambda: self.write(", "), self.traverse, node.elts)
-        self.write(")")
+        with self.delimit("(", ")"):
+            if len(node.elts) == 1:
+                elt = node.elts[0]
+                self.traverse(elt)
+                self.write(",")
+            else:
+                self.interleave(lambda: self.write(", "), self.traverse, node.elts)
 
     unop = {"Invert": "~", "Not": "not", "UAdd": "+", "USub": "-"}
 
     def visit_UnaryOp(self, node):
-        self.write("(")
-        self.write(self.unop[node.op.__class__.__name__])
-        self.write(" ")
-        self.traverse(node.operand)
-        self.write(")")
+        with self.delimit("(", ")"):
+            self.write(self.unop[node.op.__class__.__name__])
+            self.write(" ")
+            self.traverse(node.operand)
 
     binop = {
         "Add": "+",
@@ -1067,11 +1063,10 @@ def visit_UnaryOp(self, node):
     }
 
     def visit_BinOp(self, node):
-        self.write("(")
-        self.traverse(node.left)
-        self.write(" " + self.binop[node.op.__class__.__name__] + " ")
-        self.traverse(node.right)
-        self.write(")")
+        with self.delimit("(", ")"):
+            self.traverse(node.left)
+            self.write(" " + self.binop[node.op.__class__.__name__] + " ")
+            self.traverse(node.right)
 
     cmpops = {
         "Eq": "==",
@@ -1087,20 +1082,18 @@ def visit_BinOp(self, node):
     }
 
     def visit_Compare(self, node):
-        self.write("(")
-        self.traverse(node.left)
-        for o, e in zip(node.ops, node.comparators):
-            self.write(" " + self.cmpops[o.__class__.__name__] + " ")
-            self.traverse(e)
-        self.write(")")
+        with self.delimit("(", ")"):
+            self.traverse(node.left)
+            for o, e in zip(node.ops, node.comparators):
+                self.write(" " + self.cmpops[o.__class__.__name__] + " ")
+                self.traverse(e)
 
-    boolops = {And: "and", Or: "or"}
+    boolops = {"And": "and", "Or": "or"}
 
     def visit_BoolOp(self, node):
-        self.write("(")
-        s = " %s " % self.boolops[node.op.__class__]
-        self.interleave(lambda: self.write(s), self.traverse, node.values)
-        self.write(")")
+        with self.delimit("(", ")"):
+            s = " %s " % self.boolops[node.op.__class__.__name__]
+            self.interleave(lambda: self.write(s), self.traverse, node.values)
 
     def visit_Attribute(self, node):
         self.traverse(node.value)
@@ -1114,27 +1107,25 @@ def visit_Attribute(self, node):
 
     def visit_Call(self, node):
         self.traverse(node.func)
-        self.write("(")
-        comma = False
-        for e in node.args:
-            if comma:
-                self.write(", ")
-            else:
-                comma = True
-            self.traverse(e)
-        for e in node.keywords:
-            if comma:
-                self.write(", ")
-            else:
-                comma = True
-            self.traverse(e)
-        self.write(")")
+        with self.delimit("(", ")"):
+            comma = False
+            for e in node.args:
+                if comma:
+                    self.write(", ")
+                else:
+                    comma = True
+                self.traverse(e)
+            for e in node.keywords:
+                if comma:
+                    self.write(", ")
+                else:
+                    comma = True
+                self.traverse(e)
 
     def visit_Subscript(self, node):
         self.traverse(node.value)
-        self.write("[")
-        self.traverse(node.slice)
-        self.write("]")
+        with self.delimit("[", "]"):
+            self.traverse(node.slice)
 
     def visit_Starred(self, node):
         self.write("*")
@@ -1225,12 +1216,11 @@ def visit_keyword(self, node):
         self.traverse(node.value)
 
     def visit_Lambda(self, node):
-        self.write("(")
-        self.write("lambda ")
-        self.traverse(node.args)
-        self.write(": ")
-        self.traverse(node.body)
-        self.write(")")
+        with self.delimit("(", ")"):
+            self.write("lambda ")
+            self.traverse(node.args)
+            self.write(": ")
+            self.traverse(node.body)
 
     def visit_alias(self, node):
         self.write(node.name)



More information about the Python-checkins mailing list