[Python-checkins] bpo-44863: Allow generic typing.TypedDict (#27663)

JelleZijlstra webhook-mailer at python.org
Tue May 3 09:21:50 EDT 2022


https://github.com/python/cpython/commit/f6f36cc26978e3036a6c5c068fca5b8135f27ef3
commit: f6f36cc26978e3036a6c5c068fca5b8135f27ef3
branch: main
author: Samodya Abey <379594+sransara at users.noreply.github.com>
committer: JelleZijlstra <jelle.zijlstra at gmail.com>
date: 2022-05-03T07:21:42-06:00
summary:

bpo-44863: Allow generic typing.TypedDict (#27663)

Co-authored-by: Ken Jin <28750310+Fidget-Spinner at users.noreply.github.com>
Co-authored-by: Yurii Karabas <1998uriyyo at gmail.com>
Co-authored-by: Jelle Zijlstra <jelle.zijlstra at gmail.com>
Co-authored-by: Serhiy Storchaka <storchaka at gmail.com>

files:
A Misc/NEWS.d/next/Library/2021-09-03-07-56-48.bpo-44863.udgz95.rst
M Doc/library/typing.rst
M Doc/whatsnew/3.11.rst
M Lib/test/_typed_dict_helper.py
M Lib/test/test_typing.py
M Lib/typing.py

diff --git a/Doc/library/typing.rst b/Doc/library/typing.rst
index 05ac05767f32e..c9fc944fdeb56 100644
--- a/Doc/library/typing.rst
+++ b/Doc/library/typing.rst
@@ -1738,7 +1738,7 @@ These are not used in annotations. They are building blocks for declaring types.
           z: int
 
    A ``TypedDict`` cannot inherit from a non-TypedDict class,
-   notably including :class:`Generic`. For example::
+   except for :class:`Generic`. For example::
 
       class X(TypedDict):
           x: int
@@ -1755,6 +1755,12 @@ These are not used in annotations. They are building blocks for declaring types.
       T = TypeVar('T')
       class XT(X, Generic[T]): pass  # raises TypeError
 
+   A ``TypedDict`` can be generic::
+
+      class Group(TypedDict, Generic[T]):
+          key: T
+          group: list[T]
+
    A ``TypedDict`` can be introspected via annotations dicts
    (see :ref:`annotations-howto` for more information on annotations best practices),
    :attr:`__total__`, :attr:`__required_keys__`, and :attr:`__optional_keys__`.
@@ -1802,6 +1808,9 @@ These are not used in annotations. They are building blocks for declaring types.
 
    .. versionadded:: 3.8
 
+   .. versionchanged:: 3.11
+      Added support for generic ``TypedDict``\ s.
+
 Generic concrete collections
 ----------------------------
 
diff --git a/Doc/whatsnew/3.11.rst b/Doc/whatsnew/3.11.rst
index c19f158f57a31..2f32b56423de7 100644
--- a/Doc/whatsnew/3.11.rst
+++ b/Doc/whatsnew/3.11.rst
@@ -715,7 +715,10 @@ For major changes, see :ref:`new-feat-related-type-hints-311`.
   to clear all registered overloads of a function.
   (Contributed by Jelle Zijlstra in :gh:`89263`.)
 
-* :class:`~typing.NamedTuple` subclasses can be generic.
+* :data:`typing.TypedDict` subclasses can now be generic. (Contributed by
+  Samodya Abey in :gh:`89026`.)
+
+* :class:`~typing.NamedTuple` subclasses can now be generic.
   (Contributed by Serhiy Storchaka in :issue:`43923`.)
 
 
diff --git a/Lib/test/_typed_dict_helper.py b/Lib/test/_typed_dict_helper.py
index 3328330c995b3..9df0ede7d40ee 100644
--- a/Lib/test/_typed_dict_helper.py
+++ b/Lib/test/_typed_dict_helper.py
@@ -13,12 +13,18 @@ class Bar(_typed_dict_helper.Foo, total=False):
 
 from __future__ import annotations
 
-from typing import Annotated, Optional, Required, TypedDict
+from typing import Annotated, Generic, Optional, Required, TypedDict, TypeVar
+
 
 OptionalIntType = Optional[int]
 
 class Foo(TypedDict):
     a: OptionalIntType
 
+T = TypeVar("T")
+
+class FooGeneric(TypedDict, Generic[T]):
+    a: Optional[T]
+
 class VeryAnnotated(TypedDict, total=False):
     a: Annotated[Annotated[Annotated[Required[int], "a"], "b"], "c"]
diff --git a/Lib/test/test_typing.py b/Lib/test/test_typing.py
index 08f7d0211eafb..55e18c08537df 100644
--- a/Lib/test/test_typing.py
+++ b/Lib/test/test_typing.py
@@ -4530,9 +4530,16 @@ class Point2D(TypedDict):
     x: int
     y: int
 
+class Point2DGeneric(Generic[T], TypedDict):
+    a: T
+    b: T
+
 class Bar(_typed_dict_helper.Foo, total=False):
     b: int
 
+class BarGeneric(_typed_dict_helper.FooGeneric[T], total=False):
+    b: int
+
 class LabelPoint2D(Point2D, Label): ...
 
 class Options(TypedDict, total=False):
@@ -5890,6 +5897,17 @@ def test_pickle(self):
             EmpDnew = pickle.loads(ZZ)
             self.assertEqual(EmpDnew({'name': 'jane', 'id': 37}), jane)
 
+    def test_pickle_generic(self):
+        point = Point2DGeneric(a=5.0, b=3.0)
+        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+            z = pickle.dumps(point, proto)
+            point2 = pickle.loads(z)
+            self.assertEqual(point2, point)
+            self.assertEqual(point2, {'a': 5.0, 'b': 3.0})
+            ZZ = pickle.dumps(Point2DGeneric, proto)
+            Point2DGenericNew = pickle.loads(ZZ)
+            self.assertEqual(Point2DGenericNew({'a': 5.0, 'b': 3.0}), point)
+
     def test_optional(self):
         EmpD = TypedDict('EmpD', {'name': str, 'id': int})
 
@@ -6074,6 +6092,124 @@ def test_get_type_hints(self):
             {'a': typing.Optional[int], 'b': int}
         )
 
+    def test_get_type_hints_generic(self):
+        self.assertEqual(
+            get_type_hints(BarGeneric),
+            {'a': typing.Optional[T], 'b': int}
+        )
+
+        class FooBarGeneric(BarGeneric[int]):
+            c: str
+
+        self.assertEqual(
+            get_type_hints(FooBarGeneric),
+            {'a': typing.Optional[T], 'b': int, 'c': str}
+        )
+
+    def test_generic_inheritance(self):
+        class A(TypedDict, Generic[T]):
+            a: T
+
+        self.assertEqual(A.__bases__, (Generic, dict))
+        self.assertEqual(A.__orig_bases__, (TypedDict, Generic[T]))
+        self.assertEqual(A.__mro__, (A, Generic, dict, object))
+        self.assertEqual(A.__parameters__, (T,))
+        self.assertEqual(A[str].__parameters__, ())
+        self.assertEqual(A[str].__args__, (str,))
+
+        class A2(Generic[T], TypedDict):
+            a: T
+
+        self.assertEqual(A2.__bases__, (Generic, dict))
+        self.assertEqual(A2.__orig_bases__, (Generic[T], TypedDict))
+        self.assertEqual(A2.__mro__, (A2, Generic, dict, object))
+        self.assertEqual(A2.__parameters__, (T,))
+        self.assertEqual(A2[str].__parameters__, ())
+        self.assertEqual(A2[str].__args__, (str,))
+
+        class B(A[KT], total=False):
+            b: KT
+
+        self.assertEqual(B.__bases__, (Generic, dict))
+        self.assertEqual(B.__orig_bases__, (A[KT],))
+        self.assertEqual(B.__mro__, (B, Generic, dict, object))
+        self.assertEqual(B.__parameters__, (KT,))
+        self.assertEqual(B.__total__, False)
+        self.assertEqual(B.__optional_keys__, frozenset(['b']))
+        self.assertEqual(B.__required_keys__, frozenset(['a']))
+
+        self.assertEqual(B[str].__parameters__, ())
+        self.assertEqual(B[str].__args__, (str,))
+        self.assertEqual(B[str].__origin__, B)
+
+        class C(B[int]):
+            c: int
+
+        self.assertEqual(C.__bases__, (Generic, dict))
+        self.assertEqual(C.__orig_bases__, (B[int],))
+        self.assertEqual(C.__mro__, (C, Generic, dict, object))
+        self.assertEqual(C.__parameters__, ())
+        self.assertEqual(C.__total__, True)
+        self.assertEqual(C.__optional_keys__, frozenset(['b']))
+        self.assertEqual(C.__required_keys__, frozenset(['a', 'c']))
+        assert C.__annotations__ == {
+            'a': T,
+            'b': KT,
+            'c': int,
+        }
+        with self.assertRaises(TypeError):
+            C[str]
+
+
+        class Point3D(Point2DGeneric[T], Generic[T, KT]):
+            c: KT
+
+        self.assertEqual(Point3D.__bases__, (Generic, dict))
+        self.assertEqual(Point3D.__orig_bases__, (Point2DGeneric[T], Generic[T, KT]))
+        self.assertEqual(Point3D.__mro__, (Point3D, Generic, dict, object))
+        self.assertEqual(Point3D.__parameters__, (T, KT))
+        self.assertEqual(Point3D.__total__, True)
+        self.assertEqual(Point3D.__optional_keys__, frozenset())
+        self.assertEqual(Point3D.__required_keys__, frozenset(['a', 'b', 'c']))
+        assert Point3D.__annotations__ == {
+            'a': T,
+            'b': T,
+            'c': KT,
+        }
+        self.assertEqual(Point3D[int, str].__origin__, Point3D)
+
+        with self.assertRaises(TypeError):
+            Point3D[int]
+
+        with self.assertRaises(TypeError):
+            class Point3D(Point2DGeneric[T], Generic[KT]):
+                c: KT
+
+    def test_implicit_any_inheritance(self):
+        class A(TypedDict, Generic[T]):
+            a: T
+
+        class B(A[KT], total=False):
+            b: KT
+
+        class WithImplicitAny(B):
+            c: int
+
+        self.assertEqual(WithImplicitAny.__bases__, (Generic, dict,))
+        self.assertEqual(WithImplicitAny.__mro__, (WithImplicitAny, Generic, dict, object))
+        # Consistent with GenericTests.test_implicit_any
+        self.assertEqual(WithImplicitAny.__parameters__, ())
+        self.assertEqual(WithImplicitAny.__total__, True)
+        self.assertEqual(WithImplicitAny.__optional_keys__, frozenset(['b']))
+        self.assertEqual(WithImplicitAny.__required_keys__, frozenset(['a', 'c']))
+        assert WithImplicitAny.__annotations__ == {
+            'a': T,
+            'b': KT,
+            'c': int,
+        }
+        with self.assertRaises(TypeError):
+            WithImplicitAny[str]
+
     def test_non_generic_subscript(self):
         # For backward compatibility, subscription works
         # on arbitrary TypedDict types.
diff --git a/Lib/typing.py b/Lib/typing.py
index 84f0fd1a8a946..bdc14e39033dc 100644
--- a/Lib/typing.py
+++ b/Lib/typing.py
@@ -1796,7 +1796,9 @@ def __init_subclass__(cls, *args, **kwargs):
         if '__orig_bases__' in cls.__dict__:
             error = Generic in cls.__orig_bases__
         else:
-            error = Generic in cls.__bases__ and cls.__name__ != 'Protocol'
+            error = (Generic in cls.__bases__ and
+                        cls.__name__ != 'Protocol' and
+                        type(cls) != _TypedDictMeta)
         if error:
             raise TypeError("Cannot inherit from plain Generic")
         if '__orig_bases__' in cls.__dict__:
@@ -2868,14 +2870,19 @@ def __new__(cls, name, bases, ns, total=True):
         Subclasses and instances of TypedDict return actual dictionaries.
         """
         for base in bases:
-            if type(base) is not _TypedDictMeta:
+            if type(base) is not _TypedDictMeta and base is not Generic:
                 raise TypeError('cannot inherit from both a TypedDict type '
                                 'and a non-TypedDict base class')
-        tp_dict = type.__new__(_TypedDictMeta, name, (dict,), ns)
+
+        if any(issubclass(b, Generic) for b in bases):
+            generic_base = (Generic,)
+        else:
+            generic_base = ()
+
+        tp_dict = type.__new__(_TypedDictMeta, name, (*generic_base, dict), ns)
 
         annotations = {}
         own_annotations = ns.get('__annotations__', {})
-        own_annotation_keys = set(own_annotations.keys())
         msg = "TypedDict('Name', {f0: t0, f1: t1, ...}); each t must be a type"
         own_annotations = {
             n: _type_check(tp, msg, module=tp_dict.__module__)
diff --git a/Misc/NEWS.d/next/Library/2021-09-03-07-56-48.bpo-44863.udgz95.rst b/Misc/NEWS.d/next/Library/2021-09-03-07-56-48.bpo-44863.udgz95.rst
new file mode 100644
index 0000000000000..130856587fd91
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2021-09-03-07-56-48.bpo-44863.udgz95.rst
@@ -0,0 +1,4 @@
+Allow :class:`~typing.TypedDict` subclasses to also include
+:class:`~typing.Generic` as a base class in class-based syntax. This allows
+the user to define a generic ``TypedDict``, just like a user-defined generic but
+with ``TypedDict`` semantics.



More information about the Python-checkins mailing list