[Python-checkins] cpython: Issue #12428: Add a pure Python implementation of functools.partial().

antoine.pitrou python-checkins at python.org
Tue Nov 13 21:37:11 CET 2012


http://hg.python.org/cpython/rev/fcfaca024160
changeset:   80421:fcfaca024160
user:        Antoine Pitrou <solipsis at pitrou.net>
date:        Tue Nov 13 21:35:40 2012 +0100
summary:
  Issue #12428: Add a pure Python implementation of functools.partial().
Patch by Brian Thorne.

files:
  Lib/functools.py           |   28 +++-
  Lib/test/test_functools.py |  208 ++++++++++++++++--------
  Misc/ACKS                  |    1 +
  Misc/NEWS                  |    3 +
  4 files changed, 167 insertions(+), 73 deletions(-)


diff --git a/Lib/functools.py b/Lib/functools.py
--- a/Lib/functools.py
+++ b/Lib/functools.py
@@ -11,7 +11,10 @@
 __all__ = ['update_wrapper', 'wraps', 'WRAPPER_ASSIGNMENTS', 'WRAPPER_UPDATES',
            'total_ordering', 'cmp_to_key', 'lru_cache', 'reduce', 'partial']
 
-from _functools import partial, reduce
+try:
+    from _functools import reduce
+except ImportError:
+    pass
 from collections import namedtuple
 try:
     from _thread import allocate_lock as Lock
@@ -137,6 +140,29 @@
 
 
 ################################################################################
+### partial() argument application
+################################################################################
+
+def partial(func, *args, **keywords):
+    """new function with partial application of the given arguments
+    and keywords.
+    """
+    def newfunc(*fargs, **fkeywords):
+        newkeywords = keywords.copy()
+        newkeywords.update(fkeywords)
+        return func(*(args + fargs), **newkeywords)
+    newfunc.func = func
+    newfunc.args = args
+    newfunc.keywords = keywords
+    return newfunc
+
+try:
+    from _functools import partial
+except ImportError:
+    pass
+
+
+################################################################################
 ### LRU Cache function decorator
 ################################################################################
 
diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py
--- a/Lib/test/test_functools.py
+++ b/Lib/test/test_functools.py
@@ -1,4 +1,3 @@
-import functools
 import collections
 import sys
 import unittest
@@ -7,17 +6,31 @@
 import pickle
 from random import choice
 
-@staticmethod
-def PythonPartial(func, *args, **keywords):
-    'Pure Python approximation of partial()'
-    def newfunc(*fargs, **fkeywords):
-        newkeywords = keywords.copy()
-        newkeywords.update(fkeywords)
-        return func(*(args + fargs), **newkeywords)
-    newfunc.func = func
-    newfunc.args = args
-    newfunc.keywords = keywords
-    return newfunc
+import functools
+
+original_functools = functools
+py_functools = support.import_fresh_module('functools', blocked=['_functools'])
+c_functools = support.import_fresh_module('functools', fresh=['_functools'])
+
+class BaseTest(unittest.TestCase):
+
+    """Base class required for testing C and Py implementations."""
+
+    def setUp(self):
+
+        # The module must be explicitly set so that the proper
+        # interaction between the c module and the python module
+        # can be controlled.
+        self.partial = self.module.partial
+        super(BaseTest, self).setUp()
+
+class BaseTestC(BaseTest):
+    module = c_functools
+
+class BaseTestPy(BaseTest):
+    module = py_functools
+
+PythonPartial = py_functools.partial
 
 def capture(*args, **kw):
     """capture all positional and keyword arguments"""
@@ -27,31 +40,32 @@
     """ return the signature of a partial object """
     return (part.func, part.args, part.keywords, part.__dict__)
 
-class TestPartial(unittest.TestCase):
+class TestPartial(object):
 
-    thetype = functools.partial
+    partial = functools.partial
 
     def test_basic_examples(self):
