[Python-ideas] Reducing collisions in small dicts/sets

Tim Peters tim.peters at gmail.com
Sat Jun 24 17:29:28 EDT 2017


Short course:  the average number of probes needed when searching
small dicts/sets can be reduced, in both successful ("found") and
failing ("not found") cases.

But I'm not going to pursue this.  This is a brain dump for someone
who's willing to endure the interminable pain of arguing about
benchmarks ;-)

Background:

http://bugs.python.org/issue30671

raised some questions about how dict collisions are handled.  While
the analysis there didn't make sense to me, I wrote enough code to dig
into it.  As detailed in that bug report, the current implementation
appeared to meet the theoretical performance of "uniform hashing",
meaning there was no room left for improvement.

However, that missed something:  the simple expressions for expected
probes under uniform hashing are upper bounds, and while they're
excellent approximations for modest load factors in sizable tables,
for small tables they're significantly overstated.  For example, for a
table with 5 items in 8 slots, the load factor is a = 5/8 = 0.625, and

    avg probes when found = log(1/(1-a))/a = 1.57
           when not found = 1/(1-a)        = 2.67

However, exact analysis for that table gives 1.34 and 2.25 instead.
The current code achieves the upper bounds, but not the exact values.
As a sanity check, a painfully slow implementation of uniform hashing
does achieve the exact values.
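The exact figures come from a standard closed form for uniform hashing. With n of m slots filled, a failing search needs (m+1)/(m-n+1) probes on average. A successful search for a key costs what the failing search cost when that key was inserted. Below is a minimal sketch that compares these exact values with the approximations. The function names are illustrative and do not come from the attached code.

```python
import math

def approx_probes(m, n):
    """The usual approximations (upper bounds) for uniform hashing with
    n items in m slots, load factor a = n/m: (found, not found)."""
    a = n / m
    return math.log(1 / (1 - a)) / a, 1 / (1 - a)

def exact_probes(m, n):
    """Exact expected probes for uniform hashing: (found, not found).

    A failing search with k of m slots filled needs (m+1)/(m-k+1) probes
    on average.  A successful search retraces the failing search done
    when its key was inserted, so average that over k = 0, 1, ..., n-1.
    """
    assert 0 < n < m
    def miss(k):
        return (m + 1) / (m - k + 1)
    return sum(miss(k) for k in range(n)) / n, miss(n)

for m, n in [(8, 5), (1024, 682)]:
    af, anf = approx_probes(m, n)
    ef, enf = exact_probes(m, n)
    print(f"{n}/{m}: approx found {af:.2f} not found {anf:.2f}; "
          f"exact found {ef:.2f} not found {enf:.2f}")
# The first line matches the figures above:
# 5/8: approx found 1.57 not found 2.67; exact found 1.34 not found 2.25
```

The second pair, 682 items in 1024 slots, is the largest dict the attached script builds for `nbits` = 10. If the load factor is held fixed while the table grows, the exact values approach the approximations.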

Code for all this is attached, in a framework that allows you to
easily plug in any probe sequence strategy.  The current strategy is
implemented by generator "current".  There are also implementations of
"linear" probing, "quadratic" probing, "pre28201" probing (the current
strategy before bug 28201 repaired an error in its coding), "uniform"
probing, and ... "double".  The last is one form of "double hashing"
that gets very close to "uniform".

The loop guts of "double" are significantly cheaper than the current
scheme's: just 1 addition and 1 mask per probe.  However, it requires a
non-trivial modulus to get started, and that's expensive.
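The "double" prober can be examined on its own. The sketch below copies it from the attached code and checks two claims made in its comments:

- An odd increment visits every slot of a power-of-2 table before repeating.
- `h % mask` equals the result of repeatedly adding up the hash's `nbits`-wide chunks, which is the "casting out 9's" analogy.

`chunk_sum_mod` and `first_cycle` are hypothetical helpers, written only for these checks.

```python
from itertools import islice

def double(h, nbits):
    # Copied from the attached code.  The first probe is just the low bits.
    mask = (1 << nbits) - 1
    i = h & mask
    yield i
    # Reached only after a collision: an increment depending on all of h.
    inc = (h % mask) | 1    # odd, hence coprime to the power-of-2 table size
    while True:
        i = (i + inc) & mask
        yield i

def chunk_sum_mod(h, nbits):
    """Repeatedly add up the nbits-wide chunks of h until one chunk is
    left, like casting out 9's.  Since 2**nbits == 1 (mod mask), each pass
    preserves the value mod mask.  An all-ones result (the analogue of a
    digit sum of 9) is mapped to 0 so the result matches h % mask."""
    mask = (1 << nbits) - 1
    while h > mask:
        s = 0
        while h:
            s += h & mask
            h >>= nbits
        h = s
    return 0 if h == mask else h

def first_cycle(prober, h, nbits):
    """The first 2**nbits indices the prober generates."""
    return list(islice(prober(h, nbits), 1 << nbits))

h = 0x9E3779B97F4A7C15    # an arbitrary 64-bit hash value
for nbits in range(1, 13):
    mask = (1 << nbits) - 1
    assert chunk_sum_mod(h, nbits) == h % mask
    assert sorted(first_cycle(double, h, nbits)) == list(range(1 << nbits))
print("checks passed")
```

The generator computes the increment only after it has yielded the first probe. A lookup that succeeds on the first probe therefore never pays for the modulus.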

Is there a cheaper way to get close to "uniform"?  I don't know - this
was just the best I came up with so far.

Does it matter?  See above ;-)
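The attached framework rests on one idea: a prober is just a generator of slot indices. The sketch below is a stripped-down, deterministic simplification of the attached `spray`, with hypothetical helper names `probes_to_empty` and `build`. It uses small non-negative integer keys, which CPython hashes to themselves, so its output is repeatable. The attached code uses strings instead, and string hashes vary from run to run unless PYTHONHASHSEED is set.

```python
M64 = (1 << 64) - 1

def linear(h, nbits):
    # The attached `linear` prober: start at the low bits, step by 1.
    mask = (1 << nbits) - 1
    i = h & mask
    while True:
        yield i
        i = (i + 1) & mask

def probes_to_empty(used, h, nbits, prober):
    """Follow the prober until it reaches an empty slot.

    Returns (number of probes, index of the empty slot).  That count is
    the cost of a failing search, and of the insertion that a later
    successful search for the same key will retrace.
    """
    assert not all(used), "table is full; the search would never end"
    for n, i in enumerate(prober(h & M64, nbits), 1):
        if not used[i]:
            return n, i

def build(keys, nbits, prober):
    """Insert keys into an empty table of 2**nbits slots.

    Returns the occupancy list and the average probes a successful
    search for one of the keys would need.
    """
    used = [False] * (1 << nbits)
    total = 0
    for k in keys:
        n, i = probes_to_empty(used, hash(k), nbits, prober)
        used[i] = True
        total += n
    return used, total / len(keys)

used, found = build(range(5), 3, linear)     # fills slots 0..4 of 8
misses = [probes_to_empty(used, hash(k), 3, linear)[0] for k in range(8, 13)]
print(found, misses)                         # 1.0 [6, 5, 4, 3, 2]
```

A key that lands at the start of a run of filled slots has to walk the whole run; this is linear probing's primary clustering. Trying another strategy means passing a different generator as `prober`.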

If you want to pursue this, take these as given:

1. The first probe must be simply the last (lowest-order) `nbits` bits
of the hash code.  The speed of the first probe is supremely important;
that's the fastest possible first probe, and it guarantees no collisions
at all for a dict indexed by a contiguous range of integers (an
important use case).

2. The probe sequence must (at least eventually) depend on every bit
in the hash code.  Else it's waaay too easy to stumble into
quadratic-time behavior for "bad" sets of keys, even by accident.
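Small integer keys are enough to see both rules in action. This relies on CPython hashing each non-negative int below 2**61 - 1 to itself. The sketch copies `linear` and `current` from the attached code; `first_probes` and `insert_cost` are illustrative helpers.

In a 1024-slot table, `linear` only ever looks at the low 10 bits of the hash. Keys that differ only in bits 32 and up agree on all of those bits, so under `linear` they all follow one probe sequence and their insertion cost grows quadratically. `current` shifts the high bits in, and its sequences for these keys diverge by the sixth probe.

```python
from itertools import islice

M64 = (1 << 64) - 1

def phash(obj):
    # hash(obj) as uint64, as in the attached code
    return hash(obj) & M64

def linear(h, nbits):
    # Looks only at the low nbits bits of h.
    mask = (1 << nbits) - 1
    i = h & mask
    while True:
        yield i
        i = (i + 1) & mask

def current(h, nbits):
    # Shifting h right brings the higher bits into later probes.
    mask = (1 << nbits) - 1
    i = h & mask
    while True:
        yield i
        h >>= 5
        i = (5*i + h + 1) & mask

def first_probes(prober, h, nbits, count):
    """The first `count` slot indices the prober generates for hash h."""
    return tuple(islice(prober(h, nbits), count))

def insert_cost(keys, nbits, prober):
    """Total probes needed to insert `keys` into an empty table.
    Assumes the keys fit, so an empty slot is always found."""
    used = [False] * (1 << nbits)
    total = 0
    for k in keys:
        for n, i in enumerate(prober(phash(k), nbits), 1):
            if not used[i]:
                used[i] = True
                total += n
                break
    return total

nbits = 10  # 1024 slots

# Rule 1: contiguous ints get distinct first probes, so no collisions.
firsts = {first_probes(current, phash(i), nbits, 1) for i in range(1 << nbits)}
print(len(firsts))                       # 1024

# Rule 2: these keys agree in their low 32 bits.
keys = [k << 32 for k in range(1, 8)]
lin = {first_probes(linear, phash(k), nbits, 12) for k in keys}
cur = {first_probes(current, phash(k), nbits, 12) for k in keys}
print(len(lin), len(cur))                # 1 7

# Under linear probing, n such keys cost 1 + 2 + ... + n probes to insert.
many = [k << 32 for k in range(500)]
print(insert_cost(many, nbits, linear))  # 125250
```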

Have fun :-)
-------------- next part --------------
MIN_ELTS = 100_000

M64 = (1 << 64) - 1
def phash(obj, M=M64):  # hash(obj) as uint64
    return hash(obj) & M

# Probers:  generate sequence of table indices to look at,
# in table of size 2**nbits, for object with uint64 hash code h.

def linear(h, nbits):
    mask = (1 << nbits) - 1
    i = h & mask
    while True:
        yield i
        i = (i + 1) & mask

# offsets of 0, 1, 3, 6, 10, 15, ...
# this permutes the index range when the table size is a power of 2
def quadratic(h, nbits):
    mask = (1 << nbits) - 1
    i = h & mask
    inc = 1
    while True:
        yield i
        i = (i + inc) & mask
        inc += 1
    
def pre28201(h, nbits):
    mask = (1 << nbits) - 1
    i = h & mask
    while True:
        yield i
        i = (5*i + h + 1) & mask
        h >>= 5

def current(h, nbits):
    mask = (1 << nbits) - 1
    i = h & mask
    while True:
        yield i
        h >>= 5
        i = (5*i + h + 1) & mask

# One version of "double hashing".  The increment between probes is
# fixed, but varies across objects.  This does very well!  Note that the
# increment needs to be relatively prime to the table size so that all
# possible indices are generated.  Because our tables have power-of-2
# sizes, we merely need to ensure the increment is odd.
# Using `h % mask` is akin to "casting out 9's" in decimal:  it's as if
# we broke the hash code into nbits-wide chunks from the right, then
# added them, then repeated that procedure until only one "digit"
# remains. All bits in the hash code affect the result.
# While mod is expensive, a successful search usually gets out on the
# first try, & then the lookup can return before the mod completes.
def double(h, nbits):
    mask = (1 << nbits) - 1
    i = h & mask
    yield i
    inc = (h % mask) | 1    # force it odd
    while True:
        i = (i + inc) & mask
        yield i

# The theoretical "gold standard":  generate a random permutation of the
# table indices for each object.  We can't actually do that, but
# Python's PRNG gets close enough that there's no practical difference.
def uniform(h, nbits):
    from random import seed, randrange
    seed(h)
    n = 1 << nbits
    seen = set()
    while True:
        assert len(seen) < n
        while True:
            i = randrange(n)
            if i not in seen:
                break
        seen.add(i)
        yield i

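def spray(nbits, objs, cs, prober, *, used=None, shift=5):
    # Look up each object of `objs` in a table of 2**nbits slots, and tally
    # in `cs` how many probes each lookup needed.  With used=None, a fresh
    # table is built: each object is inserted where its probe sequence first
    # hits an empty slot, so the counts are what later successful searches
    # will need.  Given an existing `used` table, nothing is inserted, so
    # the counts are failing-search costs.  `shift` is currently unused.
    building = used is None
    nslots = 1 << nbits
    if building:
        used = [0] * nslots
    assert len(used) == nslots
    for o in objs:
        n = 1                # probes made so far
        for i in prober(phash(o), nbits):
            if used[i]:
                n += 1       # occupied: a collision, keep probing
            else:
                break        # an empty slot ends the search
        if building:
            used[i] = 1
        cs[n] += 1
    return used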

def objgen(i=1):
    while True:
        yield str(i)
        i += 1

# Average probes for a failing search; e.g.,
# 100 slots; 3 occupied
# 1:                        97/100
# 2:  3/100 *               97/99
# 3:  3/100 * 2/99 *        97/98
# 4:  3/100 * 2/99 * 1/98 * 97/97
#
# `total` slots, `filled` occupied
# probability `p` probes will be needed, 1 <= p <= filled+1
# p-1 collisions followed by success:
#     ff(filled, p-1) / ff(total, p-1) * (total - filled) / (total - (p-1))
# where `ff` is the falling factorial.
def avgf(total, filled):
    assert 0 <= filled < total
    ffn = float(filled)
    ffd = float(total)
    tmf = ffd - ffn
    result = 0.0
    ffpartial = 1.0
    ppartial = 0.0
    for p in range(1, filled + 2):
        thisp = ffpartial * tmf / (total - (p-1))
        ppartial += thisp
        result += thisp * p
        ffpartial *= ffn / ffd
        ffn -= 1.0
        ffd -= 1.0
    assert abs(ppartial - 1.0) < 1e-14, ppartial
    return result

# Average probes for a successful search.  Alas, this takes time
# quadratic in `filled`.
def avgs(total, filled):
    assert 0 < filled < total
    return sum(avgf(total, f) for f in range(filled)) / filled

def pstats(ns):
    total = sum(ns.values())
    small = min(ns)
    print(f"min {small}:{ns[small]/total:.2%} "
          f"max {max(ns)} "
          f"mean {sum(i * j for i, j in ns.items())/total:.2f} ")

def drive(nbits):
    from collections import defaultdict
    from itertools import islice
    import math
    import sys

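    nslots = 1 << nbits
    dlen = nslots * 2 // 3  # most items a dict holds in this many slots
    # Confirm that one more item makes the dict grow (i.e., resize).
    assert (sys.getsizeof({i: i for i in range(dlen)}) <
            sys.getsizeof({i: i for i in range(dlen + 1)}))
    alpha = dlen / nslots # actual load factor of max dict
    # Build enough tables to do at least MIN_ELTS lookups of each kind.
    ntodo = (MIN_ELTS + dlen - 1) // dlen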
    print()
    print("bits", nbits,
          f"nslots {nslots:,} dlen {dlen:,} alpha {alpha:.2f} "
          f"# built {ntodo:,}")
    print(f"theoretical avg probes for uniform hashing "
          f"when found {math.log(1/(1-alpha))/alpha:.2f} "
          f"not found {1/(1-alpha):.2f}")
    print("                                     crisp ", end="")
    if nbits > 12:
        print("... skipping (slow!)")
    else:
        print(f"when found {avgs(nslots, dlen):.2f} "
              f"not found {avgf(nslots, dlen):.2f}")

    for prober in (linear, quadratic, pre28201, current,
                   double, uniform):
        print("    prober", prober.__name__)
        objs = objgen()
        good = defaultdict(int)
        bad = defaultdict(int)
        for _ in range(ntodo):
            used = spray(nbits, islice(objs, dlen), good, prober)
            assert sum(used) == dlen
            spray(nbits, islice(objs, dlen), bad, prober, used=used)
        print(" " * 8 + "found ", end="")
        pstats(good)
        print(" " * 8 + "fail  ", end="")
        pstats(bad)

if __name__ == "__main__":
    for bits in range(3, 23):
        drive(bits)

