[Python-checkins] GH-98363: Add itertools.batched() (GH-98364)

rhettinger webhook-mailer at python.org
Mon Oct 17 19:53:50 EDT 2022


https://github.com/python/cpython/commit/de3ece769a8bc10c207a648c8a446f520504fa7e
commit: de3ece769a8bc10c207a648c8a446f520504fa7e
branch: main
author: Raymond Hettinger <rhettinger at users.noreply.github.com>
committer: rhettinger <rhettinger at users.noreply.github.com>
date: 2022-10-17T18:53:45-05:00
summary:

GH-98363:  Add itertools.batched() (GH-98364)

files:
A Misc/NEWS.d/next/Library/2022-10-17-12-49-02.gh-issue-98363.aFmSP-.rst
M Doc/library/itertools.rst
M Lib/test/test_itertools.py
M Modules/clinic/itertoolsmodule.c.h
M Modules/itertoolsmodule.c

diff --git a/Doc/library/itertools.rst b/Doc/library/itertools.rst
index 9f7ec10729cd..35a71335b35f 100644
--- a/Doc/library/itertools.rst
+++ b/Doc/library/itertools.rst
@@ -48,6 +48,7 @@ Iterator            Arguments               Results
 Iterator                        Arguments                       Results                                             Example
 ============================    ============================    =================================================   =============================================================
 :func:`accumulate`              p [,func]                       p0, p0+p1, p0+p1+p2, ...                            ``accumulate([1,2,3,4,5]) --> 1 3 6 10 15``
+:func:`batched`                 p, n                            [p0, p1, ..., p_n-1], ...                           ``batched('ABCDEFG', n=3) --> ABC DEF G``
 :func:`chain`                   p, q, ...                       p0, p1, ... plast, q0, q1, ...                      ``chain('ABC', 'DEF') --> A B C D E F``
 :func:`chain.from_iterable`     iterable                        p0, p1, ... plast, q0, q1, ...                      ``chain.from_iterable(['ABC', 'DEF']) --> A B C D E F``
 :func:`compress`                data, selectors                 (d[0] if s[0]), (d[1] if s[1]), ...                 ``compress('ABCDEF', [1,0,1,0,1,1]) --> A C E F``
@@ -170,6 +171,44 @@ loops that truncate the stream.
     .. versionchanged:: 3.8
        Added the optional *initial* parameter.
 
+
+.. function:: batched(iterable, n)
+
+   Batch data from the *iterable* into lists of length *n*. The last
+   batch may be shorter than *n*.
+
+   Loops over the input iterable and accumulates data into lists up to
+   size *n*.  The input is consumed lazily, just enough to fill a list.
+   The result is yielded as soon as the batch is full or when the input
+   iterable is exhausted:
+
+   .. doctest::
+
+      >>> flattened_data = ['roses', 'red', 'violets', 'blue', 'sugar', 'sweet']
+      >>> unflattened = list(batched(flattened_data, 2))
+      >>> unflattened
+      [['roses', 'red'], ['violets', 'blue'], ['sugar', 'sweet']]
+
+      >>> for batch in batched('ABCDEFG', 3):
+      ...     print(batch)
+      ...
+      ['A', 'B', 'C']
+      ['D', 'E', 'F']
+      ['G']
+
+   Roughly equivalent to::
+
+      def batched(iterable, n):
+          # batched('ABCDEFG', 3) --> ABC DEF G
+          if n < 1:
+              raise ValueError('n must be at least one')
+          it = iter(iterable)
+          while (batch := list(islice(it, n))):
+              yield batch
+
+   .. versionadded:: 3.12
+
+
 .. function:: chain(*iterables)
 
    Make an iterator that returns elements from the first iterable until it is
@@ -858,13 +897,6 @@ which incur interpreter overhead.
        else:
            raise ValueError('Expected fill, strict, or ignore')
 
-   def batched(iterable, n):
-       "Batch data into lists of length n. The last batch may be shorter."
-       # batched('ABCDEFG', 3) --> ABC DEF G
-       it = iter(iterable)
-       while (batch := list(islice(it, n))):
-           yield batch
-
    def triplewise(iterable):
        "Return overlapping triplets from an iterable"
        # triplewise('ABCDEFG') --> ABC BCD CDE DEF EFG
