[Python-checkins] cpython (merge 3.6 -> default): issue23591: fix flag decomposition and repr

ethan.furman python-checkins at python.org
Sun Sep 18 16:16:29 EDT 2016


https://hg.python.org/cpython/rev/7372c042e9a1
changeset:   103929:7372c042e9a1
parent:      103926:911070065e38
parent:      103928:b56290a80ff7
user:        Ethan Furman <ethan at stoneleaf.us>
date:        Sun Sep 18 13:16:04 2016 -0700
summary:
  issue23591: fix flag decomposition and repr

files:
  Doc/library/enum.rst  |   22 ++++
  Lib/enum.py           |  152 ++++++++++++++++++-----------
  Lib/test/test_enum.py |  105 +++++++++++++++-----
  3 files changed, 193 insertions(+), 86 deletions(-)


diff --git a/Doc/library/enum.rst b/Doc/library/enum.rst
--- a/Doc/library/enum.rst
+++ b/Doc/library/enum.rst
@@ -674,6 +674,8 @@
     ...     green = auto()
     ...     white = red | blue | green
     ...
+    >>> Color.white
+    <Color.white: 7>
 
 Giving a name to the "no flags set" condition does not change its boolean
 value::
@@ -1068,3 +1070,23 @@
     >>> dir(Planet.EARTH)
     ['__class__', '__doc__', '__module__', 'name', 'surface_gravity', 'value']
 
+
+Combining members of ``Flag``
+"""""""""""""""""""""""""""""
+
+If a combination of Flag members is not named, the :func:`repr` will include
+all named flags and all named combinations of flags that are in the value::
+
+    >>> class Color(Flag):
+    ...     red = auto()
+    ...     green = auto()
+    ...     blue = auto()
+    ...     magenta = red | blue
+    ...     yellow = red | green
+    ...     cyan = green | blue
+    ...
+    >>> Color(3)  # named combination
+    <Color.yellow: 3>
+    >>> Color(7)      # not named combination
+    <Color.cyan|magenta|blue|yellow|green|red: 7>
+
diff --git a/Lib/enum.py b/Lib/enum.py
--- a/Lib/enum.py
+++ b/Lib/enum.py
@@ -1,7 +1,7 @@
 import sys
 from types import MappingProxyType, DynamicClassAttribute
 from functools import reduce
-from operator import or_ as _or_
+from operator import or_ as _or_, and_ as _and_, xor, neg
 
 # try _collections first to reduce startup cost
 try:
@@ -47,11 +47,12 @@
     cls.__reduce_ex__ = _break_on_call_reduce
     cls.__module__ = '<unknown>'
 
+_auto_null = object()
 class auto:
     """
     Instances are replaced with an appropriate value in Enum class suites.
     """
-    pass
+    value = _auto_null
 
 
 class _EnumDict(dict):
@@ -77,7 +78,7 @@
         """
         if _is_sunder(key):
             if key not in (
-                    '_order_', '_create_pseudo_member_', '_decompose_',
+                    '_order_', '_create_pseudo_member_',
                     '_generate_next_value_', '_missing_',
                     ):
                 raise ValueError('_names_ are reserved for future Enum use')
@@ -94,7 +95,9 @@
                 # enum overwriting a descriptor?
                 raise TypeError('%r already defined as: %r' % (key, self[key]))
             if isinstance(value, auto):
-                value = self._generate_next_value(key, 1, len(self._member_names), self._last_values[:])
+                if value.value == _auto_null:
+                    value.value = self._generate_next_value(key, 1, len(self._member_names), self._last_values[:])
+                value = value.value
             self._member_names.append(key)
             self._last_values.append(value)
         super().__setitem__(key, value)
@@ -658,7 +661,7 @@
             try:
                 high_bit = _high_bit(last_value)
                 break
-            except TypeError:
+            except Exception:
                 raise TypeError('Invalid Flag value: %r' % last_value) from None
         return 2 ** (high_bit+1)
 
@@ -668,61 +671,38 @@
         if value < 0:
             value = ~value
         possible_member = cls._create_pseudo_member_(value)
-        for member in possible_member._decompose_():
-            if member._name_ is None and member._value_ != 0:
-                raise ValueError('%r is not a valid %s' % (original_value, cls.__name__))
         if original_value < 0:
             possible_member = ~possible_member
         return possible_member
 
     @classmethod
     def _create_pseudo_member_(cls, value):
+        """
+        Create a composite member iff value contains only members.
+        """
         pseudo_member = cls._value2member_map_.get(value, None)
         if pseudo_member is None:
