[Python-checkins] r61224 - in python/trunk: Doc/library/itertools.rst Lib/test/test_itertools.py Modules/itertoolsmodule.c

raymond.hettinger python-checkins at python.org
Tue Mar 4 05:17:09 CET 2008


Author: raymond.hettinger
Date: Tue Mar  4 05:17:08 2008
New Revision: 61224

Modified:
   python/trunk/Doc/library/itertools.rst
   python/trunk/Lib/test/test_itertools.py
   python/trunk/Modules/itertoolsmodule.c
Log:
Beef up docs and tests for itertools.  Fix up the end case for product() called with zero input iterables.

Modified: python/trunk/Doc/library/itertools.rst
==============================================================================
--- python/trunk/Doc/library/itertools.rst	(original)
+++ python/trunk/Doc/library/itertools.rst	Tue Mar  4 05:17:08 2008
@@ -89,6 +89,7 @@
 
    .. versionadded:: 2.6
 
+
 .. function:: combinations(iterable, r)
 
    Return successive *r* length combinations of elements in the *iterable*.
@@ -123,6 +124,17 @@
                     indices[j] = indices[j-1] + 1
                 yield tuple(pool[i] for i in indices)
 
+   The code for :func:`combinations` can also be expressed as a subsequence
+   of :func:`permutations` after filtering out entries where the elements are
+   not in sorted order (according to their position in the input pool)::
+
+        def combinations(iterable, r):
+            pool = tuple(iterable)
+            n = len(pool)
+            for indices in permutations(range(n), r):
+                if sorted(indices) == list(indices):
+                    yield tuple(pool[i] for i in indices)
+
    .. versionadded:: 2.6
 
 .. function:: count([n])
@@ -391,6 +403,18 @@
                 else:
                     return
 
+   The code for :func:`permutations` can also be expressed as a subsequence of
+   :func:`product`, filtered to exclude entries with repeated elements (those
+   from the same position in the input pool)::
+
+        def permutations(iterable, r=None):
+            pool = tuple(iterable)
+            n = len(pool)
+            r = n if r is None else r
+            for indices in product(range(n), repeat=r):
+                if len(set(indices)) == r:
+                    yield tuple(pool[i] for i in indices)
+
    .. versionadded:: 2.6
 
 .. function:: product(*iterables[, repeat])
@@ -401,9 +425,9 @@
    ``product(A, B)`` returns the same as ``((x,y) for x in A for y in B)``.
 
    The leftmost iterators are in the outermost for-loop, so the output tuples
-   cycle in a manner similar to an odometer (with the rightmost element
-   changing on every iteration).  This results in a lexicographic ordering
-   so that if the inputs iterables are sorted, the product tuples are emitted
+   cycle like an odometer (with the rightmost element changing on every
+   iteration).  This results in a lexicographic ordering so that if the
+   input iterables are sorted, the product tuples are emitted
    in sorted order.
 
    To compute the product of an iterable with itself, specify the number of
@@ -415,12 +439,11 @@
 
        def product(*args, **kwds):
            pools = map(tuple, args) * kwds.get('repeat', 1)
-           if pools:            
-               result = [[]]
-               for pool in pools:
-                   result = [x+[y] for x in result for y in pool]
-               for prod in result:
-                   yield tuple(prod)
+           result = [[]]
+           for pool in pools:
+               result = [x+[y] for x in result for y in pool]
+           for prod in result:
+               yield tuple(prod)
 
    .. versionadded:: 2.6
 

Modified: python/trunk/Lib/test/test_itertools.py
==============================================================================
--- python/trunk/Lib/test/test_itertools.py	(original)
+++ python/trunk/Lib/test/test_itertools.py	Tue Mar  4 05:17:08 2008
@@ -40,9 +40,21 @@
     'Convenience function for partially consuming a long of infinite iterable'
     return list(islice(seq, n))
 
