[Python-checkins] bpo-32604: Implement force-closing channels. (gh-6937)

Eric Snow webhook-mailer at python.org
Thu May 17 10:27:19 EDT 2018


https://github.com/python/cpython/commit/3ab0136ac5d6059ce96d4debca89c5f5ab0356f5
commit: 3ab0136ac5d6059ce96d4debca89c5f5ab0356f5
branch: master
author: Eric Snow <ericsnowcurrently at gmail.com>
committer: GitHub <noreply at github.com>
date: 2018-05-17T10:27:09-04:00
summary:

bpo-32604: Implement force-closing channels. (gh-6937)

This will make it easier to clean up channels (e.g. when used in tests).

files:
M Lib/test/test__xxsubinterpreters.py
M Modules/_xxsubinterpretersmodule.c

diff --git a/Lib/test/test__xxsubinterpreters.py b/Lib/test/test__xxsubinterpreters.py
index 118f2e4895fe..f66cc9516926 100644
--- a/Lib/test/test__xxsubinterpreters.py
+++ b/Lib/test/test__xxsubinterpreters.py
@@ -1379,12 +1379,104 @@ def test_close_multiple_times(self):
         with self.assertRaises(interpreters.ChannelClosedError):
             interpreters.channel_close(cid)
 
-    def test_close_with_unused_items(self):
+    def test_close_empty(self):  # every (send, recv) combination closes a drained channel
+        tests = [  # (send, recv) flags passed to channel_close()
+            (False, False),
+            (True, False),
+            (False, True),
+            (True, True),
+            ]
+        for send, recv in tests:
+            with self.subTest((send, recv)):
+                cid = interpreters.channel_create()
+                interpreters.channel_send(cid, b'spam')
+                interpreters.channel_recv(cid)
+                interpreters.channel_close(cid, send=send, recv=recv)  # queue is empty: no force needed
+
+                with self.assertRaises(interpreters.ChannelClosedError):
+                    interpreters.channel_send(cid, b'eggs')
+                with self.assertRaises(interpreters.ChannelClosedError):
+                    interpreters.channel_recv(cid)
+
+    def test_close_defaults_with_unused_items(self):  # unforced close refuses a non-empty channel
+        cid = interpreters.channel_create()
+        interpreters.channel_send(cid, b'spam')
+        interpreters.channel_send(cid, b'ham')
+
+        with self.assertRaises(interpreters.ChannelNotEmptyError):
+            interpreters.channel_close(cid)
+        interpreters.channel_recv(cid)  # the refused close left the channel open
+        interpreters.channel_send(cid, b'eggs')
+
+    def test_close_recv_with_unused_items_unforced(self):  # recv-only close also refuses
         cid = interpreters.channel_create()
         interpreters.channel_send(cid, b'spam')
         interpreters.channel_send(cid, b'ham')
-        interpreters.channel_close(cid)
+
+        with self.assertRaises(interpreters.ChannelNotEmptyError):
+            interpreters.channel_close(cid, recv=True)
+        interpreters.channel_recv(cid)
+        interpreters.channel_send(cid, b'eggs')
+        interpreters.channel_recv(cid)
+        interpreters.channel_recv(cid)
+        interpreters.channel_close(cid, recv=True)  # succeeds once the queue is drained
+
+    def test_close_send_with_unused_items_unforced(self):  # send-only close is deferred until drained
+        cid = interpreters.channel_create()
+        interpreters.channel_send(cid, b'spam')
+        interpreters.channel_send(cid, b'ham')
+        interpreters.channel_close(cid, send=True)  # no error: the channel is marked as closing
 
