[Python-checkins] bpo-40397: Refactor typing._GenericAlias (GH-19719)

Serhiy Storchaka webhook-mailer at python.org
Wed May 6 21:09:40 EDT 2020


https://github.com/python/cpython/commit/c1c7d8ead9eb214a6149a43e31a3213c52448877
commit: c1c7d8ead9eb214a6149a43e31a3213c52448877
branch: master
author: Serhiy Storchaka <storchaka at gmail.com>
committer: GitHub <noreply at github.com>
date: 2020-05-07T04:09:33+03:00
summary:

bpo-40397: Refactor typing._GenericAlias (GH-19719)

Make the design more object-oriented.
Split _GenericAlias into two almost independent classes: one for special
generic aliases like List and one for parametrized generic aliases like List[int].
Add specialized subclasses for Callable, Callable[...], Tuple and Union[...].

files:
M Lib/typing.py

diff --git a/Lib/typing.py b/Lib/typing.py
index f3cd280a09e27..681ab6d21e0a3 100644
--- a/Lib/typing.py
+++ b/Lib/typing.py
@@ -181,34 +181,11 @@ def _collect_type_vars(types):
     for t in types:
         if isinstance(t, TypeVar) and t not in tvars:
             tvars.append(t)
-        if ((isinstance(t, _GenericAlias) and not t._special)
-                or isinstance(t, GenericAlias)):
+        if isinstance(t, (_GenericAlias, GenericAlias)):
             tvars.extend([t for t in t.__parameters__ if t not in tvars])
     return tuple(tvars)
 
 
-def _subs_tvars(tp, tvars, subs):
-    """Substitute type variables 'tvars' with substitutions 'subs'.
-    These two must have the same length.
-    """
-    if not isinstance(tp, (_GenericAlias, GenericAlias)):
-        return tp
-    new_args = list(tp.__args__)
-    for a, arg in enumerate(tp.__args__):
-        if isinstance(arg, TypeVar):
-            for i, tvar in enumerate(tvars):
-                if arg == tvar:
-                    new_args[a] = subs[i]
-        else:
-            new_args[a] = _subs_tvars(arg, tvars, subs)
-    if tp.__origin__ is Union:
-        return Union[tuple(new_args)]
-    if isinstance(tp, GenericAlias):
-        return GenericAlias(tp.__origin__, tuple(new_args))
-    else:
-        return tp.copy_with(tuple(new_args))
-
-
 def _check_generic(cls, parameters):
     """Check correct count for parameters of a generic cls (internal helper).
     This gives a nice error message in case of count mismatch.
@@ -229,7 +206,7 @@ def _remove_dups_flatten(parameters):
     # Flatten out Union[Union[...], ...].
     params = []
     for p in parameters:
-        if isinstance(p, _GenericAlias) and p.__origin__ is Union:
+        if isinstance(p, _UnionGenericAlias):
             params.extend(p.__args__)
         elif isinstance(p, tuple) and len(p) > 0 and p[0] is Union:
             params.extend(p[1:])
@@ -274,18 +251,14 @@ def _eval_type(t, globalns, localns):
     """
     if isinstance(t, ForwardRef):
         return t._evaluate(globalns, localns)
-    if isinstance(t, _GenericAlias):
+    if isinstance(t, (_GenericAlias, GenericAlias)):
         ev_args = tuple(_eval_type(a, globalns, localns) for a in t.__args__)
         if ev_args == t.__args__:
             return t
-        res = t.copy_with(ev_args)
-        res._special = t._special
-        return res
-    if isinstance(t, GenericAlias):
-        ev_args = tuple(_eval_type(a, globalns, localns) for a in t.__args__)
-        if ev_args == t.__args__:
-            return t
-        return GenericAlias(t.__origin__, ev_args)
+        if isinstance(t, GenericAlias):
+            return GenericAlias(t.__origin__, ev_args)
+        else:
+            return t.copy_with(ev_args)
     return t
 
 
@@ -300,6 +273,7 @@ def __init_subclass__(self, /, *args, **kwds):
 
 class _Immutable:
     """Mixin to indicate that object should not be copied."""
+    __slots__ = ()
 
     def __copy__(self):
         return self
@@ -446,7 +420,7 @@ def Union(self, parameters):
     parameters = _remove_dups_flatten(parameters)
     if len(parameters) == 1:
         return parameters[0]
