[pypy-commit] cffi default: (fijal, arigo) (early sprint)

arigo noreply at buildbot.pypy.org
Thu Jun 21 16:37:03 CEST 2012


Author: Armin Rigo <arigo at tunes.org>
Branch: 
Changeset: r481:a225a58c9067
Date: 2012-06-21 16:36 +0200
http://bitbucket.org/cffi/cffi/changeset/a225a58c9067/

Log:	(fijal, arigo) (early sprint)

	cdata pointer comparison.

diff --git a/c/_ffi_backend.c b/c/_ffi_backend.c
--- a/c/_ffi_backend.c
+++ b/c/_ffi_backend.c
@@ -1148,34 +1148,52 @@
 
 static PyObject *cdata_richcompare(PyObject *v, PyObject *w, int op)
 {
-    CDataObject *obv, *obw;
-    int equal;
-    PyObject *res;
-
-    if (op != Py_EQ && op != Py_NE)
-        goto Unimplemented;
+    int res, full_order;   /* full_order: op is <, <=, > or >= */
+    PyObject *pyres;
+    char *v_cdata, *w_cdata;
+
+    full_order = (op != Py_EQ && op != Py_NE);
 
     assert(CData_Check(v));
-    obv = (CDataObject *)v;
+    v_cdata = ((CDataObject *)v)->c_data;
+    if (full_order &&   /* ordering is refused on primitive cdata */
+        (((CDataObject *)v)->c_type->ct_flags & CT_PRIMITIVE_ANY))
+        goto Error;
 
     if (w == Py_None) {
-        equal = (obv->c_data == NULL);
+        w_cdata = NULL;   /* None compares like a NULL pointer */
     }
     else if (CData_Check(w)) {
-        obw = (CDataObject *)w;
-        equal = (obv->c_type == obw->c_type) && (obv->c_data == obw->c_data);
+        w_cdata = ((CDataObject *)w)->c_data;
+        if (full_order &&
+            (((CDataObject *)w)->c_type->ct_flags & CT_PRIMITIVE_ANY))
+            goto Error;
     }
     else
         goto Unimplemented;
 
-    res = (equal ^ (op == Py_NE)) ? Py_True : Py_False;
+    switch (op) {   /* ordering goes through Py_uintptr_t: '<' on pointers into unrelated objects is UB in C */
+    case Py_EQ: res = (v_cdata == w_cdata); break;
+    case Py_NE: res = (v_cdata != w_cdata); break;
+    case Py_LT: res = ((Py_uintptr_t)v_cdata <  (Py_uintptr_t)w_cdata); break;
+    case Py_LE: res = ((Py_uintptr_t)v_cdata <= (Py_uintptr_t)w_cdata); break;
+    case Py_GT: res = ((Py_uintptr_t)v_cdata >  (Py_uintptr_t)w_cdata); break;
+    case Py_GE: res = ((Py_uintptr_t)v_cdata >= (Py_uintptr_t)w_cdata); break;
+    default: goto Unimplemented;   /* unknown op: never answer True/False */
+    }
+    pyres = res ? Py_True : Py_False;
  done:
-    Py_INCREF(res);
-    return res;
+    Py_INCREF(pyres);
+    return pyres;
 
  Unimplemented:
-    res = Py_NotImplemented;
+    pyres = Py_NotImplemented;
     goto done;
+
+ Error:   /* raised only for <, <=, >, >= involving a primitive cdata */
+    PyErr_SetString(PyExc_TypeError,
+                    "cannot do comparison on a primitive cdata");
+    return NULL;
 }
 
 static long cdata_hash(CDataObject *cd)
diff --git a/cffi/backend_ctypes.py b/cffi/backend_ctypes.py
--- a/cffi/backend_ctypes.py
+++ b/cffi/backend_ctypes.py
@@ -1,4 +1,4 @@
-import ctypes, ctypes.util
+import ctypes, ctypes.util, operator
 from . import model
 
 class CTypesData(object):
@@ -76,10 +76,42 @@
         raise TypeError("cdata %r does not support iteration" % (
             self._get_c_name()),)
 
