[Python-checkins] cpython: Issue 21424: Apply the nlargest() optimizations to nsmallest() as well.

raymond.hettinger python-checkins at python.org
Sun May 11 23:21:30 CEST 2014


http://hg.python.org/cpython/rev/31950174f60f
changeset:   90644:31950174f60f
user:        Raymond Hettinger <python at rcn.com>
date:        Sun May 11 14:21:23 2014 -0700
summary:
  Issue 21424:  Apply the nlargest() optimizations to nsmallest() as well.

files:
  Lib/heapq.py           |  156 +++++++++++++++++++++-------
  Lib/test/test_heapq.py |    2 +-
  Misc/NEWS              |    4 +-
  Modules/_heapqmodule.c |   94 +++-------------
  4 files changed, 138 insertions(+), 118 deletions(-)


diff --git a/Lib/heapq.py b/Lib/heapq.py
--- a/Lib/heapq.py
+++ b/Lib/heapq.py
@@ -127,7 +127,7 @@
 __all__ = ['heappush', 'heappop', 'heapify', 'heapreplace', 'merge',
            'nlargest', 'nsmallest', 'heappushpop']
 
-from itertools import islice, count, tee, chain
+from itertools import islice, count
 
 def heappush(heap, item):
     """Push item onto heap, maintaining the heap invariant."""