-    return _GenericAlias(self, parameters)
+    return _UnionGenericAlias(self, parameters)
 
 @_SpecialForm
 def Optional(self, parameters):
@@ -579,7 +553,7 @@ def longest(x: A, y: A) -> A:
     """
 
     __slots__ = ('__name__', '__bound__', '__constraints__',
-                 '__covariant__', '__contravariant__')
+                 '__covariant__', '__contravariant__', '__dict__')
 
     def __init__(self, name, *constraints, bound=None,
                  covariant=False, contravariant=False):
@@ -629,23 +603,10 @@ def __reduce__(self):
 #   e.g., Dict[T, int].__args__ == (T, int).
 
 
-# Mapping from non-generic type names that have a generic alias in typing
-# but with a different name.
-_normalize_alias = {'list': 'List',
-                    'tuple': 'Tuple',
-                    'dict': 'Dict',
-                    'set': 'Set',
-                    'frozenset': 'FrozenSet',
-                    'deque': 'Deque',
-                    'defaultdict': 'DefaultDict',
-                    'type': 'Type',
-                    'Set': 'AbstractSet'}
-
 def _is_dunder(attr):
     return attr.startswith('__') and attr.endswith('__')
 
-
-class _GenericAlias(_Final, _root=True):
+class _BaseGenericAlias(_Final, _root=True):
     """The central part of internal API.
 
     This represents a generic version of type 'origin' with type arguments 'params'.
