cpython: Add optional *func* argument to itertools.accumulate().
http://hg.python.org/cpython/rev/79ccefd30a37 changeset: 69021:79ccefd30a37 user: Raymond Hettinger <python@rcn.com> date: Sun Mar 27 18:52:10 2011 -0700 summary: Add optional *func* argument to itertools.accumulate(). files: Doc/library/itertools.rst | 33 ++++++++++++++++++++++--- Lib/test/test_itertools.py | 12 ++++++++- Misc/NEWS | 3 ++ Modules/itertoolsmodule.c | 18 +++++++++++--- 4 files changed, 56 insertions(+), 10 deletions(-) diff --git a/Doc/library/itertools.rst b/Doc/library/itertools.rst --- a/Doc/library/itertools.rst +++ b/Doc/library/itertools.rst @@ -46,7 +46,7 @@ ==================== ============================ ================================================= ============================================================= Iterator Arguments Results Example ==================== ============================ ================================================= ============================================================= -:func:`accumulate` p p0, p0+p1, p0+p1+p2, ... ``accumulate([1,2,3,4,5]) --> 1 3 6 10 15`` +:func:`accumulate` p [,func] p0, p0+p1, p0+p1+p2, ... ``accumulate([1,2,3,4,5]) --> 1 3 6 10 15`` :func:`chain` p, q, ... p0, p1, ... plast, q0, q1, ... ``chain('ABC', 'DEF') --> A B C D E F`` :func:`compress` data, selectors (d[0] if s[0]), (d[1] if s[1]), ... ``compress('ABCDEF', [1,0,1,0,1,1]) --> A C E F`` :func:`dropwhile` pred, seq seq[n], seq[n+1], starting when pred fails ``dropwhile(lambda x: x<5, [1,4,6,4,1]) --> 6 4 1`` @@ -84,23 +84,46 @@ streams of infinite length, so they should only be accessed by functions or loops that truncate the stream. -.. function:: accumulate(iterable) +.. function:: accumulate(iterable[, func]) Make an iterator that returns accumulated sums. Elements may be any addable - type including :class:`Decimal` or :class:`Fraction`. Equivalent to:: + type including :class:`Decimal` or :class:`Fraction`. If the optional + *func* argument is supplied, it should be a function of two arguments + and it will be used instead of addition. - def accumulate(iterable): + Equivalent to:: + + def accumulate(iterable, func=operator.add): 'Return running totals' # accumulate([1,2,3,4,5]) --> 1 3 6 10 15 + # accumulate([1,2,3,4,5], operator.mul) --> 1 2 6 24 120 it = iter(iterable) total = next(it) yield total for element in it: - total = total + element + total = func(total, element) yield total + Uses for the *func* argument include :func:`min` for a running minimum, + :func:`max` for a running maximum, and :func:`operator.mul` for a running + product:: + + >>> data = [3, 4, 6, 2, 1, 9, 0, 7, 5, 8] + >>> list(accumulate(data, operator.mul)) # running product + [3, 12, 72, 144, 144, 1296, 0, 0, 0, 0] + >>> list(accumulate(data, max)) # running maximum + [3, 4, 6, 6, 6, 9, 9, 9, 9, 9] + + # Amortize a 5% loan of 1000 with 4 annual payments of 90 + >>> cashflows = [1000, -90, -90, -90, -90] + >>> list(accumulate(cashflows, lambda bal, pmt: bal*1.05 + pmt)) + [1000, 960.0, 918.0, 873.9000000000001, 827.5950000000001] + .. versionadded:: 3.2 + .. versionchanged:: 3.3 + Added the optional *func* parameter. + .. function:: chain(*iterables) Make an iterator that returns elements from the first iterable until it is diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py --- a/Lib/test/test_itertools.py +++ b/Lib/test/test_itertools.py @@ -69,11 +69,21 @@ self.assertEqual(list(accumulate('abc')), ['a', 'ab', 'abc']) # works with non-numeric self.assertEqual(list(accumulate([])), []) # empty iterable self.assertEqual(list(accumulate([7])), [7]) # iterable of length one - self.assertRaises(TypeError, accumulate, range(10), 5) # too many args + self.assertRaises(TypeError, accumulate, range(10), 5, 6) # too many args self.assertRaises(TypeError, accumulate) # too few args self.assertRaises(TypeError, accumulate, x=range(10)) # unexpected kwd arg self.assertRaises(TypeError, list, accumulate([1, []])) # args that don't add + s = [2, 8, 9, 5, 7, 0, 3, 4, 1, 6] + self.assertEqual(list(accumulate(s, min)), + [2, 2, 2, 2, 2, 0, 0, 0, 0, 0]) + self.assertEqual(list(accumulate(s, max)), + [2, 8, 9, 9, 9, 9, 9, 9, 9, 9]) + self.assertEqual(list(accumulate(s, operator.mul)), + [2, 16, 144, 720, 5040, 0, 0, 0, 0, 0]) + with self.assertRaises(TypeError): + list(accumulate(s, chr)) # unary-operation + def test_chain(self): def chain2(*iterables): diff --git a/Misc/NEWS b/Misc/NEWS --- a/Misc/NEWS +++ b/Misc/NEWS @@ -89,6 +89,9 @@ - Issue #11696: Fix ID generation in msilib. +- itertools.accumulate now supports an optional *func* argument for + a user-supplied binary function. + - Issue #11692: Remove unnecessary demo functions in subprocess module. - Issue #9696: Fix exception incorrectly raised by xdrlib.Packer.pack_int when diff --git a/Modules/itertoolsmodule.c b/Modules/itertoolsmodule.c --- a/Modules/itertoolsmodule.c +++ b/Modules/itertoolsmodule.c @@ -2590,6 +2590,7 @@ PyObject_HEAD PyObject *total; PyObject *it; + PyObject *binop; } accumulateobject; static PyTypeObject accumulate_type; @@ -2597,12 +2598,14 @@ static PyObject * accumulate_new(PyTypeObject *type, PyObject *args, PyObject *kwds) { - static char *kwargs[] = {"iterable", NULL}; + static char *kwargs[] = {"iterable", "func", NULL}; PyObject *iterable; PyObject *it; + PyObject *binop = NULL; accumulateobject *lz; - if (!PyArg_ParseTupleAndKeywords(args, kwds, "O:accumulate", kwargs, &iterable)) + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O:accumulate", + kwargs, &iterable, &binop)) return NULL; /* Get iterator. */ @@ -2617,6 +2620,8 @@ return NULL; } + Py_XINCREF(binop); + lz->binop = binop; lz->total = NULL; lz->it = it; return (PyObject *)lz; @@ -2626,6 +2631,7 @@ accumulate_dealloc(accumulateobject *lz) { PyObject_GC_UnTrack(lz); + Py_XDECREF(lz->binop); Py_XDECREF(lz->total); Py_XDECREF(lz->it); Py_TYPE(lz)->tp_free(lz); @@ -2634,6 +2640,7 @@ static int accumulate_traverse(accumulateobject *lz, visitproc visit, void *arg) { + Py_VISIT(lz->binop); Py_VISIT(lz->it); Py_VISIT(lz->total); return 0; @@ -2653,8 +2660,11 @@ lz->total = val; return lz->total; } - - newtotal = PyNumber_Add(lz->total, val); + + if (lz->binop == NULL) + newtotal = PyNumber_Add(lz->total, val); + else + newtotal = PyObject_CallFunctionObjArgs(lz->binop, lz->total, val, NULL); Py_DECREF(val); if (newtotal == NULL) return NULL; -- Repository URL: http://hg.python.org/cpython
participants (1)
-
raymond.hettinger