[Python-checkins] cpython: Issue #11707: Fast C version of functools.cmp_to_key()

raymond.hettinger python-checkins at python.org
Tue Apr 5 11:34:06 CEST 2011


http://hg.python.org/cpython/rev/a03fb2fc3ed8
changeset:   69150:a03fb2fc3ed8
user:        Raymond Hettinger <python at rcn.com>
date:        Tue Apr 05 02:33:54 2011 -0700
summary:
  Issue #11707: Fast C version of functools.cmp_to_key()

files:
  Lib/functools.py           |    7 +-
  Lib/test/test_functools.py |   66 ++++++++++-
  Misc/NEWS                  |    3 +
  Modules/_functoolsmodule.c |  161 +++++++++++++++++++++++++
  4 files changed, 235 insertions(+), 2 deletions(-)


diff --git a/Lib/functools.py b/Lib/functools.py
--- a/Lib/functools.py
+++ b/Lib/functools.py
@@ -97,7 +97,7 @@
     """Convert a cmp= function into a key= function"""
     class K(object):
         __slots__ = ['obj']
-        def __init__(self, obj, *args):
+        def __init__(self, obj):
             self.obj = obj
         def __lt__(self, other):
             return mycmp(self.obj, other.obj) < 0
@@ -115,6 +115,11 @@
             raise TypeError('hash not implemented')
     return K
 
+try:  # Prefer the C implementation; it rebinds the pure-Python one above.
+    from _functools import cmp_to_key
+except ImportError:  # no _functools, or it lacks cmp_to_key:
+    pass             # keep using the pure-Python cmp_to_key defined above.
+
 _CacheInfo = namedtuple("CacheInfo", "hits misses maxsize currsize")
 
 def lru_cache(maxsize=100):
diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py
--- a/Lib/test/test_functools.py
+++ b/Lib/test/test_functools.py
@@ -435,18 +435,81 @@
         self.assertEqual(self.func(add, d), "".join(d.keys()))
 
 class TestCmpToKey(unittest.TestCase):
+
     def test_cmp_to_key(self):
+        def three_way(a, b):     # classic cmp(): returns -1, 0 or 1
+            return (a > b) - (a < b)
+        wrap = functools.cmp_to_key(three_way)
+        self.assertEqual(wrap(3), wrap(3))
+        self.assertGreater(wrap(3), wrap(1))
+        def by_int(a, b):        # mixed float/str/int compared by int value
+            return int(a) - int(b)
+        wrap = functools.cmp_to_key(by_int)
+        self.assertEqual(wrap(4.0), wrap('4'))
+        self.assertLess(wrap(2), wrap('35'))
+
+    def test_cmp_to_key_arguments(self):
+        def three_way(a, b):
+            return (a > b) - (a < b)
+        # Both cmp_to_key() and the key it returns accept keyword arguments.
+        wrap = functools.cmp_to_key(mycmp=three_way)
+        self.assertEqual(wrap(obj=3), wrap(obj=3))
+        self.assertGreater(wrap(obj=3), wrap(obj=1))
+        # A key compared with a non-key operand fails on either side.
+        with self.assertRaises((TypeError, AttributeError)):
+            wrap(3) > 1
+        with self.assertRaises((TypeError, AttributeError)):
+            1 < wrap(3)
+        # Wrong argument counts for cmp_to_key() itself ...
+        self.assertRaises(TypeError, functools.cmp_to_key)
+        self.assertRaises(TypeError, functools.cmp_to_key, three_way, None)
+        # ... and for the key function it returns.
+        wrap = functools.cmp_to_key(three_way)
+        self.assertRaises(TypeError, wrap)
+        self.assertRaises(TypeError, wrap, None, None)
+
+    def test_bad_cmp(self):
+        # An exception raised inside the user's cmp function propagates.
+        def cmp1(x, y):
+            raise ZeroDivisionError
+        key = functools.cmp_to_key(cmp1)
+        with self.assertRaises(ZeroDivisionError):
+            key(3) > key(1)
+        # So does one raised while the cmp result is compared with 0.
+        class BadCmp:
+            def __lt__(self, other):
+                raise ZeroDivisionError
+        key = functools.cmp_to_key(lambda x, y: BadCmp())
+        with self.assertRaises(ZeroDivisionError):
+            key(3) < key(1)    # evaluates BadCmp() < 0
+
+    def test_obj_field(self):
+        # The wrapped value is exposed unchanged as the .obj attribute.
+        wrap = functools.cmp_to_key(
+            mycmp=lambda a, b: (a > b) - (a < b))
+        self.assertEqual(wrap(50).obj, 50)
+
+    def test_sort_int(self):  # cmp returns y - x, so the sort is descending
         def mycmp(x, y):
             return y - x
         self.assertEqual(sorted(range(5), key=functools.cmp_to_key(mycmp)),
                          [4, 3, 2, 1, 0])
 
