Most pythonic way of weighted random selection
Steven D'Aprano
steve at REMOVE-THIS-cybersource.com.au
Sat Aug 30 22:02:17 EDT 2008
On Sat, 30 Aug 2008 17:41:27 +0200, Manuel Ebert wrote:
> Dear list,
>
> who's got aesthetic advice for the following problem?
...
[ugly code removed]
> Now that looks plain ugly, and I wonder whether you might find a
> slightly more elegant way of doing it without using numpy and the like.
Never be afraid to factor out pieces of code into small functions.
Writing a single huge while loop that does everything is not only hard to
read, hard to write and hard to debug, but it can also run slower. (It
depends on a number of factors.)
Anyway, here's my attempt to solve the problem, as best as I can
understand it:
import random

def eq(x, y, tol=1e-10):
    # floating point equality within some tolerance
    return abs(x-y) <= tol

M = [[0.2, 0.4, 0.05], [0.1, 0.05, 0.2]]
# the row sums must add up to 1.0 (all entries form one distribution)
assert eq(1.0, sum([sum(row) for row in M]))

# build a flat list of cumulative probabilities, row by row
CM = []
for row in M:
    for p in row:
        CM.append(p)  # initialize with the raw probabilities
for i in range(1, len(CM)):
    CM[i] += CM[i-1]  # and turn into cumulative probabilities
assert CM[0] >= 0.0
assert eq(CM[-1], 1.0)

def index(data, p):
    """Return the index of the first item in data
    which is no smaller than float p.
    """
    # Note: this uses a linear search. If it is too slow,
    # you can re-write it using the bisect module.
    for i, x in enumerate(data):
        if x >= p:
            return i
    # p is larger than every item (possible through round-off
    # in the last cumulative value), so return the last index.
    return len(data) - 1

def index_to_rowcolumn(i, num_columns):
    """Convert a linear index number i into a (row, column) tuple."""
    # When converting [ [a, b, c, ...], [...] ] into a single
    # array [a, b, c, ... z] we have the identity:
    # index number = row number * number of columns + column number
    return divmod(i, num_columns)

# Now with these two helper functions, we can find the row and column
# number of the first entry in M where the cumulative probability
# reaches or exceeds some given value.

# You will need to define your own fulfills_criterion_a and
# fulfills_criterion_b, but here are a couple of mock functions
# for testing with:
def fulfills_criterion_a(row, column):
    return random.random() < 0.5

fulfills_criterion_b = fulfills_criterion_a

def find_match(p=0.2):
    # Keep drawing weighted random cells until one is accepted.
    while True:
        r = random.random()
        i = index(CM, r)
        row, column = index_to_rowcolumn(i, len(M[0]))
        if fulfills_criterion_a(row, column) or \
                fulfills_criterion_b(row, column):
            return row, column
        else:
            # neither criterion holds: accept anyway with probability p
            if random.random() <= p:
                return row, column
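The comment in index() suggests switching to the bisect module if the linear search is too slow. Here is a minimal sketch of that, assuming Python 3 (itertools.accumulate is used to build the running totals); the names index_bisect and weighted_cell are hypothetical and not part of the code above. Since every probability is non-negative, CM is sorted in ascending order, which is what bisect requires.

```python
import bisect
import itertools
import random

# The example matrix from the post; all entries together sum to 1.0.
M = [[0.2, 0.4, 0.05], [0.1, 0.05, 0.2]]

# Flatten M row by row and take running totals, so CM[i] is the
# cumulative probability of entries 0..i in row-major order.
CM = list(itertools.accumulate(p for row in M for p in row))


def index_bisect(data, p):
    """Return the index of the first item in data that is >= p.

    data must be sorted in ascending order. bisect_left finds that
    position with a binary search; the result is clamped to the last
    index in case round-off leaves data[-1] slightly below p.
    """
    i = bisect.bisect_left(data, p)
    return min(i, len(data) - 1)


def weighted_cell(num_columns):
    """Draw a (row, column) cell with probability given by M."""
    i = index_bisect(CM, random.random())
    return divmod(i, num_columns)
```

bisect_left returns the same position as the linear loop in index(), but in O(log n) comparisons instead of O(n). Calling weighted_cell(len(M[0])) gives a (row, column) pair, like find_match() without the acceptance criteria.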
And here's my test (the pick is random, so your result will vary):
>>> find_match()
(1, 2)
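For readers on a newer Python than this 2008 thread: since Python 3.6 the standard library's random.choices accepts a weights argument, which handles the weighted pick directly. A minimal sketch, assuming Python 3.6+; weighted_cell_choices is a hypothetical name, and the criterion/retry loop of find_match() would still wrap around it.

```python
import random

# The example matrix from the post.
M = [[0.2, 0.4, 0.05], [0.1, 0.05, 0.2]]


def weighted_cell_choices(matrix, rng=random):
    """Draw a (row, column) cell of matrix, weighted by its entries.

    Needs Python 3.6+ for random.choices. rng may be the random module
    itself or a random.Random instance (handy for reproducible tests).
    """
    num_columns = len(matrix[0])
    weights = [p for row in matrix for p in row]  # row-major flattening
    # choices() returns a list of k picks; k defaults to 1.
    i = rng.choices(range(len(weights)), weights=weights)[0]
    return divmod(i, num_columns)
```

random.choices scales by the total weight, so the entries need not sum exactly to 1.0; the same divmod trick as index_to_rowcolumn maps the flat index back to a cell.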
Hope this helps.
--
Steven