+        with self.assertRaises(interpreters.ChannelClosedError):
+            interpreters.channel_send(cid, b'eggs')
+        interpreters.channel_recv(cid)
+        interpreters.channel_recv(cid)
+        with self.assertRaises(interpreters.ChannelClosedError):  # draining finished the close
+            interpreters.channel_recv(cid)
+
+    def test_close_both_with_unused_items_unforced(self):  # closing both ends also refuses
+        cid = interpreters.channel_create()
+        interpreters.channel_send(cid, b'spam')
+        interpreters.channel_send(cid, b'ham')
+
+        with self.assertRaises(interpreters.ChannelNotEmptyError):
+            interpreters.channel_close(cid, recv=True, send=True)
+        interpreters.channel_recv(cid)
+        interpreters.channel_send(cid, b'eggs')
+        interpreters.channel_recv(cid)
+        interpreters.channel_recv(cid)
+        interpreters.channel_close(cid, recv=True)  # NOTE(review): only recv=True despite the test name
+
+    def test_close_recv_with_unused_items_forced(self):  # force=True closes despite queued items
+        cid = interpreters.channel_create()
+        interpreters.channel_send(cid, b'spam')
+        interpreters.channel_send(cid, b'ham')
+        interpreters.channel_close(cid, recv=True, force=True)
+
+        with self.assertRaises(interpreters.ChannelClosedError):
+            interpreters.channel_send(cid, b'eggs')
+        with self.assertRaises(interpreters.ChannelClosedError):
+            interpreters.channel_recv(cid)
+
+    def test_close_send_with_unused_items_forced(self):  # forced send-only close is not deferred
+        cid = interpreters.channel_create()
+        interpreters.channel_send(cid, b'spam')
+        interpreters.channel_send(cid, b'ham')
+        interpreters.channel_close(cid, send=True, force=True)
+
+        with self.assertRaises(interpreters.ChannelClosedError):
+            interpreters.channel_send(cid, b'eggs')
+        with self.assertRaises(interpreters.ChannelClosedError):
+            interpreters.channel_recv(cid)
+
+    def test_close_both_with_unused_items_forced(self):  # forced close of both ends is immediate
+        cid = interpreters.channel_create()
+        interpreters.channel_send(cid, b'spam')
+        interpreters.channel_send(cid, b'ham')
+        interpreters.channel_close(cid, send=True, recv=True, force=True)
+
+        with self.assertRaises(interpreters.ChannelClosedError):
+            interpreters.channel_send(cid, b'eggs')
         with self.assertRaises(interpreters.ChannelClosedError):
             interpreters.channel_recv(cid)
 
@@ -1403,7 +1495,7 @@ def test_close_by_unassociated_interp(self):
         interp = interpreters.create()
         interpreters.run_string(interp, dedent(f"""
             import _xxsubinterpreters as _interpreters
-            _interpreters.channel_close({cid})
+            _interpreters.channel_close({cid}, force=True)
             """))
         with self.assertRaises(interpreters.ChannelClosedError):
             interpreters.channel_recv(cid)
@@ -1416,7 +1508,7 @@ def test_close_used_multiple_times_by_single_user(self):
         interpreters.channel_send(cid, b'spam')
         interpreters.channel_send(cid, b'spam')
         interpreters.channel_recv(cid)
-        interpreters.channel_close(cid)
+        interpreters.channel_close(cid, force=True)
 
         with self.assertRaises(interpreters.ChannelClosedError):
             interpreters.channel_send(cid, b'eggs')
diff --git a/Modules/_xxsubinterpretersmodule.c b/Modules/_xxsubinterpretersmodule.c
index 5184f6593db1..72387d8da56b 100644
--- a/Modules/_xxsubinterpretersmodule.c
+++ b/Modules/_xxsubinterpretersmodule.c
@@ -306,10 +306,15 @@ _sharedexception_apply(_sharedexception *exc, PyObject *wrapperclass)
 
 /* channel-specific code ****************************************************/
 
+#define CHANNEL_SEND 1   // channel_close() passes end = send - recv, so
+#define CHANNEL_BOTH 0   // send-only -> 1, neither or both -> 0,
+#define CHANNEL_RECV -1  // and recv-only -> -1.
+
 static PyObject *ChannelError;
 static PyObject *ChannelNotFoundError;
 static PyObject *ChannelClosedError;
 static PyObject *ChannelEmptyError;
+static PyObject *ChannelNotEmptyError;  // non-forced close of a channel with queued items
 
 static int
 channel_exceptions_init(PyObject *ns)
@@ -356,6 +361,16 @@ channel_exceptions_init(PyObject *ns)
         return -1;
     }
 
+    // An operation tried to close a non-empty channel.
+    ChannelNotEmptyError = PyErr_NewException(
+            "_xxsubinterpreters.ChannelNotEmptyError", ChannelError, NULL);
+    if (ChannelNotEmptyError == NULL) {
+        return -1;
+    }
+    if (PyDict_SetItemString(ns, "ChannelNotEmptyError", ChannelNotEmptyError) != 0) {
+        return -1;
+    }
+
     return 0;
 }
 
@@ -696,8 +711,11 @@ _channelends_close_interpreter(_channelends *ends, int64_t interp, int which)
 }
 
 static void
