[Python-checkins] bpo-46477: [Enum] ensure Flag subclasses have correct bitwise methods (GH-30816)
ethanfurman
webhook-mailer at python.org
Sat Jan 22 21:27:58 EST 2022
https://github.com/python/cpython/commit/353e3b2820bed38da16140276786eef9ba33d3bd
commit: 353e3b2820bed38da16140276786eef9ba33d3bd
branch: main
author: Ethan Furman <ethan at stoneleaf.us>
committer: ethanfurman <ethan at stoneleaf.us>
date: 2022-01-22T18:27:52-08:00
summary:
bpo-46477: [Enum] ensure Flag subclasses have correct bitwise methods (GH-30816)
files:
M Lib/enum.py
M Lib/test/test_enum.py
diff --git a/Lib/enum.py b/Lib/enum.py
index b510467731293..85245c95f9a9c 100644
--- a/Lib/enum.py
+++ b/Lib/enum.py
@@ -618,6 +618,18 @@ def __new__(metacls, cls, bases, classdict, *, boundary=None, _simple=False, **k
if name not in classdict:
setattr(enum_class, name, getattr(first_enum, name))
#
+ # for Flag, add __or__, __and__, __xor__, and __invert__
+ if Flag is not None and issubclass(enum_class, Flag):
+ for name in (
+ '__or__', '__and__', '__xor__',
+ '__ror__', '__rand__', '__rxor__',
+ '__invert__'
+ ):
+ if name not in classdict:
+ enum_method = getattr(Flag, name)
+ setattr(enum_class, name, enum_method)
+ classdict[name] = enum_method
+ #
# replace any other __new__ with our own (as long as Enum is not None,
# anyway) -- again, this is to support pickle
if Enum is not None:
@@ -1466,44 +1478,10 @@ def __str__(self):
def __bool__(self):
return bool(self._value_)
- def __or__(self, other):
- if not isinstance(other, self.__class__):
- return NotImplemented
- return self.__class__(self._value_ | other._value_)
-
- def __and__(self, other):
- if not isinstance(other, self.__class__):
- return NotImplemented
- return self.__class__(self._value_ & other._value_)
-
- def __xor__(self, other):
- if not isinstance(other, self.__class__):
- return NotImplemented
- return self.__class__(self._value_ ^ other._value_)
-
- def __invert__(self):
- if self._inverted_ is None:
- if self._boundary_ is KEEP:
- # use all bits
- self._inverted_ = self.__class__(~self._value_)
- else:
- # calculate flags not in this member
- self._inverted_ = self.__class__(self._flag_mask_ ^ self._value_)
- if isinstance(self._inverted_, self.__class__):
- self._inverted_._inverted_ = self
- return self._inverted_
-
-
-class IntFlag(int, ReprEnum, Flag, boundary=EJECT):
- """
- Support for integer-based Flags
- """
-
-
def __or__(self, other):
if isinstance(other, self.__class__):
other = other._value_
- elif isinstance(other, int):
+ elif self._member_type_ is not object and isinstance(other, self._member_type_):
other = other
else:
return NotImplemented
@@ -1513,7 +1491,7 @@ def __or__(self, other):
def __and__(self, other):
if isinstance(other, self.__class__):
other = other._value_
- elif isinstance(other, int):
+ elif self._member_type_ is not object and isinstance(other, self._member_type_):
other = other
else:
return NotImplemented
@@ -1523,17 +1501,34 @@ def __and__(self, other):
def __xor__(self, other):
if isinstance(other, self.__class__):
other = other._value_
- elif isinstance(other, int):
+ elif self._member_type_ is not object and isinstance(other, self._member_type_):
other = other
else:
return NotImplemented
value = self._value_
return self.__class__(value ^ other)
- __ror__ = __or__
+ def __invert__(self):
+ if self._inverted_ is None:
+ if self._boundary_ is KEEP:
+ # use all bits
+ self._inverted_ = self.__class__(~self._value_)
+ else:
+ # calculate flags not in this member
+ self._inverted_ = self.__class__(self._flag_mask_ ^ self._value_)
+ if isinstance(self._inverted_, self.__class__):
+ self._inverted_._inverted_ = self
+ return self._inverted_
+
__rand__ = __and__
+ __ror__ = __or__
__rxor__ = __xor__
- __invert__ = Flag.__invert__
+
+
+class IntFlag(int, ReprEnum, Flag, boundary=EJECT):
+ """
+ Support for integer-based Flags
+ """
def _high_bit(value):
@@ -1662,6 +1657,13 @@ def convert_class(cls):
body['_flag_mask_'] = None
body['_all_bits_'] = None
body['_inverted_'] = None
+ body['__or__'] = Flag.__or__
+ body['__xor__'] = Flag.__xor__
+ body['__and__'] = Flag.__and__
+ body['__ror__'] = Flag.__ror__
+ body['__rxor__'] = Flag.__rxor__
+ body['__rand__'] = Flag.__rand__
+ body['__invert__'] = Flag.__invert__
for name, obj in cls.__dict__.items():
if name in ('__dict__', '__weakref__'):
continue
diff --git a/Lib/test/test_enum.py b/Lib/test/test_enum.py
index d7ce8add78715..b8a7914355c53 100644
--- a/Lib/test/test_enum.py
+++ b/Lib/test/test_enum.py
@@ -2496,6 +2496,13 @@ def __new__(cls, val):
self.assertEqual(Some.x.value, 1)
self.assertEqual(Some.y.value, 2)
+ def test_custom_flag_bitwise(self):
+ class MyIntFlag(int, Flag):
+ ONE = 1
+ TWO = 2
+ FOUR = 4
+ self.assertTrue(isinstance(MyIntFlag.ONE | MyIntFlag.TWO, MyIntFlag), MyIntFlag.ONE | MyIntFlag.TWO)
+ self.assertTrue(isinstance(MyIntFlag.ONE | 2, MyIntFlag))
class TestOrder(unittest.TestCase):
"test usage of the `_order_` attribute"
More information about the Python-checkins mailing list