[Python-checkins] bpo-42567: [Enum] call __init_subclass__ after members are added (GH-23714)

ethanfurman webhook-mailer at python.org
Wed Dec 9 19:41:31 EST 2020


https://github.com/python/cpython/commit/6bd94de168b58ac9358277ed6f200490ab26c174
commit: 6bd94de168b58ac9358277ed6f200490ab26c174
branch: master
author: Ethan Furman <ethan at stoneleaf.us>
committer: ethanfurman <ethan at stoneleaf.us>
date: 2020-12-09T16:41:22-08:00
summary:

bpo-42567: [Enum] call __init_subclass__ after members are added (GH-23714)

When an Enum class is created, type.__new__ calls __init_subclass__ before the enum members have been added to the new class.

This patch suppresses that initial call, then manually calls the inherited __init_subclass__ once the members have been added, just before the new Enum class is returned.

files:
A Misc/NEWS.d/next/Library/2020-12-08-22-43-35.bpo-42678.ba9ktU.rst
M Lib/enum.py
M Lib/test/test_enum.py

diff --git a/Lib/enum.py b/Lib/enum.py
index f6c7e8b233413..83e050e1e41cd 100644
--- a/Lib/enum.py
+++ b/Lib/enum.py
@@ -9,6 +9,14 @@
         ]
 
 
