[Python-checkins] r68410 - in python/branches/release30-maint: Doc/library/itertools.rst Lib/test/test_itertools.py Modules/itertoolsmodule.c

raymond.hettinger python-checkins at python.org
Thu Jan 8 22:07:00 CET 2009


Author: raymond.hettinger
Date: Thu Jan  8 22:07:00 2009
New Revision: 68410

Log:
Backport r68409 fixing itertools.permutations() and combinations().

Modified:
   python/branches/release30-maint/Doc/library/itertools.rst
   python/branches/release30-maint/Lib/test/test_itertools.py
   python/branches/release30-maint/Modules/itertoolsmodule.c

Modified: python/branches/release30-maint/Doc/library/itertools.rst
==============================================================================
--- python/branches/release30-maint/Doc/library/itertools.rst	(original)
+++ python/branches/release30-maint/Doc/library/itertools.rst	Thu Jan  8 22:07:00 2009
@@ -104,7 +104,9 @@
             # combinations(range(4), 3) --> 012 013 023 123
             pool = tuple(iterable)
             n = len(pool)
-            indices = range(r)
+            if r > n:
+                return
+            indices = list(range(r))
             yield tuple(pool[i] for i in indices)
             while 1:
                 for i in reversed(range(r)):
@@ -128,6 +130,8 @@
                 if sorted(indices) == list(indices):
                     yield tuple(pool[i] for i in indices)
 
+   The number of items returned is ``n! / r! / (n-r)!`` when ``0 <= r <= n``
+   or zero when ``r > n``.
 
 .. function:: count([n])
 
@@ -325,7 +329,9 @@
             pool = tuple(iterable)
             n = len(pool)
             r = n if r is None else r
-            indices = range(n)
+            if r > n:
+                return
+            indices = list(range(n))
             cycles = range(n, n-r, -1)
             yield tuple(pool[i] for i in indices[:r])
             while n:
@@ -354,6 +360,8 @@
                 if len(set(indices)) == r:
                     yield tuple(pool[i] for i in indices)
 
+   The number of items returned is ``n! / (n-r)!`` when ``0 <= r <= n``
+   or zero when ``r > n``.
 
 .. function:: product(*iterables[, repeat])
 
@@ -593,7 +601,8 @@
        return (d for d, s in zip(data, selectors) if s)
 
    def combinations_with_replacement(iterable, r):
-       "combinations_with_replacement('ABC', 3) --> AA AB AC BB BC CC"
+       "combinations_with_replacement('ABC', 2) --> AA AB AC BB BC CC"
+       # number of items returned (when n > 0):  (n+r-1)! / r! / (n-1)!
        pool = tuple(iterable)
        n = len(pool)
        indices = [0] * r

Modified: python/branches/release30-maint/Lib/test/test_itertools.py
==============================================================================
--- python/branches/release30-maint/Lib/test/test_itertools.py	(original)
+++ python/branches/release30-maint/Lib/test/test_itertools.py	Thu Jan  8 22:07:00 2009
@@ -75,11 +75,11 @@
         self.assertRaises(TypeError, list, chain.from_iterable([2, 3]))
 
     def test_combinations(self):
-        self.assertRaises(TypeError, combinations, 'abc')   # missing r argument
+        self.assertRaises(TypeError, combinations, 'abc')       # missing r argument
         self.assertRaises(TypeError, combinations, 'abc', 2, 1) # too many arguments
         self.assertRaises(TypeError, combinations, None)        # pool is not iterable
         self.assertRaises(ValueError, combinations, 'abc', -2)  # r is negative
-        self.assertRaises(ValueError, combinations, 'abc', 32)  # r is too big
+        self.assertEqual(list(combinations('abc', 32)), [])     # r > n
         self.assertEqual(list(combinations(range(4), 3)),
                                            [(0,1,2), (0,1,3), (0,2,3), (1,2,3)])
 
@@ -87,6 +87,8 @@
             'Pure python version shown in the docs'
             pool = tuple(iterable)
             n = len(pool)
+            if r > n:
+                return
             indices = list(range(r))
             yield tuple(pool[i] for i in indices)
             while 1:
@@ -110,9 +112,9 @@
 
         for n in range(7):
             values = [5*x-12 for x in range(n)]
-            for r in range(n+1):
+            for r in range(n+2):
                 result = list(combinations(values, r))
-                self.assertEqual(len(result), fact(n) / fact(r) / fact(n-r)) # right number of combs
+                self.assertEqual(len(result), 0 if r>n else fact(n) / fact(r) / fact(n-r)) # right number of combs
                 self.assertEqual(len(result), len(set(result)))         # no repeats
                 self.assertEqual(result, sorted(result))                # lexicographic order
                 for c in result:
@@ -123,7 +125,7 @@
                     self.assertEqual(list(c),
                                      [e for e in values if e in c])      # comb is a subsequence of the input iterable
                 self.assertEqual(result, list(combinations1(values, r))) # matches first pure python version
-                self.assertEqual(result, list(combinations2(values, r))) # matches first pure python version
+                self.assertEqual(result, list(combinations2(values, r))) # matches second pure python version
 
         # Test implementation detail:  tuple re-use
         self.assertEqual(len(set(map(id, combinations('abcde', 3)))), 1)