@@ -1211,36 +1243,6 @@ which incur interpreter overhead.
     >>> list(grouper('abcdefg', n=3, incomplete='ignore'))
     [('a', 'b', 'c'), ('d', 'e', 'f')]
 
-    >>> list(batched('ABCDEFG', 3))
-    [['A', 'B', 'C'], ['D', 'E', 'F'], ['G']]
-    >>> list(batched('ABCDEF', 3))
-    [['A', 'B', 'C'], ['D', 'E', 'F']]
-    >>> list(batched('ABCDE', 3))
-    [['A', 'B', 'C'], ['D', 'E']]
-    >>> list(batched('ABCD', 3))
-    [['A', 'B', 'C'], ['D']]
-    >>> list(batched('ABC', 3))
-    [['A', 'B', 'C']]
-    >>> list(batched('AB', 3))
-    [['A', 'B']]
-    >>> list(batched('A', 3))
-    [['A']]
-    >>> list(batched('', 3))
-    []
-    >>> list(batched('ABCDEFG', 2))
-    [['A', 'B'], ['C', 'D'], ['E', 'F'], ['G']]
-    >>> list(batched('ABCDEFG', 1))
-    [['A'], ['B'], ['C'], ['D'], ['E'], ['F'], ['G']]
-    >>> list(batched('ABCDEFG', 0))
-    []
-    >>> list(batched('ABCDEFG', -1))
-    Traceback (most recent call last):
-      ...
-    ValueError: Stop argument for islice() must be None or an integer: 0 <= x <= sys.maxsize.
-    >>> s = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
-    >>> all(list(flatten(batched(s[:n], 5))) == list(s[:n]) for n in range(len(s)))
-    True
-
     >>> list(triplewise('ABCDEFG'))
     [('A', 'B', 'C'), ('B', 'C', 'D'), ('C', 'D', 'E'), ('D', 'E', 'F'), ('E', 'F', 'G')]
 
diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py
index f469bfe185e6..c0e35711a2b3 100644
--- a/Lib/test/test_itertools.py
+++ b/Lib/test/test_itertools.py
@@ -159,6 +159,44 @@ def test_accumulate(self):
         with self.assertRaises(TypeError):
             list(accumulate([10, 20], 100))
 
+    def test_batched(self):
+        self.assertEqual(list(batched('ABCDEFG', 3)),
+                             [['A', 'B', 'C'], ['D', 'E', 'F'], ['G']])
+        self.assertEqual(list(batched('ABCDEFG', 2)),
+                             [['A', 'B'], ['C', 'D'], ['E', 'F'], ['G']])
+        self.assertEqual(list(batched('ABCDEFG', 1)),
+                             [['A'], ['B'], ['C'], ['D'], ['E'], ['F'], ['G']])
+
+        with self.assertRaises(TypeError):
+            list(batched('ABCDEFG'))                # Too few arguments
+        with self.assertRaises(TypeError):
+            list(batched('ABCDEFG', 3, None))       # Too many arguments
+        with self.assertRaises(TypeError):
+            list(batched(None, 3))                  # Non-iterable input
+        with self.assertRaises(TypeError):
+            list(batched('ABCDEFG', 'hello'))       # n is a string
+        with self.assertRaises(ValueError):
+            list(batched('ABCDEFG', 0))             # n is zero
+        with self.assertRaises(ValueError):
+            list(batched('ABCDEFG', -1))            # n is negative
+
+        data = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
+        for n in range(1, 6):
+            for i in range(len(data) + 1):
+                s = data[:i]
+                batches = list(batched(s, n))
+                with self.subTest(s=s, n=n, batches=batches):
+                    # Order is preserved and no data is lost
+                    self.assertEqual(''.join(chain(*batches)), s)
+                    # Each batch is an exact list
+                    self.assertTrue(all(type(batch) is list for batch in batches))
+                    # All but the last batch are of size n (no mutation of batches)
+                    if batches:
+                        *full_batches, last_batch = batches
+                        self.assertTrue(all(len(batch) == n for batch in full_batches))
+                        # The final batch is never empty nor larger than n
+                        self.assertTrue(0 < len(last_batch) <= n)
+
     def test_chain(self):
 
         def chain2(*iterables):
