[Python-checkins] bpo-35378: Fix multiprocessing.Pool references (GH-11627)

Pablo Galindo webhook-mailer at python.org
Mon Feb 11 12:29:07 EST 2019


https://github.com/python/cpython/commit/3766f18c524c57784eea7c0001602017d2122156
commit: 3766f18c524c57784eea7c0001602017d2122156
branch: master
author: Pablo Galindo <Pablogsal at gmail.com>
committer: GitHub <noreply at github.com>
date: 2019-02-11T17:29:00Z
summary:

bpo-35378: Fix multiprocessing.Pool references (GH-11627)

Changes in this commit:

1. Use a _strong_ reference between the Pool and associated iterators
2. Rework PR #8450 to eliminate a cycle in the Pool.

There is no test in this commit because any test that automatically exercises this behaviour needs to delete the pool before joining it, in order to check that the pool object is garbage collected and does not hang. But doing this can potentially leak threads and processes (see https://bugs.python.org/issue35413).
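For context, a minimal sketch (not part of the commit) of the behaviour the strong reference enables: the iterator returned by imap() now stores the pool in its _pool attribute, so the pool stays alive while the iterator is consumed even if the caller drops its own reference:

    import multiprocessing

    def square(x):
        return x * x

    if __name__ == '__main__':
        pool = multiprocessing.Pool(processes=2)
        it = pool.imap(square, range(5))
        # Drop the caller's reference; the iterator's strong reference
        # keeps the pool alive until iteration finishes.
        del pool
        print(list(it))  # [0, 1, 4, 9, 16]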

files:
A Misc/NEWS.d/next/Library/2019-01-21-02-15-20.bpo-35378.4oF03i.rst
M Lib/multiprocessing/pool.py
M Lib/test/_test_multiprocessing.py

diff --git a/Lib/multiprocessing/pool.py b/Lib/multiprocessing/pool.py
index bfb2769ba6ec..18a56f8524b4 100644
--- a/Lib/multiprocessing/pool.py
+++ b/Lib/multiprocessing/pool.py
@@ -151,8 +151,9 @@ class Pool(object):
     '''
     _wrap_exception = True
 
-    def Process(self, *args, **kwds):
-        return self._ctx.Process(*args, **kwds)
+    @staticmethod
+    def Process(ctx, *args, **kwds):
+        return ctx.Process(*args, **kwds)
 
     def __init__(self, processes=None, initializer=None, initargs=(),
                  maxtasksperchild=None, context=None):
@@ -190,7 +191,10 @@ def __init__(self, processes=None, initializer=None, initargs=(),
 
         self._worker_handler = threading.Thread(
             target=Pool._handle_workers,
-            args=(self, )
+            args=(self._cache, self._taskqueue, self._ctx, self.Process,
+                  self._processes, self._pool, self._inqueue, self._outqueue,
+                  self._initializer, self._initargs, self._maxtasksperchild,
+                  self._wrap_exception)
             )
         self._worker_handler.daemon = True
         self._worker_handler._state = RUN
@@ -236,43 +240,61 @@ def __repr__(self):
                 f'state={self._state} '
                 f'pool_size={len(self._pool)}>')
 
-    def _join_exited_workers(self):
+    @staticmethod
+    def _join_exited_workers(pool):
         """Cleanup after any worker processes which have exited due to reaching
         their specified lifetime.  Returns True if any workers were cleaned up.
         """
         cleaned = False
-        for i in reversed(range(len(self._pool))):
-            worker = self._pool[i]
+        for i in reversed(range(len(pool))):
+            worker = pool[i]
             if worker.exitcode is not None:
                 # worker exited
                 util.debug('cleaning up worker %d' % i)
                 worker.join()
                 cleaned = True
-                del self._pool[i]
+                del pool[i]
         return cleaned
 
     def _repopulate_pool(self):
+        return self._repopulate_pool_static(self._ctx, self.Process,
+                                            self._processes,
+                                            self._pool, self._inqueue,
+                                            self._outqueue, self._initializer,
+                                            self._initargs,
+                                            self._maxtasksperchild,
+                                            self._wrap_exception)
+
+    @staticmethod
+    def _repopulate_pool_static(ctx, Process, processes, pool, inqueue,
+                                outqueue, initializer, initargs,
+                                maxtasksperchild, wrap_exception):
         """Bring the number of pool processes up to the specified number,
         for use after reaping workers which have exited.
         """
-        for i in range(self._processes - len(self._pool)):
-            w = self.Process(target=worker,
-                             args=(self._inqueue, self._outqueue,
-                                   self._initializer,
-                                   self._initargs, self._maxtasksperchild,
-                                   self._wrap_exception)
-                            )
+        for i in range(processes - len(pool)):
+            w = Process(ctx, target=worker,
+                        args=(inqueue, outqueue,
+                              initializer,
+                              initargs, maxtasksperchild,
+                              wrap_exception))
             w.name = w.name.replace('Process', 'PoolWorker')
             w.daemon = True
             w.start()
-            self._pool.append(w)
+            pool.append(w)
             util.debug('added worker')
 
-    def _maintain_pool(self):
+    @staticmethod
+    def _maintain_pool(ctx, Process, processes, pool, inqueue, outqueue,
+                       initializer, initargs, maxtasksperchild,
+                       wrap_exception):
         """Clean up any exited workers and start replacements for them.
         """
-        if self._join_exited_workers():
-            self._repopulate_pool()
+        if Pool._join_exited_workers(pool):
+            Pool._repopulate_pool_static(ctx, Process, processes, pool,
+                                         inqueue, outqueue, initializer,
+                                         initargs, maxtasksperchild,
+                                         wrap_exception)
 
     def _setup_queues(self):
         self._inqueue = self._ctx.SimpleQueue()
@@ -331,7 +353,7 @@ def imap(self, func, iterable, chunksize=1):
         '''
         self._check_running()
         if chunksize == 1:
-            result = IMapIterator(self._cache)
+            result = IMapIterator(self)
             self._taskqueue.put(
                 (
                     self._guarded_task_generation(result._job, func, iterable),
@@ -344,7 +366,7 @@ def imap(self, func, iterable, chunksize=1):
                     "Chunksize must be 1+, not {0:n}".format(
                         chunksize))
             task_batches = Pool._get_tasks(func, iterable, chunksize)
-            result = IMapIterator(self._cache)
+            result = IMapIterator(self)
             self._taskqueue.put(
                 (
                     self._guarded_task_generation(result._job,
@@ -360,7 +382,7 @@ def imap_unordered(self, func, iterable, chunksize=1):
         '''
         self._check_running()
         if chunksize == 1:
-            result = IMapUnorderedIterator(self._cache)
+            result = IMapUnorderedIterator(self)
             self._taskqueue.put(
                 (
                     self._guarded_task_generation(result._job, func, iterable),
@@ -372,7 +394,7 @@ def imap_unordered(self, func, iterable, chunksize=1):
                 raise ValueError(
                     "Chunksize must be 1+, not {0!r}".format(chunksize))
             task_batches = Pool._get_tasks(func, iterable, chunksize)
-            result = IMapUnorderedIterator(self._cache)
+            result = IMapUnorderedIterator(self)
             self._taskqueue.put(
                 (
                     self._guarded_task_generation(result._job,
@@ -388,7 +410,7 @@ def apply_async(self, func, args=(), kwds={}, callback=None,
         Asynchronous version of `apply()` method.
         '''
         self._check_running()
-        result = ApplyResult(self._cache, callback, error_callback)
+        result = ApplyResult(self, callback, error_callback)
         self._taskqueue.put(([(result._job, 0, func, args, kwds)], None))
         return result
 
@@ -417,7 +439,7 @@ def _map_async(self, func, iterable, mapper, chunksize=None, callback=None,
             chunksize = 0
 
         task_batches = Pool._get_tasks(func, iterable, chunksize)
-        result = MapResult(self._cache, chunksize, len(iterable), callback,
+        result = MapResult(self, chunksize, len(iterable), callback,
                            error_callback=error_callback)
         self._taskqueue.put(
             (
@@ -430,16 +452,20 @@ def _map_async(self, func, iterable, mapper, chunksize=None, callback=None,
         return result
 
     @staticmethod
-    def _handle_workers(pool):
+    def _handle_workers(cache, taskqueue, ctx, Process, processes, pool,
+                        inqueue, outqueue, initializer, initargs,
+                        maxtasksperchild, wrap_exception):
         thread = threading.current_thread()
 
         # Keep maintaining workers until the cache gets drained, unless the pool
         # is terminated.
-        while thread._state == RUN or (pool._cache and thread._state != TERMINATE):
-            pool._maintain_pool()
+        while thread._state == RUN or (cache and thread._state != TERMINATE):
+            Pool._maintain_pool(ctx, Process, processes, pool, inqueue,
+                                outqueue, initializer, initargs,
+                                maxtasksperchild, wrap_exception)
             time.sleep(0.1)
         # send sentinel to stop workers
-        pool._taskqueue.put(None)
+        taskqueue.put(None)
         util.debug('worker handler exiting')
 
     @staticmethod
@@ -656,13 +682,14 @@ def __exit__(self, exc_type, exc_val, exc_tb):
 
 class ApplyResult(object):
 
-    def __init__(self, cache, callback, error_callback):
+    def __init__(self, pool, callback, error_callback):
+        self._pool = pool
         self._event = threading.Event()
         self._job = next(job_counter)
-        self._cache = cache
+        self._cache = pool._cache
         self._callback = callback
         self._error_callback = error_callback
-        cache[self._job] = self
+        self._cache[self._job] = self
 
     def ready(self):
         return self._event.is_set()
@@ -692,6 +719,7 @@ def _set(self, i, obj):
             self._error_callback(self._value)
         self._event.set()
         del self._cache[self._job]
+        self._pool = None
 
 AsyncResult = ApplyResult       # create alias -- see #17805
 
@@ -701,8 +729,8 @@ def _set(self, i, obj):
 
 class MapResult(ApplyResult):
 
-    def __init__(self, cache, chunksize, length, callback, error_callback):
-        ApplyResult.__init__(self, cache, callback,
+    def __init__(self, pool, chunksize, length, callback, error_callback):
+        ApplyResult.__init__(self, pool, callback,
                              error_callback=error_callback)
         self._success = True
         self._value = [None] * length
@@ -710,7 +738,7 @@ def __init__(self, cache, chunksize, length, callback, error_callback):
         if chunksize <= 0:
             self._number_left = 0
             self._event.set()
-            del cache[self._job]
+            del self._cache[self._job]
         else:
             self._number_left = length//chunksize + bool(length % chunksize)
 
@@ -724,6 +752,7 @@ def _set(self, i, success_result):
                     self._callback(self._value)
                 del self._cache[self._job]
                 self._event.set()
+                self._pool = None
         else:
             if not success and self._success:
                 # only store first exception
@@ -735,6 +764,7 @@ def _set(self, i, success_result):
                     self._error_callback(self._value)
                 del self._cache[self._job]
                 self._event.set()
+                self._pool = None
 
 #
 # Class whose instances are returned by `Pool.imap()`
@@ -742,15 +772,16 @@ def _set(self, i, success_result):
 
 class IMapIterator(object):
 
-    def __init__(self, cache):
+    def __init__(self, pool):
+        self._pool = pool
         self._cond = threading.Condition(threading.Lock())
         self._job = next(job_counter)
-        self._cache = cache
+        self._cache = pool._cache
         self._items = collections.deque()
         self._index = 0
         self._length = None
         self._unsorted = {}
-        cache[self._job] = self
+        self._cache[self._job] = self
 
     def __iter__(self):
         return self
@@ -761,12 +792,14 @@ def next(self, timeout=None):
                 item = self._items.popleft()
             except IndexError:
                 if self._index == self._length:
+                    self._pool = None
                     raise StopIteration from None
                 self._cond.wait(timeout)
                 try:
                     item = self._items.popleft()
                 except IndexError:
                     if self._index == self._length:
+                        self._pool = None
                         raise StopIteration from None
                     raise TimeoutError from None
 
@@ -792,6 +825,7 @@ def _set(self, i, obj):
 
             if self._index == self._length:
                 del self._cache[self._job]
+                self._pool = None
 
     def _set_length(self, length):
         with self._cond:
@@ -799,6 +833,7 @@ def _set_length(self, length):
             if self._index == self._length:
                 self._cond.notify()
                 del self._cache[self._job]
+                self._pool = None
 
 #
 # Class whose instances are returned by `Pool.imap_unordered()`
@@ -813,6 +848,7 @@ def _set(self, i, obj):
             self._cond.notify()
             if self._index == self._length:
                 del self._cache[self._job]
+                self._pool = None
 
 #
 #
@@ -822,7 +858,7 @@ class ThreadPool(Pool):
     _wrap_exception = False
 
     @staticmethod
-    def Process(*args, **kwds):
+    def Process(ctx, *args, **kwds):
         from .dummy import Process
         return Process(*args, **kwds)
 
diff --git a/Lib/test/_test_multiprocessing.py b/Lib/test/_test_multiprocessing.py
index bc1072d4b30d..d93303bb9cca 100644
--- a/Lib/test/_test_multiprocessing.py
+++ b/Lib/test/_test_multiprocessing.py
@@ -2593,7 +2593,6 @@ def test_resource_warning(self):
             pool = None
             support.gc_collect()
 
-
 def raising():
     raise KeyError("key")
 
diff --git a/Misc/NEWS.d/next/Library/2019-01-21-02-15-20.bpo-35378.4oF03i.rst b/Misc/NEWS.d/next/Library/2019-01-21-02-15-20.bpo-35378.4oF03i.rst
new file mode 100644
index 000000000000..bb57f7115991
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2019-01-21-02-15-20.bpo-35378.4oF03i.rst
@@ -0,0 +1,6 @@
+Fix a reference issue inside :class:`multiprocessing.Pool` that caused
+the pool to remain alive if it was deleted without being closed or
+terminated explicitly. A new strong reference is added to the pool
+iterators to link the lifetime of the pool to the lifetime of its
+iterators so the pool does not get destroyed if a pool iterator is
+still alive.
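The cycle elimination works by passing the worker-handler thread the state it needs as explicit arguments instead of a reference to the Pool itself (see the _handle_workers changes above). Below is a minimal sketch of that pattern in isolation; the names (Owner, _work) are hypothetical and not taken from the commit:

    import gc
    import threading
    import time
    import weakref

    class Owner:
        def __init__(self):
            self._queue = []
            # Pass only the state the thread needs, never `self`,
            # so the running thread does not keep Owner alive.
            self._thread = threading.Thread(
                target=Owner._work, args=(self._queue,), daemon=True)
            self._thread.start()

        @staticmethod
        def _work(queue):
            while True:
                time.sleep(0.1)

    owner = Owner()
    ref = weakref.ref(owner)
    del owner
    gc.collect()
    print(ref() is None)  # True: Owner is collected despite the live thread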