+    def _make_cmp(name):   # build a rich-comparison method ordering cdata by address
+        cmpfunc = getattr(operator, name)
+        def cmp(self, other):
+            if isinstance(other, CTypesData):
+                return cmpfunc(self._convert_to_address(None),
+                               other._convert_to_address(None))
+            elif other is None:   # None compares like address 0 (NULL)
+                return cmpfunc(self._convert_to_address(None), 0)
+            else:
+                return NotImplemented
+        cmp.__name__ = name   # __name__ is writable on both Python 2 and 3
+        return cmp
+
+    __eq__ = _make_cmp('__eq__')
+    __ne__ = _make_cmp('__ne__')
+    __lt__ = _make_cmp('__lt__')
+    __le__ = _make_cmp('__le__')
+    __gt__ = _make_cmp('__gt__')
+    __ge__ = _make_cmp('__ge__')
+
+    def __hash__(self):   # must agree with __eq__, which ignores the ctype
+        return hash(self._convert_to_address(None))
+
 
 class CTypesGenericPrimitive(CTypesData):
     __slots__ = []
 
+    def __eq__(self, other):   # identity only; overrides the address-based __eq__ from CTypesData
+        return self is other
+
+    def __ne__(self, other):   # identity only, mirrors __eq__
+        return self is not other
+
+    def __hash__(self):   # identity hash, consistent with the identity __eq__ above
+        return object.__hash__(self)
+
 
 class CTypesGenericArray(CTypesData):
     __slots__ = []
@@ -119,18 +151,6 @@
     def __nonzero__(self):
         return bool(self._address)
 
-    def __eq__(self, other):
-        if other is None:
-            return not bool(self._address)
-        return (type(self) is type(other) and
-                self._address == other._address)
-
-    def __ne__(self, other):
-        return not self.__eq__(other)
-
-    def __hash__(self):
-        return hash(type(self)) ^ hash(self._address)
-
     @classmethod
     def _to_ctypes(cls, value):
         if value is None:
diff --git a/testing/backend_tests.py b/testing/backend_tests.py
--- a/testing/backend_tests.py
+++ b/testing/backend_tests.py
@@ -65,6 +65,7 @@
         q = ffi.cast(c_decl, long(min - 1))
         assert ffi.typeof(q) is ffi.typeof(p) and int(q) == max
         assert q != p
+        assert int(q) == int(p)
         assert hash(q) != hash(p)   # unlikely
         py.test.raises(OverflowError, ffi.new, c_decl, min - 1)
         py.test.raises(OverflowError, ffi.new, c_decl, max + 1)
@@ -818,6 +819,67 @@
         assert p == s+0
         assert p+1 == s+1
 
+    def test_pointer_comparison(self):
+        ffi = FFI(backend=self.Backend())
+        s = ffi.new("short[]", range(100))
+        p = ffi.cast("short *", s)   # same address as s, but a pointer ctype instead of an array
+        assert (p <  s) is False
+        assert (p <= s) is True
+        assert (p == s) is True
+        assert (p != s) is False
+        assert (p >  s) is False
+        assert (p >= s) is True
+        assert (s <  p) is False
+        assert (s <= p) is True
+        assert (s == p) is True
+        assert (s != p) is False
+        assert (s >  p) is False
+        assert (s >= p) is True
+        q = p + 1   # one short past p, so strictly greater than both p and s
+        assert (q <  s) is False
+        assert (q <= s) is False
+        assert (q == s) is False
+        assert (q != s) is True
+        assert (q >  s) is True
+        assert (q >= s) is True
+        assert (s <  q) is True
+        assert (s <= q) is True
+        assert (s == q) is False
+        assert (s != q) is True
+        assert (s >  q) is False
+        assert (s >= q) is False
+        assert (q <  p) is False
+        assert (q <= p) is False
+        assert (q == p) is False
+        assert (q != p) is True
+        assert (q >  p) is True
+        assert (q >= p) is True
+        assert (p <  q) is True
+        assert (p <= q) is True
+        assert (p == q) is False
+        assert (p != q) is True
+        assert (p >  q) is False
+        assert (p >= q) is False
+        # equality against None, from either side: neither s nor q is NULL
+        assert (None == s) is False
+        assert (None != s) is True
+        assert (s == None) is False
+        assert (s != None) is True
+        assert (None == q) is False
+        assert (None != q) is True
+        assert (q == None) is False
+        assert (q != None) is True
+
+    def test_no_integer_comparison(self):
+        ffi = FFI(backend=self.Backend())
+        x = ffi.cast("int", 123)
+        y = ffi.cast("int", 456)
+        py.test.raises(TypeError, "x < y")   # ordering two primitive cdata is refused
+        # mixing primitive kinds (int vs double) is refused as well
+        z = ffi.cast("double", 78.9)
+        py.test.raises(TypeError, "x < z")
+        py.test.raises(TypeError, "z < y")
+
     def test_ffi_buffer_ptr(self):
         ffi = FFI(backend=self.Backend())
         a = ffi.new("short", 100)


More information about the pypy-commit mailing list