function that counts...

Bryan bryanjugglercryptographer at yahoo.com
Sat May 22 10:18:33 EDT 2010


I wrote:
> I came up with a recursive memo-izing algorithm that
> handles 100-digit n's.
[...]

I made a couple of improvements. The code is below.

-Bryan

#---------------------

_nds = {}   # memo table: (reduced m, d) -> count; filled lazily by ndsums()
def ndsums(m, d):
    """Count d-character digit strings (leading zeros allowed) summing to m.

    Equivalently: the number of integers k with 0 <= k < 10**d whose
    decimal digits add up to m.  ndsums(0, 0) == 1 (the empty string).
    Results are memoized in the module-level ``_nds`` dict.

    Raises ValueError if m or d is negative.
    """
    # An explicit check instead of `assert`, which is stripped under -O.
    if m < 0 or d < 0:
        raise ValueError('m and d must be non-negative')
    # Replacing every digit x with 9 - x maps strings summing to m onto
    # strings summing to 9*d - m, so both counts agree; use the smaller
    # sum to keep the recursion (and the memo table) small.
    m = min(m, d * 9 - m)
    if m < 0:
        return 0    # m > 9*d: no d digits can add up to that much
    if m == 0 or d == 1:
        return 1    # all zeros, or (d == 1) the single digit m itself
    if (m, d) not in _nds:
        # Fix the leading digit i (0..9, never more than m) and count the
        # ways the remaining d - 1 digits can supply the rest of the sum.
        _nds[(m, d)] = sum(ndsums(m - i, d - 1)
                           for i in range(min(10, m + 1)))
    return _nds[(m, d)]


def prttn(m, n):
    """Count the integers k with 0 <= k < n whose decimal digits sum to m.

    Walks the digits of n from most to least significant.  At each
    position, choosing any digit d smaller than n's digit there makes the
    prefix already less than n, so the remaining positions may hold any
    digits at all; ndsums() counts those completions.  The prefix is then
    extended with n's own digit and the walk continues.

    n == 0 is accepted and yields 0 (there are no k with 0 <= k < 0).
    Raises ValueError if m or n is negative.
    """
    # Explicit checks instead of `assert`, which is stripped under -O.
    if m < 0 or n < 0:
        raise ValueError('m and n must be non-negative')
    count = 0
    digits = [int(c) for c in str(n)]
    for pos, msd in enumerate(digits):
        rest = len(digits) - pos - 1   # positions still free after this one
        # d <= m keeps the remaining sum m - d non-negative.
        count += sum(ndsums(m - d, rest) for d in range(min(msd, m + 1)))
        m -= msd
        if m < 0:
            # n's own prefix already overshoots m; every later position
            # would contribute an empty range, so stop here.
            break
    return count


#----------------------
# Testing

from bisect import bisect_right

def slow_prttn(m, n):
    """Brute-force reference: count k in [0, n) whose digits sum to m.

    Every integer is congruent to its digit sum modulo 9, so only the k
    with k % 9 == m % 9 can qualify; the scan steps through just those.
    """
    total = 0
    for k in range(m % 9, n, 9):
        if sum(map(int, str(k))) == m:
            total += 1
    return total

_sums = [0, {}]   # [numbers tabulated so far, {digit sum: sorted list of ints}]
def tab_prttn(m, n):
    """Table-driven reference: count k in [0, n) whose digits sum to m.

    Lazily extends a shared table that groups every integer below the
    largest n seen so far by its digit sum; each group is built in
    increasing order, so a binary search answers "how many are < n".
    """
    buckets = _sums[1]
    if n >= _sums[0]:
        # Tabulate only the integers not yet seen.
        for k in range(_sums[0], n):
            buckets.setdefault(sum(map(int, str(k))), []).append(k)
        _sums[0] = n
    group = buckets.get(m)
    return 0 if group is None else bisect_right(group, n - 1)

# Cross-check prttn against the table (and, for small n, brute force).
for n in range(1, 1234567):
    # Every k < n has at most len(str(n)) digits and is below 99...9, so
    # its digit sum is strictly less than 9 * len(str(n)).
    for m in range(9 * len(str(n))):
        expected = tab_prttn(m, n)
        assert prttn(m, n) == expected
        if n < 500:
            assert slow_prttn(m, n) == expected
        if not expected:
            # Decrementing a nonzero digit keeps k < n and lowers the sum
            # by one, so achievable sums are contiguous from 0: once a
            # count is zero, every larger m is zero too.
            break
    if not n % 1000:
        print('Tested to:', n)



More information about the Python-list mailing list