@@ -1737,6 +1775,31 @@ def test_takewhile(self):
 
 class TestPurePythonRoughEquivalents(unittest.TestCase):
 
+    def test_batched_recipe(self):
+        def batched_recipe(iterable, n):
+            "Batch data into lists of length n. The last batch may be shorter."
+            # batched('ABCDEFG', 3) --> ABC DEF G
+            if n < 1:
+                raise ValueError('n must be at least one')
+            it = iter(iterable)
+            while (batch := list(islice(it, n))):
+                yield batch
+
+        def outcome(func, iterable, n):
+            "Return (result list, None) on success, else (None, exception type)."
+            try:
+                return list(func(iterable, n)), None
+            except Exception as exc:
+                return None, type(exc)
+
+        inputs = ['', 'a', 'ab', 'abc', 'abcd', 'abcde', 'abcdef', 'abcdefg', None]
+        sizes = [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, None]
+        for iterable, n in product(inputs, sizes):
+            with self.subTest(iterable=iterable, n=n):
+                actual = outcome(batched, iterable, n)
+                expected = outcome(batched_recipe, iterable, n)
+                self.assertEqual(actual, expected)
+
     @staticmethod
     def islice(iterable, *args):
         s = slice(*args)
@@ -1788,6 +1851,10 @@ def test_accumulate(self):
         a = []
         self.makecycle(accumulate([1,2,a,3]), a)
 
+    def test_batched(self):
+        container = []
+        self.makecycle(batched([1, 2, container, 3], 2), container)
+
     def test_chain(self):
         a = []
         self.makecycle(chain(a), a)
@@ -1972,6 +2039,18 @@ def test_accumulate(self):
         self.assertRaises(TypeError, accumulate, N(s))
         self.assertRaises(ZeroDivisionError, list, accumulate(E(s)))
 
+    def test_batched(self):
+        data = 'abcde'
+        size = 2
+        expected = [['a', 'b'], ['c', 'd'], ['e']]
+        for wrapper in (G, I, Ig, L, R):
+            with self.subTest(g=wrapper):
+                self.assertEqual(list(batched(wrapper(data), size)), expected)
+        self.assertEqual(list(batched(S(data), size)), [])
+        self.assertRaises(TypeError, batched, X(data), size)
+        self.assertRaises(TypeError, batched, N(data), size)
+        self.assertRaises(ZeroDivisionError, list, batched(E(data), size))
+
     def test_chain(self):
         for s in ("123", "", range(1000), ('do', 1.2), range(2000,2200,5)):
             for g in (G, I, Ig, S, L, R):
diff --git a/Misc/NEWS.d/next/Library/2022-10-17-12-49-02.gh-issue-98363.aFmSP-.rst b/Misc/NEWS.d/next/Library/2022-10-17-12-49-02.gh-issue-98363.aFmSP-.rst
new file mode 100644
index 000000000000..9c6e7552a3f4
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2022-10-17-12-49-02.gh-issue-98363.aFmSP-.rst
@@ -0,0 +1,2 @@
+Added itertools.batched() to batch data into lists of a given length with
+the last list possibly being shorter than the others.
diff --git a/Modules/clinic/itertoolsmodule.c.h b/Modules/clinic/itertoolsmodule.c.h
index 8806606d85be..17f9ebb24939 100644
--- a/Modules/clinic/itertoolsmodule.c.h
+++ b/Modules/clinic/itertoolsmodule.c.h
@@ -8,6 +8,85 @@ preserve
 #endif
 
 