+def prod(iterable):
+    return reduce(operator.mul, iterable, 1)
+
 def fact(n):
     'Factorial'
-    return reduce(operator.mul, range(1, n+1), 1)
+    return prod(range(1, n+1))
+
+def permutations(iterable, r=None):
+    # XXX use this until real permutations code is added
+    pool = tuple(iterable)
+    n = len(pool)
+    r = n if r is None else r
+    for indices in product(range(n), repeat=r):
+        if len(set(indices)) == r:
+            yield tuple(pool[i] for i in indices)
 
 class TestBasicOps(unittest.TestCase):
     def test_chain(self):
@@ -62,11 +74,38 @@
     def test_combinations(self):
         self.assertRaises(TypeError, combinations, 'abc')   # missing r argument
         self.assertRaises(TypeError, combinations, 'abc', 2, 1) # too many arguments
+        self.assertRaises(TypeError, combinations, None)        # pool is not iterable
         self.assertRaises(ValueError, combinations, 'abc', -2)  # r is negative
         self.assertRaises(ValueError, combinations, 'abc', 32)  # r is too big
         self.assertEqual(list(combinations(range(4), 3)),
                                            [(0,1,2), (0,1,3), (0,2,3), (1,2,3)])
-        for n in range(8):
+
+        def combinations1(iterable, r):
+            'Pure python version shown in the docs'
+            pool = tuple(iterable)
+            n = len(pool)
+            indices = range(r)
+            yield tuple(pool[i] for i in indices)
+            while 1:
+                for i in reversed(range(r)):
+                    if indices[i] != i + n - r:
+                        break
+                else:
+                    return
+                indices[i] += 1
+                for j in range(i+1, r):
+                    indices[j] = indices[j-1] + 1
+                yield tuple(pool[i] for i in indices)
+
+        def combinations2(iterable, r):
+            'Pure python version shown in the docs'
+            pool = tuple(iterable)
+            n = len(pool)
+            for indices in permutations(range(n), r):
+                if sorted(indices) == list(indices):
+                    yield tuple(pool[i] for i in indices)
+
+        for n in range(7):
             values = [5*x-12 for x in range(n)]
             for r in range(n+1):
                 result = list(combinations(values, r))
@@ -78,6 +117,73 @@
                     self.assertEqual(len(set(c)), r)                    # no duplicate elements
                     self.assertEqual(list(c), sorted(c))                # keep original ordering
                     self.assert_(all(e in values for e in c))           # elements taken from input iterable