+    def test_sort_int_str(self):
+        def by_int(a, b):        # mixed int/str values compared numerically
+            a, b = int(a), int(b)
+            return (a > b) - (a < b)
+        mixed = [5, '3', 7, 2, '0', '1', 4, '10', 1]
+        mixed.sort(key=functools.cmp_to_key(by_int))
+        self.assertEqual(list(map(int, mixed)),
+                         [0, 1, 1, 2, 3, 4, 5, 7, 10])
+
     def test_hash(self):
         def mycmp(x, y):
             return y - x
         key = functools.cmp_to_key(mycmp)
         k = key(10)
-        self.assertRaises(TypeError, hash(k))
+        self.assertRaises(TypeError, hash, k)  # pass hash uncalled; assertRaises calls it
 
 class TestTotalOrdering(unittest.TestCase):
 
@@ -655,6 +718,7 @@
 
 def test_main(verbose=None):
     test_classes = (
+        TestCmpToKey,
         TestPartial,
         TestPartialSubclass,
         TestPythonPartial,
diff --git a/Misc/NEWS b/Misc/NEWS
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -97,6 +97,9 @@
 - Issue #10791: Implement missing method GzipFile.read1(), allowing GzipFile
   to be wrapped in a TextIOWrapper.  Patch by Nadeem Vawda.
 
+- Issue #11707: Added a fast C version of functools.cmp_to_key().
+  Patch by Filip Gruszczyński.
+
 - Issue #11688: Add sqlite3.Connection.set_trace_callback().  Patch by
   Torsten Landschoff.
 
diff --git a/Modules/_functoolsmodule.c b/Modules/_functoolsmodule.c
--- a/Modules/_functoolsmodule.c
+++ b/Modules/_functoolsmodule.c
@@ -330,6 +330,165 @@
 };
 
 
