[Python-checkins] bpo-24275: Don't downgrade unicode-only dicts to mixed on lookups (GH-25186)

methane webhook-mailer at python.org
Wed Apr 28 22:06:14 EDT 2021


https://github.com/python/cpython/commit/8557edbfa8f74514de82feea4c62f5963e4e0aa7
commit: 8557edbfa8f74514de82feea4c62f5963e4e0aa7
branch: master
author: Hristo Venev <hristo at venev.name>
committer: methane <songofacandy at gmail.com>
date: 2021-04-29T11:06:03+09:00
summary:

bpo-24275: Don't downgrade unicode-only dicts to mixed on lookups (GH-25186)

files:
M Lib/test/test_dict.py
M Objects/dictobject.c

diff --git a/Lib/test/test_dict.py b/Lib/test/test_dict.py
index 4b31cdc79415f..666cd81e68d81 100644
--- a/Lib/test/test_dict.py
+++ b/Lib/test/test_dict.py
@@ -1471,6 +1471,106 @@ def test_dict_items_result_gc(self):
         gc.collect()
         self.assertTrue(gc.is_tracked(next(it)))
 
+    def test_str_nonstr(self):
+        # cpython uses a different lookup function if the dict only contains
+        # `str` keys. Make sure the unoptimized path is used when a non-`str`
+        # key appears.
+
+        class StrSub(str):
+            pass
+
+        eq_count = 0
+        # This class compares equal to the string 'key3'
+        class Key3:
+            def __hash__(self):
+                return hash('key3')
+
+            def __eq__(self, other):
+                nonlocal eq_count
+                if isinstance(other, Key3) or isinstance(other, str) and other == 'key3':
+                    eq_count += 1
+                    return True
+                return False
+
+        key3_1 = StrSub('key3')
+        key3_2 = Key3()
+        key3_3 = Key3()
+
+        dicts = []
+
+        # Create dicts of the form `{'key1': 42, 'key2': 43, key3: 44}` in a
+        # bunch of different ways. In all cases, `key3` is not of type `str`.
+        # `key3_1` is a `str` subclass and `key3_2` is a completely unrelated
+        # type.
+        for key3 in (key3_1, key3_2):
+            # A literal
+            dicts.append({'key1': 42, 'key2': 43, key3: 44})
+
+            # key3 inserted via `dict.__setitem__`
+            d = {'key1': 42, 'key2': 43}
+            d[key3] = 44
+            dicts.append(d)
+
+            # key3 inserted via `dict.setdefault`
+            d = {'key1': 42, 'key2': 43}
+            self.assertEqual(d.setdefault(key3, 44), 44)
+            dicts.append(d)
+
+            # key3 inserted via `dict.update`
+            d = {'key1': 42, 'key2': 43}
+            d.update({key3: 44})
+            dicts.append(d)
+
+            # key3 inserted via `dict.__ior__`
+            d = {'key1': 42, 'key2': 43}
+            d |= {key3: 44}
+            dicts.append(d)
+
+            # `dict(iterable)`
+            def make_pairs():
+                yield ('key1', 42)
+                yield ('key2', 43)
+                yield (key3, 44)
+            d = dict(make_pairs())
+            dicts.append(d)
+
+            # `dict.copy`
+            d = d.copy()
+            dicts.append(d)
+
+            # dict comprehension
+            d = {key: 42 + i for i,key in enumerate(['key1', 'key2', key3])}
+            dicts.append(d)
+
+        for d in dicts:
+            with self.subTest(d=d):
+                self.assertEqual(d.get('key1'), 42)
+
+                # Try to make an object that is of type `str` and is equal to
+                # `'key1'`, but (at least on cpython) is a different object.
+                noninterned_key1 = 'ke'
+                noninterned_key1 += 'y1'
+                if support.check_impl_detail(cpython=True):
+                    # suppress a SyntaxWarning
+                    interned_key1 = 'key1'
+                    self.assertFalse(noninterned_key1 is interned_key1)
+                self.assertEqual(d.get(noninterned_key1), 42)
+
+                self.assertEqual(d.get('key3'), 44)
+                self.assertEqual(d.get(key3_1), 44)
+                self.assertEqual(d.get(key3_2), 44)
+
+                # `key3_3` itself is definitely not a dict key, so make sure
+                # that `__eq__` gets called.
+                #
+                # Note that this might not hold for `key3_1` and `key3_2`
+                # because they might be the same object as one of the dict keys,
+                # in which case implementations are allowed to skip the call to
+                # `__eq__`.
+                eq_count = 0
+                self.assertEqual(d.get(key3_3), 44)
+                self.assertGreaterEqual(eq_count, 1)
+
 
 class CAPITest(unittest.TestCase):
 
diff --git a/Objects/dictobject.c b/Objects/dictobject.c
index 44796a6066728..0aeee7011e844 100644
--- a/Objects/dictobject.c
+++ b/Objects/dictobject.c
@@ -857,7 +857,6 @@ lookdict_unicode(PyDictObject *mp, PyObject *key,
        unicodes is to override __eq__, and for speed we don't cater to
        that here. */
     if (!PyUnicode_CheckExact(key)) {
-        mp->ma_keys->dk_lookup = lookdict;
         return lookdict(mp, key, hash, value_addr);
     }
 
@@ -900,7 +899,6 @@ lookdict_unicode_nodummy(PyDictObject *mp, PyObject *key,
        unicodes is to override __eq__, and for speed we don't cater to
        that here. */
     if (!PyUnicode_CheckExact(key)) {
-        mp->ma_keys->dk_lookup = lookdict;
         return lookdict(mp, key, hash, value_addr);
     }
 
@@ -1084,7 +1082,6 @@ insertdict(PyDictObject *mp, PyObject *key, Py_hash_t hash, PyObject *value)
     if (ix == DKIX_ERROR)
         goto Fail;
 
-    assert(PyUnicode_CheckExact(key) || mp->ma_keys->dk_lookup == lookdict);
     MAINTAIN_TRACKING(mp, key, value);
 
     /* When insertion order is different from shared key, we can't share
@@ -1106,6 +1103,9 @@ insertdict(PyDictObject *mp, PyObject *key, Py_hash_t hash, PyObject *value)
             if (insertion_resize(mp) < 0)
                 goto Fail;
         }
+        if (!PyUnicode_CheckExact(key) && mp->ma_keys->dk_lookup != lookdict) {
+            mp->ma_keys->dk_lookup = lookdict;
+        }
         Py_ssize_t hashpos = find_empty_slot(mp->ma_keys, hash);
         ep = &DK_ENTRIES(mp->ma_keys)[mp->ma_keys->dk_nentries];
         dictkeys_set_index(mp->ma_keys, hashpos, mp->ma_keys->dk_nentries);
@@ -3068,6 +3068,9 @@ PyDict_SetDefault(PyObject *d, PyObject *key, PyObject *defaultobj)
                 return NULL;
             }
         }
+        if (!PyUnicode_CheckExact(key) && mp->ma_keys->dk_lookup != lookdict) {
+            mp->ma_keys->dk_lookup = lookdict;
+        }
         Py_ssize_t hashpos = find_empty_slot(mp->ma_keys, hash);
         ep0 = DK_ENTRIES(mp->ma_keys);
         ep = &ep0[mp->ma_keys->dk_nentries];



More information about the Python-checkins mailing list