[Python-checkins] cpython (merge 3.5 -> default): Issue #26494: Fixed crash on iterating exhausting iterators.

serhiy.storchaka python-checkins at python.org
Wed Mar 30 13:44:11 EDT 2016


https://hg.python.org/cpython/rev/73ce47d4a7b2
changeset:   100801:73ce47d4a7b2
parent:      100799:854018bf929f
parent:      100800:905b5944119c
user:        Serhiy Storchaka <storchaka at gmail.com>
date:        Wed Mar 30 20:41:15 2016 +0300
summary:
  Issue #26494: Fixed a crash when iterating an iterator that is being exhausted.

Affected classes are generic sequence iterators; iterators of str, bytes,
bytearray, list, tuple, set, frozenset, dict, and OrderedDict; the corresponding
views; and the os.scandir() iterator.

files:
  Lib/test/seq_tests.py         |   5 +++++
  Lib/test/support/__init__.py  |  19 +++++++++++++++++++
  Lib/test/test_bytes.py        |   4 ++++
  Lib/test/test_deque.py        |   4 ++++
  Lib/test/test_dict.py         |   6 ++++++
  Lib/test/test_iter.py         |   4 ++++
  Lib/test/test_ordered_dict.py |   6 ++++++
  Lib/test/test_set.py          |   3 +++
  Lib/test/test_unicode.py      |   4 ++++
  Misc/NEWS                     |   5 +++++
  Modules/posixmodule.c         |  16 ++++++++++------
  Objects/bytearrayobject.c     |   2 +-
  Objects/bytesobject.c         |   2 +-
  Objects/dictobject.c          |   6 +++---
  Objects/iterobject.c          |   2 +-
  Objects/listobject.c          |  20 +++++++++++++-------
  Objects/setobject.c           |   2 +-
  Objects/tupleobject.c         |   2 +-
  Objects/unicodeobject.c       |   2 +-
  19 files changed, 92 insertions(+), 22 deletions(-)


diff --git a/Lib/test/seq_tests.py b/Lib/test/seq_tests.py
--- a/Lib/test/seq_tests.py
+++ b/Lib/test/seq_tests.py
@@ -5,6 +5,7 @@
 import unittest
 import sys
 import pickle
+from test import support
 
 # Various iterables
 # This is used for checking the constructor (here and in test_deque.py)
@@ -408,3 +409,7 @@
             lst2 = pickle.loads(pickle.dumps(lst, proto))
             self.assertEqual(lst2, lst)
             self.assertNotEqual(id(lst2), id(lst))
+
+    def test_free_after_iterating(self):
+        support.check_free_after_iterating(self, iter, self.type2test)
+        support.check_free_after_iterating(self, reversed, self.type2test)
diff --git a/Lib/test/support/__init__.py b/Lib/test/support/__init__.py
--- a/Lib/test/support/__init__.py
+++ b/Lib/test/support/__init__.py
@@ -2432,3 +2432,22 @@
                                      "memory allocations")
     import _testcapi
     return _testcapi.run_in_subinterp(code)
+
+
+def check_free_after_iterating(test, iter, cls, args=()):
+    class A(cls):
+        def __del__(self):
+            nonlocal done
+            done = True
+            try:
+                next(it)
+            except StopIteration:
+                pass
+
+    done = False
+    it = iter(A(*args))
+    # Issue 26494: Shouldn't crash
+    test.assertRaises(StopIteration, next, it)
+    # The sequence should be deallocated just after the end of iterating
+    gc_collect()
+    test.assertTrue(done)
diff --git a/Lib/test/test_bytes.py b/Lib/test/test_bytes.py
--- a/Lib/test/test_bytes.py
+++ b/Lib/test/test_bytes.py
@@ -761,6 +761,10 @@
         self.assertRaisesRegex(TypeError, r'\bendswith\b', b.endswith,
                                 x, None, None, None)
 
+    def test_free_after_iterating(self):
+        test.support.check_free_after_iterating(self, iter, self.type2test)
+        test.support.check_free_after_iterating(self, reversed, self.type2test)
+
 
 class BytesTest(BaseBytesTest, unittest.TestCase):
     type2test = bytes
diff --git a/Lib/test/test_deque.py b/Lib/test/test_deque.py
--- a/Lib/test/test_deque.py
+++ b/Lib/test/test_deque.py
@@ -918,6 +918,10 @@
         # For now, bypass tests that require slicing
         pass
 