-        p = self.thetype(capture, 1, 2, a=10, b=20)
+        p = self.partial(capture, 1, 2, a=10, b=20)
+        self.assertTrue(callable(p))
         self.assertEqual(p(3, 4, b=30, c=40),
                          ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
-        p = self.thetype(map, lambda x: x*10)
+        p = self.partial(map, lambda x: x*10)
         self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])
 
     def test_attributes(self):
-        p = self.thetype(capture, 1, 2, a=10, b=20)
+        p = self.partial(capture, 1, 2, a=10, b=20)
         # attributes should be readable
         self.assertEqual(p.func, capture)
         self.assertEqual(p.args, (1, 2))
         self.assertEqual(p.keywords, dict(a=10, b=20))
         # attributes should not be writable
-        if not isinstance(self.thetype, type):
+        if not isinstance(self.partial, type):
             return
         self.assertRaises(AttributeError, setattr, p, 'func', map)
         self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
         self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))
 
-        p = self.thetype(hex)
+        p = self.partial(hex)
         try:
             del p.__dict__
         except TypeError:
@@ -60,9 +74,9 @@
             self.fail('partial object allowed __dict__ to be deleted')
 
     def test_argument_checking(self):
-        self.assertRaises(TypeError, self.thetype)     # need at least a func arg
+        self.assertRaises(TypeError, self.partial)     # need at least a func arg
         try:
-            self.thetype(2)()
+            self.partial(2)()
         except TypeError:
             pass
         else:
@@ -73,7 +87,7 @@
         def func(a=10, b=20):
             return a
         d = {'a':3}
-        p = self.thetype(func, a=5)
+        p = self.partial(func, a=5)
         self.assertEqual(p(**d), 3)
         self.assertEqual(d, {'a':3})
         p(b=7)
@@ -82,20 +96,20 @@
     def test_arg_combinations(self):
         # exercise special code paths for zero args in either partial
         # object or the caller
-        p = self.thetype(capture)
+        p = self.partial(capture)
         self.assertEqual(p(), ((), {}))
         self.assertEqual(p(1,2), ((1,2), {}))
-        p = self.thetype(capture, 1, 2)
+        p = self.partial(capture, 1, 2)
         self.assertEqual(p(), ((1,2), {}))
         self.assertEqual(p(3,4), ((1,2,3,4), {}))
 
     def test_kw_combinations(self):
         # exercise special code paths for no keyword args in
         # either the partial object or the caller
-        p = self.thetype(capture)
+        p = self.partial(capture)
         self.assertEqual(p(), ((), {}))
         self.assertEqual(p(a=1), ((), {'a':1}))
-        p = self.thetype(capture, a=1)
+        p = self.partial(capture, a=1)
         self.assertEqual(p(), ((), {'a':1}))
         self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
         # keyword args in the call override those in the partial object
@@ -104,7 +118,7 @@
     def test_positional(self):
         # make sure positional arguments are captured correctly
         for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
-            p = self.thetype(capture, *args)
+            p = self.partial(capture, *args)
             expected = args + ('x',)
             got, empty = p('x')
             self.assertTrue(expected == got and empty == {})
@@ -112,14 +126,14 @@
     def test_keyword(self):
         # make sure keyword arguments are captured correctly
         for a in ['a', 0, None, 3.5]:
-            p = self.thetype(capture, a=a)
+            p = self.partial(capture, a=a)
             expected = {'a':a,'x':None}
             empty, got = p(x=None)
             self.assertTrue(expected == got and empty == ())
 
     def test_no_side_effects(self):
         # make sure there are no side effects that affect subsequent calls
-        p = self.thetype(capture, 0, a=1)
+        p = self.partial(capture, 0, a=1)
         args1, kw1 = p(1, b=2)
         self.assertTrue(args1 == (0,1) and kw1 == {'a':1,'b':2})
         args2, kw2 = p()
@@ -128,13 +142,13 @@
     def test_error_propagation(self):
         def f(x, y):
             x / y