+/* cmp_to_key ***************************************************************/
+
+typedef struct {
+    PyObject_HEAD
+    PyObject *cmp;              /* user's cmp function (owned reference) */
+    PyObject *object;           /* wrapped value, or NULL for the factory */
+} keyobject;
+
+static void
+keyobject_dealloc(keyobject *ko)
+{
+    Py_DECREF(ko->cmp);         /* set by every allocation path */
+    Py_XDECREF(ko->object);     /* NULL e.g. for the cmp_to_key() factory */
+    PyObject_Del(ko);           /* documented counterpart of PyObject_New */
+}
+
+static int
+keyobject_traverse(keyobject *ko, visitproc visit, void *arg)
+{
+    /* NOTE(review): unused until keyobject_type sets Py_TPFLAGS_HAVE_GC. */
+    Py_VISIT(ko->cmp);
+    Py_VISIT(ko->object);       /* Py_VISIT already skips NULL */
+    return 0;
+}
+
+static PyMemberDef keyobject_members[] = {
+    {"obj", T_OBJECT,                   /* keyobject.object as .obj */
+     offsetof(keyobject, object), 0,    /* flags 0: not READONLY */
+     PyDoc_STR("Value wrapped by a key function.")},
+    {NULL}                              /* sentinel */
+};
+/* Forward declarations: keyobject_type (below) refers to these slots. */
+static PyObject *
+keyobject_call(keyobject *ko, PyObject *args, PyObject *kw);
+
+static PyObject *
+keyobject_richcompare(PyObject *ko, PyObject *other, int op);
+
+static PyTypeObject keyobject_type = {  /* type of the factory and of keys */
+    PyVarObject_HEAD_INIT(&PyType_Type, 0)
+    "functools.KeyWrapper",             /* tp_name */
+    sizeof(keyobject),                  /* tp_basicsize */
+    0,                                  /* tp_itemsize */
+    /* methods */
+    (destructor)keyobject_dealloc,      /* tp_dealloc */
+    0,                                  /* tp_print */
+    0,                                  /* tp_getattr */
+    0,                                  /* tp_setattr */
+    0,                                  /* tp_reserved */
+    0,                                  /* tp_repr */
+    0,                                  /* tp_as_number */
+    0,                                  /* tp_as_sequence */
+    0,                                  /* tp_as_mapping */
+    0,                                  /* tp_hash (test_hash expects TypeError) */
+    (ternaryfunc)keyobject_call,        /* tp_call: key(obj) wraps obj */
+    0,                                  /* tp_str */
+    PyObject_GenericGetAttr,            /* tp_getattro (serves .obj) */
+    0,                                  /* tp_setattro */
+    0,                                  /* tp_as_buffer */
+    Py_TPFLAGS_DEFAULT,                 /* tp_flags (no HAVE_GC: traverse unused) */
+    0,                                  /* tp_doc */
+    (traverseproc)keyobject_traverse,   /* tp_traverse */
+    0,                                  /* tp_clear */
+    keyobject_richcompare,              /* tp_richcompare */
+    0,                                  /* tp_weaklistoffset */
+    0,                                  /* tp_iter */
+    0,                                  /* tp_iternext */
+    0,                                  /* tp_methods */
+    keyobject_members,                  /* tp_members */
+    0,                                  /* tp_getset */
+};
+
+static PyObject *
+keyobject_call(keyobject *ko, PyObject *args, PyObject *kwds)
+{
+    /* K(obj): build a new key that wraps obj and shares ko's cmp. */
+    static char *kwlist[] = {"obj", NULL};
+    PyObject *value;
+    keyobject *wrapper;
+
+    if (!PyArg_ParseTupleAndKeywords(args, kwds, "O:K", kwlist, &value))
+        return NULL;
+    if ((wrapper = PyObject_New(keyobject, &keyobject_type)) == NULL)
+        return NULL;
+    wrapper->cmp = ko->cmp;
+    Py_INCREF(wrapper->cmp);
+    wrapper->object = value;
+    Py_INCREF(wrapper->object);
+    return (PyObject *)wrapper;
+}
+
+/* tp_richcompare: apply op to cmp(ko.obj, other.obj) versus 0. */
+static PyObject *
+keyobject_richcompare(PyObject *ko, PyObject *other, int op)
+{
+    PyObject *res;
+    PyObject *args;
+    PyObject *x;
+    PyObject *y;
+    PyObject *compare;
+    PyObject *answer;
+    static PyObject *zero;      /* cached int 0, created on first use */
+
+    if (zero == NULL) {
+        zero = PyLong_FromLong(0);
+        if (!zero)
+            return NULL;
+    }
+
+    /* Only two key wrappers can be compared with each other. */
+    if (Py_TYPE(other) != &keyobject_type) {
+        PyErr_SetString(PyExc_TypeError, "other argument must be K instance");
+        return NULL;
+    }
+    compare = ((keyobject *) ko)->cmp;
+    assert(compare != NULL);
+    x = ((keyobject *) ko)->object;
+    y = ((keyobject *) other)->object;
+    if (!x || !y) {
+        /* No wrapped value, e.g. the factory returned by cmp_to_key(). */
+        PyErr_SetString(PyExc_AttributeError, "object");
+        return NULL;
+    }
+
+    /* Call the user's comparison function and translate the 3-way
+     * result into true or false (or error) by comparing it with 0
+     * under the requested operator.
+     */
+    args = PyTuple_Pack(2, x, y);   /* takes new references to x and y */
+    if (args == NULL)
+        return NULL;
+    res = PyObject_Call(compare, args, NULL);
+    Py_DECREF(args);
+    if (res == NULL)
+        return NULL;
+    answer = PyObject_RichCompare(res, zero, op);
+    Py_DECREF(res);
+    return answer;
+}
+
+static PyObject *
+functools_cmp_to_key(PyObject *self, PyObject *args, PyObject *kwds)
+{
+    PyObject *cmp;
+    keyobject *object;          /* declared up front: C89, no mixed decls */
+    static char *kwargs[] = {"mycmp", NULL};
+    if (!PyArg_ParseTupleAndKeywords(args, kwds, "O:cmp_to_key", kwargs, &cmp))
+        return NULL;
+    if ((object = PyObject_New(keyobject, &keyobject_type)) == NULL)
+        return NULL;
+    Py_INCREF(cmp);
+    object->cmp = cmp;
+    object->object = NULL;      /* factory: calling it wraps a value */
+    return (PyObject *)object;
+}
+
+PyDoc_STRVAR(functools_cmp_to_key_doc,     /* __doc__ of cmp_to_key */
+"Convert a cmp= function into a key= function.");
+
 /* reduce (used to be a builtin) ********************************************/
 
 static PyObject *
@@ -413,6 +572,8 @@
 
 static PyMethodDef module_methods[] = {
     {"reduce",          functools_reduce,     METH_VARARGS, functools_reduce_doc},
+    {"cmp_to_key",      (PyCFunction)functools_cmp_to_key,  /* 3-arg: cast */
+     METH_VARARGS | METH_KEYWORDS, functools_cmp_to_key_doc},
     {NULL,              NULL}           /* sentinel */
 };
 

-- 
Repository URL: http://hg.python.org/cpython


More information about the Python-checkins mailing list