@@ -654,12 +615,8 @@ class _GenericAlias(_Final, _root=True):
     have 'name' always set. If 'inst' is False, then the alias can't be instantiated,
     this is used by e.g. typing.List and typing.Dict.
     """
-    def __init__(self, origin, params, *, inst=True, special=False, name=None):
+    def __init__(self, origin, params, *, inst=True, name=None):
         self._inst = inst
-        self._special = special
-        if special and name is None:
-            orig_name = origin.__name__
-            name = _normalize_alias.get(orig_name, orig_name)
         self._name = name
         if not isinstance(params, tuple):
             params = (params,)
@@ -671,68 +628,20 @@ def __init__(self, origin, params, *, inst=True, special=False, name=None):
         self.__slots__ = None  # This is not documented.
         if not name:
             self.__module__ = origin.__module__
-        if special:
-            self.__doc__ = f'A generic version of {origin.__module__}.{origin.__qualname__}'
-
-    @_tp_cache
-    def __getitem__(self, params):
-        if self.__origin__ in (Generic, Protocol):
-            # Can't subscript Generic[...] or Protocol[...].
-            raise TypeError(f"Cannot subscript already-subscripted {self}")
-        if not isinstance(params, tuple):
-            params = (params,)
-        msg = "Parameters to generic types must be types."
-        params = tuple(_type_check(p, msg) for p in params)
-        _check_generic(self, params)
-        return _subs_tvars(self, self.__parameters__, params)
-
-    def copy_with(self, params):
-        # We don't copy self._special.
-        return _GenericAlias(self.__origin__, params, name=self._name, inst=self._inst)
-
-    def __repr__(self):
-        if (self.__origin__ == Union and len(self.__args__) == 2
-                and type(None) in self.__args__):
-            if self.__args__[0] is not type(None):
-                arg = self.__args__[0]
-            else:
-                arg = self.__args__[1]
-            return (f'typing.Optional[{_type_repr(arg)}]')
-        if (self._name != 'Callable' or
-                len(self.__args__) == 2 and self.__args__[0] is Ellipsis):
-            if self._name:
-                name = 'typing.' + self._name
-            else:
-                name = _type_repr(self.__origin__)
-            if not self._special:
-                args = f'[{", ".join([_type_repr(a) for a in self.__args__])}]'
-            else:
-                args = ''
-            return (f'{name}{args}')
-        if self._special:
-            return 'typing.Callable'
-        return (f'typing.Callable'
-                f'[[{", ".join([_type_repr(a) for a in self.__args__[:-1]])}], '
-                f'{_type_repr(self.__args__[-1])}]')
 
     def __eq__(self, other):
-        if not isinstance(other, _GenericAlias):
+        if not isinstance(other, _BaseGenericAlias):
             return NotImplemented
-        if self.__origin__ != other.__origin__:
-            return False
-        if self.__origin__ is Union and other.__origin__ is Union:
-            return frozenset(self.__args__) == frozenset(other.__args__)
-        return self.__args__ == other.__args__
+        return (self.__origin__ == other.__origin__
+                and self.__args__ == other.__args__)
 
     def __hash__(self):
-        if self.__origin__ is Union:
-            return hash((Union, frozenset(self.__args__)))
         return hash((self.__origin__, self.__args__))
 
     def __call__(self, *args, **kwargs):
         if not self._inst:
             raise TypeError(f"Type {self._name} cannot be instantiated; "
-                            f"use {self._name.lower()}() instead")
+                            f"use {self.__origin__.__name__}() instead")
         result = self.__origin__(*args, **kwargs)
         try:
             result.__orig_class__ = self
@@ -741,23 +650,16 @@ def __call__(self, *args, **kwargs):
         return result
 
     def __mro_entries__(self, bases):
-        if self._name:  # generic version of an ABC or built-in class
-            res = []
-            if self.__origin__ not in bases:
-                res.append(self.__origin__)
-            i = bases.index(self)
-            if not any(isinstance(b, _GenericAlias) or issubclass(b, Generic)
-                       for b in bases[i+1:]):
-                res.append(Generic)
-            return tuple(res)
-        if self.__origin__ is Generic:
-            if Protocol in bases:
-                return ()
-            i = bases.index(self)
-            for b in bases[i+1:]:
-                if isinstance(b, _GenericAlias) and b is not self:
-                    return ()
-        return (self.__origin__,)
+        res = []
+        if self.__origin__ not in bases:
+            res.append(self.__origin__)
+        i = bases.index(self)
+        for b in bases[i+1:]:
+            if isinstance(b, _BaseGenericAlias) or issubclass(b, Generic):
+                break
+        else:
+            res.append(Generic)
+        return tuple(res)
 
     def __getattr__(self, attr):
         # We are careful for copy and pickle.
@@ -767,7 +669,7 @@ def __getattr__(self, attr):
         raise AttributeError(attr)
 
     def __setattr__(self, attr, val):
-        if _is_dunder(attr) or attr in ('_name', '_inst', '_special'):
+        if _is_dunder(attr) or attr in ('_name', '_inst'):
             super().__setattr__(attr, val)
         else:
             setattr(self.__origin__, attr, val)
@@ -776,39 +678,124 @@ def __instancecheck__(self, obj):
         return self.__subclasscheck__(type(obj))
 
     def __subclasscheck__(self, cls):
-        if self._special:
-            if not isinstance(cls, _GenericAlias):
-                return issubclass(cls, self.__origin__)
-            if cls._special:
-                return issubclass(cls.__origin__, self.__origin__)
         raise TypeError("Subscripted generics cannot be used with"
                         " class and instance checks")
 
-    def __reduce__(self):
-        if self._special:
-            return self._name
 
+class _GenericAlias(_BaseGenericAlias, _root=True):
+    @_tp_cache
+    def __getitem__(self, params):
+        if self.__origin__ in (Generic, Protocol):
+            # Can't subscript Generic[...] or Protocol[...].
+            raise TypeError(f"Cannot subscript already-subscripted {self}")
+        if not isinstance(params, tuple):
+            params = (params,)
+        msg = "Parameters to generic types must be types."
+        params = tuple(_type_check(p, msg) for p in params)
+        _check_generic(self, params)
+
+        subst = dict(zip(self.__parameters__, params))
+        new_args = []
+        for arg in self.__args__:
+            if isinstance(arg, TypeVar):
+                arg = subst[arg]
+            elif isinstance(arg, (_BaseGenericAlias, GenericAlias)):
+                subargs = tuple(subst[x] for x in arg.__parameters__)
+                arg = arg[subargs]
+            new_args.append(arg)
+        return self.copy_with(tuple(new_args))
+
+    def copy_with(self, params):
+        return self.__class__(self.__origin__, params, name=self._name, inst=self._inst)
+
+    def __repr__(self):
+        if self._name:
+            name = 'typing.' + self._name
+        else:
+            name = _type_repr(self.__origin__)
+        args = ", ".join([_type_repr(a) for a in self.__args__])
+        return f'{name}[{args}]'
+
+    def __reduce__(self):
         if self._name:
             origin = globals()[self._name]
         else:
             origin = self.__origin__
-        if (origin is Callable and
-            not (len(self.__args__) == 2 and self.__args__[0] is Ellipsis)):
-            args = list(self.__args__[:-1]), self.__args__[-1]
-        else:
-            args = tuple(self.__args__)
-            if len(args) == 1 and not isinstance(args[0], tuple):
-                args, = args
+        args = tuple(self.__args__)
+        if len(args) == 1 and not isinstance(args[0], tuple):
+            args, = args
         return operator.getitem, (origin, args)
 
+    def __mro_entries__(self, bases):
+        if self._name:  # generic version of an ABC or built-in class
+            return super().__mro_entries__(bases)
+        if self.__origin__ is Generic:
+            if Protocol in bases:
+                return ()
+            i = bases.index(self)
+            for b in bases[i+1:]:
+                if isinstance(b, _BaseGenericAlias) and b is not self:
+                    return ()
+        return (self.__origin__,)
+
+
+class _SpecialGenericAlias(_BaseGenericAlias, _root=True):
+    def __init__(self, origin, params, *, inst=True, name=None):
+        if name is None:
+            name = origin.__name__
+        super().__init__(origin, params, inst=inst, name=name)
+        self.__doc__ = f'A generic version of {origin.__module__}.{origin.__qualname__}'
+
+    @_tp_cache
+    def __getitem__(self, params):
+        if not isinstance(params, tuple):
+            params = (params,)
+        msg = "Parameters to generic types must be types."
+        params = tuple(_type_check(p, msg) for p in params)
+        _check_generic(self, params)
+        assert self.__args__ == self.__parameters__
+        return self.copy_with(params)
+
+    def copy_with(self, params):
+        return _GenericAlias(self.__origin__, params,
+                             name=self._name, inst=self._inst)
+
+    def __repr__(self):
+        return 'typing.' + self._name
+
+    def __subclasscheck__(self, cls):
+        if isinstance(cls, _SpecialGenericAlias):
+            return issubclass(cls.__origin__, self.__origin__)
+        if not isinstance(cls, _GenericAlias):
+            return issubclass(cls, self.__origin__)
+        return super().__subclasscheck__(cls)
+
+    def __reduce__(self):
+        return self._name
+
+
+class _CallableGenericAlias(_GenericAlias, _root=True):
+    def __repr__(self):
+        assert self._name == 'Callable'
+        if len(self.__args__) == 2 and self.__args__[0] is Ellipsis:
+            return super().__repr__()
+        return (f'typing.Callable'
+                f'[[{", ".join([_type_repr(a) for a in self.__args__[:-1]])}], '
+                f'{_type_repr(self.__args__[-1])}]')
+
+    def __reduce__(self):
+        args = self.__args__
+        if not (len(args) == 2 and args[0] is ...):
+            args = list(args[:-1]), args[-1]
+        return operator.getitem, (Callable, args)
+
+
+class _CallableType(_SpecialGenericAlias, _root=True):
+    def copy_with(self, params):
+        return _CallableGenericAlias(self.__origin__, params,
+                                     name=self._name, inst=self._inst)
 
-class _VariadicGenericAlias(_GenericAlias, _root=True):
-    """Same as _GenericAlias above but for variadic aliases. Currently,
-    this is used only by special internal aliases: Tuple and Callable.
-    """
     def __getitem__(self, params):
-        if self._name != 'Callable' or not self._special:
-            return self.__getitem_inner__(params)
         if not isinstance(params, tuple) or len(params) != 2:
             raise TypeError("Callable must be used as "
                             "Callable[[arg, ...], result].")
@@ -824,29 +811,53 @@ def __getitem__(self, params):
 
     @_tp_cache
     def __getitem_inner__(self, params):
-        if self.__origin__ is tuple and self._special:
-            if params == ():
-                return self.copy_with((_TypingEmpty,))
-            if not isinstance(params, tuple):
-                params = (params,)
-            if len(params) == 2 and params[1] is ...:
-                msg = "Tuple[t, ...]: t must be a type."
-                p = _type_check(params[0], msg)
-                return self.copy_with((p, _TypingEllipsis))
-            msg = "Tuple[t0, t1, ...]: each t must be a type."
-            params = tuple(_type_check(p, msg) for p in params)
-            return self.copy_with(params)
-        if self.__origin__ is collections.abc.Callable and self._special:
-            args, result = params
-            msg = "Callable[args, result]: result must be a type."
-            result = _type_check(result, msg)
-            if args is Ellipsis:
-                return self.copy_with((_TypingEllipsis, result))
-            msg = "Callable[[arg, ...], result]: each arg must be a type."
-            args = tuple(_type_check(arg, msg) for arg in args)
-            params = args + (result,)
-            return self.copy_with(params)
-        return super().__getitem__(params)
+        args, result = params
+        msg = "Callable[args, result]: result must be a type."
+        result = _type_check(result, msg)
+        if args is Ellipsis:
+            return self.copy_with((_TypingEllipsis, result))
+        msg = "Callable[[arg, ...], result]: each arg must be a type."
+        args = tuple(_type_check(arg, msg) for arg in args)
+        params = args + (result,)
+        return self.copy_with(params)
+
+
+class _TupleType(_SpecialGenericAlias, _root=True):
+    @_tp_cache
+    def __getitem__(self, params):
+        if params == ():
+            return self.copy_with((_TypingEmpty,))
+        if not isinstance(params, tuple):
+            params = (params,)
+        if len(params) == 2 and params[1] is ...:
+            msg = "Tuple[t, ...]: t must be a type."
+            p = _type_check(params[0], msg)
+            return self.copy_with((p, _TypingEllipsis))
+        msg = "Tuple[t0, t1, ...]: each t must be a type."
+        params = tuple(_type_check(p, msg) for p in params)
+        return self.copy_with(params)
+
+
+class _UnionGenericAlias(_GenericAlias, _root=True):
+    def copy_with(self, params):
+        return Union[params]
+
+    def __eq__(self, other):
+        if not isinstance(other, _UnionGenericAlias):
+            return NotImplemented
+        return set(self.__args__) == set(other.__args__)
+
+    def __hash__(self):
+        return hash(frozenset(self.__args__))
+
+    def __repr__(self):
+        args = self.__args__
+        if len(args) == 2:
+            if args[0] is type(None):
+                return f'typing.Optional[{_type_repr(args[1])}]'
+            elif args[1] is type(None):
+                return f'typing.Optional[{_type_repr(args[0])}]'
+        return super().__repr__()
 
 
 class Generic:
@@ -1162,9 +1173,8 @@ def __reduce__(self):
     def __eq__(self, other):
         if not isinstance(other, _AnnotatedAlias):
             return NotImplemented
-        if self.__origin__ != other.__origin__:
-            return False
-        return self.__metadata__ == other.__metadata__
+        return (self.__origin__ == other.__origin__
+                and self.__metadata__ == other.__metadata__)
 
     def __hash__(self):
         return hash((self.__origin__, self.__metadata__))
@@ -1380,9 +1390,7 @@ def _strip_annotations(t):
         stripped_args = tuple(_strip_annotations(a) for a in t.__args__)
         if stripped_args == t.__args__:
             return t
-        res = t.copy_with(stripped_args)
-        res._special = t._special
-        return res
+        return t.copy_with(stripped_args)
     if isinstance(t, GenericAlias):
         stripped_args = tuple(_strip_annotations(a) for a in t.__args__)
         if stripped_args == t.__args__:
@@ -1407,7 +1415,7 @@ def get_origin(tp):
     """
     if isinstance(tp, _AnnotatedAlias):
         return Annotated
