[pypy-svn] pypy default: Add support for tp_compare

amauryfa commits-noreply at bitbucket.org
Fri Apr 8 23:07:34 CEST 2011


Author: Amaury Forgeot d'Arc <amauryfa at gmail.com>
Branch: 
Changeset: r43238:48571fedd5a4
Date: 2011-04-08 16:28 +0200
http://bitbucket.org/pypy/pypy/changeset/48571fedd5a4/

Log:	Add support for tp_compare

diff --git a/pypy/module/cpyext/slotdefs.py b/pypy/module/cpyext/slotdefs.py
--- a/pypy/module/cpyext/slotdefs.py
+++ b/pypy/module/cpyext/slotdefs.py
@@ -6,7 +6,7 @@
     unaryfunc, wrapperfunc, ternaryfunc, PyTypeObjectPtr, binaryfunc,
     getattrfunc, getattrofunc, setattrofunc, lenfunc, ssizeargfunc,
     ssizessizeargfunc, ssizeobjargproc, iternextfunc, initproc, richcmpfunc,
-    hashfunc, descrgetfunc, descrsetfunc, objobjproc)
+    cmpfunc, hashfunc, descrgetfunc, descrsetfunc, objobjproc)
 from pypy.module.cpyext.pyobject import from_ref
 from pypy.module.cpyext.pyerrors import PyErr_Occurred
 from pypy.module.cpyext.state import State
@@ -197,10 +197,9 @@
     def inner(space, w_self, w_args, func):
         func_target = rffi.cast(richcmpfunc, func)
         check_num_args(space, w_args, 1)
-        args_w = space.fixedview(w_args)
-        other_w = args_w[0]
+        w_other, = space.fixedview(w_args)
         return generic_cpy_call(space, func_target,
-            w_self, other_w, rffi.cast(rffi.INT_real, OP_CONST))
+            w_self, w_other, rffi.cast(rffi.INT_real, OP_CONST))
     return inner
 
 richcmp_eq = get_richcmp_func(Py_EQ)
@@ -210,6 +209,21 @@
 richcmp_gt = get_richcmp_func(Py_GT)
 richcmp_ge = get_richcmp_func(Py_GE)
 
+def wrap_cmpfunc(space, w_self, w_args, func):
+    """Call the tp_compare slot `func` on (w_self, w_other); wrap the result."""
+    func_target = rffi.cast(cmpfunc, func)
+    check_num_args(space, w_args, 1)
+    w_other, = space.fixedview(w_args)
+    w_selftype, w_othertype = space.type(w_self), space.type(w_other)
+    if not space.is_true(space.issubtype(w_selftype, w_othertype)):
+        # The three names must be passed to '%' as one tuple; unparenthesized,
+        # '%' takes only the first and formatting itself raises TypeError.
+        raise OperationError(space.w_TypeError, space.wrap(
+            "%s.__cmp__(x,y) requires y to be a '%s', not a '%s'" % (
+                w_selftype.getname(space), w_selftype.getname(space),
+                w_othertype.getname(space))))
+    return space.wrap(generic_cpy_call(space, func_target, w_self, w_other))
+
 @cpython_api([PyTypeObjectPtr, PyObject, PyObject], PyObject, external=False)
 def slot_tp_new(space, type, w_args, w_kwds):
     from pypy.module.cpyext.tupleobject import PyTuple_Check

diff --git a/pypy/module/cpyext/test/comparisons.c b/pypy/module/cpyext/test/comparisons.c
--- a/pypy/module/cpyext/test/comparisons.c
+++ b/pypy/module/cpyext/test/comparisons.c
@@ -69,12 +69,31 @@
 };
 
 
+static int cmp_compare(PyObject *self, PyObject *other) {
+    return -1;  /* constant result: both arguments are ignored */
+}
+
+PyTypeObject OldCmpType = {  /* besides name and size, only tp_compare is set */
+    PyVarObject_HEAD_INIT(NULL, 0)
+    "comparisons.OldCmpType",                       /* tp_name */
+    sizeof(CmpObject),                              /* tp_basicsize */
+    0,                                              /* tp_itemsize */
+    0,                                              /* tp_dealloc */
+    0,                                              /* tp_print */
+    0,                                              /* tp_getattr */
+    0,                                              /* tp_setattr */
+    (cmpfunc)cmp_compare,                           /* tp_compare */
+};
+
+
 void initcomparisons(void)
 {
     PyObject *m, *d;
 
     if (PyType_Ready(&CmpType) < 0)
         return;
+    if (PyType_Ready(&OldCmpType) < 0)
+        return;
     m = Py_InitModule("comparisons", NULL);
     if (m == NULL)
         return;
@@ -83,4 +102,6 @@
         return;
     if (PyDict_SetItemString(d, "CmpType", (PyObject *)&CmpType) < 0)
         return;
+    if (PyDict_SetItemString(d, "OldCmpType", (PyObject *)&OldCmpType) < 0)
+        return;
 }

diff --git a/pypy/module/cpyext/test/test_typeobject.py b/pypy/module/cpyext/test/test_typeobject.py
--- a/pypy/module/cpyext/test/test_typeobject.py
+++ b/pypy/module/cpyext/test/test_typeobject.py
@@ -213,6 +213,11 @@
 
         assert cmpr.__le__(4) is NotImplemented
 
+    def test_tpcompare(self):
+        module = self.import_module("comparisons")
+        cmpr = module.OldCmpType()
+        assert cmpr < cmpr  # holds because the C tp_compare stub always returns -1
+
     def test_hash(self):
         module = self.import_module("comparisons")
         cmpr = module.CmpType()


More information about the Pypy-commit mailing list