[pypy-commit] pypy singledispatch: Copy singledispatch 3.4.0.2 to rpython/tool/singledispatch/

rlamy noreply at buildbot.pypy.org
Sun Feb 16 04:24:14 CET 2014


Author: Ronan Lamy <ronan.lamy at gmail.com>
Branch: singledispatch
Changeset: r69168:b2ff76f3966b
Date: 2014-02-16 01:59 +0000
http://bitbucket.org/pypy/pypy/changeset/b2ff76f3966b/

Log:	Copy singledispatch 3.4.0.2 to rpython/tool/singledispatch/

diff --git a/rpython/tool/singledispatch/singledispatch.py b/rpython/tool/singledispatch/singledispatch.py
new file mode 100644
--- /dev/null
+++ b/rpython/tool/singledispatch/singledispatch.py
@@ -0,0 +1,219 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+__all__ = ['singledispatch']
+
+from functools import update_wrapper
+from weakref import WeakKeyDictionary
+from singledispatch_helpers import MappingProxyType, get_cache_token
+
+################################################################################
+### singledispatch() - single-dispatch generic function decorator
+################################################################################
+
+def _c3_merge(sequences):
+    """Merges MROs in *sequences* to a single MRO using the C3 algorithm.
+
+    Adapted from http://www.python.org/download/releases/2.3/mro/.
+
+    """
+    result = []
+    while True:
+        sequences = [s for s in sequences if s]   # purge empty sequences
+        if not sequences:
+            return result
+        for s1 in sequences:   # find merge candidates among seq heads
+            candidate = s1[0]
+            for s2 in sequences:
+                if candidate in s2[1:]:
+                    candidate = None
+                    break      # reject the current head, it appears later
+            else:
+                break
+        if not candidate:
+            raise RuntimeError("Inconsistent hierarchy")
+        result.append(candidate)
+        # remove the chosen candidate
+        for seq in sequences:
+            if seq[0] == candidate:
+                del seq[0]
+
+def _c3_mro(cls, abcs=None):
+    """Computes the method resolution order using extended C3 linearization.
+
+    If no *abcs* are given, the algorithm works exactly like the built-in C3
+    linearization used for method resolution.
+
+    If given, *abcs* is a list of abstract base classes that should be inserted
+    into the resulting MRO. Unrelated ABCs are ignored and don't end up in the
+    result. The algorithm inserts ABCs where their functionality is introduced,
+    i.e. issubclass(cls, abc) returns True for the class itself but returns
+    False for all its direct base classes. Implicit ABCs for a given class
+    (either registered or inferred from the presence of a special method like
+    __len__) are inserted directly after the last ABC explicitly listed in the
+    MRO of said class. If two implicit ABCs end up next to each other in the
+    resulting MRO, their ordering depends on the order of types in *abcs*.
+
+    """
+    for i, base in enumerate(reversed(cls.__bases__)):
+        if hasattr(base, '__abstractmethods__'):
+            boundary = len(cls.__bases__) - i
+            break   # Bases up to the last explicit ABC are considered first.
+    else:
+        boundary = 0
+    abcs = list(abcs) if abcs else []
+    explicit_bases = list(cls.__bases__[:boundary])
+    abstract_bases = []
+    other_bases = list(cls.__bases__[boundary:])
+    for base in abcs:
+        if issubclass(cls, base) and not any(
+                issubclass(b, base) for b in cls.__bases__
+            ):
+            # If *cls* is the class that introduces behaviour described by
+            # an ABC *base*, insert said ABC into its MRO.
+            abstract_bases.append(base)
+    for base in abstract_bases:
+        abcs.remove(base)
+    explicit_c3_mros = [_c3_mro(base, abcs=abcs) for base in explicit_bases]
+    abstract_c3_mros = [_c3_mro(base, abcs=abcs) for base in abstract_bases]
+    other_c3_mros = [_c3_mro(base, abcs=abcs) for base in other_bases]
+    return _c3_merge(
+        [[cls]] +
+        explicit_c3_mros + abstract_c3_mros + other_c3_mros +
+        [explicit_bases] + [abstract_bases] + [other_bases]
+    )
+
+def _compose_mro(cls, types):
+    """Calculates the method resolution order for a given class *cls*.
+
+    Includes relevant abstract base classes (with their respective bases) from
+    the *types* iterable. Uses a modified C3 linearization algorithm.
+
+    """
+    bases = set(cls.__mro__)
+    # Remove entries which are already present in the __mro__ or unrelated.
+    def is_related(typ):
+        return (typ not in bases and hasattr(typ, '__mro__')
+                                 and issubclass(cls, typ))
+    types = [n for n in types if is_related(n)]
+    # Remove entries which are strict bases of other entries (they will end up
+    # in the MRO anyway).
+    def is_strict_base(typ):
+        for other in types:
+            if typ != other and typ in other.__mro__:
+                return True
+        return False
+    types = [n for n in types if not is_strict_base(n)]
+    # Subclasses of the ABCs in *types* which are also implemented by
+    # *cls* can be used to stabilize ABC ordering.
+    type_set = set(types)
+    mro = []
+    for typ in types:
+        found = []
+        for sub in typ.__subclasses__():
+            if sub not in bases and issubclass(cls, sub):
+                found.append([s for s in sub.__mro__ if s in type_set])
+        if not found:
+            mro.append(typ)
+            continue
+        # Favor subclasses with the biggest number of useful bases
+        found.sort(key=len, reverse=True)
+        for sub in found:
+            for subcls in sub:
+                if subcls not in mro:
+                    mro.append(subcls)
+    return _c3_mro(cls, abcs=mro)
+
+def _find_impl(cls, registry):
+    """Returns the best matching implementation from *registry* for type *cls*.
+
+    Where there is no registered implementation for a specific type, its method
+    resolution order is used to find a more generic implementation.
+
+    Note: if *registry* does not contain an implementation for the base
+    *object* type, this function may return None.
+
+    """
+    mro = _compose_mro(cls, registry.keys())
+    match = None
+    for t in mro:
+        if match is not None:
+            # If *match* is an implicit ABC but there is another unrelated,
+            # equally matching implicit ABC, refuse the temptation to guess.
+            if (t in registry and t not in cls.__mro__
+                              and match not in cls.__mro__
+                              and not issubclass(match, t)):
+                raise RuntimeError("Ambiguous dispatch: {0} or {1}".format(
+                    match, t))
+            break
+        if t in registry:
+            match = t
+    return registry.get(match)
+
+def singledispatch(func):
+    """Single-dispatch generic function decorator.
+
+    Transforms a function into a generic function, which can have different
+    behaviours depending upon the type of its first argument. The decorated
+    function acts as the default implementation, and additional
+    implementations can be registered using the register() attribute of the
+    generic function.
+
+    """
+    registry = {}
+    dispatch_cache = WeakKeyDictionary()
+    def ns(): pass
+    ns.cache_token = None
+
+    def dispatch(cls):
+        """generic_func.dispatch(cls) -> <function implementation>
+
+        Runs the dispatch algorithm to return the best available implementation
+        for the given *cls* registered on *generic_func*.
+
+        """
+        if ns.cache_token is not None:
+            current_token = get_cache_token()
+            if ns.cache_token != current_token:
+                dispatch_cache.clear()
+                ns.cache_token = current_token
+        try:
+            impl = dispatch_cache[cls]
+        except KeyError:
+            try:
+                impl = registry[cls]
+            except KeyError:
+                impl = _find_impl(cls, registry)
+            dispatch_cache[cls] = impl
+        return impl
+
+    def register(cls, func=None):
+        """generic_func.register(cls, func) -> func
+
+        Registers a new implementation for the given *cls* on a *generic_func*.
+
+        """
+        if func is None:
+            return lambda f: register(cls, f)
+        registry[cls] = func
+        if ns.cache_token is None and hasattr(cls, '__abstractmethods__'):
+            ns.cache_token = get_cache_token()
+        dispatch_cache.clear()
+        return func
+
+    def wrapper(*args, **kw):
+        return dispatch(args[0].__class__)(*args, **kw)
+
+    registry[object] = func
+    wrapper.register = register
+    wrapper.dispatch = dispatch
+    wrapper.registry = MappingProxyType(registry)
+    wrapper._clear_cache = dispatch_cache.clear
+    update_wrapper(wrapper, func)
+    return wrapper
+
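
For reviewers unfamiliar with the module copied above, here is a minimal usage
sketch of the decorator it exposes. The function name `fmt` and the registered
types are illustrative only, and the import assumes rpython/tool/singledispatch/
is on sys.path (the same assumption the test module below makes).

    from singledispatch import singledispatch

    @singledispatch
    def fmt(obj):
        return "default"                 # fallback used for unregistered types

    @fmt.register(int)                   # decorator form of register()
    def _(obj):
        return "int: %d" % obj

    fmt.register(list, lambda obj: "list of %d" % len(obj))   # direct form

    assert fmt(3.5) == "default"
    assert fmt(3) == "int: 3"
    assert fmt([1, 2]) == "list of 2"
    assert fmt.dispatch(int)(7) == "int: 7"   # look up an implementation without calling fmt
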
diff --git a/rpython/tool/singledispatch/singledispatch_helpers.py b/rpython/tool/singledispatch/singledispatch_helpers.py
new file mode 100644
--- /dev/null
+++ b/rpython/tool/singledispatch/singledispatch_helpers.py
@@ -0,0 +1,170 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+from abc import ABCMeta
+from collections import MutableMapping
+import sys
+try:
+    from collections import UserDict
+except ImportError:
+    from UserDict import UserDict
+try:
+    from collections import OrderedDict
+except ImportError:
+    from ordereddict import OrderedDict
+try:
+    from thread import get_ident
+except ImportError:
+    try:
+        from _thread import get_ident
+    except ImportError:
+        from _dummy_thread import get_ident
+
+
+def recursive_repr(fillvalue='...'):
+    'Decorator to make a repr function return fillvalue for a recursive call'
+
+    def decorating_function(user_function):
+        repr_running = set()
+
+        def wrapper(self):
+            key = id(self), get_ident()
+            if key in repr_running:
+                return fillvalue
+            repr_running.add(key)
+            try:
+                result = user_function(self)
+            finally:
+                repr_running.discard(key)
+            return result
+
+        # Can't use functools.wraps() here because of bootstrap issues
+        wrapper.__module__ = getattr(user_function, '__module__')
+        wrapper.__doc__ = getattr(user_function, '__doc__')
+        wrapper.__name__ = getattr(user_function, '__name__')
+        wrapper.__annotations__ = getattr(user_function, '__annotations__', {})
+        return wrapper
+
+    return decorating_function
+
+
+class ChainMap(MutableMapping):
+    ''' A ChainMap groups multiple dicts (or other mappings) together
+    to create a single, updateable view.
+
+    The underlying mappings are stored in a list.  That list is public and can
+    be accessed or updated using the *maps* attribute.  There is no other state.
+
+    Lookups search the underlying mappings successively until a key is found.
+    In contrast, writes, updates, and deletions only operate on the first
+    mapping.
+
+    '''
+
+    def __init__(self, *maps):
+        '''Initialize a ChainMap by setting *maps* to the given mappings.
+        If no mappings are provided, a single empty dictionary is used.
+
+        '''
+        self.maps = list(maps) or [{}]          # always at least one map
+
+    def __missing__(self, key):
+        raise KeyError(key)
+
+    def __getitem__(self, key):
+        for mapping in self.maps:
+            try:
+                return mapping[key]             # can't use 'key in mapping' with defaultdict
+            except KeyError:
+                pass
+        return self.__missing__(key)            # support subclasses that define __missing__
+
+    def get(self, key, default=None):
+        return self[key] if key in self else default
+
+    def __len__(self):
+        return len(set().union(*self.maps))     # reuses stored hash values if possible
+
+    def __iter__(self):
+        return iter(set().union(*self.maps))
+
+    def __contains__(self, key):
+        return any(key in m for m in self.maps)
+
+    @recursive_repr()
+    def __repr__(self):
+        return '{0.__class__.__name__}({1})'.format(
+            self, ', '.join(map(repr, self.maps)))
+
+    @classmethod
+    def fromkeys(cls, iterable, *args):
+        'Create a ChainMap with a single dict created from the iterable.'
+        return cls(dict.fromkeys(iterable, *args))
+
+    def copy(self):
+        'New ChainMap or subclass with a new copy of maps[0] and refs to maps[1:]'
+        return self.__class__(self.maps[0].copy(), *self.maps[1:])
+
+    __copy__ = copy
+
+    def new_child(self):                        # like Django's Context.push()
+        'New ChainMap with a new dict followed by all previous maps.'
+        return self.__class__({}, *self.maps)
+
+    @property
+    def parents(self):                          # like Django's Context.pop()
+        'New ChainMap from maps[1:].'
+        return self.__class__(*self.maps[1:])
+
+    def __setitem__(self, key, value):
+        self.maps[0][key] = value
+
+    def __delitem__(self, key):
+        try:
+            del self.maps[0][key]
+        except KeyError:
+            raise KeyError('Key not found in the first mapping: {!r}'.format(key))
+
+    def popitem(self):
+        'Remove and return an item pair from maps[0]. Raise KeyError if maps[0] is empty.'
+        try:
+            return self.maps[0].popitem()
+        except KeyError:
+            raise KeyError('No keys found in the first mapping.')
+
+    def pop(self, key, *args):
+        'Remove *key* from maps[0] and return its value. Raise KeyError if *key* not in maps[0].'
+        try:
+            return self.maps[0].pop(key, *args)
+        except KeyError:
+            raise KeyError('Key not found in the first mapping: {!r}'.format(key))
+
+    def clear(self):
+        'Clear maps[0], leaving maps[1:] intact.'
+        self.maps[0].clear()
+
+
+class MappingProxyType(UserDict):
+    def __init__(self, data):
+        UserDict.__init__(self)
+        self.data = data
+
+
+def get_cache_token():
+    return ABCMeta._abc_invalidation_counter
+
+
+
+class Support(object):
+    def dummy(self):
+        pass
+
+    def cpython_only(self, func):
+        if 'PyPy' in sys.version:
+            return self.dummy
+        return func
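
The ChainMap class above is there as a fallback for the tests; a short sketch of
the behaviour its docstring describes (lookups fall through the maps in order,
writes only ever touch maps[0]). The dictionaries and keys here are illustrative.

    from singledispatch_helpers import ChainMap

    defaults = {'colour': 'red', 'user': 'guest'}
    overrides = {'user': 'admin'}
    cm = ChainMap(overrides, defaults)

    assert cm['user'] == 'admin'       # found in the first mapping
    assert cm['colour'] == 'red'       # falls through to the second mapping

    cm['colour'] = 'blue'              # writes only ever go to maps[0]
    assert overrides['colour'] == 'blue'
    assert defaults['colour'] == 'red' # later maps are left untouched

    child = cm.new_child()             # push a fresh dict in front of the chain
    child['user'] = 'temp'
    assert cm['user'] == 'admin'       # the parent view is unaffected
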
diff --git a/rpython/tool/singledispatch/test_singledispatch.py b/rpython/tool/singledispatch/test_singledispatch.py
new file mode 100644
--- /dev/null
+++ b/rpython/tool/singledispatch/test_singledispatch.py
@@ -0,0 +1,519 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import collections
+import decimal
+from itertools import permutations
+import singledispatch as functools
+from singledispatch_helpers import Support
+try:
+    from collections import ChainMap
+except ImportError:
+    from singledispatch_helpers import ChainMap
+    collections.ChainMap = ChainMap
+try:
+    from collections import OrderedDict
+except ImportError:
+    from singledispatch_helpers import OrderedDict
+    collections.OrderedDict = OrderedDict
+try:
+    import unittest2 as unittest
+except ImportError:
+    import unittest
+
+
+support = Support()
+for _prefix in ('collections.abc', '_abcoll'):
+    if _prefix in repr(collections.Container):
+        abcoll_prefix = _prefix
+        break
+else:
+    abcoll_prefix = '?'
+del _prefix
+
+
+class TestSingleDispatch(unittest.TestCase):
+    def test_simple_overloads(self):
+        @functools.singledispatch
+        def g(obj):
+            return "base"
+        def g_int(i):
+            return "integer"
+        g.register(int, g_int)
+        self.assertEqual(g("str"), "base")
+        self.assertEqual(g(1), "integer")
+        self.assertEqual(g([1,2,3]), "base")
+
+    def test_mro(self):
+        @functools.singledispatch
+        def g(obj):
+            return "base"
+        class A(object):
+            pass
+        class C(A):
+            pass
+        class B(A):
+            pass
+        class D(C, B):
+            pass
+        def g_A(a):
+            return "A"
+        def g_B(b):
+            return "B"
+        g.register(A, g_A)
+        g.register(B, g_B)
+        self.assertEqual(g(A()), "A")
+        self.assertEqual(g(B()), "B")
+        self.assertEqual(g(C()), "A")
+        self.assertEqual(g(D()), "B")
+
+    def test_register_decorator(self):
+        @functools.singledispatch
+        def g(obj):
+            return "base"
+        @g.register(int)
+        def g_int(i):
+            return "int %s" % (i,)
+        self.assertEqual(g(""), "base")
+        self.assertEqual(g(12), "int 12")
+        self.assertIs(g.dispatch(int), g_int)
+        self.assertIs(g.dispatch(object), g.dispatch(str))
+        # Note: g.dispatch(object) above is not g itself but the original
+        # undecorated function; @singledispatch returns a separate wrapper.
+
+    def test_wrapping_attributes(self):
+        @functools.singledispatch
+        def g(obj):
+            "Simple test"
+            return "Test"
+        self.assertEqual(g.__name__, "g")
+        self.assertEqual(g.__doc__, "Simple test")
+
+    @unittest.skipUnless(decimal, 'requires _decimal')
+    @support.cpython_only
+    def test_c_classes(self):
+        @functools.singledispatch
+        def g(obj):
+            return "base"
+        @g.register(decimal.DecimalException)
+        def _(obj):
+            return obj.args
+        subn = decimal.Subnormal("Exponent < Emin")
+        rnd = decimal.Rounded("Number got rounded")
+        self.assertEqual(g(subn), ("Exponent < Emin",))
+        self.assertEqual(g(rnd), ("Number got rounded",))
+        @g.register(decimal.Subnormal)
+        def _(obj):
+            return "Too small to care."
+        self.assertEqual(g(subn), "Too small to care.")
+        self.assertEqual(g(rnd), ("Number got rounded",))
+
+    def test_compose_mro(self):
+        # None of the examples in this test depend on haystack ordering.
+        c = collections
+        mro = functools._compose_mro
+        bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
+        for haystack in permutations(bases):
+            m = mro(dict, haystack)
+            self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, c.Sized,
+                                 c.Iterable, c.Container, object])
+        bases = [c.Container, c.Mapping, c.MutableMapping, c.OrderedDict]
+        for haystack in permutations(bases):
+            m = mro(c.ChainMap, haystack)
+            self.assertEqual(m, [c.ChainMap, c.MutableMapping, c.Mapping,
+                                 c.Sized, c.Iterable, c.Container, object])
+
+        # If there's a generic function with implementations registered for
+        # both Sized and Container, passing a defaultdict to it results in an
+        # ambiguous dispatch which will cause a RuntimeError (see
+        # test_mro_conflicts).
+        bases = [c.Container, c.Sized, str]
+        for haystack in permutations(bases):
+            m = mro(c.defaultdict, [c.Sized, c.Container, str])
+            self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container,
+                                 object])
+
+        # MutableSequence below is registered directly on D. In other words, it
+        # precedes MutableMapping which means single dispatch will always
+        # choose MutableSequence here.
+        class D(c.defaultdict):
+            pass
+        c.MutableSequence.register(D)
+        bases = [c.MutableSequence, c.MutableMapping]
+        for haystack in permutations(bases):
+            m = mro(D, bases)
+            self.assertEqual(m, [D, c.MutableSequence, c.Sequence,
+                                 c.defaultdict, dict, c.MutableMapping,
+                                 c.Mapping, c.Sized, c.Iterable, c.Container,
+                                 object])
+
+        # Container and Callable are registered on different base classes and
+        # a generic function supporting both should always pick the Callable
+        # implementation if a C instance is passed.
+        class C(c.defaultdict):
+            def __call__(self):
+                pass
+        bases = [c.Sized, c.Callable, c.Container, c.Mapping]
+        for haystack in permutations(bases):
+            m = mro(C, haystack)
+            self.assertEqual(m, [C, c.Callable, c.defaultdict, dict, c.Mapping,
+                                 c.Sized, c.Iterable, c.Container, object])
+
+    def test_register_abc(self):
+        c = collections
+        d = {"a": "b"}
+        l = [1, 2, 3]
+        s = set([object(), None])
+        f = frozenset(s)
+        t = (1, 2, 3)
+        @functools.singledispatch
+        def g(obj):
+            return "base"
+        self.assertEqual(g(d), "base")
+        self.assertEqual(g(l), "base")
+        self.assertEqual(g(s), "base")
+        self.assertEqual(g(f), "base")
+        self.assertEqual(g(t), "base")
+        g.register(c.Sized, lambda obj: "sized")
+        self.assertEqual(g(d), "sized")
+        self.assertEqual(g(l), "sized")
+        self.assertEqual(g(s), "sized")
+        self.assertEqual(g(f), "sized")
+        self.assertEqual(g(t), "sized")
+        g.register(c.MutableMapping, lambda obj: "mutablemapping")
+        self.assertEqual(g(d), "mutablemapping")
+        self.assertEqual(g(l), "sized")
+        self.assertEqual(g(s), "sized")
+        self.assertEqual(g(f), "sized")
+        self.assertEqual(g(t), "sized")
+        g.register(c.ChainMap, lambda obj: "chainmap")
+        self.assertEqual(g(d), "mutablemapping")  # irrelevant ABCs registered
+        self.assertEqual(g(l), "sized")
+        self.assertEqual(g(s), "sized")
+        self.assertEqual(g(f), "sized")
+        self.assertEqual(g(t), "sized")
+        g.register(c.MutableSequence, lambda obj: "mutablesequence")
+        self.assertEqual(g(d), "mutablemapping")
+        self.assertEqual(g(l), "mutablesequence")
+        self.assertEqual(g(s), "sized")
+        self.assertEqual(g(f), "sized")
+        self.assertEqual(g(t), "sized")
+        g.register(c.MutableSet, lambda obj: "mutableset")
+        self.assertEqual(g(d), "mutablemapping")
+        self.assertEqual(g(l), "mutablesequence")
+        self.assertEqual(g(s), "mutableset")
+        self.assertEqual(g(f), "sized")
+        self.assertEqual(g(t), "sized")
+        g.register(c.Mapping, lambda obj: "mapping")
+        self.assertEqual(g(d), "mutablemapping")  # not specific enough
+        self.assertEqual(g(l), "mutablesequence")
+        self.assertEqual(g(s), "mutableset")
+        self.assertEqual(g(f), "sized")
+        self.assertEqual(g(t), "sized")
+        g.register(c.Sequence, lambda obj: "sequence")
+        self.assertEqual(g(d), "mutablemapping")
+        self.assertEqual(g(l), "mutablesequence")
+        self.assertEqual(g(s), "mutableset")
+        self.assertEqual(g(f), "sized")
+        self.assertEqual(g(t), "sequence")
+        g.register(c.Set, lambda obj: "set")
+        self.assertEqual(g(d), "mutablemapping")
+        self.assertEqual(g(l), "mutablesequence")
+        self.assertEqual(g(s), "mutableset")
+        self.assertEqual(g(f), "set")
+        self.assertEqual(g(t), "sequence")
+        g.register(dict, lambda obj: "dict")
+        self.assertEqual(g(d), "dict")
+        self.assertEqual(g(l), "mutablesequence")
+        self.assertEqual(g(s), "mutableset")
+        self.assertEqual(g(f), "set")
+        self.assertEqual(g(t), "sequence")
+        g.register(list, lambda obj: "list")
+        self.assertEqual(g(d), "dict")
+        self.assertEqual(g(l), "list")
+        self.assertEqual(g(s), "mutableset")
+        self.assertEqual(g(f), "set")
+        self.assertEqual(g(t), "sequence")
+        g.register(set, lambda obj: "concrete-set")
+        self.assertEqual(g(d), "dict")
+        self.assertEqual(g(l), "list")
+        self.assertEqual(g(s), "concrete-set")
+        self.assertEqual(g(f), "set")
+        self.assertEqual(g(t), "sequence")
+        g.register(frozenset, lambda obj: "frozen-set")
+        self.assertEqual(g(d), "dict")
+        self.assertEqual(g(l), "list")
+        self.assertEqual(g(s), "concrete-set")
+        self.assertEqual(g(f), "frozen-set")
+        self.assertEqual(g(t), "sequence")
+        g.register(tuple, lambda obj: "tuple")
+        self.assertEqual(g(d), "dict")
+        self.assertEqual(g(l), "list")
+        self.assertEqual(g(s), "concrete-set")
+        self.assertEqual(g(f), "frozen-set")
+        self.assertEqual(g(t), "tuple")
+
+    def test_c3_abc(self):
+        c = collections
+        mro = functools._c3_mro
+        class A(object):
+            pass
+        class B(A):
+            def __len__(self):
+                return 0   # implies Sized
+        #@c.Container.register
+        class C(object):
+            pass
+        c.Container.register(C)
+        class D(object):
+            pass   # unrelated
+        class X(D, C, B):
+            def __call__(self):
+                pass   # implies Callable
+        expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
+        for abcs in permutations([c.Sized, c.Callable, c.Container]):
+            self.assertEqual(mro(X, abcs=abcs), expected)
+        # unrelated ABCs don't appear in the resulting MRO
+        many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
+        self.assertEqual(mro(X, abcs=many_abcs), expected)
+
+    def test_mro_conflicts(self):
+        c = collections
+        @functools.singledispatch
+        def g(arg):
+            return "base"
+        class O(c.Sized):
+            def __len__(self):
+                return 0
+        o = O()
+        self.assertEqual(g(o), "base")
+        g.register(c.Iterable, lambda arg: "iterable")
+        g.register(c.Container, lambda arg: "container")
+        g.register(c.Sized, lambda arg: "sized")
+        g.register(c.Set, lambda arg: "set")
+        self.assertEqual(g(o), "sized")
+        c.Iterable.register(O)
+        self.assertEqual(g(o), "sized")   # because it's explicitly in __mro__
+        c.Container.register(O)
+        self.assertEqual(g(o), "sized")   # see above: Sized is in __mro__
+        c.Set.register(O)
+        self.assertEqual(g(o), "set")     # because c.Set is a subclass of
+                                          # c.Sized and c.Container
+        class P(object):
+            pass
+        p = P()
+        self.assertEqual(g(p), "base")
+        c.Iterable.register(P)
+        self.assertEqual(g(p), "iterable")
+        c.Container.register(P)
+        with self.assertRaises(RuntimeError) as re_one:
+            g(p)
+        self.assertIn(
+            str(re_one.exception),
+            (("Ambiguous dispatch: <class '{prefix}.Container'> "
+              "or <class '{prefix}.Iterable'>").format(prefix=abcoll_prefix),
+             ("Ambiguous dispatch: <class '{prefix}.Iterable'> "
+              "or <class '{prefix}.Container'>").format(prefix=abcoll_prefix)),
+        )
+        class Q(c.Sized):
+            def __len__(self):
+                return 0
+        q = Q()
+        self.assertEqual(g(q), "sized")
+        c.Iterable.register(Q)
+        self.assertEqual(g(q), "sized")   # because it's explicitly in __mro__
+        c.Set.register(Q)
+        self.assertEqual(g(q), "set")     # because c.Set is a subclass of
+                                          # c.Sized and c.Iterable
+        @functools.singledispatch
+        def h(arg):
+            return "base"
+        @h.register(c.Sized)
+        def _(arg):
+            return "sized"
+        @h.register(c.Container)
+        def _(arg):
+            return "container"
+        # Even though Sized and Container are explicit bases of MutableMapping,
+        # this ABC is implicitly registered on defaultdict which makes all of
+        # MutableMapping's bases implicit as well from defaultdict's
+        # perspective.
+        with self.assertRaises(RuntimeError) as re_two:
+            h(c.defaultdict(lambda: 0))
+        self.assertIn(
+            str(re_two.exception),
+            (("Ambiguous dispatch: <class '{prefix}.Container'> "
+              "or <class '{prefix}.Sized'>").format(prefix=abcoll_prefix),
+             ("Ambiguous dispatch: <class '{prefix}.Sized'> "
+              "or <class '{prefix}.Container'>").format(prefix=abcoll_prefix)),
+        )
+        class R(c.defaultdict):
+            pass
+        c.MutableSequence.register(R)
+        @functools.singledispatch
+        def i(arg):
+            return "base"
+        @i.register(c.MutableMapping)
+        def _(arg):
+            return "mapping"
+        @i.register(c.MutableSequence)
+        def _(arg):
+            return "sequence"
+        r = R()
+        self.assertEqual(i(r), "sequence")
+        class S(object):
+            pass
+        class T(S, c.Sized):
+            def __len__(self):
+                return 0
+        t = T()
+        self.assertEqual(h(t), "sized")
+        c.Container.register(T)
+        self.assertEqual(h(t), "sized")   # because it's explicitly in the MRO
+        class U(object):
+            def __len__(self):
+                return 0
+        u = U()
+        self.assertEqual(h(u), "sized")   # implicit Sized subclass inferred
+                                          # from the existence of __len__()
+        c.Container.register(U)
+        # There is no preference for registered versus inferred ABCs.
+        with self.assertRaises(RuntimeError) as re_three:
+            h(u)
+        self.assertIn(
+            str(re_three.exception),
+            (("Ambiguous dispatch: <class '{prefix}.Container'> "
+              "or <class '{prefix}.Sized'>").format(prefix=abcoll_prefix),
+             ("Ambiguous dispatch: <class '{prefix}.Sized'> "
+              "or <class '{prefix}.Container'>").format(prefix=abcoll_prefix)),
+        )
+        class V(c.Sized, S):
+            def __len__(self):
+                return 0
+        @functools.singledispatch
+        def j(arg):
+            return "base"
+        @j.register(S)
+        def _(arg):
+            return "s"
+        @j.register(c.Container)
+        def _(arg):
+            return "container"
+        v = V()
+        self.assertEqual(j(v), "s")
+        c.Container.register(V)
+        self.assertEqual(j(v), "container")   # because it ends up right after
+                                              # Sized in the MRO
+
+    def test_cache_invalidation(self):
+        try:
+            from collections import UserDict
+        except ImportError:
+            from UserDict import UserDict
+        class TracingDict(UserDict):
+            def __init__(self, *args, **kwargs):
+                UserDict.__init__(self, *args, **kwargs)
+                self.set_ops = []
+                self.get_ops = []
+            def __getitem__(self, key):
+                result = self.data[key]
+                self.get_ops.append(key)
+                return result
+            def __setitem__(self, key, value):
+                self.set_ops.append(key)
+                self.data[key] = value
+            def clear(self):
+                self.data.clear()
+        _orig_wkd = functools.WeakKeyDictionary
+        td = TracingDict()
+        functools.WeakKeyDictionary = lambda: td
+        c = collections
+        @functools.singledispatch
+        def g(arg):
+            return "base"
+        d = {}
+        l = []
+        self.assertEqual(len(td), 0)
+        self.assertEqual(g(d), "base")
+        self.assertEqual(len(td), 1)
+        self.assertEqual(td.get_ops, [])
+        self.assertEqual(td.set_ops, [dict])
+        self.assertEqual(td.data[dict], g.registry[object])
+        self.assertEqual(g(l), "base")
+        self.assertEqual(len(td), 2)
+        self.assertEqual(td.get_ops, [])
+        self.assertEqual(td.set_ops, [dict, list])
+        self.assertEqual(td.data[dict], g.registry[object])
+        self.assertEqual(td.data[list], g.registry[object])
+        self.assertEqual(td.data[dict], td.data[list])
+        self.assertEqual(g(l), "base")
+        self.assertEqual(g(d), "base")
+        self.assertEqual(td.get_ops, [list, dict])
+        self.assertEqual(td.set_ops, [dict, list])
+        g.register(list, lambda arg: "list")
+        self.assertEqual(td.get_ops, [list, dict])
+        self.assertEqual(len(td), 0)
+        self.assertEqual(g(d), "base")
+        self.assertEqual(len(td), 1)
+        self.assertEqual(td.get_ops, [list, dict])
+        self.assertEqual(td.set_ops, [dict, list, dict])
+        self.assertEqual(td.data[dict],
+                         functools._find_impl(dict, g.registry))
+        self.assertEqual(g(l), "list")
+        self.assertEqual(len(td), 2)
+        self.assertEqual(td.get_ops, [list, dict])
+        self.assertEqual(td.set_ops, [dict, list, dict, list])
+        self.assertEqual(td.data[list],
+                         functools._find_impl(list, g.registry))
+        class X(object):
+            pass
+        c.MutableMapping.register(X)   # Will not invalidate the cache,
+                                       # not using ABCs yet.
+        self.assertEqual(g(d), "base")
+        self.assertEqual(g(l), "list")
+        self.assertEqual(td.get_ops, [list, dict, dict, list])
+        self.assertEqual(td.set_ops, [dict, list, dict, list])
+        g.register(c.Sized, lambda arg: "sized")
+        self.assertEqual(len(td), 0)
+        self.assertEqual(g(d), "sized")
+        self.assertEqual(len(td), 1)
+        self.assertEqual(td.get_ops, [list, dict, dict, list])
+        self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
+        self.assertEqual(g(l), "list")
+        self.assertEqual(len(td), 2)
+        self.assertEqual(td.get_ops, [list, dict, dict, list])
+        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
+        self.assertEqual(g(l), "list")
+        self.assertEqual(g(d), "sized")
+        self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
+        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
+        g.dispatch(list)
+        g.dispatch(dict)
+        self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
+                                      list, dict])
+        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
+        c.MutableSet.register(X)       # Will invalidate the cache.
+        self.assertEqual(len(td), 2)   # Stale cache.
+        self.assertEqual(g(l), "list")
+        self.assertEqual(len(td), 1)
+        g.register(c.MutableMapping, lambda arg: "mutablemapping")
+        self.assertEqual(len(td), 0)
+        self.assertEqual(g(d), "mutablemapping")
+        self.assertEqual(len(td), 1)
+        self.assertEqual(g(l), "list")
+        self.assertEqual(len(td), 2)
+        g.register(dict, lambda arg: "dict")
+        self.assertEqual(g(d), "dict")
+        self.assertEqual(g(l), "list")
+        g._clear_cache()
+        self.assertEqual(len(td), 0)
+        functools.WeakKeyDictionary = _orig_wkd
+
+
+if __name__ == '__main__':
+    unittest.main()
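
One behaviour these tests exercise repeatedly is the ambiguity check in
_find_impl(): when two unrelated implicit ABCs match a class equally well,
dispatch raises rather than guessing. A condensed sketch of the scenario from
test_mro_conflicts (it uses the pre-3.3 collections ABC names, as the tests do,
and assumes the package directory is on sys.path):

    import collections
    from singledispatch import singledispatch

    @singledispatch
    def g(arg):
        return "base"

    g.register(collections.Iterable, lambda arg: "iterable")
    g.register(collections.Container, lambda arg: "container")

    class P(object):
        pass

    collections.Iterable.register(P)    # P becomes an implicit Iterable
    assert g(P()) == "iterable"

    collections.Container.register(P)   # ... and an implicit Container too
    try:
        g(P())                          # neither ABC is more specific than the other
    except RuntimeError as exc:
        print(exc)                      # Ambiguous dispatch: ... Iterable/Container
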

