function that counts...

Bryan bryanjugglercryptographer at yahoo.com
Fri Jun 11 16:19:33 EDT 2010


Lie Ryan wrote:
> In my original post in comp.programming, I
> used this definition of factorial:
>
> def fact(n):
>     """ factorial function (i.e. n! = n * (n-1) * ... * 2 * 1) """
>     p = 1
>     for i in range(1,n+1):
>         p *= i
>     return p

Ah, much better, but partition10(M, i) gets the wrong answer when i is
1 or 2. I think you don't want to let M go negative. With that tweak,
it seems to work in general, and fact() never gets called with a
negative number.

What I really like about your partition10() is that it's adaptable to
efficiently handle bases much larger than 10. Richard Thomas's
algorithm is poly-time and efficient as long as the base is small.

I'll take the liberty of tweaking your code to handle the 1 or 2 digit
case, and write the more general form. I'll also memoize fact(), and
add prttn() and a test.

--
--Bryan


_ft = [1]  # memo table: _ft[k] == k!, extended on demand by fact()


def fact(n):
    """Return n! for a non-negative integral n, memoizing results in _ft.

    Integral floats such as 4.0 are accepted and treated as the int 4.

    Raises:
        ValueError: if n is negative or not integral.
    """
    if n < 0 or n % 1 != 0:
        raise ValueError('fact() requires a non-negative integer, got %r' % (n,))
    n = int(n)  # integral floats would otherwise break range() and indexing
    # Extend the table from the last cached entry up to n (no-op if cached).
    for i in range(len(_ft), n + 1):
        _ft.append(i * _ft[-1])
    return _ft[n]

def C(n, r):
    """Return the binomial coefficient nCr ("n choose r").

    For n >= 0, an r outside [0, n] yields 0 (there is no way to choose that
    many items), matching the usual combinatorial convention. A negative n is
    passed through to fact(), which rejects it.
    """
    if n >= 0 and (r < 0 or r > n):
        return 0
    return fact(n) // (fact(n - r) * fact(r))

def D(M, N):
    """Count the ways to distribute M identical items among N distinct bins
    with no per-bin limit ("stars and bars"), i.e. C(M + N - 1, M).

    M must be >= 0 and N must be >= 1 (checked with assert).
    """
    assert M >= 0 and N > 0
    slots = M + N - 1  # M stars plus N - 1 bar separators
    return C(slots, M)

def partition(nballs, nbins, binmax):
    """Count ways to put nballs identical balls into nbins distinct bins when
    every bin holds strictly fewer than binmax balls (0 .. binmax - 1 each).

    Inclusion-exclusion over j, the number of bins pre-loaded with binmax
    balls (i.e. forced to overflow):
        sum_j (-1)**j * D(nballs - j*binmax, nbins) * C(nbins, j)
    j stops once there are not enough balls left to overflow another bin.
    With nbins == 0 the answer is 1 for nballs == 0, else 0. Assuming
    binmax > 0, a negative nballs yields 0 because no j terms remain.
    """
    if nbins == 0:
        return int(nballs == 0)
    last_j = min(nbins, nballs // binmax)
    total = 0
    for j in range(last_j + 1):
        term = D(nballs - j * binmax, nbins) * C(nbins, j)
        # Odd j are subtracted, even j added (inclusion-exclusion).
        total += -term if j % 2 else term
    return total

def prttn(m, n, base=10):
    """Count the integers x with 0 <= x < n whose base-`base` digit sum is m.

    Walks n's digits from most to least significant while tracking the
    digit-sum budget m still to be spent. At each position, every digit d
    smaller than n's digit there (and no larger than the remaining budget)
    gives a prefix already below n's. The lower positions are then free and
    contribute partition(m - d, <positions left>, base) completions. Then n's
    own digit is charged against the budget and the walk continues.

    Args:
        m: target digit sum; a negative m gives 0.
        n: exclusive upper bound; n == 0 gives 0.
        base: radix of the digit representation (default 10).

    Raises:
        ValueError: if n < 0 or base < 2.
    """
    if n < 0 or base < 2:
        raise ValueError('prttn() requires n >= 0 and base >= 2')
    # Digits of n, least significant first, so pop() yields the MSD.
    dls = []
    while n:
        n, digit = divmod(n, base)
        dls.append(digit)
    count = 0
    while dls:
        msd = dls.pop()
        count += sum(partition(m - d, len(dls), base)
                     for d in range(min(msd, m + 1)))
        m -= msd  # once m < 0, later ranges are empty and add nothing
    return count


def test(upto=123456):
    """Check prttn() against direct digit sums for every n < upto.

    After n's digit sum is tallied, counts[m] is the number of x in [0, n]
    with digit sum m, which prttn(m, n + 1) must reproduce. For each n, m is
    probed upward until the first m whose count is 0. Prints progress every
    1000 values of n.
    """
    # The inner loop probes m up to 9 * len(digits) + 1 (one past the largest
    # possible digit sum), so size the table to cover that index for every n.
    counts = [0] * (9 * len(str(upto)) + 2)
    for n in range(upto):
        digits = [int(c) for c in str(n)]
        counts[sum(digits)] += 1
        for m in range(9 * len(digits) + 2):
            count = prttn(m, n + 1)
            assert count == counts[m]
            if count == 0:
                break
        assert count == 0
        if n % 1000 == 0:
            print('Tested to:', n)


if __name__ == '__main__':
    test()







More information about the Python-list mailing list