-        self.assertRaises(ZeroDivisionError, self.thetype(f, 1, 0))
-        self.assertRaises(ZeroDivisionError, self.thetype(f, 1), 0)
-        self.assertRaises(ZeroDivisionError, self.thetype(f), 1, 0)
-        self.assertRaises(ZeroDivisionError, self.thetype(f, y=0), 1)
+        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
+        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
+        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
+        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)
 
     def test_weakref(self):
-        f = self.thetype(int, base=16)
+        f = self.partial(int, base=16)
         p = proxy(f)
         self.assertEqual(f.func, p.func)
         f = None
@@ -142,9 +156,9 @@
 
     def test_with_bound_and_unbound_methods(self):
         data = list(map(str, range(10)))
-        join = self.thetype(str.join, '')
+        join = self.partial(str.join, '')
         self.assertEqual(join(data), '0123456789')
-        join = self.thetype(''.join)
+        join = self.partial(''.join)
         self.assertEqual(join(data), '0123456789')
 
     def test_repr(self):
@@ -152,49 +166,57 @@
         args_repr = ', '.join(repr(a) for a in args)
         kwargs = {'a': object(), 'b': object()}
         kwargs_repr = ', '.join("%s=%r" % (k, v) for k, v in kwargs.items())
-        if self.thetype is functools.partial:
+        if self.partial is functools.partial:
             name = 'functools.partial'
         else:
-            name = self.thetype.__name__
+            name = self.partial.__name__
 
-        f = self.thetype(capture)
+        f = self.partial(capture)
         self.assertEqual('{}({!r})'.format(name, capture),
                          repr(f))
 
-        f = self.thetype(capture, *args)
+        f = self.partial(capture, *args)
         self.assertEqual('{}({!r}, {})'.format(name, capture, args_repr),
                          repr(f))
 
-        f = self.thetype(capture, **kwargs)
+        f = self.partial(capture, **kwargs)
         self.assertEqual('{}({!r}, {})'.format(name, capture, kwargs_repr),
                          repr(f))
 
-        f = self.thetype(capture, *args, **kwargs)
+        f = self.partial(capture, *args, **kwargs)
         self.assertEqual('{}({!r}, {}, {})'.format(name, capture, args_repr, kwargs_repr),
                          repr(f))
 
     def test_pickle(self):
-        f = self.thetype(signature, 'asdf', bar=True)
+        f = self.partial(signature, 'asdf', bar=True)
         f.add_something_to__dict__ = True
         f_copy = pickle.loads(pickle.dumps(f))
         self.assertEqual(signature(f), signature(f_copy))
 
-class PartialSubclass(functools.partial):
+class TestPartialC(BaseTestC, TestPartial):
     pass
 
-class TestPartialSubclass(TestPartial):
+class TestPartialPy(BaseTestPy, TestPartial):
 
-    thetype = PartialSubclass
+    def test_pickle(self):
+        raise unittest.SkipTest("Python implementation of partial isn't picklable")
+    
+    def test_repr(self):
+        raise unittest.SkipTest("Python implementation of partial uses own repr")
 
-class TestPythonPartial(TestPartial):
+class TestPartialCSubclass(BaseTestC, TestPartial):
 
-    thetype = PythonPartial
+    class PartialSubclass(c_functools.partial):
+        pass
 
-    # the python version hasn't a nice repr
-    def test_repr(self): pass
+    partial = staticmethod(PartialSubclass)
 
-    # the python version isn't picklable
-    def test_pickle(self): pass
+class TestPartialPySubclass(TestPartialPy):
+
+    class PartialSubclass(c_functools.partial):
+        pass
+
+    partial = staticmethod(PartialSubclass)
 
 class TestUpdateWrapper(unittest.TestCase):
 
@@ -320,7 +342,7 @@
         self.assertEqual(wrapper.__qualname__, f.__qualname__)
         self.assertEqual(wrapper.attr, 'This is also a test')
 
