itertools.intersect?

Arnaud Delobelle arnodel at googlemail.com
Mon Jun 15 16:10:04 EDT 2009


"Andrew Henshaw" <andrew.henshaw at gtri.gatech.edu> writes:

> "Raymond Hettinger" <python at rcn.com> wrote in message 
> news:fb1feeeb-c430-4ca7-9e76-fea02ea3ef6f at v23g2000pro.googlegroups.com...
>> [David Wilson]
>>> The problem is simple: given one or more ordered sequences, return
>>> only the objects that appear in each sequence, without reading the
>>> whole set into memory. This is basically an SQL many-many join.
>>
>> FWIW, this is equivalent to the Welfare Crook problem in David Gries'
>> book, The Science of Programming, http://tinyurl.com/mzoqk4 .
>>
>>
>>> I thought it could be accomplished through recursively embedded
>>> generators, but that approach failed in the end.
>>
>> Translated into Python, David Gries' solution looks like this:
>>
>> def intersect(f, g, h):
>>    i = j = k = 0
>>    try:
>>        while True:
>>            if f[i] < g[j]:
>>                i += 1
>>            elif g[j] < h[k]:
>>                j += 1
>>            elif h[k] < f[i]:
>>                k += 1
>>            else:
>>                print(f[i])
>>                i += 1
>>    except IndexError:
>>        pass
>>
>> from random import sample
>>
>> streams = [sorted(sample(range(50), 30)) for i in range(3)]
>> for s in streams:
>>    print(s)
>> intersect(*streams)
>>
>>
>> Raymond
>
> Here's my translation of your code to support a variable number of streams:
>
> def intersect(*s):
>     num_streams = len(s)
>     indices = [0]*num_streams
>     try:
>         while True:
>             for i in range(num_streams):
>                 j = (i + 1) % num_streams
>                 if s[i][indices[i]] < s[j][indices[j]]:
>                     indices[i] += 1
>                     break
>             else:
>                 print(s[0][indices[0]])
>                 indices[0] += 1
>     except IndexError:
>         pass
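The quoted index-based loops print each match. If the caller needs the matches as values (the original question asks to "return" them), one option is to turn the n-stream variant into a generator. This is a minimal sketch, not code from the thread. It still assumes every input is an already-sorted, indexable sequence such as a list, so it does not avoid holding the inputs in memory; only the output becomes lazy.

```python
def intersect(*seqs):
    """Yield items common to all sorted, indexable sequences.

    Same cyclic comparison as the quoted n-stream loop, but each
    match is yielded instead of printed.
    """
    n = len(seqs)
    indices = [0] * n
    try:
        while True:
            for i in range(n):
                j = (i + 1) % n
                # Sequence i is strictly behind its neighbour: advance it
                if seqs[i][indices[i]] < seqs[j][indices[j]]:
                    indices[i] += 1
                    break
            else:
                # No head is smaller than the next one (cyclically),
                # so all heads are equal
                yield seqs[0][indices[0]]
                indices[0] += 1
    except IndexError:
        # One of the sequences has run out of items
        return


if __name__ == "__main__":
    from random import sample
    streams = [sorted(sample(range(50), 30)) for _ in range(3)]
    common = list(intersect(*streams))
    # For duplicate-free inputs this matches a plain set intersection
    assert common == sorted(set.intersection(*map(set, streams)))
    print(common)
```

With duplicate-free inputs the result equals a sorted set intersection; if the first sequence repeats a common value, the value is yielded once per repeat.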

I posted this solution earlier on:

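def intersect(iterables):
    iterators = [iter(iterable) for iterable in iterables]
    try:
        v = [next(it) for it in iterators]
        while True:
            for i in range(1, len(v)):
                # Advance stream i until it catches up with stream 0
                while v[0] > v[i]:
                    v[i] = next(iterators[i])
                # Stream i is now ahead, so v[0] is not common to all
                if v[0] < v[i]:
                    break
            else:
                # Every stream's head equals v[0]
                yield v[0]
            v[0] = next(iterators[0])
    except StopIteration:
        # An input is exhausted; letting StopIteration escape a
        # generator is a RuntimeError in Python 3.7+ (PEP 479)
        return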

It's quite similar but not as clever as the solution proposed by
R. Hettinger insofar as it doesn't exploit the fact that if a, b, c are
members of a totally ordered set, then:

    if a >= b >= c >= a then a = b = c.
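The same fact holds for any number n of values. That is why the n-stream loop quoted above, and the modified generator below, may stop at the first cyclically adjacent pair with v[i] < v[i+1]: if no such pair exists, every head must be equal.

```latex
v_0 \ge v_1 \ge \cdots \ge v_{n-1} \ge v_0
\;\Longrightarrow\;
\forall k:\;
\underbrace{v_0 \ge v_k}_{\text{via } v_0 \ge \cdots \ge v_k}
\;\wedge\;
\underbrace{v_k \ge v_0}_{\text{via } v_k \ge \cdots \ge v_{n-1} \ge v_0}
\;\Longrightarrow\;
v_0 = v_1 = \cdots = v_{n-1}
```

Each inequality uses only transitivity along part of the cycle; the final equality follows from antisymmetry of the total order.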

However, it can easily be modified to do so:

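def intersect(iterables):
    iterators = [iter(iterable) for iterable in iterables]
    try:
        v = [next(it) for it in iterators]
        while True:
            # Compare heads cyclically: v[-1] vs v[0], v[0] vs v[1], ...
            for i in range(-1, len(v) - 1):
                if v[i] < v[i + 1]:
                    # Only a stream that is strictly behind is advanced
                    v[i] = next(iterators[i])
                    break
            else:
                # v[0] >= v[1] >= ... >= v[-1] >= v[0], so all are equal
                yield v[0]
                v[0] = next(iterators[0])
    except StopIteration:
        # An input is exhausted, so nothing further can be common
        return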
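For comparison, here is a different technique from the ones in this thread, sketched under stated assumptions: merge all the sorted inputs lazily with `heapq.merge` and count how many copies of each value arrive together. It assumes every input is strictly increasing (no repeats within one input); under that assumption a value is common to all k inputs exactly when it occurs k times in the merged stream. The name `intersect_merged` is hypothetical.

```python
import heapq
from itertools import count, groupby, islice


def intersect_merged(iterables):
    """Yield values present in every input (hypothetical helper).

    Assumes each input is sorted and strictly increasing; a value
    repeated inside one input would be over-counted.
    """
    iterables = list(iterables)
    k = len(iterables)
    # heapq.merge lazily interleaves the sorted inputs into one sorted stream
    for value, group in groupby(heapq.merge(*iterables)):
        # Common to all inputs iff exactly one copy came from each input
        if sum(1 for _ in group) == k:
            yield value


if __name__ == "__main__":
    # Infinite inputs are fine as long as the consumer stops early
    evens, threes, fives = count(0, 2), count(0, 3), count(0, 5)
    print(list(islice(intersect_merged([evens, threes, fives]), 3)))
```

One trade-off: after one input is exhausted, `heapq.merge` keeps draining the others, whereas the cyclic generators stop as soon as any input runs out.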

I haven't really thought about it too much, but there may be cases where
the original version terminates faster (I guess when the intersection is
expected to be empty).

-- 
Arnaud