+    def test_free_after_iterating(self):
+        # For now, bypass tests that require slicing
+        self.skipTest("Exhausted deque iterator doesn't free a deque")
+
 #==============================================================================
 
 libreftest = """
diff --git a/Lib/test/test_dict.py b/Lib/test/test_dict.py
--- a/Lib/test/test_dict.py
+++ b/Lib/test/test_dict.py
@@ -954,6 +954,12 @@
         d = {X(): 0, 1: 1}
         self.assertRaises(RuntimeError, d.update, other)
 
+    def test_free_after_iterating(self):
+        support.check_free_after_iterating(self, iter, dict)
+        support.check_free_after_iterating(self, lambda d: iter(d.keys()), dict)
+        support.check_free_after_iterating(self, lambda d: iter(d.values()), dict)
+        support.check_free_after_iterating(self, lambda d: iter(d.items()), dict)
+
 from test import mapping_tests
 
 class GeneralMappingTests(mapping_tests.BasicTestMappingProtocol):
diff --git a/Lib/test/test_iter.py b/Lib/test/test_iter.py
--- a/Lib/test/test_iter.py
+++ b/Lib/test/test_iter.py
@@ -3,6 +3,7 @@
 import sys
 import unittest
 from test.support import run_unittest, TESTFN, unlink, cpython_only
+from test.support import check_free_after_iterating
 import pickle
 import collections.abc
 
@@ -980,6 +981,9 @@
         self.assertEqual(next(it), 0)
         self.assertEqual(next(it), 1)
 
+    def test_free_after_iterating(self):
+        check_free_after_iterating(self, iter, SequenceClass, (0,))
+
 
 def test_main():
     run_unittest(TestCase)
diff --git a/Lib/test/test_ordered_dict.py b/Lib/test/test_ordered_dict.py
--- a/Lib/test/test_ordered_dict.py
+++ b/Lib/test/test_ordered_dict.py
@@ -608,6 +608,12 @@
         gc.collect()
         self.assertIsNone(r())
 
+    def test_free_after_iterating(self):
+        support.check_free_after_iterating(self, iter, self.OrderedDict)
+        support.check_free_after_iterating(self, lambda d: iter(d.keys()), self.OrderedDict)
+        support.check_free_after_iterating(self, lambda d: iter(d.values()), self.OrderedDict)
+        support.check_free_after_iterating(self, lambda d: iter(d.items()), self.OrderedDict)
+
 
 class PurePythonOrderedDictTests(OrderedDictTests, unittest.TestCase):
 
diff --git a/Lib/test/test_set.py b/Lib/test/test_set.py
--- a/Lib/test/test_set.py
+++ b/Lib/test/test_set.py
@@ -364,6 +364,9 @@
         gc.collect()
         self.assertTrue(ref() is None, "Cycle was not collected")
 
+    def test_free_after_iterating(self):
+        support.check_free_after_iterating(self, iter, self.thetype)
+
 class TestSet(TestJointOps, unittest.TestCase):
     thetype = set
     basetype = set
diff --git a/Lib/test/test_unicode.py b/Lib/test/test_unicode.py
--- a/Lib/test/test_unicode.py
+++ b/Lib/test/test_unicode.py
@@ -2729,6 +2729,10 @@
                 # Check that the second call returns the same result
                 self.assertEqual(getargs_s_hash(s), chr(k).encode() * (i + 1))
 
+    def test_free_after_iterating(self):
+        support.check_free_after_iterating(self, iter, str)
+        support.check_free_after_iterating(self, reversed, str)
+
 
 class StringModuleTest(unittest.TestCase):
     def test_formatter_parser(self):
diff --git a/Misc/NEWS b/Misc/NEWS
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -10,6 +10,11 @@
 Core and Builtins
 -----------------
 
+- Issue #26494: Fixed crash on iterating exhausting iterators.
+  Affected classes are generic sequence iterators, iterators of str, bytes,
+  bytearray, list, tuple, set, frozenset, dict, OrderedDict, corresponding
+  views and os.scandir() iterator.
+
 - Issue #26574: Optimize ``bytes.replace(b'', b'.')`` and
   ``bytearray.replace(b'', b'.')``. Patch written by Josh Snider.
 
diff --git a/Modules/posixmodule.c b/Modules/posixmodule.c
--- a/Modules/posixmodule.c
+++ b/Modules/posixmodule.c
@@ -11956,13 +11956,15 @@
 static void
 ScandirIterator_closedir(ScandirIterator *iterator)
 {
-    if (iterator->handle == INVALID_HANDLE_VALUE)
+    HANDLE handle = iterator->handle;
+
+    if (handle == INVALID_HANDLE_VALUE)
         return;
 
+    iterator->handle = INVALID_HANDLE_VALUE;
     Py_BEGIN_ALLOW_THREADS
-    FindClose(iterator->handle);
+    FindClose(handle);
     Py_END_ALLOW_THREADS
-    iterator->handle = INVALID_HANDLE_VALUE;
 }
 
 static PyObject *
@@ -12018,13 +12020,15 @@
 static void
 ScandirIterator_closedir(ScandirIterator *iterator)
 {
-    if (!iterator->dirp)
+    DIR *dirp = iterator->dirp;
+
+    if (!dirp)
         return;
 
+    iterator->dirp = NULL;
     Py_BEGIN_ALLOW_THREADS
-    closedir(iterator->dirp);
+    closedir(dirp);
     Py_END_ALLOW_THREADS
-    iterator->dirp = NULL;
     return;
 }
 
diff --git a/Objects/bytearrayobject.c b/Objects/bytearrayobject.c
--- a/Objects/bytearrayobject.c
+++ b/Objects/bytearrayobject.c
@@ -3126,8 +3126,8 @@
         return item;
     }
 
+    it->it_seq = NULL;
     Py_DECREF(seq);
-    it->it_seq = NULL;
     return NULL;
 }
 
diff --git a/Objects/bytesobject.c b/Objects/bytesobject.c
--- a/Objects/bytesobject.c
+++ b/Objects/bytesobject.c
@@ -3806,8 +3806,8 @@
         return item;
     }
 
+    it->it_seq = NULL;
     Py_DECREF(seq);
-    it->it_seq = NULL;
     return NULL;
 }
 
diff --git a/Objects/dictobject.c b/Objects/dictobject.c
--- a/Objects/dictobject.c
+++ b/Objects/dictobject.c
@@ -2988,8 +2988,8 @@
     return key;
 
 fail:
+    di->di_dict = NULL;
     Py_DECREF(d);
-    di->di_dict = NULL;
     return NULL;
 }
 
@@ -3069,8 +3069,8 @@
     return value;
 
 fail:
+    di->di_dict = NULL;
     Py_DECREF(d);
-    di->di_dict = NULL;
     return NULL;
 }
 
@@ -3164,8 +3164,8 @@
     return result;
 
 fail:
+    di->di_dict = NULL;
     Py_DECREF(d);
-    di->di_dict = NULL;
     return NULL;
 }
 
diff --git a/Objects/iterobject.c b/Objects/iterobject.c
--- a/Objects/iterobject.c
+++ b/Objects/iterobject.c
@@ -69,8 +69,8 @@
         PyErr_ExceptionMatches(PyExc_StopIteration))
     {
         PyErr_Clear();
+        it->it_seq = NULL;
         Py_DECREF(seq);
-        it->it_seq = NULL;
     }
     return NULL;
 }
diff --git a/Objects/listobject.c b/Objects/listobject.c
--- a/Objects/listobject.c
+++ b/Objects/listobject.c
@@ -2776,8 +2776,8 @@
         return item;
     }
 
+    it->it_seq = NULL;
     Py_DECREF(seq);
-    it->it_seq = NULL;
     return NULL;
 }
 
@@ -2906,9 +2906,17 @@
 listreviter_next(listreviterobject *it)
 {
     PyObject *item;
-    Py_ssize_t index = it->it_index;
-    PyListObject *seq = it->it_seq;
+    Py_ssize_t index;
+    PyListObject *seq;
 
+    assert(it != NULL);
+    seq = it->it_seq;
+    if (seq == NULL) {
+        return NULL;
+    }
+    assert(PyList_Check(seq));
+
+    index = it->it_index;
     if (index>=0 && index < PyList_GET_SIZE(seq)) {
         item = PyList_GET_ITEM(seq, index);
         it->it_index--;
@@ -2916,10 +2924,8 @@
         return item;
     }
     it->it_index = -1;
-    if (seq != NULL) {
-        it->it_seq = NULL;
-        Py_DECREF(seq);
-    }
+    it->it_seq = NULL;
+    Py_DECREF(seq);
     return NULL;
 }
 
diff --git a/Objects/setobject.c b/Objects/setobject.c
--- a/Objects/setobject.c
+++ b/Objects/setobject.c
@@ -916,8 +916,8 @@
     return key;
 
 fail:
+    si->si_set = NULL;
     Py_DECREF(so);
-    si->si_set = NULL;
     return NULL;
 }
 
diff --git a/Objects/tupleobject.c b/Objects/tupleobject.c
--- a/Objects/tupleobject.c
+++ b/Objects/tupleobject.c
@@ -961,8 +961,8 @@
         return item;
     }
 
+    it->it_seq = NULL;
     Py_DECREF(seq);
-    it->it_seq = NULL;
     return NULL;
 }
 
diff --git a/Objects/unicodeobject.c b/Objects/unicodeobject.c
--- a/Objects/unicodeobject.c
+++ b/Objects/unicodeobject.c
@@ -15401,8 +15401,8 @@
         return item;
     }
 
+    it->it_seq = NULL;
     Py_DECREF(seq);
-    it->it_seq = NULL;
     return NULL;
 }
 

-- 
Repository URL: https://hg.python.org/cpython


More information about the Python-checkins mailing list