-            # construct a non-singleton enum pseudo-member
+            # verify all bits are accounted for
+            _, extra_flags = _decompose(cls, value)
+            if extra_flags:
+                raise ValueError("%r is not a valid %s" % (value, cls.__name__))
+            # construct a singleton enum pseudo-member
             pseudo_member = object.__new__(cls)
             pseudo_member._name_ = None
             pseudo_member._value_ = value
             cls._value2member_map_[value] = pseudo_member
         return pseudo_member
 
-    def _decompose_(self):
-        """Extract all members from the value."""
-        value = self._value_
-        members = []
-        cls = self.__class__
-        for member in sorted(cls, key=lambda m: m._value_, reverse=True):
-            while _high_bit(value) > _high_bit(member._value_):
-                unknown = self._create_pseudo_member_(2 ** _high_bit(value))
-                members.append(unknown)
-                value &= ~unknown._value_
-            if (
-                    (value & member._value_ == member._value_)
-                    and (member._value_ or not members)
-                    ):
-                value &= ~member._value_
-                members.append(member)
-        if not members or value:
-            members.append(self._create_pseudo_member_(value))
-        members = list(members)
-        return members
-
     def __contains__(self, other):
         if not isinstance(other, self.__class__):
             return NotImplemented
         return other._value_ & self._value_ == other._value_
 
-    def __iter__(self):
-        if self.value == 0:
-            return iter([])
-        else:
-            return iter(self._decompose_())
-
     def __repr__(self):
         cls = self.__class__
         if self._name_ is not None:
             return '<%s.%s: %r>' % (cls.__name__, self._name_, self._value_)
