[Python-checkins] r67657 - in sandbox/trunk/2to3/lib2to3: fixer_util.py fixes/fix_intern.py fixes/fix_reduce.py tests/test_fixers.py tests/test_util.py

armin.ronacher python-checkins at python.org
Mon Dec 8 01:29:36 CET 2008


Author: armin.ronacher
Date: Mon Dec  8 01:29:35 2008
New Revision: 67657

Log:
2to3: the intern and reduce fixers now add the required imports if they are missing.  Because that is a common task, the fixer_util module now has a function "touch_import" that adds an import if it is missing.



Added:
   sandbox/trunk/2to3/lib2to3/fixes/fix_reduce.py
Modified:
   sandbox/trunk/2to3/lib2to3/fixer_util.py
   sandbox/trunk/2to3/lib2to3/fixes/fix_intern.py
   sandbox/trunk/2to3/lib2to3/tests/test_fixers.py
   sandbox/trunk/2to3/lib2to3/tests/test_util.py

Modified: sandbox/trunk/2to3/lib2to3/fixer_util.py
==============================================================================
--- sandbox/trunk/2to3/lib2to3/fixer_util.py	(original)
+++ sandbox/trunk/2to3/lib2to3/fixer_util.py	Mon Dec  8 01:29:35 2008
@@ -232,20 +232,78 @@
     suite.parent = parent
     return suite
 
-def does_tree_import(package, name, node):
-    """ Returns true if name is imported from package at the
-        top level of the tree which node belongs to.
-        To cover the case of an import like 'import foo', use
-        Null for the package and 'foo' for the name. """
+def find_root(node):
+    """Find the top level namespace."""
     # Scamper up to the top level namespace
     while node.type != syms.file_input:
         assert node.parent, "Tree is insane! root found before "\
                            "file_input node was found."
         node = node.parent
+    return node
 
-    binding = find_binding(name, node, package)
+def does_tree_import(package, name, node):
+    """ Returns true if name is imported from package at the
+        top level of the tree which node belongs to.
+        To cover the case of an import like 'import foo', use
+        None for the package and 'foo' for the name. """
+    binding = find_binding(name, find_root(node), package)
     return bool(binding)
 
