[Python-checkins] bpo-38005: Fixed comparing and creating of InterpreterID and ChannelID. (GH-15652)

Serhiy Storchaka webhook-mailer at python.org
Fri Sep 13 15:50:33 EDT 2019


https://github.com/python/cpython/commit/bf169915ecdd42329726104278eb723a7dda2736
commit: bf169915ecdd42329726104278eb723a7dda2736
branch: master
author: Serhiy Storchaka <storchaka at gmail.com>
committer: GitHub <noreply at github.com>
date: 2019-09-13T22:50:27+03:00
summary:

bpo-38005: Fixed comparing and creating of InterpreterID and ChannelID. (GH-15652)

* Fix a crash in comparing with float (and maybe other crashes).
* They now never compare equal to strings or non-integer numbers.
* Comparison with a large number no longer raises OverflowError.
* Arbitrary exceptions are no longer silenced in constructors and comparisons.
* The TypeError raised in the constructor now contains the name of the type.
* Accept only ChannelID and int-like objects in channel functions.
* Accept only InterpreterID, int-like objects and str in the InterpreterID constructor.
* Accept int-like objects, not just int, in interpreter-related functions.

files:
A Misc/NEWS.d/next/Core and Builtins/2019-09-02-20-00-31.bpo-38005.e7VsTA.rst
M Include/cpython/interpreteridobject.h
M Lib/test/test__xxsubinterpreters.py
M Modules/_xxsubinterpretersmodule.c
M Objects/interpreteridobject.c

diff --git a/Include/cpython/interpreteridobject.h b/Include/cpython/interpreteridobject.h
index cb72c2b0895a..67ec5873542d 100644
--- a/Include/cpython/interpreteridobject.h
+++ b/Include/cpython/interpreteridobject.h
@@ -14,8 +14,6 @@ PyAPI_FUNC(PyObject *) _PyInterpreterID_New(int64_t);
 PyAPI_FUNC(PyObject *) _PyInterpreterState_GetIDObject(PyInterpreterState *);
 PyAPI_FUNC(PyInterpreterState *) _PyInterpreterID_LookUp(PyObject *);
 
-PyAPI_FUNC(int64_t) _Py_CoerceID(PyObject *);
-
 #ifdef __cplusplus
 }
 #endif
diff --git a/Lib/test/test__xxsubinterpreters.py b/Lib/test/test__xxsubinterpreters.py
index 78b2030a1f6d..f5474f4101d0 100644
--- a/Lib/test/test__xxsubinterpreters.py
+++ b/Lib/test/test__xxsubinterpreters.py
@@ -526,30 +526,23 @@ def test_with_int(self):
         self.assertEqual(int(id), 10)
 
     def test_coerce_id(self):
-        id = interpreters.InterpreterID('10', force=True)
-        self.assertEqual(int(id), 10)
-
-        id = interpreters.InterpreterID(10.0, force=True)
-        self.assertEqual(int(id), 10)
-
         class Int(str):
-            def __init__(self, value):
-                self._value = value
-            def __int__(self):
-                return self._value
+            def __index__(self):
+                return 10
 
-        id = interpreters.InterpreterID(Int(10), force=True)
-        self.assertEqual(int(id), 10)
+        for id in ('10', '1_0', Int()):
+            with self.subTest(id=id):
+                id = interpreters.InterpreterID(id, force=True)
+                self.assertEqual(int(id), 10)
 
     def test_bad_id(self):
-        for id in [-1, 'spam']:
-            with self.subTest(id):
-                with self.assertRaises(ValueError):
-                    interpreters.InterpreterID(id)
-        with self.assertRaises(OverflowError):
-            interpreters.InterpreterID(2**64)
-        with self.assertRaises(TypeError):
-            interpreters.InterpreterID(object())
+        self.assertRaises(TypeError, interpreters.InterpreterID, object())
+        self.assertRaises(TypeError, interpreters.InterpreterID, 10.0)
+        self.assertRaises(TypeError, interpreters.InterpreterID, b'10')
+        self.assertRaises(ValueError, interpreters.InterpreterID, -1)
+        self.assertRaises(ValueError, interpreters.InterpreterID, '-1')
+        self.assertRaises(ValueError, interpreters.InterpreterID, 'spam')
+        self.assertRaises(OverflowError, interpreters.InterpreterID, 2**64)
 
     def test_does_not_exist(self):
         id = interpreters.channel_create()
