[Python-checkins] cpython: Issue #4331: Added functools.partialmethod

nick.coghlan python-checkins at python.org
Sun Nov 3 07:42:40 CET 2013


http://hg.python.org/cpython/rev/46d3c5539981
changeset:   86867:46d3c5539981
user:        Nick Coghlan <ncoghlan at gmail.com>
date:        Sun Nov 03 16:41:46 2013 +1000
summary:
  Issue #4331: Added functools.partialmethod

Initial patch by Alon Horev

files:
  Doc/library/functools.rst  |   43 +++++++++-
  Doc/whatsnew/3.4.rst       |   20 ++++-
  Lib/functools.py           |   78 ++++++++++++++++-
  Lib/test/test_functools.py |  116 +++++++++++++++++++++++++
  Misc/NEWS                  |    2 +
  5 files changed, 255 insertions(+), 4 deletions(-)


diff --git a/Doc/library/functools.rst b/Doc/library/functools.rst
--- a/Doc/library/functools.rst
+++ b/Doc/library/functools.rst
@@ -194,6 +194,48 @@
       18
 
 
+.. class:: partialmethod(func, *args, **keywords)
+
+   Return a new :class:`partialmethod` descriptor which behaves
+   like :class:`partial` except that it is designed to be used as a method
+   definition rather than being directly callable.
+
+   *func* must be a :term:`descriptor` or a callable (objects which are both,
+   like normal functions, are handled as descriptors).
+
+   When *func* is a descriptor (such as a normal Python function,
+   :func:`classmethod`, :func:`staticmethod`, :func:`~abc.abstractmethod` or
+   another instance of :class:`partialmethod`), calls to ``__get__`` are
+   delegated to the underlying descriptor, and an appropriate
+   :class:`partial` object is returned as the result.
+
+   When *func* is a non-descriptor callable, an appropriate bound method is
+   created dynamically. This behaves like a normal Python function when
+   used as a method: the *self* argument will be inserted as the first
+   positional argument, even before the *args* and *keywords* supplied to
+   the :class:`partialmethod` constructor.
+
+   Example::
+
+      >>> class Cell(object):
+      ...     @property
+      ...     def alive(self):
+      ...         return getattr(self, "_alive", False)
+      ...     def set_state(self, state):
+      ...         self._alive = bool(state)
+      ...     set_alive = partialmethod(set_state, True)
+      ...     set_dead = partialmethod(set_state, False)
+      ...
+      >>> c = Cell()
+      >>> c.alive
+      False
+      >>> c.set_alive()
+      >>> c.alive
+      True
+
+   .. versionadded:: 3.4
+
+
 .. function:: reduce(function, iterable[, initializer])
 
    Apply *function* of two arguments cumulatively to the items of *sequence*, from
@@ -431,4 +473,3 @@
 are not created automatically.  Also, :class:`partial` objects defined in
 classes behave like static methods and do not transform into bound methods
 during instance attribute look-up.
-
diff --git a/Doc/whatsnew/3.4.rst b/Doc/whatsnew/3.4.rst
--- a/Doc/whatsnew/3.4.rst
+++ b/Doc/whatsnew/3.4.rst
@@ -342,7 +342,25 @@
 functools
 ---------
 
-New :func:`functools.singledispatch` decorator: see the :pep:`443`.
+The new :class:`~functools.partialmethod` descriptor brings partial argument
+application to descriptors, just as :func:`~functools.partial` provides
+for normal callables. The new descriptor also makes it easier to get
+arbitrary callables (including :func:`~functools.partial` instances)
+to behave like normal instance methods when included in a class definition.
+
+(Contributed by Alon Horev and Nick Coghlan in :issue:`4331`)
+
+The new :func:`~functools.singledispatch` decorator brings support for
+single-dispatch generic functions to the Python standard library. Where
+object oriented programming focuses on grouping multiple operations on a
+common set of data into a class, a generic function focuses on grouping
+multiple implementations of an operation that allows it to work with
+*different* kinds of data.
+
+.. seealso::
+
+   :pep:`443` - Single-dispatch generic functions
+      PEP written and implemented by Łukasz Langa.
 
 
 hashlib