@@ -179,12 +179,12 @@
     for i in reversed(range(n//2)):
         _siftup(x, i)
 
-def _heappushpop_max(heap, item):
-    """Maxheap version of a heappush followed by a heappop."""
-    if heap and item < heap[0]:
-        item, heap[0] = heap[0], item
-        _siftup_max(heap, 0)
-    return item
+def _heapreplace_max(heap, item):
+    """Maxheap version of a heappop followed by a heappush."""
+    returnitem = heap[0]    # raises appropriate IndexError if heap is empty
+    heap[0] = item
+    _siftup_max(heap, 0)
+    return returnitem
 
 def _heapify_max(x):
     """Transform list into a maxheap, in-place, in O(len(x)) time."""
@@ -192,24 +192,6 @@
     for i in reversed(range(n//2)):
         _siftup_max(x, i)
 
-def nsmallest(n, iterable):
-    """Find the n smallest elements in a dataset.
-
-    Equivalent to:  sorted(iterable)[:n]
-    """
-    if n <= 0:
-        return []
-    it = iter(iterable)
-    result = list(islice(it, n))
-    if not result:
-        return result
-    _heapify_max(result)
-    _heappushpop = _heappushpop_max
-    for elem in it:
-        _heappushpop(result, elem)
-    result.sort()
-    return result
-
 # 'heap' is a heap at all indices >= startpos, except possibly for pos.  pos
 # is the index of a leaf with a possibly out-of-order value.  Restore the
 # heap invariant.
@@ -327,6 +309,10 @@
     from _heapq import *
 except ImportError:
     pass
+try:
+    from _heapq import _heapreplace_max
+except ImportError:
+    pass
 
 def merge(*iterables):
     '''Merge multiple sorted inputs into a single sorted output.
@@ -367,22 +353,86 @@
         yield v
         yield from next.__self__
 
-# Extend the implementations of nsmallest and nlargest to use a key= argument
-_nsmallest = nsmallest
+
+# Algorithm notes for nlargest() and nsmallest()
+# ==============================================
+#
+# Makes just a single pass over the data while keeping the k most extreme values
+# in a heap.  Memory consumption is limited to keeping k values in a list.
+#
+# Measured performance for random inputs:
+#
+#                                   number of comparisons
+#    n inputs     k-extreme values  (average of 5 trials)   % more than min()
+# -------------   ----------------  - -------------------   -----------------
+#      1,000           100                  3,317               133.2%
+#     10,000           100                 14,046                40.5%
+#    100,000           100                105,749                 5.7%
+#  1,000,000           100              1,007,751                 0.8%
+# 10,000,000           100             10,009,401                 0.1%
+#
+# Theoretical number of comparisons for k smallest of n random inputs:
+#
+# Step   Comparisons                  Action
+# ----   --------------------------   ---------------------------
+#  1     1.66 * k                     heapify the first k-inputs
+#  2     n - k                        compare remaining elements to top of heap
+#  3     k * (1 + lg2(k)) * ln(n/k)   replace the topmost value on the heap
+#  4     k * lg2(k) - (k/2)           final sort of the k most extreme values
+# Combining and simplifying for a rough estimate gives:
+#        comparisons = n + k * (1 + log(n/k)) * (1 + log(k, 2))
+#
+# Computing the number of comparisons for step 3:
+# -----------------------------------------------
+# * For the i-th new value from the iterable, the probability of being in the
+#   k most extreme values is k/i.  For example, the probability of the 101st
+#   value seen being in the 100 most extreme values is 100/101.
+# * If the value is a new extreme value, the cost of inserting it into the
+#   heap is 1 + log(k, 2).
+# * The probability times the cost gives:
+#            (k/i) * (1 + log(k, 2))
+# * Summing across the remaining n-k elements gives:
+#            sum((k/i) * (1 + log(k, 2)) for i in range(k+1, n+1))
+# * This reduces to:
+#            (H(n) - H(k)) * k * (1 + log(k, 2))
+# * Where H(n) is the n-th harmonic number estimated by:
+#            gamma = 0.5772156649
+#            H(n) = log(n, e) + gamma + 1.0 / (2.0 * n)
+#   http://en.wikipedia.org/wiki/Harmonic_series_(mathematics)#Rate_of_divergence
+# * Substituting the H(n) formula:
+#            comparisons = k * (1 + log(k, 2)) * (log(n/k, e) + (1/n - 1/k) / 2)
+#
+# Worst-case for step 3:
+# ----------------------
+# In the worst case, the input data is reversed sorted so that every new element
+# must be inserted in the heap:
+#
+#             comparisons = 1.66 * k + log(k, 2) * (n - k)
+#
+# Alternative Algorithms
+# ----------------------
+# Other algorithms were not used because they:
+# 1) Took much more auxiliary memory,
+# 2) Made multiple passes over the data.
+# 3) Made more comparisons in common cases (small k, large n, semi-random input).
+# See the more detailed comparison of approaches at:
+# http://code.activestate.com/recipes/577573-compare-algorithms-for-heapqsmallest
+
 def nsmallest(n, iterable, key=None):
     """Find the n smallest elements in a dataset.
 
     Equivalent to:  sorted(iterable, key=key)[:n]
     """
+
     # Short-cut for n==1 is to use min() when len(iterable)>0
     if n == 1:
         it = iter(iterable)
-        head = list(islice(it, 1))
-        if not head:
-            return []
+        sentinel = object()
         if key is None:
-            return [min(chain(head, it))]
-        return [min(chain(head, it), key=key)]
+            result = min(it, default=sentinel)
+        else:
+            result = min(it, default=sentinel, key=key)
+        return [] if result is sentinel else [result]
 
     # When n>=size, it's faster to use sorted()
     try:
@@ -395,15 +445,39 @@
 
     # When key is none, use simpler decoration
     if key is None:
-        it = zip(iterable, count())                         # decorate
-        result = _nsmallest(n, it)
-        return [r[0] for r in result]                       # undecorate
+        it = iter(iterable)
+        result = list(islice(zip(it, count()), n))
+        if not result:
+            return result
+        _heapify_max(result)
+        order = n
+        top = result[0][0]
+        _heapreplace = _heapreplace_max
+        for elem in it:
+            if elem < top:
+                _heapreplace(result, (elem, order))
+                top = result[0][0]
+                order += 1
+        result.sort()
+        return [r[0] for r in result]
 
     # General case, slowest method
-    in1, in2 = tee(iterable)
-    it = zip(map(key, in1), count(), in2)                   # decorate
-    result = _nsmallest(n, it)
-    return [r[2] for r in result]                           # undecorate
+    it = iter(iterable)
+    result = [(key(elem), i, elem) for i, elem in zip(range(n), it)]
+    if not result:
+        return result
+    _heapify_max(result)
+    order = n
+    top = result[0][0]
+    _heapreplace = _heapreplace_max
+    for elem in it:
+        k = key(elem)
+        if k < top:
+            _heapreplace(result, (k, order, elem))
+            top = result[0][0]
+            order += 1
+    result.sort()
+    return [r[2] for r in result]
 
 def nlargest(n, iterable, key=None):
     """Find the n largest elements in a dataset.
@@ -442,9 +516,9 @@
         _heapreplace = heapreplace
         for elem in it:
             if top < elem:
-                order -= 1
                 _heapreplace(result, (elem, order))
                 top = result[0][0]
+                order -= 1
         result.sort(reverse=True)
         return [r[0] for r in result]
 
@@ -460,9 +534,9 @@
     for elem in it:
         k = key(elem)
         if top < k:
-            order -= 1
             _heapreplace(result, (k, order, elem))
             top = result[0][0]
+            order -= 1
     result.sort(reverse=True)
     return [r[2] for r in result]
 
diff --git a/Lib/test/test_heapq.py b/Lib/test/test_heapq.py
--- a/Lib/test/test_heapq.py
+++ b/Lib/test/test_heapq.py
@@ -13,7 +13,7 @@
 # _heapq.nlargest/nsmallest are saved in heapq._nlargest/_smallest when
 # _heapq is imported, so check them there
 func_names = ['heapify', 'heappop', 'heappush', 'heappushpop',
-              'heapreplace', '_nsmallest']
+              'heapreplace', '_heapreplace_max']
 
 class TestModules(TestCase):
     def test_py_functions(self):
diff --git a/Misc/NEWS b/Misc/NEWS
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -84,8 +84,8 @@
 - Issue #21156: importlib.abc.InspectLoader.source_to_code() is now a
   staticmethod.
 
-- Issue #21424: Simplified and optimized heaqp.nlargest() to make fewer
-  tuple comparisons.
+- Issue #21424: Simplified and optimized heapq.nlargest() and nsmallest()
+  to make fewer tuple comparisons.
 
 - Issue #21396: Fix TextIOWrapper(..., write_through=True) to not force a
   flush() on the underlying binary stream.  Patch by akira.
diff --git a/Modules/_heapqmodule.c b/Modules/_heapqmodule.c
--- a/Modules/_heapqmodule.c
+++ b/Modules/_heapqmodule.c
@@ -354,88 +354,34 @@
 }
 
 static PyObject *
-nsmallest(PyObject *self, PyObject *args)
+_heapreplace_max(PyObject *self, PyObject *args)
 {
-    PyObject *heap=NULL, *elem, *iterable, *los, *it, *oldelem;
-    Py_ssize_t i, n;
-    int cmp;
+    PyObject *heap, *item, *returnitem;
 
-    if (!PyArg_ParseTuple(args, "nO:nsmallest", &n, &iterable))
+    if (!PyArg_UnpackTuple(args, "_heapreplace_max", 2, 2, &heap, &item))
         return NULL;
 
-    it = PyObject_GetIter(iterable);
-    if (it == NULL)
+    if (!PyList_Check(heap)) {
+        PyErr_SetString(PyExc_TypeError, "heap argument must be a list");
         return NULL;
-
-    heap = PyList_New(0);
-    if (heap == NULL)
-        goto fail;
-
-    for (i=0 ; i<n ; i++ ){
-        elem = PyIter_Next(it);
-        if (elem == NULL) {
-            if (PyErr_Occurred())
-                goto fail;
-            else
-                goto sortit;
-        }
-        if (PyList_Append(heap, elem) == -1) {
-            Py_DECREF(elem);
-            goto fail;
-        }
-        Py_DECREF(elem);
-    }
-    n = PyList_GET_SIZE(heap);
-    if (n == 0)
-        goto sortit;
-
-    for (i=n/2-1 ; i>=0 ; i--)
-        if(_siftupmax((PyListObject *)heap, i) == -1)
-            goto fail;
-
-    los = PyList_GET_ITEM(heap, 0);
-    while (1) {
-        elem = PyIter_Next(it);
-        if (elem == NULL) {
-            if (PyErr_Occurred())
-                goto fail;
-            else
-                goto sortit;
-        }
-        cmp = PyObject_RichCompareBool(elem, los, Py_LT);
-        if (cmp == -1) {
-            Py_DECREF(elem);
-            goto fail;
-        }
-        if (cmp == 0) {
-            Py_DECREF(elem);
-            continue;
-        }
-
-        oldelem = PyList_GET_ITEM(heap, 0);
-        PyList_SET_ITEM(heap, 0, elem);
-        Py_DECREF(oldelem);
-        if (_siftupmax((PyListObject *)heap, 0) == -1)
-            goto fail;
-        los = PyList_GET_ITEM(heap, 0);
     }
 
-sortit:
-    if (PyList_Sort(heap) == -1)
-        goto fail;
-    Py_DECREF(it);
-    return heap;
+    if (PyList_GET_SIZE(heap) < 1) {
+        PyErr_SetString(PyExc_IndexError, "index out of range");
+        return NULL;
+    }
 
-fail:
-    Py_DECREF(it);
-    Py_XDECREF(heap);
-    return NULL;
+    returnitem = PyList_GET_ITEM(heap, 0);
+    Py_INCREF(item);
+    PyList_SET_ITEM(heap, 0, item);
+    if (_siftupmax((PyListObject *)heap, 0) == -1) {
+        Py_DECREF(returnitem);
+        return NULL;
+    }
+    return returnitem;
 }
 
-PyDoc_STRVAR(nsmallest_doc,
-"Find the n smallest elements in a dataset.\n\
-\n\
-Equivalent to:  sorted(iterable)[:n]\n");
+PyDoc_STRVAR(heapreplace_max_doc, "Maxheap variant of heapreplace");
 
 static PyMethodDef heapq_methods[] = {
     {"heappush",        (PyCFunction)heappush,
@@ -448,8 +394,8 @@
         METH_VARARGS,           heapreplace_doc},
     {"heapify",         (PyCFunction)heapify,
         METH_O,                 heapify_doc},
-    {"nsmallest",       (PyCFunction)nsmallest,
-        METH_VARARGS,           nsmallest_doc},
+    {"_heapreplace_max",(PyCFunction)_heapreplace_max,
+        METH_VARARGS,           heapreplace_max_doc},
     {NULL,              NULL}           /* sentinel */
 };
 

-- 
Repository URL: http://hg.python.org/cpython


More information about the Python-checkins mailing list