@@ -572,6 +565,14 @@ def test_equality(self):
         self.assertTrue(id1 == id1)
         self.assertTrue(id1 == id2)
         self.assertTrue(id1 == int(id1))
+        self.assertTrue(int(id1) == id1)
+        self.assertTrue(id1 == float(int(id1)))
+        self.assertTrue(float(int(id1)) == id1)
+        self.assertFalse(id1 == float(int(id1)) + 0.1)
+        self.assertFalse(id1 == str(int(id1)))
+        self.assertFalse(id1 == 2**1000)
+        self.assertFalse(id1 == float('inf'))
+        self.assertFalse(id1 == 'spam')
         self.assertFalse(id1 == id3)
 
         self.assertFalse(id1 != id1)
@@ -1105,30 +1106,20 @@ def test_with_kwargs(self):
         self.assertEqual(cid.end, 'both')
 
     def test_coerce_id(self):
-        cid = interpreters._channel_id('10', force=True)
-        self.assertEqual(int(cid), 10)
-
-        cid = interpreters._channel_id(10.0, force=True)
-        self.assertEqual(int(cid), 10)
-
         class Int(str):
-            def __init__(self, value):
-                self._value = value
-            def __int__(self):
-                return self._value
+            def __index__(self):
+                return 10
 
-        cid = interpreters._channel_id(Int(10), force=True)
+        cid = interpreters._channel_id(Int(), force=True)
         self.assertEqual(int(cid), 10)
 
     def test_bad_id(self):
-        for cid in [-1, 'spam']:
-            with self.subTest(cid):
-                with self.assertRaises(ValueError):
-                    interpreters._channel_id(cid)
-        with self.assertRaises(OverflowError):
-            interpreters._channel_id(2**64)
-        with self.assertRaises(TypeError):
-            interpreters._channel_id(object())
+        self.assertRaises(TypeError, interpreters._channel_id, object())
+        self.assertRaises(TypeError, interpreters._channel_id, 10.0)
+        self.assertRaises(TypeError, interpreters._channel_id, '10')
+        self.assertRaises(TypeError, interpreters._channel_id, b'10')
+        self.assertRaises(ValueError, interpreters._channel_id, -1)
+        self.assertRaises(OverflowError, interpreters._channel_id, 2**64)
 
     def test_bad_kwargs(self):
         with self.assertRaises(ValueError):
@@ -1164,6 +1155,14 @@ def test_equality(self):
         self.assertTrue(cid1 == cid1)
         self.assertTrue(cid1 == cid2)
         self.assertTrue(cid1 == int(cid1))
+        self.assertTrue(int(cid1) == cid1)
+        self.assertTrue(cid1 == float(int(cid1)))
+        self.assertTrue(float(int(cid1)) == cid1)
+        self.assertFalse(cid1 == float(int(cid1)) + 0.1)
+        self.assertFalse(cid1 == str(int(cid1)))
+        self.assertFalse(cid1 == 2**1000)
+        self.assertFalse(cid1 == float('inf'))
+        self.assertFalse(cid1 == 'spam')
         self.assertFalse(cid1 == cid3)
 
         self.assertFalse(cid1 != cid1)
