[Python-checkins] bpo-46477: [Enum] ensure Flag subclasses have correct bitwise methods (GH-30816)

ethanfurman webhook-mailer at python.org
Sat Jan 22 21:27:58 EST 2022


https://github.com/python/cpython/commit/353e3b2820bed38da16140276786eef9ba33d3bd
commit: 353e3b2820bed38da16140276786eef9ba33d3bd
branch: main
author: Ethan Furman <ethan at stoneleaf.us>
committer: ethanfurman <ethan at stoneleaf.us>
date: 2022-01-22T18:27:52-08:00
summary:

bpo-46477: [Enum] ensure Flag subclasses have correct bitwise methods (GH-30816)

files:
M Lib/enum.py
M Lib/test/test_enum.py

diff --git a/Lib/enum.py b/Lib/enum.py
index b510467731293..85245c95f9a9c 100644
--- a/Lib/enum.py
+++ b/Lib/enum.py
@@ -618,6 +618,18 @@ def __new__(metacls, cls, bases, classdict, *, boundary=None, _simple=False, **k
             if name not in classdict:
                 setattr(enum_class, name, getattr(first_enum, name))
         #
+        # for Flag, add __or__, __and__, __xor__, and __invert__
+        if Flag is not None and issubclass(enum_class, Flag):
+            for name in (
+                    '__or__', '__and__', '__xor__',
+                    '__ror__', '__rand__', '__rxor__',
+                    '__invert__'
+                ):
+                if name not in classdict:
+                    enum_method = getattr(Flag, name)
+                    setattr(enum_class, name, enum_method)
+                    classdict[name] = enum_method
+        #
         # replace any other __new__ with our own (as long as Enum is not None,
         # anyway) -- again, this is to support pickle
         if Enum is not None:
@@ -1466,44 +1478,10 @@ def __str__(self):
     def __bool__(self):
         return bool(self._value_)
 
-    def __or__(self, other):
-        if not isinstance(other, self.__class__):
-            return NotImplemented
-        return self.__class__(self._value_ | other._value_)
-
-    def __and__(self, other):
-        if not isinstance(other, self.__class__):
-            return NotImplemented
-        return self.__class__(self._value_ & other._value_)
-
-    def __xor__(self, other):
-        if not isinstance(other, self.__class__):
-            return NotImplemented
-        return self.__class__(self._value_ ^ other._value_)
-
-    def __invert__(self):
-        if self._inverted_ is None:
-            if self._boundary_ is KEEP:
-                # use all bits
-                self._inverted_ = self.__class__(~self._value_)
-            else:
-                # calculate flags not in this member
-                self._inverted_ = self.__class__(self._flag_mask_ ^ self._value_)
-            if isinstance(self._inverted_, self.__class__):
-                self._inverted_._inverted_ = self
-        return self._inverted_
-
-
-class IntFlag(int, ReprEnum, Flag, boundary=EJECT):
-    """
-    Support for integer-based Flags
-    """
-
-
     def __or__(self, other):
         if isinstance(other, self.__class__):
             other = other._value_
-        elif isinstance(other, int):
+        elif self._member_type_ is not object and isinstance(other, self._member_type_):
             other = other
         else:
             return NotImplemented
@@ -1513,7 +1491,7 @@ def __or__(self, other):
     def __and__(self, other):
         if isinstance(other, self.__class__):
             other = other._value_
-        elif isinstance(other, int):
+        elif self._member_type_ is not object and isinstance(other, self._member_type_):
             other = other
         else:
             return NotImplemented
@@ -1523,17 +1501,34 @@ def __and__(self, other):
     def __xor__(self, other):
         if isinstance(other, self.__class__):
             other = other._value_
-        elif isinstance(other, int):
+        elif self._member_type_ is not object and isinstance(other, self._member_type_):
             other = other
         else:
             return NotImplemented
         value = self._value_
         return self.__class__(value ^ other)
 
-    __ror__ = __or__
+    def __invert__(self):
+        if self._inverted_ is None:
+            if self._boundary_ is KEEP:
+                # use all bits
+                self._inverted_ = self.__class__(~self._value_)
+            else:
+                # calculate flags not in this member
+                self._inverted_ = self.__class__(self._flag_mask_ ^ self._value_)
+            if isinstance(self._inverted_, self.__class__):
+                self._inverted_._inverted_ = self
+        return self._inverted_
+
     __rand__ = __and__
+    __ror__ = __or__
     __rxor__ = __xor__
-    __invert__ = Flag.__invert__
+
+
+class IntFlag(int, ReprEnum, Flag, boundary=EJECT):
+    """
+    Support for integer-based Flags
+    """
 
 
 def _high_bit(value):
@@ -1662,6 +1657,13 @@ def convert_class(cls):
             body['_flag_mask_'] = None
             body['_all_bits_'] = None
             body['_inverted_'] = None
+            body['__or__'] = Flag.__or__
+            body['__xor__'] = Flag.__xor__
+            body['__and__'] = Flag.__and__
+            body['__ror__'] = Flag.__ror__
+            body['__rxor__'] = Flag.__rxor__
+            body['__rand__'] = Flag.__rand__
+            body['__invert__'] = Flag.__invert__
         for name, obj in cls.__dict__.items():
             if name in ('__dict__', '__weakref__'):
                 continue
diff --git a/Lib/test/test_enum.py b/Lib/test/test_enum.py
index d7ce8add78715..b8a7914355c53 100644
--- a/Lib/test/test_enum.py
+++ b/Lib/test/test_enum.py
@@ -2496,6 +2496,13 @@ def __new__(cls, val):
         self.assertEqual(Some.x.value, 1)
         self.assertEqual(Some.y.value, 2)
 
+    def test_custom_flag_bitwise(self):
+        class MyIntFlag(int, Flag):
+            ONE = 1
+            TWO = 2
+            FOUR = 4
+        self.assertTrue(isinstance(MyIntFlag.ONE | MyIntFlag.TWO, MyIntFlag), MyIntFlag.ONE | MyIntFlag.TWO)
+        self.assertTrue(isinstance(MyIntFlag.ONE | 2, MyIntFlag))
 
 class TestOrder(unittest.TestCase):
     "test usage of the `_order_` attribute"



More information about the Python-checkins mailing list