+class _NoInitSubclass:
+    """
+    temporary base class to suppress calling __init_subclass__
+    """
+    @classmethod
+    def __init_subclass__(cls, **kwds):
+        pass
+
 def _is_descriptor(obj):
     """
     Returns True if obj is a descriptor, False otherwise.
@@ -157,7 +165,7 @@ def __prepare__(metacls, cls, bases):
                     )
         return enum_dict
 
-    def __new__(metacls, cls, bases, classdict):
+    def __new__(metacls, cls, bases, classdict, **kwds):
         # an Enum class is final once enumeration items have been defined; it
         # cannot be mixed with other types (int, float, etc.) if it has an
         # inherited __new__ unless a new __new__ is defined (or the resulting
@@ -192,8 +200,22 @@ def __new__(metacls, cls, bases, classdict):
         if '__doc__' not in classdict:
             classdict['__doc__'] = 'An enumeration.'
 
+        # postpone calling __init_subclass__
+        if '__init_subclass__' in classdict and classdict['__init_subclass__'] is None:
+            raise TypeError('%s.__init_subclass__ cannot be None')
+        # remove current __init_subclass__ so previous one can be found with getattr
+        new_init_subclass = classdict.pop('__init_subclass__', None)
         # create our new Enum type
-        enum_class = super().__new__(metacls, cls, bases, classdict)
+        if bases:
+            bases = (_NoInitSubclass, ) + bases
+            enum_class = type.__new__(metacls, cls, bases, classdict)
+            enum_class.__bases__ = enum_class.__bases__[1:] #or (object, )
+        else:
+            enum_class = type.__new__(metacls, cls, bases, classdict)
+        old_init_subclass = getattr(enum_class, '__init_subclass__', None)
+        # and restore the new one (if there was one)
+        if new_init_subclass is not None:
+            enum_class.__init_subclass__ = classmethod(new_init_subclass)
         enum_class._member_names_ = []               # names in definition order
         enum_class._member_map_ = {}                 # name->value map
         enum_class._member_type_ = member_type
@@ -305,6 +327,9 @@ def __new__(metacls, cls, bases, classdict):
             if _order_ != enum_class._member_names_:
                 raise TypeError('member order does not match _order_')
 
+        # finally, call parents' __init_subclass__
+        if Enum is not None and old_init_subclass is not None:
+            old_init_subclass(**kwds)
         return enum_class
 
     def __bool__(self):
@@ -682,6 +707,9 @@ def _generate_next_value_(name, start, count, last_values):
         else:
             return start
 
+    def __init_subclass__(cls, **kwds):
+        super().__init_subclass__(**kwds)
+
     @classmethod
     def _missing_(cls, value):
         return None
diff --git a/Lib/test/test_enum.py b/Lib/test/test_enum.py
index ab4b52f785273..20bc5b3c750d6 100644
--- a/Lib/test/test_enum.py
+++ b/Lib/test/test_enum.py
@@ -2117,6 +2117,43 @@ class ThirdFailedStrEnum(StrEnum):
             class ThirdFailedStrEnum(StrEnum):
                 one = '1'
                 two = b'2', 'ascii', 9
+                
+    def test_init_subclass(self):
+        class MyEnum(Enum):
+            def __init_subclass__(cls, **kwds):
+                super(MyEnum, cls).__init_subclass__(**kwds)
+                self.assertFalse(cls.__dict__.get('_test', False))
+                cls._test1 = 'MyEnum'
+        #
+        class TheirEnum(MyEnum):
+            def __init_subclass__(cls, **kwds):
+                super().__init_subclass__(**kwds)
+                cls._test2 = 'TheirEnum'
+        class WhoseEnum(TheirEnum):
+            def __init_subclass__(cls, **kwds):
+                pass
+        class NoEnum(WhoseEnum):
+            ONE = 1
+        self.assertEqual(TheirEnum.__dict__['_test1'], 'MyEnum')
+        self.assertEqual(WhoseEnum.__dict__['_test1'], 'MyEnum')
+        self.assertEqual(WhoseEnum.__dict__['_test2'], 'TheirEnum')
+        self.assertFalse(NoEnum.__dict__.get('_test1', False))
+        self.assertFalse(NoEnum.__dict__.get('_test2', False))
+        #
+        class OurEnum(MyEnum):
+            def __init_subclass__(cls, **kwds):
+                cls._test2 = 'OurEnum'
+        class WhereEnum(OurEnum):
+            def __init_subclass__(cls, **kwds):
+                pass
+        class NeverEnum(WhereEnum):
+            ONE = 'one'
+        self.assertEqual(OurEnum.__dict__['_test1'], 'MyEnum')
+        self.assertFalse(WhereEnum.__dict__.get('_test1', False))
+        self.assertEqual(WhereEnum.__dict__['_test2'], 'OurEnum')
+        self.assertFalse(NeverEnum.__dict__.get('_test1', False))
+        self.assertFalse(NeverEnum.__dict__.get('_test2', False))
+
 
 class TestOrder(unittest.TestCase):
 
@@ -2573,6 +2610,42 @@ def cycle_enum():
                 'at least one thread failed while creating composite members')
         self.assertEqual(256, len(seen), 'too many composite members created')
 
+    def test_init_subclass(self):
+        class MyEnum(Flag):
+            def __init_subclass__(cls, **kwds):
+                super().__init_subclass__(**kwds)
+                self.assertFalse(cls.__dict__.get('_test', False))
+                cls._test1 = 'MyEnum'
+        #
+        class TheirEnum(MyEnum):
+            def __init_subclass__(cls, **kwds):
+                super(TheirEnum, cls).__init_subclass__(**kwds)
+                cls._test2 = 'TheirEnum'
+        class WhoseEnum(TheirEnum):
+            def __init_subclass__(cls, **kwds):
+                pass
+        class NoEnum(WhoseEnum):
+            ONE = 1
+        self.assertEqual(TheirEnum.__dict__['_test1'], 'MyEnum')
+        self.assertEqual(WhoseEnum.__dict__['_test1'], 'MyEnum')
+        self.assertEqual(WhoseEnum.__dict__['_test2'], 'TheirEnum')
+        self.assertFalse(NoEnum.__dict__.get('_test1', False))
+        self.assertFalse(NoEnum.__dict__.get('_test2', False))
+        #
+        class OurEnum(MyEnum):
+            def __init_subclass__(cls, **kwds):
+                cls._test2 = 'OurEnum'
+        class WhereEnum(OurEnum):
+            def __init_subclass__(cls, **kwds):
+                pass
+        class NeverEnum(WhereEnum):
+            ONE = 1
+        self.assertEqual(OurEnum.__dict__['_test1'], 'MyEnum')
+        self.assertFalse(WhereEnum.__dict__.get('_test1', False))
+        self.assertEqual(WhereEnum.__dict__['_test2'], 'OurEnum')
+        self.assertFalse(NeverEnum.__dict__.get('_test1', False))
+        self.assertFalse(NeverEnum.__dict__.get('_test2', False))
+
 
 class TestIntFlag(unittest.TestCase):
     """Tests of the IntFlags."""
diff --git a/Misc/NEWS.d/next/Library/2020-12-08-22-43-35.bpo-42678.ba9ktU.rst b/Misc/NEWS.d/next/Library/2020-12-08-22-43-35.bpo-42678.ba9ktU.rst
new file mode 100644
index 0000000000000..7c94cdf40dd4c
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2020-12-08-22-43-35.bpo-42678.ba9ktU.rst
@@ -0,0 +1 @@
+`Enum`: call `__init_subclass__` after members have been added



More information about the Python-checkins mailing list