-        members = self._decompose_()
+        members, uncovered = _decompose(cls, self._value_)
         return '<%s.%s: %r>' % (
                 cls.__name__,
                 '|'.join([str(m._name_ or m._value_) for m in members]),
@@ -733,7 +713,7 @@
         cls = self.__class__
         if self._name_ is not None:
             return '%s.%s' % (cls.__name__, self._name_)
-        members = self._decompose_()
+        members, uncovered = _decompose(cls, self._value_)
         if len(members) == 1 and members[0]._name_ is None:
             return '%s.%r' % (cls.__name__, members[0]._value_)
         else:
@@ -761,8 +741,11 @@
         return self.__class__(self._value_ ^ other._value_)
 
     def __invert__(self):
-        members = self._decompose_()
-        inverted_members = [m for m in self.__class__ if m not in members and not m._value_ & self._value_]
+        members, uncovered = _decompose(self.__class__, self._value_)
+        inverted_members = [
+                m for m in self.__class__
+                if m not in members and not m._value_ & self._value_
+                ]
         inverted = reduce(_or_, inverted_members, self.__class__(0))
         return self.__class__(inverted)
 
@@ -771,25 +754,45 @@
     """Support for integer-based Flags"""
 
     @classmethod
+    def _missing_(cls, value):
+        if not isinstance(value, int):
+            raise ValueError("%r is not a valid %s" % (value, cls.__name__))
+        new_member = cls._create_pseudo_member_(value)
+        return new_member
+
+    @classmethod
     def _create_pseudo_member_(cls, value):
         pseudo_member = cls._value2member_map_.get(value, None)
         if pseudo_member is None:
-            # construct a non-singleton enum pseudo-member
-            pseudo_member = int.__new__(cls, value)
-            pseudo_member._name_ = None
-            pseudo_member._value_ = value
-            cls._value2member_map_[value] = pseudo_member
+            need_to_create = [value]
+            # get unaccounted for bits
+            _, extra_flags = _decompose(cls, value)
+            # timer = 10
+            while extra_flags:
+                # timer -= 1
+                bit = _high_bit(extra_flags)
+                flag_value = 2 ** bit
+                if (flag_value not in cls._value2member_map_ and
+                        flag_value not in need_to_create
+                        ):
+                    need_to_create.append(flag_value)
+                if extra_flags == -flag_value:
+                    extra_flags = 0
+                else:
+                    extra_flags ^= flag_value
+            for value in reversed(need_to_create):
+                # construct singleton pseudo-members
+                pseudo_member = int.__new__(cls, value)
+                pseudo_member._name_ = None
+                pseudo_member._value_ = value
+                cls._value2member_map_[value] = pseudo_member
         return pseudo_member
 
-    @classmethod
-    def _missing_(cls, value):
-        possible_member = cls._create_pseudo_member_(value)
-        return possible_member
-
     def __or__(self, other):
         if not isinstance(other, (self.__class__, int)):
             return NotImplemented
-        return self.__class__(self._value_ | self.__class__(other)._value_)
+        result = self.__class__(self._value_ | self.__class__(other)._value_)
+        return result
 
     def __and__(self, other):
         if not isinstance(other, (self.__class__, int)):
@@ -806,17 +809,13 @@
     __rxor__ = __xor__
 
     def __invert__(self):
-        # members = self._decompose_()
-        # inverted_members = [m for m in self.__class__ if m not in members and not m._value_ & self._value_]
-        # inverted = reduce(_or_, inverted_members, self.__class__(0))
-        return self.__class__(~self._value_)
-
-
+        result = self.__class__(~self._value_)
+        return result
 
 
 def _high_bit(value):
     """returns index of highest bit, or -1 if value is zero or negative"""
-    return value.bit_length() - 1 if value > 0 else -1
+    return value.bit_length() - 1
 
 def unique(enumeration):
     """Class decorator for enumerations ensuring unique member values."""
@@ -830,3 +829,40 @@
         raise ValueError('duplicate values found in %r: %s' %
                 (enumeration, alias_details))
     return enumeration
+
+def _decompose(flag, value):
+    """Extract all members from the value."""
+    # _decompose is only called if the value is not named
+    not_covered = value
+    negative = value < 0
+    if negative:
+        # only check for named flags
+        flags_to_check = [
+                (m, v)
+                for v, m in flag._value2member_map_.items()
+                if m.name is not None
+                ]
+    else:
+        # check for named flags and powers-of-two flags
+        flags_to_check = [
+                (m, v)
+                for v, m in flag._value2member_map_.items()
+                if m.name is not None or _power_of_two(v)
+                ]
+    members = []
+    for member, member_value in flags_to_check:
+        if member_value and member_value & value == member_value:
+            members.append(member)
+            not_covered &= ~member_value
+    if not members and value in flag._value2member_map_:
+        members.append(flag._value2member_map_[value])
+    members.sort(key=lambda m: m._value_, reverse=True)
+    if len(members) > 1 and members[0].value == value:
+        # we have the breakdown, don't need the value member itself
+        members.pop(0)
+    return members, not_covered
+
+def _power_of_two(value):
+    if value < 1:
+        return False
+    return value == 2 ** _high_bit(value)
diff --git a/Lib/test/test_enum.py b/Lib/test/test_enum.py
--- a/Lib/test/test_enum.py
+++ b/Lib/test/test_enum.py
@@ -1634,6 +1634,13 @@
         self.assertEqual(Color.blue.value, 2)
         self.assertEqual(Color.green.value, 3)
 
+    def test_duplicate_auto(self):
+        class Dupes(Enum):
+            first = primero = auto()
+            second = auto()
+            third = auto()
+        self.assertEqual([Dupes.first, Dupes.second, Dupes.third], list(Dupes))
+
 
 class TestOrder(unittest.TestCase):
 
@@ -1731,7 +1738,7 @@
         self.assertEqual(str(Open.AC), 'Open.AC')
         self.assertEqual(str(Open.RO | Open.CE), 'Open.CE')
         self.assertEqual(str(Open.WO | Open.CE), 'Open.CE|WO')
-        self.assertEqual(str(~Open.RO), 'Open.CE|AC')
+        self.assertEqual(str(~Open.RO), 'Open.CE|AC|RW|WO')
         self.assertEqual(str(~Open.WO), 'Open.CE|RW')
         self.assertEqual(str(~Open.AC), 'Open.CE')
         self.assertEqual(str(~(Open.RO | Open.CE)), 'Open.AC')
@@ -1758,7 +1765,7 @@
         self.assertEqual(repr(Open.AC), '<Open.AC: 3>')
         self.assertEqual(repr(Open.RO | Open.CE), '<Open.CE: 524288>')
         self.assertEqual(repr(Open.WO | Open.CE), '<Open.CE|WO: 524289>')
-        self.assertEqual(repr(~Open.RO), '<Open.CE|AC: 524291>')
+        self.assertEqual(repr(~Open.RO), '<Open.CE|AC|RW|WO: 524291>')
         self.assertEqual(repr(~Open.WO), '<Open.CE|RW: 524290>')
         self.assertEqual(repr(~Open.AC), '<Open.CE: 524288>')
         self.assertEqual(repr(~(Open.RO | Open.CE)), '<Open.AC: 3>')
@@ -1949,6 +1956,33 @@
                 red = 'not an int'
                 blue = auto()
 
+    def test_cascading_failure(self):
+        class Bizarre(Flag):
+            c = 3
+            d = 4
+            f = 6
+        # Bizarre.c | Bizarre.d
+        self.assertRaisesRegex(ValueError, "5 is not a valid Bizarre", Bizarre, 5)
+        self.assertRaisesRegex(ValueError, "5 is not a valid Bizarre", Bizarre, 5)
+        self.assertRaisesRegex(ValueError, "2 is not a valid Bizarre", Bizarre, 2)
+        self.assertRaisesRegex(ValueError, "2 is not a valid Bizarre", Bizarre, 2)
+        self.assertRaisesRegex(ValueError, "1 is not a valid Bizarre", Bizarre, 1)
+        self.assertRaisesRegex(ValueError, "1 is not a valid Bizarre", Bizarre, 1)
+
+    def test_duplicate_auto(self):
+        class Dupes(Enum):
+            first = primero = auto()
+            second = auto()
+            third = auto()
+        self.assertEqual([Dupes.first, Dupes.second, Dupes.third], list(Dupes))
+
+    def test_bizarre(self):
+        class Bizarre(Flag):
+            b = 3
+            c = 4
+            d = 6
+        self.assertEqual(repr(Bizarre(7)), '<Bizarre.d|c|b: 7>')
+
 
 class TestIntFlag(unittest.TestCase):
     """Tests of the IntFlags."""
@@ -1965,6 +1999,21 @@
         AC = 3
         CE = 1<<19
 
+    def test_type(self):
+        Perm = self.Perm
+        Open = self.Open
+        for f in Perm:
+            self.assertTrue(isinstance(f, Perm))
+            self.assertEqual(f, f.value)
+        self.assertTrue(isinstance(Perm.W | Perm.X, Perm))
+        self.assertEqual(Perm.W | Perm.X, 3)
+        for f in Open:
+            self.assertTrue(isinstance(f, Open))
+            self.assertEqual(f, f.value)
+        self.assertTrue(isinstance(Open.WO | Open.RW, Open))
+        self.assertEqual(Open.WO | Open.RW, 3)
+
+
     def test_str(self):
         Perm = self.Perm
         self.assertEqual(str(Perm.R), 'Perm.R')
@@ -1975,14 +2024,14 @@
         self.assertEqual(str(Perm.R | 8), 'Perm.8|R')
         self.assertEqual(str(Perm(0)), 'Perm.0')
         self.assertEqual(str(Perm(8)), 'Perm.8')
-        self.assertEqual(str(~Perm.R), 'Perm.W|X|-8')
-        self.assertEqual(str(~Perm.W), 'Perm.R|X|-8')
-        self.assertEqual(str(~Perm.X), 'Perm.R|W|-8')
-        self.assertEqual(str(~(Perm.R | Perm.W)), 'Perm.X|-8')
+        self.assertEqual(str(~Perm.R), 'Perm.W|X')
+        self.assertEqual(str(~Perm.W), 'Perm.R|X')
+        self.assertEqual(str(~Perm.X), 'Perm.R|W')
+        self.assertEqual(str(~(Perm.R | Perm.W)), 'Perm.X')
         self.assertEqual(str(~(Perm.R | Perm.W | Perm.X)), 'Perm.-8')
-        self.assertEqual(str(~(Perm.R | 8)), 'Perm.W|X|-16')
-        self.assertEqual(str(Perm(~0)), 'Perm.R|W|X|-8')
-        self.assertEqual(str(Perm(~8)), 'Perm.R|W|X|-16')
+        self.assertEqual(str(~(Perm.R | 8)), 'Perm.W|X')
+        self.assertEqual(str(Perm(~0)), 'Perm.R|W|X')
+        self.assertEqual(str(Perm(~8)), 'Perm.R|W|X')
 
         Open = self.Open
         self.assertEqual(str(Open.RO), 'Open.RO')
@@ -1991,12 +2040,12 @@
         self.assertEqual(str(Open.RO | Open.CE), 'Open.CE')
         self.assertEqual(str(Open.WO | Open.CE), 'Open.CE|WO')
         self.assertEqual(str(Open(4)), 'Open.4')
-        self.assertEqual(str(~Open.RO), 'Open.CE|AC|-524292')
-        self.assertEqual(str(~Open.WO), 'Open.CE|RW|-524292')
-        self.assertEqual(str(~Open.AC), 'Open.CE|-524292')
-        self.assertEqual(str(~(Open.RO | Open.CE)), 'Open.AC|-524292')
-        self.assertEqual(str(~(Open.WO | Open.CE)), 'Open.RW|-524292')
-        self.assertEqual(str(Open(~4)), 'Open.CE|AC|-524296')
+        self.assertEqual(str(~Open.RO), 'Open.CE|AC|RW|WO')
+        self.assertEqual(str(~Open.WO), 'Open.CE|RW')
+        self.assertEqual(str(~Open.AC), 'Open.CE')
+        self.assertEqual(str(~(Open.RO | Open.CE)), 'Open.AC|RW|WO')
+        self.assertEqual(str(~(Open.WO | Open.CE)), 'Open.RW')
+        self.assertEqual(str(Open(~4)), 'Open.CE|AC|RW|WO')
 
     def test_repr(self):
         Perm = self.Perm
@@ -2008,14 +2057,14 @@
         self.assertEqual(repr(Perm.R | 8), '<Perm.8|R: 12>')
         self.assertEqual(repr(Perm(0)), '<Perm.0: 0>')
         self.assertEqual(repr(Perm(8)), '<Perm.8: 8>')
-        self.assertEqual(repr(~Perm.R), '<Perm.W|X|-8: -5>')
-        self.assertEqual(repr(~Perm.W), '<Perm.R|X|-8: -3>')
-        self.assertEqual(repr(~Perm.X), '<Perm.R|W|-8: -2>')
-        self.assertEqual(repr(~(Perm.R | Perm.W)), '<Perm.X|-8: -7>')
+        self.assertEqual(repr(~Perm.R), '<Perm.W|X: -5>')
+        self.assertEqual(repr(~Perm.W), '<Perm.R|X: -3>')
+        self.assertEqual(repr(~Perm.X), '<Perm.R|W: -2>')
+        self.assertEqual(repr(~(Perm.R | Perm.W)), '<Perm.X: -7>')
         self.assertEqual(repr(~(Perm.R | Perm.W | Perm.X)), '<Perm.-8: -8>')
-        self.assertEqual(repr(~(Perm.R | 8)), '<Perm.W|X|-16: -13>')
-        self.assertEqual(repr(Perm(~0)), '<Perm.R|W|X|-8: -1>')
-        self.assertEqual(repr(Perm(~8)), '<Perm.R|W|X|-16: -9>')
+        self.assertEqual(repr(~(Perm.R | 8)), '<Perm.W|X: -13>')
+        self.assertEqual(repr(Perm(~0)), '<Perm.R|W|X: -1>')
+        self.assertEqual(repr(Perm(~8)), '<Perm.R|W|X: -9>')
 
         Open = self.Open
         self.assertEqual(repr(Open.RO), '<Open.RO: 0>')
@@ -2024,12 +2073,12 @@
         self.assertEqual(repr(Open.RO | Open.CE), '<Open.CE: 524288>')
         self.assertEqual(repr(Open.WO | Open.CE), '<Open.CE|WO: 524289>')
         self.assertEqual(repr(Open(4)), '<Open.4: 4>')
-        self.assertEqual(repr(~Open.RO), '<Open.CE|AC|-524292: -1>')
-        self.assertEqual(repr(~Open.WO), '<Open.CE|RW|-524292: -2>')
-        self.assertEqual(repr(~Open.AC), '<Open.CE|-524292: -4>')
-        self.assertEqual(repr(~(Open.RO | Open.CE)), '<Open.AC|-524292: -524289>')
-        self.assertEqual(repr(~(Open.WO | Open.CE)), '<Open.RW|-524292: -524290>')
-        self.assertEqual(repr(Open(~4)), '<Open.CE|AC|-524296: -5>')
+        self.assertEqual(repr(~Open.RO), '<Open.CE|AC|RW|WO: -1>')
+        self.assertEqual(repr(~Open.WO), '<Open.CE|RW: -2>')
+        self.assertEqual(repr(~Open.AC), '<Open.CE: -4>')
+        self.assertEqual(repr(~(Open.RO | Open.CE)), '<Open.AC|RW|WO: -524289>')
+        self.assertEqual(repr(~(Open.WO | Open.CE)), '<Open.RW: -524290>')
+        self.assertEqual(repr(Open(~4)), '<Open.CE|AC|RW|WO: -5>')
 
     def test_or(self):
         Perm = self.Perm

-- 
Repository URL: https://hg.python.org/cpython


More information about the Python-checkins mailing list