diff --git a/Lib/functools.py b/Lib/functools.py
--- a/Lib/functools.py
+++ b/Lib/functools.py
@@ -19,7 +19,7 @@
     pass
 from abc import get_cache_token
 from collections import namedtuple
-from types import MappingProxyType
+from types import MappingProxyType, MethodType
 from weakref import WeakKeyDictionary
 try:
     from _thread import RLock
@@ -223,8 +223,9 @@
 ### partial() argument application
 ################################################################################
 
+# Purely functional, no descriptor behaviour
 def partial(func, *args, **keywords):
-    """new function with partial application of the given arguments
+    """New function with partial application of the given arguments
     and keywords.
     """
     def newfunc(*fargs, **fkeywords):
@@ -241,6 +242,79 @@
 except ImportError:
     pass
 
+# Descriptor version
+class partialmethod(object):
+    """Method descriptor with partial application of the given arguments
+    and keywords.
+
+    Supports wrapping existing descriptors and handles non-descriptor
+    callables as instance methods.
+    """
+
+    def __init__(self, func, *args, **keywords):
+        if not callable(func) and not hasattr(func, "__get__"):
+            raise TypeError("{!r} is not callable or a descriptor"
+                            .format(func))
+
+        # func could be a descriptor like classmethod which isn't callable,
+        # so we can't inherit from partial (it verifies func is callable)
+        if isinstance(func, partialmethod):
+            # flattening is mandatory to place cls/self before all other
+            # arguments, and is cheaper since only one function is called
+            self.func = func.func
+            self.args = func.args + args
+            self.keywords = func.keywords.copy()
+            self.keywords.update(keywords)
+        else:
+            self.func = func
+            self.args = args
+            self.keywords = keywords
+
+    def __repr__(self):
+        # Join only non-empty parts so no dangling ", , " is produced
+        parts = [repr(self.func)]
+        parts.extend(map(repr, self.args))
+        parts.extend("{}={!r}".format(k, v)
+                     for k, v in self.keywords.items())
+        format_string = "{module}.{cls}({args})"
+        return format_string.format(module=self.__class__.__module__,
+                                    cls=self.__class__.__name__,
+                                    args=", ".join(parts))
+
+    def _make_unbound_method(self):
+        def _method(*args, **keywords):
+            if not args:
+                raise TypeError("partialmethod requires a self/cls argument")
+            call_keywords = self.keywords.copy()
+            call_keywords.update(keywords)
+            call_args = args[:1] + self.args + args[1:]
+            return self.func(*call_args, **call_keywords)
+        _method.__isabstractmethod__ = self.__isabstractmethod__
+        return _method
+
+    def __get__(self, obj, cls=None):
+        get = getattr(self.func, "__get__", None)
+        result = None
+        if get is not None:
+            new_func = get(obj, cls)
+            if new_func is not self.func:
+                # Assume __get__ returning something new indicates the
+                # creation of an appropriate callable
+                result = partial(new_func, *self.args, **self.keywords)
+                try:
+                    result.__self__ = new_func.__self__
+                except AttributeError:
+                    pass
+        if result is None:
+            # If the underlying descriptor didn't do anything, treat this
+            # like an instance method
+            result = self._make_unbound_method().__get__(obj, cls)
+        return result
+
+    @property
+    def __isabstractmethod__(self):
+        return getattr(self.func, "__isabstractmethod__", False)
+
 
 ################################################################################
 ### LRU Cache function decorator
diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py
--- a/Lib/test/test_functools.py
+++ b/Lib/test/test_functools.py
@@ -1,3 +1,4 @@
+import abc
 import collections
 from itertools import permutations
 import pickle
@@ -217,6 +218,120 @@
         partial = PartialSubclass
 
 
+class TestPartialMethod(unittest.TestCase):
+
+    class A(object):
+        nothing = functools.partialmethod(capture)
+        positional = functools.partialmethod(capture, 1)
+        keywords = functools.partialmethod(capture, a=2)
+        both = functools.partialmethod(capture, 3, b=4)
+
+        nested = functools.partialmethod(positional, 5)
+
+        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)
+
+        static = functools.partialmethod(staticmethod(capture), 8)
+        cls = functools.partialmethod(classmethod(capture), d=9)
+
+    a = A()
+
+    def test_arg_combinations(self):
+        self.assertEqual(self.a.nothing(), ((self.a,), {}))
+        self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
+        self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
+        self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))
+
+        self.assertEqual(self.a.positional(), ((self.a, 1), {}))
+        self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
+        self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
+        self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))
+
+        self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
+        self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
+        self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
+        self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))
+
+        self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
+        self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
+        self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
+        self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))
+
+        self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))
+
+    def test_nested(self):
+        self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
+        self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
+        self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
+        self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))
+
+        self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))
+
+    def test_over_partial(self):
+        self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
+        self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
+        self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
+        self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))
+
+        self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))
+
+    def test_bound_method_introspection(self):
+        obj = self.a
+        self.assertIs(obj.both.__self__, obj)
+        self.assertIs(obj.nested.__self__, obj)
+        self.assertIs(obj.over_partial.__self__, obj)
+        self.assertIs(obj.cls.__self__, self.A)
+        self.assertIs(self.A.cls.__self__, self.A)
+
+    def test_unbound_method_retrieval(self):
+        obj = self.A
+        self.assertFalse(hasattr(obj.both, "__self__"))
+        self.assertFalse(hasattr(obj.nested, "__self__"))
+        self.assertFalse(hasattr(obj.over_partial, "__self__"))
+        self.assertFalse(hasattr(obj.static, "__self__"))
+        self.assertFalse(hasattr(self.a.static, "__self__"))
+
+    def test_descriptors(self):
+        for obj in [self.A, self.a]:
+            with self.subTest(obj=obj):
+                self.assertEqual(obj.static(), ((8,), {}))
+                self.assertEqual(obj.static(5), ((8, 5), {}))
+                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
+                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))
+
+                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
+                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
+                self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
+                self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))
+
+    def test_overriding_keywords(self):
+        self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
+        self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))
+
+    def test_invalid_args(self):
+        with self.assertRaises(TypeError):
+            class B(object):
+                method = functools.partialmethod(None, 1)
+
+    def test_repr(self):
+        self.assertEqual(repr(vars(self.A)['both']),
+                         'functools.partialmethod({}, 3, b=4)'.format(capture))
+
+    def test_abstract(self):
+        class Abstract(metaclass=abc.ABCMeta):  # an ABC, not a metaclass
+
+            @abc.abstractmethod
+            def add(self, x, y):
+                pass
+
+            add5 = functools.partialmethod(add, 5)
+
+        self.assertTrue(Abstract.add.__isabstractmethod__)
+        self.assertTrue(Abstract.add5.__isabstractmethod__)
+
+        for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
+            self.assertFalse(getattr(func, '__isabstractmethod__', False))
+
+
 class TestUpdateWrapper(unittest.TestCase):
 
     def check_wrapper(self, wrapper, wrapped,
@@ -1433,6 +1548,7 @@
         TestPartialC,
         TestPartialPy,
         TestPartialCSubclass,
+        TestPartialMethod,
         TestUpdateWrapper,
         TestTotalOrdering,
         TestCmpToKeyC,
diff --git a/Misc/NEWS b/Misc/NEWS
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -1192,6 +1192,8 @@
 Library
 -------
 
+- Issue #4331: Added functools.partialmethod (Initial patch by Alon Horev)
+
 - Issue #13461: Fix a crash in the TextIOWrapper.tell method on 64-bit
   platforms.  Patch by Yogesh Chaudhari.
 

-- 
Repository URL: http://hg.python.org/cpython


More information about the Python-checkins mailing list