+                self.assertEqual(result, list(combinations1(values, r))) # matches first pure python version
+                self.assertEqual(result, list(combinations2(values, r))) # matches second pure python version
+
+        # Test implementation detail:  tuple re-use
+        self.assertEqual(len(set(map(id, combinations('abcde', 3)))), 1)
+        self.assertNotEqual(len(set(map(id, list(combinations('abcde', 3))))), 1)
+
+    def test_permutations(self):
+        self.assertRaises(TypeError, permutations)              # too few arguments
+        self.assertRaises(TypeError, permutations, 'abc', 2, 1) # too many arguments
+##        self.assertRaises(TypeError, permutations, None)        # pool is not iterable
+##        self.assertRaises(ValueError, permutations, 'abc', -2)  # r is negative
+##        self.assertRaises(ValueError, permutations, 'abc', 32)  # r is too big
+        self.assertEqual(list(permutations(range(3), 2)),
+                                           [(0,1), (0,2), (1,0), (1,2), (2,0), (2,1)])
+
+        def permutations1(iterable, r=None):
+            'Pure python version shown in the docs'
+            pool = tuple(iterable)
+            n = len(pool)
+            r = n if r is None else r
+            indices = range(n)
+            cycles = range(n-r+1, n+1)[::-1]
+            yield tuple(pool[i] for i in indices[:r])
+            while n:
+                for i in reversed(range(r)):
+                    cycles[i] -= 1
+                    if cycles[i] == 0:
+                        indices[i:] = indices[i+1:] + indices[i:i+1]
+                        cycles[i] = n - i
+                    else:
+                        j = cycles[i]
+                        indices[i], indices[-j] = indices[-j], indices[i]
+                        yield tuple(pool[i] for i in indices[:r])
+                        break
+                else:
+                    return
+
+        def permutations2(iterable, r=None):
+            'Pure python version shown in the docs'
+            pool = tuple(iterable)
+            n = len(pool)
+            r = n if r is None else r
+            for indices in product(range(n), repeat=r):
+                if len(set(indices)) == r:
+                    yield tuple(pool[i] for i in indices)
+
+        for n in range(7):
+            values = [5*x-12 for x in range(n)]
+            for r in range(n+1):
+                result = list(permutations(values, r))
+                self.assertEqual(len(result), fact(n) / fact(n-r))      # right number of perms
+                self.assertEqual(len(result), len(set(result)))         # no repeats
+                self.assertEqual(result, sorted(result))                # lexicographic order
+                for p in result:
+                    self.assertEqual(len(p), r)                         # r-length permutations
+                    self.assertEqual(len(set(p)), r)                    # no duplicate elements
+                    self.assert_(all(e in values for e in p))           # elements taken from input iterable
+                self.assertEqual(result, list(permutations1(values, r))) # matches first pure python version
+                self.assertEqual(result, list(permutations2(values, r))) # matches second pure python version
+                if r == n:
+                    self.assertEqual(result, list(permutations(values, None))) # test r as None
+                    self.assertEqual(result, list(permutations(values)))       # test default r
+
+        # Test implementation detail:  tuple re-use
+##        self.assertEqual(len(set(map(id, permutations('abcde', 3)))), 1)
+        self.assertNotEqual(len(set(map(id, list(permutations('abcde', 3))))), 1)
 
     def test_count(self):
         self.assertEqual(zip('abc',count()), [('a', 0), ('b', 1), ('c', 2)])
@@ -288,7 +394,7 @@
 
     def test_product(self):
         for args, result in [
-            ([], []),                       # zero iterables   ??? is this correct
+            ([], [()]),                     # zero iterables
             (['ab'], [('a',), ('b',)]),     # one iterable
             ([range(2), range(3)], [(0,0), (0,1), (0,2), (1,0), (1,1), (1,2)]),     # two iterables
             ([range(0), range(2), range(3)], []),           # first iterable with zero length
@@ -305,10 +411,10 @@
                     set('abcdefg'), range(11), tuple(range(13))]
         for i in range(100):
             args = [random.choice(argtypes) for j in range(random.randrange(5))]
-            n = reduce(operator.mul, map(len, args), 1) if args else 0
-            self.assertEqual(len(list(product(*args))), n)
+            expected_len = prod(map(len, args))
+            self.assertEqual(len(list(product(*args))), expected_len)
             args = map(iter, args)
-            self.assertEqual(len(list(product(*args))), n)
+            self.assertEqual(len(list(product(*args))), expected_len)
 
         # Test implementation detail:  tuple re-use
         self.assertEqual(len(set(map(id, product('abc', 'def')))), 1)

Modified: python/trunk/Modules/itertoolsmodule.c
==============================================================================
--- python/trunk/Modules/itertoolsmodule.c	(original)
+++ python/trunk/Modules/itertoolsmodule.c	Tue Mar  4 05:17:08 2008
@@ -1885,10 +1885,7 @@
 
 	if (result == NULL) {
                 /* On the first pass, return an initial tuple filled with the 
-                   first element from each pool.  If any pool is empty, then 
-                   whole product is empty and we're already done */
-		if (npools == 0)
-			goto empty;
+                   first element from each pool. */
 		result = PyTuple_New(npools);
 		if (result == NULL)
             		goto empty;


More information about the Python-checkins mailing list