+def is_import(node):
+    """Returns true if the node is an import statement."""
+    return node.type in (syms.import_name, syms.import_from)
+
+def touch_import(package, name, node):
+    """ Works like `does_tree_import` but adds an import statement
+        if it was not imported. """
+    def is_import_stmt(node):
+        return node.type == syms.simple_stmt and node.children and \
+               is_import(node.children[0])
+
+    root = find_root(node)
+
+    if does_tree_import(package, name, root):
+        return
+
+    add_newline_before = False
+
+    # figure out where to insert the new import.  First try to find
+    # the first import and then skip to the last one.
+    insert_pos = offset = 0
+    for idx, node in enumerate(root.children):
+        if not is_import_stmt(node):
+            continue
+        for offset, node2 in enumerate(root.children[idx:]):
+            if not is_import_stmt(node2):
+                break
+        insert_pos = idx + offset
+        break
+
+    # if there are no imports where we can insert, find the docstring.
+    # if that also fails, we stick to the beginning of the file
+    if insert_pos == 0:
+        for idx, node in enumerate(root.children):
+            if node.type == syms.simple_stmt and node.children and \
+               node.children[0].type == token.STRING:
+                insert_pos = idx + 1
+                add_newline_before
+                break
+
+    if package is None:
+        import_ = Node(syms.import_name, [
+            Leaf(token.NAME, 'import'),
+            Leaf(token.NAME, name, prefix=' ')
+        ])
+    else:
+        import_ = FromImport(package, [Leaf(token.NAME, name, prefix=' ')])
+
+    children = [import_, Newline()]
+    if add_newline_before:
+        children.insert(0, Newline())
+    root.changed()
+    root.children.insert(insert_pos, Node(syms.simple_stmt, children))
+
+
 _def_syms = set([syms.classdef, syms.funcdef])
 def find_binding(name, node, package=None):
     """ Returns the node which binds variable name, otherwise None.
@@ -285,7 +343,7 @@
         if ret:
             if not package:
                 return ret
-            if ret.type in (syms.import_name, syms.import_from):
+            if is_import(ret):
                 return ret
     return None
 

Modified: sandbox/trunk/2to3/lib2to3/fixes/fix_intern.py
==============================================================================
--- sandbox/trunk/2to3/lib2to3/fixes/fix_intern.py	(original)
+++ sandbox/trunk/2to3/lib2to3/fixes/fix_intern.py	Mon Dec  8 01:29:35 2008
@@ -8,7 +8,7 @@
 # Local imports
 from .. import pytree
 from .. import fixer_base
-from ..fixer_util import Name, Attr
+from ..fixer_util import Name, Attr, touch_import
 
 
 class FixIntern(fixer_base.BaseFix):
@@ -40,4 +40,5 @@
                                         newarglist,
                                         results["rpar"].clone()])] + after)
         new.set_prefix(node.get_prefix())
+        touch_import(None, 'sys', node)
         return new

Added: sandbox/trunk/2to3/lib2to3/fixes/fix_reduce.py
==============================================================================
--- (empty file)
+++ sandbox/trunk/2to3/lib2to3/fixes/fix_reduce.py	Mon Dec  8 01:29:35 2008
@@ -0,0 +1,33 @@
+# Copyright 2008 Armin Ronacher.
+# Licensed to PSF under a Contributor Agreement.
+
+"""Fixer for reduce().
+
+Makes sure reduce() is imported from the functools module if reduce is
+used in that module.
+"""
+
+from .. import pytree
+from .. import fixer_base
+from ..fixer_util import Name, Attr, touch_import
+
+
+
+class FixReduce(fixer_base.BaseFix):
+
+    PATTERN = """
+    power< 'reduce'
+        trailer< '('
+            arglist< (
+                (not(argument<any '=' any>) any ','
+                 not(argument<any '=' any>) any) |
+                (not(argument<any '=' any>) any ','
+                 not(argument<any '=' any>) any ','
+                 not(argument<any '=' any>) any)
+            ) >
+        ')' >
+    >
+    """
+
+    def transform(self, node, results):
+        touch_import('functools', 'reduce', node)

Modified: sandbox/trunk/2to3/lib2to3/tests/test_fixers.py
==============================================================================
--- sandbox/trunk/2to3/lib2to3/tests/test_fixers.py	(original)
+++ sandbox/trunk/2to3/lib2to3/tests/test_fixers.py	Mon Dec  8 01:29:35 2008
@@ -293,30 +293,30 @@
 
     def test_prefix_preservation(self):
+        # FixIntern now calls touch_import(None, 'sys', ...), and these
+        # snippets have no imports, so "import sys" is prepended.
         b = """x =   intern(  a  )"""
-        a = """x =   sys.intern(  a  )"""
+        a = """import sys\nx =   sys.intern(  a  )"""
         self.check(b, a)
 
         b = """y = intern("b" # test
               )"""
-        a = """y = sys.intern("b" # test
+        a = """import sys\ny = sys.intern("b" # test
               )"""
         self.check(b, a)
 
         b = """z = intern(a+b+c.d,   )"""
-        a = """z = sys.intern(a+b+c.d,   )"""
+        a = """import sys\nz = sys.intern(a+b+c.d,   )"""
         self.check(b, a)
 
     def test(self):
+        # Every converted snippet also gains "import sys" at the top,
+        # added by touch_import() inside the intern fixer.
         b = """x = intern(a)"""
-        a = """x = sys.intern(a)"""
+        a = """import sys\nx = sys.intern(a)"""
         self.check(b, a)
 
         b = """z = intern(a+b+c.d,)"""
-        a = """z = sys.intern(a+b+c.d,)"""
+        a = """import sys\nz = sys.intern(a+b+c.d,)"""
         self.check(b, a)
 
         b = """intern("y%s" % 5).replace("y", "")"""
-        a = """sys.intern("y%s" % 5).replace("y", "")"""
+        a = """import sys\nsys.intern("y%s" % 5).replace("y", "")"""
         self.check(b, a)
 
     # These should not be refactored
@@ -337,6 +337,35 @@
         s = """intern()"""
         self.unchanged(s)
 
+class Test_reduce(FixerTestCase):
+    fixer = "reduce"
+
+    def test_simple_call(self):
+        b = "reduce(a, b, c)"
+        a = "from functools import reduce\nreduce(a, b, c)"
+        self.check(b, a)
+
+    def test_call_with_lambda(self):
+        b = "reduce(lambda x, y: x + y, seq)"
+        a = "from functools import reduce\nreduce(lambda x, y: x + y, seq)"
+        self.check(b, a)
+
+    def test_unchanged(self):
+        s = "reduce(a)"
+        self.unchanged(s)
+
+        s = "reduce(a, b=42)"
+        self.unchanged(s)
+
+        s = "reduce(a, b, c, d)"
+        self.unchanged(s)
+
+        s = "reduce(**c)"
+        self.unchanged(s)
+
+        s = "reduce()"
+        self.unchanged(s)
+
 class Test_print(FixerTestCase):
     fixer = "print"
 

Modified: sandbox/trunk/2to3/lib2to3/tests/test_util.py
==============================================================================
--- sandbox/trunk/2to3/lib2to3/tests/test_util.py	(original)
+++ sandbox/trunk/2to3/lib2to3/tests/test_util.py	Mon Dec  8 01:29:35 2008
@@ -526,6 +526,33 @@
                     b = 7"""
         self.failIf(self.find_binding("a", s))
 
+class Test_touch_import(support.TestCase):
+
+    def test_after_docstring(self):
+        node = parse('"""foo"""\nbar()')
+        fixer_util.touch_import(None, "foo", node)
+        self.assertEqual(str(node), '"""foo"""\nimport foo\nbar()\n\n')
+
+    def test_after_imports(self):
+        node = parse('"""foo"""\nimport bar\nbar()')
+        fixer_util.touch_import(None, "foo", node)
+        self.assertEqual(str(node), '"""foo"""\nimport bar\nimport foo\nbar()\n\n')
+
+    def test_beginning(self):
+        node = parse('bar()')
+        fixer_util.touch_import(None, "foo", node)
+        self.assertEqual(str(node), 'import foo\nbar()\n\n')
+
+    def test_from_import(self):
+        node = parse('bar()')
+        fixer_util.touch_import("cgi", "escape", node)
+        self.assertEqual(str(node), 'from cgi import escape\nbar()\n\n')
+
+    def test_name_import(self):
+        node = parse('bar()')
+        fixer_util.touch_import(None, "cgi", node)
+        self.assertEqual(str(node), 'import cgi\nbar()\n\n')
+
 
 if __name__ == "__main__":
     import __main__


More information about the Python-checkins mailing list