[Python-checkins] r86370 - in python/branches/py3k/Lib: multiprocessing/pool.py test/test_multiprocessing.py

ask.solem python-checkins at python.org
Tue Nov 9 21:55:52 CET 2010


Author: ask.solem
Date: Tue Nov  9 21:55:52 2010
New Revision: 86370

Log:
Issue #9244: multiprocessing.pool: Worker crashes if result can't be encoded

Modified:
   python/branches/py3k/Lib/multiprocessing/pool.py
   python/branches/py3k/Lib/test/test_multiprocessing.py

Modified: python/branches/py3k/Lib/multiprocessing/pool.py
==============================================================================
--- python/branches/py3k/Lib/multiprocessing/pool.py	(original)
+++ python/branches/py3k/Lib/multiprocessing/pool.py	Tue Nov  9 21:55:52 2010
@@ -42,6 +42,23 @@
 # Code run by worker processes
 #
 
+class MaybeEncodingError(Exception):
+    """Wraps possible unpickleable errors, so they can be
+    safely sent through the socket."""
+
+    def __init__(self, exc, value):
+        self.exc = repr(exc)
+        self.value = repr(value)
+        super(MaybeEncodingError, self).__init__(self.exc, self.value)
+
+    def __str__(self):
+        return "Error sending result: '%s'. Reason: '%s'" % (self.value,
+                                                             self.exc)
+
+    def __repr__(self):
+        return "<MaybeEncodingError: %s>" % str(self)
+
+
 def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None):
     assert maxtasks is None or (type(maxtasks) == int and maxtasks > 0)
     put = outqueue.put
@@ -70,7 +87,13 @@
             result = (True, func(*args, **kwds))
         except Exception as e:
             result = (False, e)
-        put((job, i, result))
+        try:
+            put((job, i, result))
+        except Exception as e:
+            wrapped = MaybeEncodingError(e, result[1])
+            debug("Possible encoding error while sending result: %s" % (
+                wrapped))
+            put((job, i, (False, wrapped)))
         completed += 1
     debug('worker exiting after %d tasks' % completed)
 
@@ -235,16 +258,18 @@
                      for i, x in enumerate(task_batches)), result._set_length))
             return (item for chunk in result for item in chunk)
 
-    def apply_async(self, func, args=(), kwds={}, callback=None):
+    def apply_async(self, func, args=(), kwds={}, callback=None,
+            error_callback=None):
         '''
         Asynchronous version of `apply()` method.
         '''
         assert self._state == RUN
-        result = ApplyResult(self._cache, callback)
+        result = ApplyResult(self._cache, callback, error_callback)
         self._taskqueue.put(([(result._job, None, func, args, kwds)], None))
         return result
 
-    def map_async(self, func, iterable, chunksize=None, callback=None):
+    def map_async(self, func, iterable, chunksize=None, callback=None,
+            error_callback=None):
         '''
         Asynchronous version of `map()` method.
         '''
@@ -260,7 +285,8 @@
             chunksize = 0
 
         task_batches = Pool._get_tasks(func, iterable, chunksize)
-        result = MapResult(self._cache, chunksize, len(iterable), callback)
+        result = MapResult(self._cache, chunksize, len(iterable), callback,
+                           error_callback=error_callback)
         self._taskqueue.put((((result._job, i, mapstar, (x,), {})
                               for i, x in enumerate(task_batches)), None))
         return result
@@ -459,12 +485,13 @@
 
 class ApplyResult(object):
 
-    def __init__(self, cache, callback):
+    def __init__(self, cache, callback, error_callback):
         self._cond = threading.Condition(threading.Lock())
         self._job = next(job_counter)
         self._cache = cache
         self._ready = False
         self._callback = callback
+        self._error_callback = error_callback
         cache[self._job] = self
 
     def ready(self):
@@ -495,6 +522,8 @@
         self._success, self._value = obj
         if self._callback and self._success:
             self._callback(self._value)
+        if self._error_callback and not self._success:
+            self._error_callback(self._value)
         self._cond.acquire()
         try:
             self._ready = True
@@ -509,8 +538,9 @@
 
 class MapResult(ApplyResult):
 
-    def __init__(self, cache, chunksize, length, callback):
-        ApplyResult.__init__(self, cache, callback)
+    def __init__(self, cache, chunksize, length, callback, error_callback):
+        ApplyResult.__init__(self, cache, callback,
+                             error_callback=error_callback)
         self._success = True
         self._value = [None] * length
         self._chunksize = chunksize
@@ -535,10 +565,11 @@
                     self._cond.notify()
                 finally:
                     self._cond.release()
-
         else:
             self._success = False
             self._value = result
+            if self._error_callback:
+                self._error_callback(self._value)
             del self._cache[self._job]
             self._cond.acquire()
             try:

Modified: python/branches/py3k/Lib/test/test_multiprocessing.py
==============================================================================
--- python/branches/py3k/Lib/test/test_multiprocessing.py	(original)
+++ python/branches/py3k/Lib/test/test_multiprocessing.py	Tue Nov  9 21:55:52 2010
@@ -1011,6 +1011,7 @@
 def sqr(x, wait=0.0):
     time.sleep(wait)
     return x*x
+
 class _TestPool(BaseTestCase):
 
     def test_apply(self):
@@ -1087,9 +1088,55 @@
         join()
         self.assertTrue(join.elapsed < 0.2)
 
-class _TestPoolWorkerLifetime(BaseTestCase):
+def raising():
+    raise KeyError("key")
+
+def unpickleable_result():
+    return lambda: 42
+
+class _TestPoolWorkerErrors(BaseTestCase):
+    ALLOWED_TYPES = ('processes', )
+
+    def test_async_error_callback(self):
+        p = multiprocessing.Pool(2)
+
+        scratchpad = [None]
+        def errback(exc):
+            scratchpad[0] = exc
+
+        res = p.apply_async(raising, error_callback=errback)
+        self.assertRaises(KeyError, res.get)
+        self.assertTrue(scratchpad[0])
+        self.assertIsInstance(scratchpad[0], KeyError)
+
+        p.close()
+        p.join()
+
+    def test_unpickleable_result(self):
+        from multiprocessing.pool import MaybeEncodingError
+        p = multiprocessing.Pool(2)
+
+        # Make sure we don't lose pool processes because of encoding errors.
+        for iteration in range(20):
+
+            scratchpad = [None]
+            def errback(exc):
+                scratchpad[0] = exc
+
+            res = p.apply_async(unpickleable_result, error_callback=errback)
+            self.assertRaises(MaybeEncodingError, res.get)
+            wrapped = scratchpad[0]
+            self.assertTrue(wrapped)
+            self.assertIsInstance(scratchpad[0], MaybeEncodingError)
+            self.assertIsNotNone(wrapped.exc)
+            self.assertIsNotNone(wrapped.value)
+
+        p.close()
+        p.join()
 
+class _TestPoolWorkerLifetime(BaseTestCase):
     ALLOWED_TYPES = ('processes', )
+
     def test_pool_worker_lifetime(self):
         p = multiprocessing.Pool(3, maxtasksperchild=10)
         self.assertEqual(3, len(p._pool))


More information about the Python-checkins mailing list