@@ -134,7 +136,7 @@
         self.assertRaises(TypeError, permutations, 'abc', 2, 1) # too many arguments
         self.assertRaises(TypeError, permutations, None)        # pool is not iterable
         self.assertRaises(ValueError, permutations, 'abc', -2)  # r is negative
-        self.assertRaises(ValueError, permutations, 'abc', 32)  # r is too big
+        self.assertEqual(list(permutations('abc', 32)), [])     # r > n
         self.assertRaises(TypeError, permutations, 'abc', 's')  # r is not an int or None
         self.assertEqual(list(permutations(range(3), 2)),
                                            [(0,1), (0,2), (1,0), (1,2), (2,0), (2,1)])
@@ -144,6 +146,8 @@
             pool = tuple(iterable)
             n = len(pool)
             r = n if r is None else r
+            if r > n:
+                return
             indices = list(range(n))
             cycles = list(range(n-r+1, n+1))[::-1]
             yield tuple(pool[i] for i in indices[:r])
@@ -172,9 +176,9 @@
 
         for n in range(7):
             values = [5*x-12 for x in range(n)]
-            for r in range(n+1):
+            for r in range(n+2):
                 result = list(permutations(values, r))
-                self.assertEqual(len(result), fact(n) / fact(n-r))      # right number of perms
+                self.assertEqual(len(result), 0 if r>n else fact(n) / fact(n-r))      # right number of perms
                 self.assertEqual(len(result), len(set(result)))         # no repeats
                 self.assertEqual(result, sorted(result))                # lexicographic order
                 for p in result:
@@ -182,7 +186,7 @@
                     self.assertEqual(len(set(p)), r)                    # no duplicate elements
                     self.assert_(all(e in values for e in p))           # elements taken from input iterable
                 self.assertEqual(result, list(permutations1(values, r))) # matches first pure python version
-                self.assertEqual(result, list(permutations2(values, r))) # matches first pure python version
+                self.assertEqual(result, list(permutations2(values, r))) # matches second pure python version
                 if r == n:
                     self.assertEqual(result, list(permutations(values, None))) # test r as None
                     self.assertEqual(result, list(permutations(values)))       # test default r
@@ -1384,6 +1388,26 @@
 >>> list(combinations_with_replacement('abc', 2))
 [('a', 'a'), ('a', 'b'), ('a', 'c'), ('b', 'b'), ('b', 'c'), ('c', 'c')]
 
+>>> list(combinations_with_replacement('01', 3))
+[('0', '0', '0'), ('0', '0', '1'), ('0', '1', '1'), ('1', '1', '1')]
+
+>>> def combinations_with_replacement2(iterable, r):
+...     'Alternate version that filters from product()'
+...     pool = tuple(iterable)
+...     n = len(pool)
+...     for indices in product(range(n), repeat=r):
+...         if sorted(indices) == list(indices):
+...             yield tuple(pool[i] for i in indices)
+
+>>> list(combinations_with_replacement('abc', 2)) == list(combinations_with_replacement2('abc', 2))
+True
+
+>>> list(combinations_with_replacement('01', 3)) == list(combinations_with_replacement2('01', 3))
+True
+
+>>> list(combinations_with_replacement('2310', 6)) == list(combinations_with_replacement2('2310', 6))
+True
+
 >>> list(unique_everseen('AAAABBBCCDAABBB'))
 ['A', 'B', 'C', 'D']
 

Modified: python/branches/release30-maint/Modules/itertoolsmodule.c
==============================================================================
--- python/branches/release30-maint/Modules/itertoolsmodule.c	(original)
+++ python/branches/release30-maint/Modules/itertoolsmodule.c	Thu Jan  8 22:07:00 2009
@@ -1880,10 +1880,6 @@
 		PyErr_SetString(PyExc_ValueError, "r must be non-negative");
 		goto error;
 	}
-	if (r > n) {
-		PyErr_SetString(PyExc_ValueError, "r cannot be bigger than the iterable");
-		goto error;
-	}
 
 	indices = PyMem_Malloc(r * sizeof(Py_ssize_t));
 	if (indices == NULL) {
@@ -1903,7 +1899,7 @@
 	co->indices = indices;
 	co->result = NULL;
 	co->r = r;
-	co->stopped = 0;
+	co->stopped = r > n ? 1 : 0;
 
 	return (PyObject *)co;
 
@@ -2143,10 +2139,6 @@
 		PyErr_SetString(PyExc_ValueError, "r must be non-negative");
 		goto error;
 	}
-	if (r > n) {
-		PyErr_SetString(PyExc_ValueError, "r cannot be bigger than the iterable");
-		goto error;
-	}
 
 	indices = PyMem_Malloc(n * sizeof(Py_ssize_t));
 	cycles = PyMem_Malloc(r * sizeof(Py_ssize_t));
@@ -2170,7 +2162,7 @@
 	po->cycles = cycles;
 	po->result = NULL;
 	po->r = r;
-	po->stopped = 0;
+	po->stopped = r > n ? 1 : 0;
 
 	return (PyObject *)po;
 


More information about the Python-checkins mailing list