-    if isinstance(tp, (_GenericAlias, GenericAlias)):
+    if isinstance(tp, (_BaseGenericAlias, GenericAlias)):
         return tp.__origin__
     if tp is Generic:
         return Generic
@@ -1427,7 +1435,7 @@ def get_args(tp):
     """
     if isinstance(tp, _AnnotatedAlias):
         return (tp.__origin__,) + tp.__metadata__
-    if isinstance(tp, _GenericAlias) and not tp._special:
+    if isinstance(tp, _GenericAlias):
         res = tp.__args__
         if tp.__origin__ is collections.abc.Callable and res[0] is not Ellipsis:
             res = (list(res[:-1]), res[-1])
@@ -1561,8 +1569,7 @@ class Other(Leaf):  # Error reported by type checker
 
 
 # Various ABCs mimicking those in collections.abc.
-def _alias(origin, params, inst=True):
-    return _GenericAlias(origin, params, special=True, inst=inst)
+_alias = _SpecialGenericAlias
 
 Hashable = _alias(collections.abc.Hashable, ())  # Not generic.
 Awaitable = _alias(collections.abc.Awaitable, T_co)
@@ -1575,7 +1582,7 @@ def _alias(origin, params, inst=True):
 Sized = _alias(collections.abc.Sized, ())  # Not generic.
 Container = _alias(collections.abc.Container, T_co)
 Collection = _alias(collections.abc.Collection, T_co)
-Callable = _VariadicGenericAlias(collections.abc.Callable, (), special=True)
+Callable = _CallableType(collections.abc.Callable, ())
 Callable.__doc__ = \
     """Callable type; Callable[[int], str] is a function of (int) -> str.
 