-    @unittest.skipIf(not sys.flags.optimize <= 1,
+    @unittest.skipIf(sys.flags.optimize >= 2,
                      "Docstrings are omitted with -O2 and above")
     def test_default_update_doc(self):
         wrapper, _ = self._default_update()
@@ -441,24 +463,28 @@
         d = {"one": 1, "two": 2, "three": 3}
         self.assertEqual(self.func(add, d), "".join(d.keys()))
 
-class TestCmpToKey(unittest.TestCase):
+class TestCmpToKey(object):
 
     def test_cmp_to_key(self):
         def cmp1(x, y):
             return (x > y) - (x < y)
-        key = functools.cmp_to_key(cmp1)
+        key = self.cmp_to_key(cmp1)
         self.assertEqual(key(3), key(3))
         self.assertGreater(key(3), key(1))
+        self.assertGreaterEqual(key(3), key(3))
+
         def cmp2(x, y):
             return int(x) - int(y)
-        key = functools.cmp_to_key(cmp2)
+        key = self.cmp_to_key(cmp2)
         self.assertEqual(key(4.0), key('4'))
         self.assertLess(key(2), key('35'))
+        self.assertLessEqual(key(2), key('35'))
+        self.assertNotEqual(key(2), key('35'))
 
     def test_cmp_to_key_arguments(self):
         def cmp1(x, y):
             return (x > y) - (x < y)
-        key = functools.cmp_to_key(mycmp=cmp1)
+        key = self.cmp_to_key(mycmp=cmp1)
         self.assertEqual(key(obj=3), key(obj=3))
         self.assertGreater(key(obj=3), key(obj=1))
         with self.assertRaises((TypeError, AttributeError)):
@@ -466,10 +492,10 @@
         with self.assertRaises((TypeError, AttributeError)):
             1 < key(3)    # lhs is not a K object
         with self.assertRaises(TypeError):
-            key = functools.cmp_to_key()             # too few args
+            key = self.cmp_to_key()             # too few args
         with self.assertRaises(TypeError):
-            key = functools.cmp_to_key(cmp1, None)   # too many args
-        key = functools.cmp_to_key(cmp1)
+            key = self.module.cmp_to_key(cmp1, None)   # too many args
+        key = self.cmp_to_key(cmp1)
         with self.assertRaises(TypeError):
             key()                                    # too few args
         with self.assertRaises(TypeError):
@@ -478,7 +504,7 @@
     def test_bad_cmp(self):
         def cmp1(x, y):
             raise ZeroDivisionError
-        key = functools.cmp_to_key(cmp1)
+        key = self.cmp_to_key(cmp1)
         with self.assertRaises(ZeroDivisionError):
             key(3) > key(1)
 
@@ -493,13 +519,13 @@
     def test_obj_field(self):
         def cmp1(x, y):
             return (x > y) - (x < y)
-        key = functools.cmp_to_key(mycmp=cmp1)
+        key = self.cmp_to_key(mycmp=cmp1)
         self.assertEqual(key(50).obj, 50)
 
     def test_sort_int(self):
         def mycmp(x, y):
             return y - x
-        self.assertEqual(sorted(range(5), key=functools.cmp_to_key(mycmp)),
+        self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
                          [4, 3, 2, 1, 0])
 
     def test_sort_int_str(self):
@@ -507,18 +533,24 @@
             x, y = int(x), int(y)
             return (x > y) - (x < y)
         values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
-        values = sorted(values, key=functools.cmp_to_key(mycmp))
+        values = sorted(values, key=self.cmp_to_key(mycmp))
         self.assertEqual([int(value) for value in values],
                          [0, 1, 1, 2, 3, 4, 5, 7, 10])
 
     def test_hash(self):
         def mycmp(x, y):
             return y - x
-        key = functools.cmp_to_key(mycmp)
+        key = self.cmp_to_key(mycmp)
         k = key(10)
         self.assertRaises(TypeError, hash, k)
         self.assertNotIsInstance(k, collections.Hashable)
 
+class TestCmpToKeyC(BaseTestC, TestCmpToKey):
+    cmp_to_key = c_functools.cmp_to_key
+
+class TestCmpToKeyPy(BaseTestPy, TestCmpToKey):
+    cmp_to_key = staticmethod(py_functools.cmp_to_key)
+
 class TestTotalOrdering(unittest.TestCase):
 
     def test_total_ordering_lt(self):
@@ -623,7 +655,7 @@
 
     def test_lru(self):
         def orig(x, y):
-            return 3*x+y
+            return 3 * x + y
         f = functools.lru_cache(maxsize=20)(orig)
         hits, misses, maxsize, currsize = f.cache_info()
         self.assertEqual(maxsize, 20)
@@ -728,7 +760,7 @@
         # Verify that user_function exceptions get passed through without
         # creating a hard-to-read chained exception.
         # http://bugs.python.org/issue13177
-        for maxsize in (None, 100):
+        for maxsize in (None, 128):
             @functools.lru_cache(maxsize)
             def func(i):
                 return 'abc'[i]
@@ -741,7 +773,7 @@
                 func(15)
 
     def test_lru_with_types(self):
-        for maxsize in (None, 100):
+        for maxsize in (None, 128):
             @functools.lru_cache(maxsize=maxsize, typed=True)
             def square(x):
                 return x * x
@@ -756,14 +788,46 @@
             self.assertEqual(square.cache_info().hits, 4)
             self.assertEqual(square.cache_info().misses, 4)
 
+    def test_lru_with_keyword_args(self):
+        @functools.lru_cache()
+        def fib(n):
+            if n < 2:
+                return n
+            return fib(n=n-1) + fib(n=n-2)
+        self.assertEqual(
+            [fib(n=number) for number in range(16)],
+            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
+        )
+        self.assertEqual(fib.cache_info(),
+            functools._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
+        fib.cache_clear()
+        self.assertEqual(fib.cache_info(),
+            functools._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
+
+    def test_lru_with_keyword_args_maxsize_none(self):
+        @functools.lru_cache(maxsize=None)
+        def fib(n):
+            if n < 2:
+                return n
+            return fib(n=n-1) + fib(n=n-2)
+        self.assertEqual([fib(n=number) for number in range(16)],
+            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
+        self.assertEqual(fib.cache_info(),
+            functools._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
+        fib.cache_clear()
+        self.assertEqual(fib.cache_info(),
+            functools._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
+
 def test_main(verbose=None):
     test_classes = (
-        TestPartial,
-        TestPartialSubclass,
-        TestPythonPartial,
+        TestPartialC,
+        TestPartialPy,
+        TestPartialCSubclass,
+        TestPartialPySubclass,
         TestUpdateWrapper,
         TestTotalOrdering,
-        TestCmpToKey,
+        TestCmpToKeyC,
+        TestCmpToKeyPy,
         TestWraps,
         TestReduce,
         TestLRU,
diff --git a/Misc/ACKS b/Misc/ACKS
--- a/Misc/ACKS
+++ b/Misc/ACKS
@@ -1166,6 +1166,7 @@
 Nicolas M. Thiéry
 James Thomas
 Robin Thomas
+Brian Thorne
 Stephen Thorne
 Jeremy Thurgood
 Eric Tiedemann
diff --git a/Misc/NEWS b/Misc/NEWS
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -124,6 +124,9 @@
 Library
 -------
 
+- Issue #12428: Add a pure Python implementation of functools.partial().
+  Patch by Brian Thorne.
+
 - Issue #16140: The subprocess module no longer double closes its child
   subprocess.PIPE parent file descriptors on child error prior to exec().
 

-- 
Repository URL: http://hg.python.org/cpython


More information about the Python-checkins mailing list