+PyDoc_STRVAR(batched_new__doc__,
+"batched(iterable, n)\n"
+"--\n"
+"\n"
+"Batch data into lists of length n. The last batch may be shorter than n.\n"
+"\n"
+"Loops over the input iterable and accumulates data into lists\n"
+"up to size n.  The input is consumed lazily, just enough to\n"
+"fill a list.  The result is yielded as soon as a batch is full\n"
+"or when the input iterable is exhausted.\n"
+"\n"
+"    >>> for batch in batched(\'ABCDEFG\', 3):\n"
+"    ...     print(batch)\n"
+"    ...\n"
+"    [\'A\', \'B\', \'C\']\n"
+"    [\'D\', \'E\', \'F\']\n"
+"    [\'G\']");
+
+static PyObject *
+batched_new_impl(PyTypeObject *type, PyObject *iterable, Py_ssize_t n);
+
+static PyObject *
+batched_new(PyTypeObject *type, PyObject *args, PyObject *kwargs)
+{
+    PyObject *return_value = NULL;
+    #if defined(Py_BUILD_CORE) && !defined(Py_BUILD_CORE_MODULE)
+
+    #define NUM_KEYWORDS 2
+    static struct {
+        PyGC_Head _this_is_not_used;
+        PyObject_VAR_HEAD
+        PyObject *ob_item[NUM_KEYWORDS];
+    } _kwtuple = {
+        .ob_base = PyVarObject_HEAD_INIT(&PyTuple_Type, NUM_KEYWORDS)
+        .ob_item = { &_Py_ID(iterable), &_Py_ID(n), },
+    };
+    #undef NUM_KEYWORDS
+    #define KWTUPLE (&_kwtuple.ob_base.ob_base)
+
+    #else  // !Py_BUILD_CORE
+    #  define KWTUPLE NULL
+    #endif  // !Py_BUILD_CORE
+
+    static const char * const _keywords[] = {"iterable", "n", NULL};
+    static _PyArg_Parser _parser = {
+        .keywords = _keywords,
+        .fname = "batched",
+        .kwtuple = KWTUPLE,
+    };
+    #undef KWTUPLE
+    PyObject *argsbuf[2];
+    PyObject * const *fastargs;
+    Py_ssize_t nargs = PyTuple_GET_SIZE(args);
+    PyObject *iterable;
+    Py_ssize_t n;
+
+    fastargs = _PyArg_UnpackKeywords(_PyTuple_CAST(args)->ob_item, nargs, kwargs, NULL, &_parser, 2, 2, 0, argsbuf);
+    if (!fastargs) {
+        goto exit;
+    }
+    iterable = fastargs[0];
+    {
+        Py_ssize_t ival = -1;
+        PyObject *iobj = _PyNumber_Index(fastargs[1]);
+        if (iobj != NULL) {
+            ival = PyLong_AsSsize_t(iobj);
+            Py_DECREF(iobj);
+        }
+        if (ival == -1 && PyErr_Occurred()) {
+            goto exit;
+        }
+        n = ival;
+    }
+    return_value = batched_new_impl(type, iterable, n);
+
+exit:
+    return return_value;
+}
+
 PyDoc_STRVAR(pairwise_new__doc__,
 "pairwise(iterable, /)\n"
 "--\n"
@@ -834,4 +913,4 @@ itertools_count(PyTypeObject *type, PyObject *args, PyObject *kwargs)
 exit:
     return return_value;
 }
-/*[clinic end generated code: output=b1056d63f68a9059 input=a9049054013a1b77]*/
+/*[clinic end generated code: output=efea8cd1e647bd17 input=a9049054013a1b77]*/
diff --git a/Modules/itertoolsmodule.c b/Modules/itertoolsmodule.c
index 4a7a95730395..99dc30eb412c 100644
--- a/Modules/itertoolsmodule.c
+++ b/Modules/itertoolsmodule.c
@@ -16,6 +16,7 @@ class itertools.groupby "groupbyobject *" "&groupby_type"
 class itertools._grouper "_grouperobject *" "&_grouper_type"
 class itertools.teedataobject "teedataobject *" "&teedataobject_type"
 class itertools._tee "teeobject *" "&tee_type"
+class itertools.batched "batchedobject *" "&batched_type"
 class itertools.cycle "cycleobject *" "&cycle_type"
 class itertools.dropwhile "dropwhileobject *" "&dropwhile_type"
 class itertools.takewhile "takewhileobject *" "&takewhile_type"