diff --git a/Misc/NEWS.d/next/Core and Builtins/2019-09-02-20-00-31.bpo-38005.e7VsTA.rst b/Misc/NEWS.d/next/Core and Builtins/2019-09-02-20-00-31.bpo-38005.e7VsTA.rst
new file mode 100644
index 000000000000..706abf587b90
--- /dev/null
+++ b/Misc/NEWS.d/next/Core and Builtins/2019-09-02-20-00-31.bpo-38005.e7VsTA.rst	
@@ -0,0 +1 @@
+Fixed comparing and creating of InterpreterID and ChannelID.
diff --git a/Modules/_xxsubinterpretersmodule.c b/Modules/_xxsubinterpretersmodule.c
index 19d98fd96934..7842947e54ac 100644
--- a/Modules/_xxsubinterpretersmodule.c
+++ b/Modules/_xxsubinterpretersmodule.c
@@ -1405,6 +1405,34 @@ typedef struct channelid {
     _channels *channels;
 } channelid;
 
+static int
+channel_id_converter(PyObject *arg, void *ptr)
+{
+    int64_t cid;
+    if (PyObject_TypeCheck(arg, &ChannelIDtype)) {
+        cid = ((channelid *)arg)->id;
+    }
+    else if (PyIndex_Check(arg)) {
+        cid = PyLong_AsLongLong(arg);
+        if (cid == -1 && PyErr_Occurred()) {
+            return 0;
+        }
+        if (cid < 0) {
+            PyErr_Format(PyExc_ValueError,
+                        "channel ID must be a non-negative int, got %R", arg);
+            return 0;
+        }
+    }
+    else {
+        PyErr_Format(PyExc_TypeError,
+                     "channel ID must be an int, got %.100s",
+                     arg->ob_type->tp_name);
+        return 0;
+    }
+    *(int64_t *)ptr = cid;
+    return 1;
+}
+
 static channelid *
 newchannelid(PyTypeObject *cls, int64_t cid, int end, _channels *channels,
              int force, int resolve)
@@ -1437,28 +1465,16 @@ static PyObject *
 channelid_new(PyTypeObject *cls, PyObject *args, PyObject *kwds)
 {
     static char *kwlist[] = {"id", "send", "recv", "force", "_resolve", NULL};
-    PyObject *id;
+    int64_t cid;
     int send = -1;
     int recv = -1;
     int force = 0;
     int resolve = 0;
     if (!PyArg_ParseTupleAndKeywords(args, kwds,
-                                     "O|$pppp:ChannelID.__new__", kwlist,
-                                     &id, &send, &recv, &force, &resolve))
+                                     "O&|$pppp:ChannelID.__new__", kwlist,
+                                     channel_id_converter, &cid, &send, &recv, &force, &resolve))
         return NULL;
 