-_channelends_close_all(_channelends *ends)
+_channelends_close_all(_channelends *ends, int which, int force)
 {
+    // XXX Handle the ends.
+    // XXX Handle force is True.
+
     // Ensure all the "send"-associated interpreters are closed.
     _channelend *end;
     for (end = ends->send; end != NULL; end = end->next) {
@@ -713,12 +731,16 @@ _channelends_close_all(_channelends *ends)
 /* channels */
 
 struct _channel;
+struct _channel_closing;  // defined with the closing-support helpers
+static void _channel_clear_closing(struct _channel *);
+static void _channel_finish_closing(struct _channel *);
 
 typedef struct _channel {
     PyThread_type_lock mutex;
     _channelqueue *queue;
     _channelends *ends;
     int open;
+    struct _channel_closing *closing;  // non-NULL while a send-side close waits for the queue to drain
 } _PyChannelState;
 
 static _PyChannelState *
@@ -747,12 +769,14 @@ _channel_new(void)
         return NULL;
     }
     chan->open = 1;
+    chan->closing = NULL;
     return chan;
 }
 
 static void
 _channel_free(_PyChannelState *chan)
 {
+    _channel_clear_closing(chan);
     PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
     _channelqueue_free(chan->queue);
     _channelends_free(chan->ends);
@@ -802,13 +826,20 @@ _channel_next(_PyChannelState *chan, int64_t interp)
     }
 
     data = _channelqueue_get(chan->queue);
+    if (data == NULL && !PyErr_Occurred() && chan->closing != NULL) {
+        chan->open = 0;
+    }
+
 done:
     PyThread_release_lock(chan->mutex);
+    if (chan->queue->count == 0) {
+        _channel_finish_closing(chan);
+    }
     return data;
 }
 
 static int
-_channel_close_interpreter(_PyChannelState *chan, int64_t interp, int which)
+_channel_close_interpreter(_PyChannelState *chan, int64_t interp, int end)
 {
     PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
 
@@ -818,7 +849,7 @@ _channel_close_interpreter(_PyChannelState *chan, int64_t interp, int which)
         goto done;
     }
 
-    if (_channelends_close_interpreter(chan->ends, interp, which) != 0) {
+    if (_channelends_close_interpreter(chan->ends, interp, end) != 0) {
         goto done;
     }
     chan->open = _channelends_is_open(chan->ends);
@@ -830,7 +861,7 @@ _channel_close_interpreter(_PyChannelState *chan, int64_t interp, int which)
 }
 
 static int
-_channel_close_all(_PyChannelState *chan)
+_channel_close_all(_PyChannelState *chan, int end, int force)
 {
     int res = -1;
     PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
@@ -840,11 +871,17 @@ _channel_close_all(_PyChannelState *chan)
         goto done;
     }
 
+    if (!force && chan->queue->count > 0) {
+        PyErr_SetString(ChannelNotEmptyError,
+                        "may not be closed if not empty (try force=True)");
+        goto done;
+    }
+
     chan->open = 0;
 
     // We *could* also just leave these in place, since we've marked
     // the channel as closed already.
-    _channelends_close_all(chan->ends);
+    _channelends_close_all(chan->ends, end, force);
 
     res = 0;
 done:
@@ -889,6 +926,9 @@ _channelref_new(int64_t id, _PyChannelState *chan)
 static void
 _channelref_free(_channelref *ref)
 {
+    if (ref->chan != NULL) {  // drop any pending "closing" record; it may point back at this ref
+        _channel_clear_closing(ref->chan);
+    }
     //_channelref_clear(ref);
     PyMem_Free(ref);
 }
@@ -1009,8 +1049,12 @@ _channels_add(_channels *channels, _PyChannelState *chan)
     return cid;
 }
 
+/* forward (defined under "support for closing non-empty channels") */
+static int _channel_set_closing(struct _channelref *, PyThread_type_lock);
+
 static int
-_channels_close(_channels *channels, int64_t cid, _PyChannelState **pchan)
+_channels_close(_channels *channels, int64_t cid, _PyChannelState **pchan,
+                int end, int force)
 {
     int res = -1;
     PyThread_acquire_lock(channels->mutex, WAIT_LOCK);
@@ -1028,14 +1072,35 @@ _channels_close(_channels *channels, int64_t cid, _PyChannelState **pchan)
         PyErr_Format(ChannelClosedError, "channel %d closed", cid);
         goto done;
     }