@@ -30,12 +31,13 @@ class itertools.filterfalse "filterfalseobject *" "&filterfalse_type"
 class itertools.count "countobject *" "&count_type"
 class itertools.pairwise "pairwiseobject *" "&pairwise_type"
 [clinic start generated code]*/
-/*[clinic end generated code: output=da39a3ee5e6b4b0d input=6498ed21fbe1bf94]*/
+/*[clinic end generated code: output=da39a3ee5e6b4b0d input=1168b274011ce21b]*/
 
 static PyTypeObject groupby_type;
 static PyTypeObject _grouper_type;
 static PyTypeObject teedataobject_type;
 static PyTypeObject tee_type;
+static PyTypeObject batched_type;
 static PyTypeObject cycle_type;
 static PyTypeObject dropwhile_type;
 static PyTypeObject takewhile_type;
@@ -51,6 +53,171 @@ static PyTypeObject pairwise_type;
 
 #include "clinic/itertoolsmodule.c.h"
 
+/* batched object ************************************************************/
+
+/* Note:  The built-in zip() function includes a "strict" argument
+   that is needed because that function can silently truncate data
+   and there is no easy way for a user to detect that condition.
+   The same reasoning does not apply to batched(), which never drops
+   data.  Instead, it produces a shorter list which can be handled
+   as the user sees fit.
+ */
+
+typedef struct {
+    PyObject_HEAD
+    PyObject *it;               /* input iterator; cleared to NULL when done */
+    Py_ssize_t batch_size;      /* n from the constructor, always >= 1 */
+} batchedobject;
+
+/*[clinic input]
+@classmethod
+itertools.batched.__new__ as batched_new
+    iterable: object
+    n: Py_ssize_t
+Batch data into lists of length n. The last batch may be shorter than n.
+
+Loops over the input iterable and accumulates data into lists
+up to size n.  The input is consumed lazily, just enough to
+fill a list.  The result is yielded as soon as a batch is full
+or when the input iterable is exhausted.
+
+    >>> for batch in batched('ABCDEFG', 3):
+    ...     print(batch)
+    ...
+    ['A', 'B', 'C']
+    ['D', 'E', 'F']
+    ['G']
+
+[clinic start generated code]*/
+
+static PyObject *
+batched_new_impl(PyTypeObject *type, PyObject *iterable, Py_ssize_t n)
+/*[clinic end generated code: output=7ebc954d655371b6 input=f28fd12cb52365f0]*/
+{
+    PyObject *it;
+    batchedobject *bo;
+
+    if (n < 1) {
+        /* Checked before iter() so a bad n is reported first.  We could
+           make n==0 return an empty iterator, but that is at odds with
+           the idea that batching should never throw away input data.
+        */
+        PyErr_SetString(PyExc_ValueError, "n must be at least one");
+        return NULL;
+    }
+    it = PyObject_GetIter(iterable);
+    if (it == NULL) {
+        return NULL;
+    }
+
+    /* create batchedobject structure; bo->it takes over our reference */
+    bo = (batchedobject *)type->tp_alloc(type, 0);
+    if (bo == NULL) {
+        Py_DECREF(it);
+        return NULL;
+    }
+    bo->batch_size = n;
+    bo->it = it;
+    return (PyObject *)bo;
+}
+
+static void
+batched_dealloc(batchedobject *bo)
+{
+    PyObject_GC_UnTrack(bo);        /* leave the GC before releasing fields */
+    Py_XDECREF(bo->it);             /* XDECREF: it is NULL once cleared */
+    Py_TYPE(bo)->tp_free(bo);
+}
+
+static int
+batched_traverse(batchedobject *bo, visitproc visit, void *arg)
+{
+    if (bo->it != NULL) {           /* NULL after exhaustion; Py_VISIT also skips NULL */
+        Py_VISIT(bo->it);
+    }
+    return 0;
+}
+
+static PyObject *
+batched_next(batchedobject *bo)
+{
+    PyObject *it = bo->it;
+    PyObject *item;
+    PyObject *result;
+
+    if (it == NULL) {               /* iteration already finished */
+        return NULL;
+    }
+    result = PyList_New(0);
+    if (result == NULL) {
+        return NULL;
+    }
+    for (Py_ssize_t i = 0; i < bo->batch_size; i++) {
+        item = PyIter_Next(it);
+        if (item == NULL) {
+            break;                  /* input exhausted or raised an error */
+        }
+        if (PyList_Append(result, item) < 0) {
+            Py_DECREF(item);
+            Py_DECREF(result);
+            return NULL;
+        }
+        Py_DECREF(item);
+    }
+    /* PyIter_Next() returns NULL both at exhaustion and on error; an error
+       must propagate even when a partial batch has already been collected. */
+    if (PyErr_Occurred() || PyList_GET_SIZE(result) == 0) {
+        Py_CLEAR(bo->it);
+        Py_DECREF(result);
+        return NULL;
+    }
+    return result;
+}
+
+static PyTypeObject batched_type = {    /* registered via typelist in itertoolsmodule_exec */
+    PyVarObject_HEAD_INIT(&PyType_Type, 0)
+    "itertools.batched",            /* tp_name */
+    sizeof(batchedobject),          /* tp_basicsize */
+    0,                              /* tp_itemsize */
+    /* methods */
+    (destructor)batched_dealloc,    /* tp_dealloc */
+    0,                              /* tp_vectorcall_offset */
+    0,                              /* tp_getattr */
+    0,                              /* tp_setattr */
+    0,                              /* tp_as_async */
+    0,                              /* tp_repr */
+    0,                              /* tp_as_number */
+    0,                              /* tp_as_sequence */
+    0,                              /* tp_as_mapping */
+    0,                              /* tp_hash */
+    0,                              /* tp_call */
+    0,                              /* tp_str */
+    PyObject_GenericGetAttr,        /* tp_getattro */
+    0,                              /* tp_setattro */
+    0,                              /* tp_as_buffer */
+    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC |
+        Py_TPFLAGS_BASETYPE,        /* tp_flags */
+    batched_new__doc__,             /* tp_doc */
+    (traverseproc)batched_traverse, /* tp_traverse */
+    0,                              /* tp_clear */
+    0,                              /* tp_richcompare */
+    0,                              /* tp_weaklistoffset */
+    PyObject_SelfIter,              /* tp_iter */
+    (iternextfunc)batched_next,     /* tp_iternext */
+    0,                              /* tp_methods */
+    0,                              /* tp_members */
+    0,                              /* tp_getset */
+    0,                              /* tp_base */
+    0,                              /* tp_dict */
+    0,                              /* tp_descr_get */
+    0,                              /* tp_descr_set */
+    0,                              /* tp_dictoffset */
+    0,                              /* tp_init */
+    PyType_GenericAlloc,            /* tp_alloc */
+    batched_new,                    /* tp_new */
+    PyObject_GC_Del,                /* tp_free */
+};
+
+
 /* pairwise object ***********************************************************/
 
 typedef struct {
@@ -4815,6 +4982,7 @@ repeat(elem [,n]) --> elem, elem, elem, ... endlessly or up to n times\n\
 \n\
 Iterators terminating on the shortest input sequence:\n\
 accumulate(p[, func]) --> p0, p0+p1, p0+p1+p2\n\
+batched(p, n) --> [p0, p1, ..., p_n-1], [p_n, p_n+1, ..., p_2n-1], ...\n\
 chain(p, q, ...) --> p0, p1, ... plast, q0, q1, ...\n\
 chain.from_iterable([p, q, ...]) --> p0, p1, ... plast, q0, q1, ...\n\
 compress(data, selectors) --> (d[0] if s[0]), (d[1] if s[1]), ...\n\
@@ -4841,6 +5009,7 @@ itertoolsmodule_exec(PyObject *m)
 {
     PyTypeObject *typelist[] = {
         &accumulate_type,
+        &batched_type,
         &combinations_type,
         &cwr_type,
         &cycle_type,



More information about the Python-checkins mailing list