@@ -1586,7 +1593,7 @@ def _alias(origin, params, inst=True):
     There is no syntax to indicate optional or keyword arguments,
     such function types are rarely used as callback types.
     """
-AbstractSet = _alias(collections.abc.Set, T_co)
+AbstractSet = _alias(collections.abc.Set, T_co, name='AbstractSet')
 MutableSet = _alias(collections.abc.MutableSet, T)
 # NOTE: Mapping is only covariant in the value type.
 Mapping = _alias(collections.abc.Mapping, (KT, VT_co))
@@ -1594,7 +1601,7 @@ def _alias(origin, params, inst=True):
 Sequence = _alias(collections.abc.Sequence, T_co)
 MutableSequence = _alias(collections.abc.MutableSequence, T)
 ByteString = _alias(collections.abc.ByteString, ())  # Not generic
-Tuple = _VariadicGenericAlias(tuple, (), inst=False, special=True)
+Tuple = _TupleType(tuple, (), inst=False, name='Tuple')
 Tuple.__doc__ = \
     """Tuple type; Tuple[X, Y] is the cross-product type of X and Y.
 
@@ -1604,24 +1611,24 @@ def _alias(origin, params, inst=True):
 
     To specify a variable-length tuple of homogeneous type, use Tuple[T, ...].
     """
-List = _alias(list, T, inst=False)
-Deque = _alias(collections.deque, T)
-Set = _alias(set, T, inst=False)
-FrozenSet = _alias(frozenset, T_co, inst=False)
+List = _alias(list, T, inst=False, name='List')
+Deque = _alias(collections.deque, T, name='Deque')
+Set = _alias(set, T, inst=False, name='Set')
+FrozenSet = _alias(frozenset, T_co, inst=False, name='FrozenSet')
 MappingView = _alias(collections.abc.MappingView, T_co)
 KeysView = _alias(collections.abc.KeysView, KT)
 ItemsView = _alias(collections.abc.ItemsView, (KT, VT_co))
 ValuesView = _alias(collections.abc.ValuesView, VT_co)
-ContextManager = _alias(contextlib.AbstractContextManager, T_co)
-AsyncContextManager = _alias(contextlib.AbstractAsyncContextManager, T_co)
-Dict = _alias(dict, (KT, VT), inst=False)
-DefaultDict = _alias(collections.defaultdict, (KT, VT))
+ContextManager = _alias(contextlib.AbstractContextManager, T_co, name='ContextManager')
+AsyncContextManager = _alias(contextlib.AbstractAsyncContextManager, T_co, name='AsyncContextManager')
+Dict = _alias(dict, (KT, VT), inst=False, name='Dict')
+DefaultDict = _alias(collections.defaultdict, (KT, VT), name='DefaultDict')
 OrderedDict = _alias(collections.OrderedDict, (KT, VT))
 Counter = _alias(collections.Counter, T)
 ChainMap = _alias(collections.ChainMap, (KT, VT))
 Generator = _alias(collections.abc.Generator, (T_co, T_contra, V_co))
 AsyncGenerator = _alias(collections.abc.AsyncGenerator, (T_co, T_contra))
-Type = _alias(type, CT_co, inst=False)
+Type = _alias(type, CT_co, inst=False, name='Type')
 Type.__doc__ = \
     """A special construct usable to annotate class objects.
 



More information about the Python-checkins mailing list