+    else if (!force && end == CHANNEL_SEND && ref->chan->closing != NULL) {
+        PyErr_Format(ChannelClosedError, "channel %d closed", cid);
+        goto done;
+    }
     else {
-        if (_channel_close_all(ref->chan) != 0) {
+        if (_channel_close_all(ref->chan, end, force) != 0) {
+            if (end == CHANNEL_SEND &&
+                    PyErr_ExceptionMatches(ChannelNotEmptyError)) {
+                if (ref->chan->closing != NULL) {
+                    PyErr_Format(ChannelClosedError, "channel %d closed", cid);
+                    goto done;
+                }
+                // Mark the channel as closing and return.  The channel
+                // will be cleaned up in _channel_next().
+                PyErr_Clear();
+                if (_channel_set_closing(ref, channels->mutex) != 0) {
+                    goto done;
+                }
+                if (pchan != NULL) {
+                    *pchan = ref->chan;
+                }
+                res = 0;
+            }
             goto done;
         }
         if (pchan != NULL) {
             *pchan = ref->chan;
         }
-        else {
+        else  {
             _channel_free(ref->chan);
         }
         ref->chan = NULL;
@@ -1161,6 +1226,60 @@ _channels_list_all(_channels *channels, int64_t *count)
     return cids;
 }
 
+/* support for closing non-empty channels */
+
+struct _channel_closing {  // allocated by _channel_set_closing(), freed by _channel_clear_closing()
+    struct _channelref *ref;  // ref detached by _channel_finish_closing() when it frees the channel
+};
+
+static int  // Mark ref->chan as closing: 0 on success (or if already detached), -1 with an exception set.
+_channel_set_closing(struct _channelref *ref, PyThread_type_lock mutex) {
+    struct _channel *chan = ref->chan;  // NOTE(review): "mutex" is currently unused.
+    if (chan == NULL) {
+        // already closed
+        return 0;
+    }
+    int res = -1;
+    PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
+    if (chan->closing != NULL) {
+        PyErr_SetString(ChannelClosedError, "channel closed");
+        goto done;
+    }
+    chan->closing = PyMem_NEW(struct _channel_closing, 1);
+    if (chan->closing == NULL) {
+        PyErr_NoMemory();  // PyMem_NEW returns NULL without setting an exception.
+        goto done;
+    }
+    chan->closing->ref = ref;
+    res = 0;
+done:
+    PyThread_release_lock(chan->mutex);
+    return res;
+}
+
+static void
+_channel_clear_closing(struct _channel *chan) {  // Free chan->closing, if set, under chan->mutex.
+    PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
+    if (chan->closing != NULL) {
+        PyMem_Free(chan->closing);
+        chan->closing = NULL;  // other code tests this field to detect a pending close
+    }
+    PyThread_release_lock(chan->mutex);
+}
+
+static void
+_channel_finish_closing(struct _channel *chan) {  // If chan is marked closing, detach it from its ref and free it.
+    struct _channel_closing *closing = chan->closing;  // NOTE(review): read without chan->mutex.
+    if (closing == NULL) {
+        return;
+    }
+    _channelref *ref = closing->ref;  // Grab it before _channel_clear_closing() frees "closing".
+    _channel_clear_closing(chan);
+    // Do the things that would have been done in _channels_close().
+    ref->chan = NULL;  // NOTE(review): confirm the channels mutex is held when this runs.
+    _channel_free(chan);
+}
+
 /* "high"-level channel-related functions */
 
 static int64_t
@@ -1207,6 +1326,12 @@ _channel_send(_channels *channels, int64_t id, PyObject *obj)
     }
     // Past this point we are responsible for releasing the mutex.
 
+    if (chan->closing != NULL) {
+        PyErr_Format(ChannelClosedError, "channel %d closed", id);
+        PyThread_release_lock(mutex);
+        return -1;
+    }
+
     // Convert the object to cross-interpreter data.
     _PyCrossInterpreterData *data = PyMem_NEW(_PyCrossInterpreterData, 1);
     if (data == NULL) {
@@ -1290,16 +1415,13 @@ _channel_drop(_channels *channels, int64_t id, int send, int recv)
 }
 
 static int
-_channel_close(_channels *channels, int64_t id)
+_channel_close(_channels *channels, int64_t id, int end, int force)  // end: CHANNEL_SEND/RECV/BOTH
 {
-    return _channels_close(channels, id, NULL);
+    return _channels_close(channels, id, NULL, end, force);  // NULL: no channel pointer handed back
 }
 
 /* ChannelID class */
 
-#define CHANNEL_SEND 1
-#define CHANNEL_RECV -1
-
 static PyTypeObject ChannelIDtype;
 
 typedef struct channelid {
@@ -2555,15 +2677,8 @@ channel_close(PyObject *self, PyObject *args, PyObject *kwds)
     if (cid < 0) {
         return NULL;
     }
-    if (send == 0 && recv == 0) {
-        send = 1;
-        recv = 1;
-    }
-
-    // XXX Handle the ends.
-    // XXX Handle force is True.
 
-    if (_channel_close(&_globals.channels, cid) != 0) {
+    if (_channel_close(&_globals.channels, cid, send-recv, force) != 0) {
         return NULL;
     }
     Py_RETURN_NONE;



More information about the Python-checkins mailing list