-    // Coerce and check the ID.
-    int64_t cid;
-    if (PyObject_TypeCheck(id, &ChannelIDtype)) {
-        cid = ((channelid *)id)->id;
-    }
-    else {
-        cid = _Py_CoerceID(id);
-        if (cid < 0) {
-            return NULL;
-        }
-    }
-
     // Handle "send" and "recv".
     if (send == 0 && recv == 0) {
         PyErr_SetString(PyExc_ValueError,
@@ -1592,30 +1608,28 @@ channelid_richcompare(PyObject *self, PyObject *other, int op)
     int equal;
     if (PyObject_TypeCheck(other, &ChannelIDtype)) {
         channelid *othercid = (channelid *)other;
-        if (cid->end != othercid->end) {
-            equal = 0;
-        }
-        else {
-            equal = (cid->id == othercid->id);
-        }
+        equal = (cid->end == othercid->end) && (cid->id == othercid->id);
     }
-    else {
-        other = PyNumber_Long(other);
-        if (other == NULL) {
-            PyErr_Clear();
-            Py_RETURN_NOTIMPLEMENTED;
-        }
-        int64_t othercid = PyLong_AsLongLong(other);
-        Py_DECREF(other);
-        if (othercid == -1 && PyErr_Occurred() != NULL) {
+    else if (PyLong_Check(other)) {
+        /* Fast path */
+        int overflow;
+        long long othercid = PyLong_AsLongLongAndOverflow(other, &overflow);
+        if (othercid == -1 && PyErr_Occurred()) {
             return NULL;
         }
-        if (othercid < 0) {
-            equal = 0;
-        }
-        else {
-            equal = (cid->id == othercid);
+        equal = !overflow && (othercid >= 0) && (cid->id == othercid);
+    }
+    else if (PyNumber_Check(other)) {
+        PyObject *pyid = PyLong_FromLongLong(cid->id);
+        if (pyid == NULL) {
+            return NULL;
         }
+        PyObject *res = PyObject_RichCompare(pyid, other, op);
+        Py_DECREF(pyid);
+        return res;
+    }
+    else {
+        Py_RETURN_NOTIMPLEMENTED;
     }
 
     if ((op == Py_EQ && equal) || (op == Py_NE && !equal)) {
@@ -1754,8 +1768,7 @@ static PyTypeObject ChannelIDtype = {
     0,                              /* tp_getattro */
     0,                              /* tp_setattro */
     0,                              /* tp_as_buffer */
-    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE |
-        Py_TPFLAGS_LONG_SUBCLASS,   /* tp_flags */
+    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
     channelid_doc,                  /* tp_doc */
     0,                              /* tp_traverse */
     0,                              /* tp_clear */
@@ -2017,10 +2030,6 @@ interp_destroy(PyObject *self, PyObject *args, PyObject *kwds)
                                      "O:destroy", kwlist, &id)) {
         return NULL;
     }
-    if (!PyLong_Check(id)) {
-        PyErr_SetString(PyExc_TypeError, "ID must be an int");
-        return NULL;
-    }
 
     // Look up the interpreter.
     PyInterpreterState *interp = _PyInterpreterID_LookUp(id);
@@ -2145,10 +2154,6 @@ interp_run_string(PyObject *self, PyObject *args, PyObject *kwds)
                                      &id, &code, &shared)) {
         return NULL;
     }
-    if (!PyLong_Check(id)) {
-        PyErr_SetString(PyExc_TypeError, "first arg (ID) must be an int");
-        return NULL;
-    }
 
     // Look up the interpreter.
     PyInterpreterState *interp = _PyInterpreterID_LookUp(id);
@@ -2216,10 +2221,6 @@ interp_is_running(PyObject *self, PyObject *args, PyObject *kwds)
                                      "O:is_running", kwlist, &id)) {
         return NULL;
     }
-    if (!PyLong_Check(id)) {
-        PyErr_SetString(PyExc_TypeError, "ID must be an int");
-        return NULL;
-    }
 
     PyInterpreterState *interp = _PyInterpreterID_LookUp(id);
     if (interp == NULL) {
@@ -2268,13 +2269,9 @@ static PyObject *
 channel_destroy(PyObject *self, PyObject *args, PyObject *kwds)
 {
     static char *kwlist[] = {"cid", NULL};
-    PyObject *id;
-    if (!PyArg_ParseTupleAndKeywords(args, kwds,
-                                     "O:channel_destroy", kwlist, &id)) {
-        return NULL;
-    }
-    int64_t cid = _Py_CoerceID(id);
-    if (cid < 0) {
+    int64_t cid;
+    if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&:channel_destroy", kwlist,
+                                     channel_id_converter, &cid)) {
         return NULL;
     }
 
@@ -2331,14 +2328,10 @@ static PyObject *
 channel_send(PyObject *self, PyObject *args, PyObject *kwds)
 {
     static char *kwlist[] = {"cid", "obj", NULL};
-    PyObject *id;
+    int64_t cid;
     PyObject *obj;
-    if (!PyArg_ParseTupleAndKeywords(args, kwds,
-                                     "OO:channel_send", kwlist, &id, &obj)) {
-        return NULL;
-    }
-    int64_t cid = _Py_CoerceID(id);
-    if (cid < 0) {
+    if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&O:channel_send", kwlist,
+                                     channel_id_converter, &cid, &obj)) {
         return NULL;
     }
 
@@ -2357,13 +2350,9 @@ static PyObject *
 channel_recv(PyObject *self, PyObject *args, PyObject *kwds)
 {
     static char *kwlist[] = {"cid", NULL};
-    PyObject *id;
-    if (!PyArg_ParseTupleAndKeywords(args, kwds,
-                                     "O:channel_recv", kwlist, &id)) {
-        return NULL;
-    }
-    int64_t cid = _Py_CoerceID(id);
-    if (cid < 0) {
+    int64_t cid;
+    if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&:channel_recv", kwlist,
+                                     channel_id_converter, &cid)) {
         return NULL;
     }
 
@@ -2379,17 +2368,13 @@ static PyObject *
 channel_close(PyObject *self, PyObject *args, PyObject *kwds)
 {
     static char *kwlist[] = {"cid", "send", "recv", "force", NULL};
-    PyObject *id;
+    int64_t cid;
     int send = 0;
     int recv = 0;
     int force = 0;
     if (!PyArg_ParseTupleAndKeywords(args, kwds,
-                                     "O|$ppp:channel_close", kwlist,
-                                     &id, &send, &recv, &force)) {
-        return NULL;
-    }
-    int64_t cid = _Py_CoerceID(id);
-    if (cid < 0) {
+                                     "O&|$ppp:channel_close", kwlist,
+                                     channel_id_converter, &cid, &send, &recv, &force)) {
         return NULL;
     }
 
@@ -2431,17 +2416,13 @@ channel_release(PyObject *self, PyObject *args, PyObject *kwds)
 {
     // Note that only the current interpreter is affected.
     static char *kwlist[] = {"cid", "send", "recv", "force", NULL};
-    PyObject *id;
+    int64_t cid;
     int send = 0;
     int recv = 0;
     int force = 0;
     if (!PyArg_ParseTupleAndKeywords(args, kwds,
-                                     "O|$ppp:channel_release", kwlist,
-                                     &id, &send, &recv, &force)) {
-        return NULL;
-    }
-    int64_t cid = _Py_CoerceID(id);
-    if (cid < 0) {
+                                     "O&|$ppp:channel_release", kwlist,
+                                     channel_id_converter, &cid, &send, &recv, &force)) {
         return NULL;
     }
     if (send == 0 && recv == 0) {
@@ -2538,7 +2519,6 @@ PyInit__xxsubinterpreters(void)
     }
 
     /* Initialize types */
-    ChannelIDtype.tp_base = &PyLong_Type;
     if (PyType_Ready(&ChannelIDtype) != 0) {
         return NULL;
     }
diff --git a/Objects/interpreteridobject.c b/Objects/interpreteridobject.c
index 0a1dfa25795f..3edbb85e6ac3 100644
--- a/Objects/interpreteridobject.c
+++ b/Objects/interpreteridobject.c
@@ -5,38 +5,6 @@
 #include "interpreteridobject.h"
 
 
-int64_t
-_Py_CoerceID(PyObject *orig)
-{
-    PyObject *pyid = PyNumber_Long(orig);
-    if (pyid == NULL) {
-        if (PyErr_ExceptionMatches(PyExc_TypeError)) {
-            PyErr_Format(PyExc_TypeError,
-                         "'id' must be a non-negative int, got %R", orig);
-        }
-        else {
-            PyErr_Format(PyExc_ValueError,
-                         "'id' must be a non-negative int, got %R", orig);
-        }
-        return -1;
-    }
-    int64_t id = PyLong_AsLongLong(pyid);
-    Py_DECREF(pyid);
-    if (id == -1 && PyErr_Occurred() != NULL) {
-        if (!PyErr_ExceptionMatches(PyExc_OverflowError)) {
-            PyErr_Format(PyExc_ValueError,
-                         "'id' must be a non-negative int, got %R", orig);
-        }
-        return -1;
-    }
-    if (id < 0) {
-        PyErr_Format(PyExc_ValueError,
-                     "'id' must be a non-negative int, got %R", orig);
-        return -1;
-    }
-    return id;
-}
-
 typedef struct interpid {
     PyObject_HEAD
     int64_t id;
@@ -85,8 +53,31 @@ interpid_new(PyTypeObject *cls, PyObject *args, PyObject *kwds)
         id = ((interpid *)idobj)->id;
     }
     else {
-        id = _Py_CoerceID(idobj);
+        PyObject *pyid;
+        if (PyIndex_Check(idobj)) {
+            pyid = idobj;
+            Py_INCREF(pyid);
+        }
+        else if (PyUnicode_Check(idobj)) {
+            pyid = PyNumber_Long(idobj);
+            if (pyid == NULL) {
+                return NULL;
+            }
+        }
+        else {
+            PyErr_Format(PyExc_TypeError,
+                         "interpreter ID must be an int, got %.100s",
+                         idobj->ob_type->tp_name);
+            return NULL;
+        }
+        id = PyLong_AsLongLong(pyid);
+        Py_DECREF(pyid);
+        if (id == -1 && PyErr_Occurred()) {
+            return NULL;
+        }
         if (id < 0) {
+            PyErr_Format(PyExc_ValueError,
+                         "interpreter ID must be a non-negative int, got %R", idobj);
             return NULL;
         }
     }
@@ -202,23 +193,26 @@ interpid_richcompare(PyObject *self, PyObject *other, int op)
         interpid *otherid = (interpid *)other;
         equal = (id->id == otherid->id);
     }
-    else {
-        other = PyNumber_Long(other);
-        if (other == NULL) {
-            PyErr_Clear();
-            Py_RETURN_NOTIMPLEMENTED;
-        }
-        int64_t otherid = PyLong_AsLongLong(other);
-        Py_DECREF(other);
-        if (otherid == -1 && PyErr_Occurred() != NULL) {
+    else if (PyLong_CheckExact(other)) {
+        /* Fast path */
+        int overflow;
+        long long otherid = PyLong_AsLongLongAndOverflow(other, &overflow);
+        if (otherid == -1 && PyErr_Occurred()) {
             return NULL;
         }
-        if (otherid < 0) {
-            equal = 0;
-        }
-        else {
-            equal = (id->id == otherid);
+        equal = !overflow && (otherid >= 0) && (id->id == otherid);
+    }
+    else if (PyNumber_Check(other)) {
+        PyObject *pyid = PyLong_FromLongLong(id->id);
+        if (pyid == NULL) {
+            return NULL;
         }
+        PyObject *res = PyObject_RichCompare(pyid, other, op);
+        Py_DECREF(pyid);
+        return res;
+    }
+    else {
+        Py_RETURN_NOTIMPLEMENTED;
     }
 
     if ((op == Py_EQ && equal) || (op == Py_NE && !equal)) {
@@ -250,8 +244,7 @@ PyTypeObject _PyInterpreterID_Type = {
     0,                              /* tp_getattro */
     0,                              /* tp_setattro */
     0,                              /* tp_as_buffer */
-    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE |
-        Py_TPFLAGS_LONG_SUBCLASS,   /* tp_flags */
+    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
     interpid_doc,                   /* tp_doc */
     0,                              /* tp_traverse */
     0,                              /* tp_clear */
@@ -262,7 +255,7 @@ PyTypeObject _PyInterpreterID_Type = {
     0,                              /* tp_methods */
     0,                              /* tp_members */
     0,                              /* tp_getset */
-    &PyLong_Type,                   /* tp_base */
+    0,                              /* tp_base */
     0,                              /* tp_dict */
     0,                              /* tp_descr_get */
     0,                              /* tp_descr_set */
@@ -297,12 +290,17 @@ _PyInterpreterID_LookUp(PyObject *requested_id)
     if (PyObject_TypeCheck(requested_id, &_PyInterpreterID_Type)) {
         id = ((interpid *)requested_id)->id;
     }
-    else {
+    else if (PyIndex_Check(requested_id)) {
         id = PyLong_AsLongLong(requested_id);
         if (id == -1 && PyErr_Occurred() != NULL) {
             return NULL;
         }
         assert(id <= INT64_MAX);
     }
+    else {
+        PyErr_Format(PyExc_TypeError, "interpreter ID must be an int, got %.100s",
+                     requested_id->ob_type->tp_name);
+        return NULL;
+    }
     return _PyInterpreterState_LookUpID(id);
 }



More information about the Python-checkins mailing list