[Python-checkins] bpo-21423: Add an initializer argument to {Process, Thread}PoolExecutor (#4241)

Antoine Pitrou webhook-mailer at python.org
Sat Nov 4 06:05:54 EDT 2017


https://github.com/python/cpython/commit/63ff4131af86e8a48cbedb9fbba95bd65ca90061
commit: 63ff4131af86e8a48cbedb9fbba95bd65ca90061
branch: master
author: Antoine Pitrou <pitrou at free.fr>
committer: GitHub <noreply at github.com>
date: 2017-11-04T11:05:49+01:00
summary:

bpo-21423: Add an initializer argument to {Process,Thread}PoolExecutor (#4241)

* bpo-21423: Add an initializer argument to {Process,Thread}PoolExecutor

* Fix docstring

files:
A Misc/NEWS.d/next/Library/2017-11-02-22-26-16.bpo-21423.hw5mEh.rst
M Doc/library/concurrent.futures.rst
M Lib/concurrent/futures/__init__.py
M Lib/concurrent/futures/_base.py
M Lib/concurrent/futures/process.py
M Lib/concurrent/futures/thread.py
M Lib/test/test_concurrent_futures.py

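For context (not part of the patch below), a minimal usage sketch of the new arguments, assuming Python 3.7+ with this change applied; the init_worker helper, the thread-local object and the logging setup are illustrative names only, not anything defined by the commit:

    import logging
    import threading
    from concurrent.futures import ThreadPoolExecutor

    # Per-worker setup: the initializer runs once in each worker thread
    # before that thread picks up any work item.
    local_data = threading.local()

    def init_worker(level):
        logging.basicConfig(level=level)
        local_data.tag = threading.current_thread().name

    def task(n):
        # Worker code can rely on the initializer having run first.
        return n * n

    with ThreadPoolExecutor(max_workers=4,
                            initializer=init_worker,
                            initargs=(logging.INFO,)) as executor:
        print(list(executor.map(task, range(8))))
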
diff --git a/Doc/library/concurrent.futures.rst b/Doc/library/concurrent.futures.rst
index 30556fbb345..d4b698e1c17 100644
--- a/Doc/library/concurrent.futures.rst
+++ b/Doc/library/concurrent.futures.rst
@@ -124,11 +124,17 @@ And::
    executor.submit(wait_on_future)
 
 
-.. class:: ThreadPoolExecutor(max_workers=None, thread_name_prefix='')
+.. class:: ThreadPoolExecutor(max_workers=None, thread_name_prefix='', initializer=None, initargs=())
 
    An :class:`Executor` subclass that uses a pool of at most *max_workers*
    threads to execute calls asynchronously.
 
+   *initializer* is an optional callable that is called at the start of
+   each worker thread; *initargs* is a tuple of arguments passed to the
+   initializer.  Should *initializer* raise an exception, all currently
+   pending jobs will raise a :exc:`~concurrent.futures.thread.BrokenThreadPool`,
+   as well as any attempt to submit more jobs to the pool.
+
    .. versionchanged:: 3.5
       If *max_workers* is ``None`` or
       not given, it will default to the number of processors on the machine,
@@ -142,6 +148,10 @@ And::
       control the threading.Thread names for worker threads created by
       the pool for easier debugging.
 
+   .. versionchanged:: 3.7
+      Added the *initializer* and *initargs* arguments.
+
+
 .. _threadpoolexecutor-example:
 
 ThreadPoolExecutor Example
@@ -191,7 +201,7 @@ that :class:`ProcessPoolExecutor` will not work in the interactive interpreter.
 Calling :class:`Executor` or :class:`Future` methods from a callable submitted
 to a :class:`ProcessPoolExecutor` will result in deadlock.
 
-.. class:: ProcessPoolExecutor(max_workers=None, mp_context=None)
+.. class:: ProcessPoolExecutor(max_workers=None, mp_context=None, initializer=None, initargs=())
 
    An :class:`Executor` subclass that executes calls asynchronously using a pool
    of at most *max_workers* processes.  If *max_workers* is ``None`` or not
@@ -202,6 +212,12 @@ to a :class:`ProcessPoolExecutor` will result in deadlock.
    launch the workers. If *mp_context* is ``None`` or not given, the default
    multiprocessing context is used.
 
+   *initializer* is an optional callable that is called at the start of
+   each worker process; *initargs* is a tuple of arguments passed to the
+   initializer.  Should *initializer* raise an exception, all currently
+   pending jobs will raise a :exc:`~concurrent.futures.process.BrokenProcessPool`,
+   as well as any attempt to submit more jobs to the pool.
+
    .. versionchanged:: 3.3
       When one of the worker processes terminates abruptly, a
       :exc:`BrokenProcessPool` error is now raised.  Previously, behaviour
@@ -212,6 +228,8 @@ to a :class:`ProcessPoolExecutor` will result in deadlock.
       The *mp_context* argument was added to allow users to control the
       start_method for worker processes created by the pool.
 
+      Added the *initializer* and *initargs* arguments.
+
 
 .. _processpoolexecutor-example:
 
@@ -432,13 +450,31 @@ Exception classes
 
    Raised when a future operation exceeds the given timeout.
 
+.. exception:: BrokenExecutor
+
+   Derived from :exc:`RuntimeError`, this exception class is raised
+   when an executor is broken for some reason, and cannot be used
+   to submit or execute new tasks.
+
+   .. versionadded:: 3.7
+
+.. currentmodule:: concurrent.futures.thread
+
+.. exception:: BrokenThreadPool
+
+   Derived from :exc:`~concurrent.futures.BrokenExecutor`, this exception
+   class is raised when one of the workers of a :class:`ThreadPoolExecutor`
+   has failed initializing.
+
+   .. versionadded:: 3.7
+
 .. currentmodule:: concurrent.futures.process
 
 .. exception:: BrokenProcessPool
 
-   Derived from :exc:`RuntimeError`, this exception class is raised when
-   one of the workers of a :class:`ProcessPoolExecutor` has terminated
-   in a non-clean fashion (for example, if it was killed from the outside).
+   Derived from :exc:`~concurrent.futures.BrokenExecutor` (formerly
+   :exc:`RuntimeError`), this exception class is raised when one of the
+   workers of a :class:`ProcessPoolExecutor` has terminated in a non-clean
+   fashion (for example, if it was killed from the outside).
 
    .. versionadded:: 3.3
-
diff --git a/Lib/concurrent/futures/__init__.py b/Lib/concurrent/futures/__init__.py
index b5231f8aab3..ba8de163905 100644
--- a/Lib/concurrent/futures/__init__.py
+++ b/Lib/concurrent/futures/__init__.py
@@ -10,6 +10,7 @@
                                       ALL_COMPLETED,
                                       CancelledError,
                                       TimeoutError,
+                                      BrokenExecutor,
                                       Future,
                                       Executor,
                                       wait,
diff --git a/Lib/concurrent/futures/_base.py b/Lib/concurrent/futures/_base.py
index 6bace6c7464..4f22f7ee0e6 100644
--- a/Lib/concurrent/futures/_base.py
+++ b/Lib/concurrent/futures/_base.py
@@ -610,3 +610,9 @@ def __enter__(self):
     def __exit__(self, exc_type, exc_val, exc_tb):
         self.shutdown(wait=True)
         return False
+
+
+class BrokenExecutor(RuntimeError):
+    """
+    Raised when an executor has become non-functional after a severe failure.
+    """
diff --git a/Lib/concurrent/futures/process.py b/Lib/concurrent/futures/process.py
index 67ebbf51521..35af65d0bee 100644
--- a/Lib/concurrent/futures/process.py
+++ b/Lib/concurrent/futures/process.py
@@ -131,6 +131,7 @@ def __init__(self, work_id, fn, args, kwargs):
         self.args = args
         self.kwargs = kwargs
 
+
 def _get_chunks(*iterables, chunksize):
     """ Iterates over zip()ed iterables in chunks. """
     it = zip(*iterables)
@@ -151,7 +152,7 @@ def _process_chunk(fn, chunk):
     """
     return [fn(*args) for args in chunk]
 
-def _process_worker(call_queue, result_queue):
+def _process_worker(call_queue, result_queue, initializer, initargs):
     """Evaluates calls from call_queue and places the results in result_queue.
 
     This worker is run in a separate process.
@@ -161,7 +162,17 @@ def _process_worker(call_queue, result_queue):
             evaluated by the worker.
         result_queue: A ctx.Queue of _ResultItems that will be written
             to by the worker.
+        initializer: A callable initializer, or None
+        initargs: A tuple of args for the initializer
     """
+    if initializer is not None:
+        try:
+            initializer(*initargs)
+        except BaseException:
+            _base.LOGGER.critical('Exception in initializer:', exc_info=True)
+            # The parent will notice that the process stopped and
+            # mark the pool broken
+            return
     while True:
         call_item = call_queue.get(block=True)
         if call_item is None:
@@ -277,7 +288,9 @@ def shutdown_worker():
             # Mark the process pool broken so that submits fail right now.
             executor = executor_reference()
             if executor is not None:
-                executor._broken = True
+                executor._broken = ('A child process terminated '
+                                    'abruptly, the process pool is not '
+                                    'usable anymore')
                 executor._shutdown_thread = True
                 executor = None
             # All futures in flight must be marked failed
@@ -372,7 +385,7 @@ def _chain_from_iterable_of_lists(iterable):
             yield element.pop()
 
 
-class BrokenProcessPool(RuntimeError):
+class BrokenProcessPool(_base.BrokenExecutor):
     """
     Raised when a process in a ProcessPoolExecutor terminated abruptly
     while a future was in the running state.
@@ -380,7 +393,8 @@ class BrokenProcessPool(RuntimeError):
 
 
 class ProcessPoolExecutor(_base.Executor):
-    def __init__(self, max_workers=None, mp_context=None):
+    def __init__(self, max_workers=None, mp_context=None,
+                 initializer=None, initargs=()):
         """Initializes a new ProcessPoolExecutor instance.
 
         Args:
@@ -389,6 +403,8 @@ def __init__(self, max_workers=None, mp_context=None):
                 worker processes will be created as the machine has processors.
             mp_context: A multiprocessing context to launch the workers. This
                 object should provide SimpleQueue, Queue and Process.
+            initializer: A callable used to initialize worker processes.
+            initargs: A tuple of arguments to pass to the initializer.
         """
         _check_system_limits()
 
@@ -403,6 +419,11 @@ def __init__(self, max_workers=None, mp_context=None):
             mp_context = mp.get_context()
         self._mp_context = mp_context
 
+        if initializer is not None and not callable(initializer):
+            raise TypeError("initializer must be a callable")
+        self._initializer = initializer
+        self._initargs = initargs
+
         # Make the call queue slightly larger than the number of processes to
         # prevent the worker processes from idling. But don't make it too big
         # because futures in the call queue cannot be cancelled.
@@ -450,15 +471,16 @@ def _adjust_process_count(self):
             p = self._mp_context.Process(
                 target=_process_worker,
                 args=(self._call_queue,
-                      self._result_queue))
+                      self._result_queue,
+                      self._initializer,
+                      self._initargs))
             p.start()
             self._processes[p.pid] = p
 
     def submit(self, fn, *args, **kwargs):
         with self._shutdown_lock:
             if self._broken:
-                raise BrokenProcessPool('A child process terminated '
-                    'abruptly, the process pool is not usable anymore')
+                raise BrokenProcessPool(self._broken)
             if self._shutdown_thread:
                 raise RuntimeError('cannot schedule new futures after shutdown')
 
diff --git a/Lib/concurrent/futures/thread.py b/Lib/concurrent/futures/thread.py
index 0b5d5373ffd..2e7100bc352 100644
--- a/Lib/concurrent/futures/thread.py
+++ b/Lib/concurrent/futures/thread.py
@@ -41,6 +41,7 @@ def _python_exit():
 
 atexit.register(_python_exit)
 
+
 class _WorkItem(object):
     def __init__(self, future, fn, args, kwargs):
         self.future = future
@@ -61,7 +62,17 @@ def run(self):
         else:
             self.future.set_result(result)
 
-def _worker(executor_reference, work_queue):
+
+def _worker(executor_reference, work_queue, initializer, initargs):
+    if initializer is not None:
+        try:
+            initializer(*initargs)
+        except BaseException:
+            _base.LOGGER.critical('Exception in initializer:', exc_info=True)
+            executor = executor_reference()
+            if executor is not None:
+                executor._initializer_failed()
+            return
     try:
         while True:
             work_item = work_queue.get(block=True)
@@ -83,18 +94,28 @@ def _worker(executor_reference, work_queue):
     except BaseException:
         _base.LOGGER.critical('Exception in worker', exc_info=True)
 
+
+class BrokenThreadPool(_base.BrokenExecutor):
+    """
+    Raised when a worker thread in a ThreadPoolExecutor failed initializing.
+    """
+
+
 class ThreadPoolExecutor(_base.Executor):
 
     # Used to assign unique thread names when thread_name_prefix is not supplied.
     _counter = itertools.count().__next__
 
-    def __init__(self, max_workers=None, thread_name_prefix=''):
+    def __init__(self, max_workers=None, thread_name_prefix='',
+                 initializer=None, initargs=()):
         """Initializes a new ThreadPoolExecutor instance.
 
         Args:
             max_workers: The maximum number of threads that can be used to
                 execute the given calls.
             thread_name_prefix: An optional name prefix to give our threads.
+            initializer: A callable used to initialize worker threads.
+            initargs: A tuple of arguments to pass to the initializer.
         """
         if max_workers is None:
             # Use this number because ThreadPoolExecutor is often
@@ -103,16 +124,25 @@ def __init__(self, max_workers=None, thread_name_prefix=''):
         if max_workers <= 0:
             raise ValueError("max_workers must be greater than 0")
 
+        if initializer is not None and not callable(initializer):
+            raise TypeError("initializer must be a callable")
+
         self._max_workers = max_workers
         self._work_queue = queue.Queue()
         self._threads = set()
+        self._broken = False
         self._shutdown = False
         self._shutdown_lock = threading.Lock()
         self._thread_name_prefix = (thread_name_prefix or
                                     ("ThreadPoolExecutor-%d" % self._counter()))
+        self._initializer = initializer
+        self._initargs = initargs
 
     def submit(self, fn, *args, **kwargs):
         with self._shutdown_lock:
+            if self._broken:
+                raise BrokenThreadPool(self._broken)
+
             if self._shutdown:
                 raise RuntimeError('cannot schedule new futures after shutdown')
 
@@ -137,12 +167,27 @@ def weakref_cb(_, q=self._work_queue):
                                      num_threads)
             t = threading.Thread(name=thread_name, target=_worker,
                                  args=(weakref.ref(self, weakref_cb),
-                                       self._work_queue))
+                                       self._work_queue,
+                                       self._initializer,
+                                       self._initargs))
             t.daemon = True
             t.start()
             self._threads.add(t)
             _threads_queues[t] = self._work_queue
 
+    def _initializer_failed(self):
+        with self._shutdown_lock:
+            self._broken = ('A thread initializer failed, the thread pool '
+                            'is not usable anymore')
+            # Drain work queue and mark pending futures failed
+            while True:
+                try:
+                    work_item = self._work_queue.get_nowait()
+                except queue.Empty:
+                    break
+                if work_item is not None:
+                    work_item.future.set_exception(BrokenThreadPool(self._broken))
+
     def shutdown(self, wait=True):
         with self._shutdown_lock:
             self._shutdown = True
diff --git a/Lib/test/test_concurrent_futures.py b/Lib/test/test_concurrent_futures.py
index ed8ad41f8e6..296398f0d94 100644
--- a/Lib/test/test_concurrent_futures.py
+++ b/Lib/test/test_concurrent_futures.py
@@ -7,6 +7,7 @@
 
 from test.support.script_helper import assert_python_ok
 
+import contextlib
 import itertools
 import os
 import sys
@@ -17,7 +18,8 @@
 
 from concurrent import futures
 from concurrent.futures._base import (
-    PENDING, RUNNING, CANCELLED, CANCELLED_AND_NOTIFIED, FINISHED, Future)
+    PENDING, RUNNING, CANCELLED, CANCELLED_AND_NOTIFIED, FINISHED, Future,
+    BrokenExecutor)
 from concurrent.futures.process import BrokenProcessPool
 from multiprocessing import get_context
 
@@ -37,11 +39,12 @@ def create_future(state=PENDING, exception=None, result=None):
 EXCEPTION_FUTURE = create_future(state=FINISHED, exception=OSError())
 SUCCESSFUL_FUTURE = create_future(state=FINISHED, result=42)
 
+INITIALIZER_STATUS = 'uninitialized'
+
 
 def mul(x, y):
     return x * y
 
-
 def sleep_and_raise(t):
     time.sleep(t)
     raise Exception('this is an exception')
@@ -51,6 +54,17 @@ def sleep_and_print(t, msg):
     print(msg)
     sys.stdout.flush()
 
+def init(x):
+    global INITIALIZER_STATUS
+    INITIALIZER_STATUS = x
+
+def get_init_status():
+    return INITIALIZER_STATUS
+
+def init_fail():
+    time.sleep(0.1)  # let some futures be scheduled
+    raise ValueError('error in initializer')
+
 
 class MyObject(object):
     def my_method(self):
@@ -81,6 +95,7 @@ def tearDown(self):
 
 class ExecutorMixin:
     worker_count = 5
+    executor_kwargs = {}
 
     def setUp(self):
         super().setUp()
@@ -90,10 +105,12 @@ def setUp(self):
             if hasattr(self, "ctx"):
                 self.executor = self.executor_type(
                     max_workers=self.worker_count,
-                    mp_context=get_context(self.ctx))
+                    mp_context=get_context(self.ctx),
+                    **self.executor_kwargs)
             else:
                 self.executor = self.executor_type(
-                    max_workers=self.worker_count)
+                    max_workers=self.worker_count,
+                    **self.executor_kwargs)
         except NotImplementedError as e:
             self.skipTest(str(e))
         self._prime_executor()
@@ -114,7 +131,6 @@ def _prime_executor(self):
         # tests. This should reduce the probability of timeouts in the tests.
         futures = [self.executor.submit(time.sleep, 0.1)
                    for _ in range(self.worker_count)]
-
         for f in futures:
             f.result()
 
@@ -148,6 +164,90 @@ def setUp(self):
         super().setUp()
 
 
+def create_executor_tests(mixin, bases=(BaseTestCase,),
+                          executor_mixins=(ThreadPoolMixin,
+                                           ProcessPoolForkMixin,
+                                           ProcessPoolForkserverMixin,
+                                           ProcessPoolSpawnMixin)):
+    def strip_mixin(name):
+        if name.endswith(('Mixin', 'Tests')):
+            return name[:-5]
+        elif name.endswith('Test'):
+            return name[:-4]
+        else:
+            return name
+
+    for exe in executor_mixins:
+        name = ("%s%sTest"
+                % (strip_mixin(exe.__name__), strip_mixin(mixin.__name__)))
+        cls = type(name, (mixin,) + (exe,) + bases, {})
+        globals()[name] = cls
+
+
+class InitializerMixin(ExecutorMixin):
+    worker_count = 2
+
+    def setUp(self):
+        global INITIALIZER_STATUS
+        INITIALIZER_STATUS = 'uninitialized'
+        self.executor_kwargs = dict(initializer=init,
+                                    initargs=('initialized',))
+        super().setUp()
+
+    def test_initializer(self):
+        futures = [self.executor.submit(get_init_status)
+                   for _ in range(self.worker_count)]
+
+        for f in futures:
+            self.assertEqual(f.result(), 'initialized')
+
+
+class FailingInitializerMixin(ExecutorMixin):
+    worker_count = 2
+
+    def setUp(self):
+        self.executor_kwargs = dict(initializer=init_fail)
+        super().setUp()
+
+    def test_initializer(self):
+        with self._assert_logged('ValueError: error in initializer'):
+            try:
+                future = self.executor.submit(get_init_status)
+            except BrokenExecutor:
+                # Perhaps the executor is already broken
+                pass
+            else:
+                with self.assertRaises(BrokenExecutor):
+                    future.result()
+            # At some point, the executor should break
+            t1 = time.time()
+            while not self.executor._broken:
+                if time.time() - t1 > 5:
+                    self.fail("executor not broken after 5 s.")
+                time.sleep(0.01)
+            # ... and from this point submit() is guaranteed to fail
+            with self.assertRaises(BrokenExecutor):
+                self.executor.submit(get_init_status)
+
+    def _prime_executor(self):
+        pass
+
+    @contextlib.contextmanager
+    def _assert_logged(self, msg):
+        if self.executor_type is futures.ProcessPoolExecutor:
+            # No easy way to catch the child processes' stderr
+            yield
+        else:
+            with self.assertLogs('concurrent.futures', 'CRITICAL') as cm:
+                yield
+            self.assertTrue(any(msg in line for line in cm.output),
+                            cm.output)
+
+
+create_executor_tests(InitializerMixin)
+create_executor_tests(FailingInitializerMixin)
+
+
 class ExecutorShutdownTest:
     def test_run_after_shutdown(self):
         self.executor.shutdown()
@@ -278,20 +378,11 @@ def test_del_shutdown(self):
         call_queue.join_thread()
 
 
-class ProcessPoolForkShutdownTest(ProcessPoolForkMixin, BaseTestCase,
-                                  ProcessPoolShutdownTest):
-    pass
-
-
-class ProcessPoolForkserverShutdownTest(ProcessPoolForkserverMixin,
-                                        BaseTestCase,
-                                        ProcessPoolShutdownTest):
-    pass
-
 
-class ProcessPoolSpawnShutdownTest(ProcessPoolSpawnMixin, BaseTestCase,
-                                   ProcessPoolShutdownTest):
-    pass
+create_executor_tests(ProcessPoolShutdownTest,
+                      executor_mixins=(ProcessPoolForkMixin,
+                                       ProcessPoolForkserverMixin,
+                                       ProcessPoolSpawnMixin))
 
 
 class WaitTests:
@@ -413,18 +504,10 @@ def future_func():
             sys.setswitchinterval(oldswitchinterval)
 
 
-class ProcessPoolForkWaitTests(ProcessPoolForkMixin, WaitTests, BaseTestCase):
-    pass
-
-
-class ProcessPoolForkserverWaitTests(ProcessPoolForkserverMixin, WaitTests,
-                                     BaseTestCase):
-    pass
-
-
-class ProcessPoolSpawnWaitTests(ProcessPoolSpawnMixin, BaseTestCase,
-                                WaitTests):
-    pass
+create_executor_tests(WaitTests,
+                      executor_mixins=(ProcessPoolForkMixin,
+                                       ProcessPoolForkserverMixin,
+                                       ProcessPoolSpawnMixin))
 
 
 class AsCompletedTests:
@@ -507,24 +590,7 @@ def test_correct_timeout_exception_msg(self):
         self.assertEqual(str(cm.exception), '2 (of 4) futures unfinished')
 
 
-class ThreadPoolAsCompletedTests(ThreadPoolMixin, AsCompletedTests, BaseTestCase):
-    pass
-
-
-class ProcessPoolForkAsCompletedTests(ProcessPoolForkMixin, AsCompletedTests,
-                                      BaseTestCase):
-    pass
-
-
-class ProcessPoolForkserverAsCompletedTests(ProcessPoolForkserverMixin,
-                                            AsCompletedTests,
-                                            BaseTestCase):
-    pass
-
-
-class ProcessPoolSpawnAsCompletedTests(ProcessPoolSpawnMixin, AsCompletedTests,
-                                       BaseTestCase):
-    pass
+create_executor_tests(AsCompletedTests)
 
 
 class ExecutorTest:
@@ -688,23 +754,10 @@ def test_ressources_gced_in_workers(self):
         self.assertTrue(obj.event.wait(timeout=1))
 
 
-class ProcessPoolForkExecutorTest(ProcessPoolForkMixin,
-                                  ProcessPoolExecutorTest,
-                                  BaseTestCase):
-    pass
-
-
-class ProcessPoolForkserverExecutorTest(ProcessPoolForkserverMixin,
-                                        ProcessPoolExecutorTest,
-                                        BaseTestCase):
-    pass
-
-
-class ProcessPoolSpawnExecutorTest(ProcessPoolSpawnMixin,
-                                   ProcessPoolExecutorTest,
-                                   BaseTestCase):
-    pass
-
+create_executor_tests(ProcessPoolExecutorTest,
+                      executor_mixins=(ProcessPoolForkMixin,
+                                       ProcessPoolForkserverMixin,
+                                       ProcessPoolSpawnMixin))
 
 
 class FutureTests(BaseTestCase):
@@ -932,6 +985,7 @@ def notification():
         self.assertTrue(isinstance(f1.exception(timeout=5), OSError))
         t.join()
 
+
 @test.support.reap_threads
 def test_main():
     try:
diff --git a/Misc/NEWS.d/next/Library/2017-11-02-22-26-16.bpo-21423.hw5mEh.rst b/Misc/NEWS.d/next/Library/2017-11-02-22-26-16.bpo-21423.hw5mEh.rst
new file mode 100644
index 00000000000..e11f15f90f1
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2017-11-02-22-26-16.bpo-21423.hw5mEh.rst
@@ -0,0 +1 @@
+Add an initializer argument to {Process,Thread}PoolExecutor
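For reference, a minimal sketch (again not part of the patch) of the failure mode exercised by FailingInitializerMixin above: when the initializer raises, pending futures and later submit() calls fail with BrokenThreadPool (or BrokenProcessPool), both subclasses of the new concurrent.futures.BrokenExecutor:

    from concurrent.futures import ThreadPoolExecutor, BrokenExecutor

    def bad_init():
        # Any exception raised here marks the pool broken.
        raise RuntimeError("initializer failed")

    executor = ThreadPoolExecutor(max_workers=2, initializer=bad_init)
    try:
        future = executor.submit(pow, 2, 10)
    except BrokenExecutor as exc:
        # The pool may already be marked broken by the time we submit.
        print("submit refused:", exc)
    else:
        try:
            future.result(timeout=5)
        except BrokenExecutor as exc:
            # The pending future fails with BrokenThreadPool, which the
            # patch exposes through its BrokenExecutor base class.
            print("pool is broken:", exc)
    executor.shutdown()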


