[pypy-svn] r46777 - in pypy/dist/pypy/rlib: . test

arigo at codespeak.net arigo at codespeak.net
Fri Sep 21 11:21:07 CEST 2007


Author: arigo
Date: Fri Sep 21 11:21:06 2007
New Revision: 46777

Modified:
   pypy/dist/pypy/rlib/rsocket_rffi.py
   pypy/dist/pypy/rlib/test/test_rsocket_rffi.py
Log:
Trying to get the lifetime of the raw 'addr' pointer of Address instances
right (the attribute is now called 'addr_p').  See the comments before the
new lock()/unlock() methods.

It fixes an issue where lltype and ll2ctypes complain about structures
being used after they are freed, so I guess it also fixes a real issue
in the generated C code.


Modified: pypy/dist/pypy/rlib/rsocket_rffi.py
==============================================================================
--- pypy/dist/pypy/rlib/rsocket_rffi.py	(original)
+++ pypy/dist/pypy/rlib/rsocket_rffi.py	Fri Sep 21 11:21:06 2007
@@ -14,7 +14,7 @@
 # It's unclear if makefile() and SSL support belong here or only as
 # app-level code for PyPy.
 
-from pypy.rlib.objectmodel import instantiate
+from pypy.rlib.objectmodel import instantiate, keepalive_until_here
 from pypy.rlib import _rsocket_rffi as _c
 from pypy.rlib.rarithmetic import intmask
 from pypy.rpython.lltypesystem import lltype, rffi
@@ -58,32 +58,58 @@
             return A
 
     # default uninitialized value: NULL ptr
-    addr = lltype.nullptr(_c.sockaddr_ptr.TO)
+    addr_p = lltype.nullptr(_c.sockaddr_ptr.TO)
 
     def __init__(self, addr, addrlen):
-        self.addr = addr
+        self.addr_p = addr
         self.addrlen = addrlen
 
     def __del__(self):
-        addr = self.addr
+        addr = self.addr_p
         if addr:
             lltype.free(addr, flavor='raw')
 
     def setdata(self, addr, addrlen):
         # initialize self.addr and self.addrlen.  'addr' can be a different
         # pointer type than exactly sockaddr_ptr, and we cast it for you.
-        assert not self.addr
-        self.addr = rffi.cast(_c.sockaddr_ptr, addr)
+        assert not self.addr_p
+        self.addr_p = rffi.cast(_c.sockaddr_ptr, addr)
         self.addrlen = addrlen
     setdata._annspecialcase_ = 'specialize:ll'
 
+    # the following slightly strange interface is needed to manipulate
+    # what self.addr_p points to in a safe way.  The problem is that
+    # after inlining we might end up with operations that look like:
+    #    addr = self.addr_p
+    #    <self is freed here, and its __del__ calls lltype.free()>
+    #    read from addr
+    # To prevent this we have to insert a keepalive after the last
+    # use of 'addr'.  The interface to do that is called lock()/unlock()
+    # because it strongly reminds callers not to forget unlock().
+    #
+    def lock(self, TYPE=_c.sockaddr):
+        """Return self.addr_p, cast as a pointer to TYPE.  Must call unlock()!
+        Raises RSocketError if self.addrlen is outside [minlen, maxlen]."""
+        if not (self.minlen <= self.addrlen <= self.maxlen):
+            raise RSocketError("invalid address")
+        return rffi.cast(lltype.Ptr(TYPE), self.addr_p)
+    lock._annspecialcase_ = 'specialize:ll'
+
+    def unlock(self):
+        """Call this after you are done with the pointer returned by lock().
+        Note that locking and unlocking cost nothing at run-time.
+        """
+        keepalive_until_here(self)
+
     def as_object(self, space):
         """Convert the address to an app-level object."""
         # If we don't know the address family, don't raise an
         # exception -- return it as a tuple.
-        family = rffi.cast(lltype.Signed, self.addr.c_sa_family)
+        addr = self.lock()
+        family = rffi.cast(lltype.Signed, addr.c_sa_family)
         datalen = self.addrlen - offsetof(_c.sockaddr, 'c_sa_data')
-        rawdata = ''.join([self.addr.c_sa_data[i] for i in range(datalen)])
+        rawdata = ''.join([addr.c_sa_data[i] for i in range(datalen)])
+        self.unlock()
         return space.newtuple([space.wrap(family),
                                space.wrap(rawdata)])
 
@@ -160,17 +186,13 @@
 class INETAddress(IPAddress):
     family = AF_INET
     struct = _c.sockaddr_in
-    maxlen = sizeof(struct)
+    maxlen = minlen = sizeof(struct)
 
     def __init__(self, host, port):
         makeipaddr(host, self)
-        a = self.as_sockaddr_in()
+        a = self.lock(_c.sockaddr_in)
         a.c_sin_port = htons(port)
-
-    def as_sockaddr_in(self):
-        if self.addrlen != INETAddress.maxlen:
-            raise RSocketError("invalid address")
-        return rffi.cast(lltype.Ptr(_c.sockaddr_in), self.addr)
+        self.unlock()
 
     def __repr__(self):
         try:
@@ -179,8 +201,10 @@
             return '<INETAddress ?>'
 
     def get_port(self):
-        a = self.as_sockaddr_in()
-        return ntohs(a.c_sin_port)
+        a = self.lock(_c.sockaddr_in)
+        port = ntohs(a.c_sin_port)
+        self.unlock()
+        return port
 
     def eq(self, other):   # __eq__() is not called by RPython :-/
         return (isinstance(other, INETAddress) and
@@ -206,8 +230,9 @@
         # XXX a bit of code duplication
         _, w_port = space.unpackiterable(w_address, 2)
         port = space.int_w(w_port)
-        a = self.as_sockaddr_in()
+        a = self.lock(_c.sockaddr_in)
         a.c_sin_port = htons(port)
+        self.unlock()
 
     def from_in_addr(in_addr):
         result = instantiate(INETAddress)
@@ -221,8 +246,9 @@
         return result
     from_in_addr = staticmethod(from_in_addr)
 
-    def extract_in_addr(self):
-        p = rffi.cast(rffi.VOIDP, self.as_sockaddr_in().sin_addr)
+    def lock_in_addr(self):
+        a = self.lock(_c.sockaddr_in)
+        p = rffi.cast(rffi.VOIDP, a.c_sin_addr)
         return p, sizeof(_c.in_addr)
 
 # ____________________________________________________________
@@ -230,19 +256,15 @@
 class INET6Address(IPAddress):
     family = AF_INET6
     struct = _c.sockaddr_in6
-    maxlen = sizeof(struct)
+    maxlen = minlen = sizeof(struct)
 
     def __init__(self, host, port, flowinfo=0, scope_id=0):
         makeipaddr(host, self)
-        a = self.as_sockaddr_in6()
+        a = self.lock(_c.sockaddr_in6)
         a.c_sin6_port = htons(port)
         a.c_sin6_flowinfo = flowinfo
         a.c_sin6_scope_id = scope_id
-
-    def as_sockaddr_in6(self):
-        if self.addrlen != INET6Address.maxlen:
-            raise RSocketError("invalid address")
-        return rffi.cast(lltype.Ptr(_c.sockaddr_in6), self.addr)
+        self.unlock()
 
     def __repr__(self):
         try:
@@ -254,16 +276,22 @@
             return '<INET6Address ?>'
 
     def get_port(self):
-        a = self.as_sockaddr_in6()
-        return ntohs(a.c_sin6_port)
+        a = self.lock(_c.sockaddr_in6)
+        port = ntohs(a.c_sin6_port)
+        self.unlock()
+        return port
 
     def get_flowinfo(self):
-        a = self.as_sockaddr_in6()
-        return a.c_sin6_flowinfo
+        a = self.lock(_c.sockaddr_in6)
+        flowinfo = a.c_sin6_flowinfo
+        self.unlock()
+        return flowinfo
 
     def get_scope_id(self):
-        a = self.as_sockaddr_in6()
-        return a.c_sin6_scope_id
+        a = self.lock(_c.sockaddr_in6)
+        scope_id = a.c_sin6_scope_id
+        self.unlock()
+        return scope_id
 
     def eq(self, other):   # __eq__() is not called by RPython :-/
         return (isinstance(other, INET6Address) and
@@ -303,10 +331,11 @@
         else:                 flowinfo = 0
         if len(pieces_w) > 3: scope_id = space.int_w(pieces_w[3])
         else:                 scope_id = 0
-        a = self.as_sockaddr_in6()
+        a = self.lock(_c.sockaddr_in6)
         a.c_sin6_port = htons(port)
         a.c_sin6_flowinfo = flowinfo
         a.c_sin6_scope_id = scope_id
+        self.unlock()
 
     def from_in6_addr(in6_addr):
         result = instantiate(INET6Address)
@@ -319,8 +348,9 @@
         return result
     from_in6_addr = staticmethod(from_in6_addr)
 
-    def extract_in_addr(self):
-        p = rffi.cast(rffi.VOIDP, self.as_sockaddr_in6().sin6_addr)
+    def lock_in_addr(self):
+        a = self.lock(_c.sockaddr_in6)
+        p = rffi.cast(rffi.VOIDP, a.c_sin6_addr)
         return p, sizeof(_c.in6_addr)
 
 # ____________________________________________________________
@@ -329,6 +359,7 @@
     class UNIXAddress(Address):
         family = AF_UNIX
         struct = _c.sockaddr_un
+        minlen = offsetof(_c.sockaddr_un, 'c_sun_path') + 1
         maxlen = sizeof(struct)
 
         def __init__(self, path):
@@ -348,11 +379,6 @@
             for i in range(len(path)):
                 sun.c_sun_path[i] = path[i]
 
-        def as_sockaddr_un(self):
-            if self.addrlen <= offsetof(_c.sockaddr_un, 'c_sun_path'):
-                raise RSocketError("invalid address")
-            return rffi.cast(lltype.Ptr(_c.sockaddr_un), self.addr)
-
         def __repr__(self):
             try:
                 return '<UNIXAddress %r>' % (self.get_path(),)
@@ -360,7 +386,7 @@
                 return '<UNIXAddress ?>'
 
         def get_path(self):
-            a = self.as_sockaddr_un()
+            a = self.lock(_c.sockaddr_un)
             maxlength = self.addrlen - offsetof(_c.sockaddr_un, 'c_sun_path')
             if _c.linux and a.c_sun_path[0] == '\x00':
                 # Linux abstract namespace
@@ -370,7 +396,9 @@
                 length = 0
                 while length < maxlength and a.c_sun_path[length] != '\x00':
                     length += 1
-            return ''.join([a.c_sun_path[i] for i in range(length)])
+            result = ''.join([a.c_sun_path[i] for i in range(length)])
+            self.unlock()
+            return result
 
         def eq(self, other):   # __eq__() is not called by RPython :-/
             return (isinstance(other, UNIXAddress) and
@@ -387,7 +415,7 @@
     class NETLINKAddress(Address):
         family = AF_NETLINK
         struct = _c.sockaddr_nl
-        maxlen = sizeof(struct)
+        maxlen = minlen = sizeof(struct)
 
         def __init__(self, pid, groups):
             addr = rffi.make(_c.sockaddr_nl)
@@ -396,16 +424,17 @@
             rffi.setintfield(addr, 'c_nl_pid', pid)
             rffi.setintfield(addr, 'c_nl_groups', groups)
 
-        def as_sockaddr_nl(self):
-            if self.addrlen != NETLINKAddress.maxlen:
-                raise RSocketError("invalid address")
-            return rffi.cast(lltype.Ptr(_c.sockaddr_nl), self.addr)
-
         def get_pid(self):
-            return self.as_sockaddr_nl().c_nl_pid
+            a = self.lock(_c.sockaddr_nl)
+            pid = a.c_nl_pid
+            self.unlock()
+            return pid
 
         def get_groups(self):
-            return self.as_sockaddr_nl().c_nl_groups
+            a = self.lock(_c.sockaddr_nl)
+            groups = a.c_nl_groups
+            self.unlock()
+            return groups
 
         def __repr__(self):
             return '<NETLINKAddress %r>' % (self.get_pid(), self.get_groups())
@@ -562,11 +591,14 @@
     def addr_from_object(self, space, w_address):
         return af_get(self.family).from_object(space, w_address)
 
+    # build a null address object, ready to be used as output argument to
+    # C functions that return an address.  It must be unlock()ed after you
+    # are done using addr_p.
     def _addrbuf(self):
         addr, maxlen = make_null_address(self.family)
         addrlen_p = lltype.malloc(_c.socklen_t_ptr.TO, flavor='raw')
         addrlen_p[0] = rffi.cast(_c.socklen_t, maxlen)
-        return addr, addrlen_p
+        return addr, addr.addr_p, addrlen_p
 
     def accept(self, SocketClass=None):
         """Wait for an incoming connection.
@@ -575,12 +607,13 @@
             SocketClass = RSocket
         if self._select(False) == 1:
             raise SocketTimeout
-        address, addrlen_p = self._addrbuf()
+        address, addr_p, addrlen_p = self._addrbuf()
         try:
-            newfd = _c.socketaccept(self.fd, address.addr, addrlen_p)
+            newfd = _c.socketaccept(self.fd, addr_p, addrlen_p)
             addrlen = addrlen_p[0]
         finally:
             lltype.free(addrlen_p, flavor='raw')
+            address.unlock()
         if _c.invalid_socket(newfd):
             raise self.error_handler()
         address.addrlen = addrlen
@@ -590,7 +623,9 @@
 
     def bind(self, address):
         """Bind the socket to a local address."""
-        res = _c.socketbind(self.fd, address.addr, address.addrlen)
+        addr = address.lock()
+        res = _c.socketbind(self.fd, addr, address.addrlen)
+        address.unlock()
         if res < 0:
             raise self.error_handler()
 
@@ -605,14 +640,17 @@
 
     def connect(self, address):
         """Connect the socket to a remote address."""
-        res = _c.socketconnect(self.fd, address.addr, address.addrlen)
+        addr = address.lock()
+        res = _c.socketconnect(self.fd, addr, address.addrlen)
+        address.unlock()
         if self.timeout > 0.0:
             errno = _c.geterrno()
             if res < 0 and errno == _c.EINPROGRESS:
                 timeout = self._select(True)
                 if timeout == 0:
-                    res = _c.socketconnect(self.fd, address.addr,
-                                           address.addrlen)
+                    addr = address.lock()
+                    res = _c.socketconnect(self.fd, addr, address.addrlen)
+                    address.unlock()
                 elif timeout == -1:
                     raise self.error_handler()
                 else:
@@ -624,14 +662,17 @@
     def connect_ex(self, address):
         """This is like connect(address), but returns an error code (the errno
         value) instead of raising an exception when an error occurs."""
-        res = _c.socketconnect(self.fd, address.addr, address.addrlen)
+        addr = address.lock()
+        res = _c.socketconnect(self.fd, addr, address.addrlen)
+        address.unlock()
         if self.timeout > 0.0:
             errno = _c.geterrno()
             if res < 0 and errno == _c.EINPROGRESS:
                 timeout = self._select(True)
                 if timeout == 0:
-                    res = _c.socketconnect(self.fd, address.addr,
-                                           address.addrlen)
+                    addr = address.lock()
+                    res = _c.socketconnect(self.fd, addr, address.addrlen)
+                    address.unlock()
                 elif timeout == -1:
                     return _c.geterrno()
                 else:
@@ -659,12 +700,13 @@
 
     def getpeername(self):
         """Return the address of the remote endpoint."""
-        address, addrlen_p = self._addrbuf()
+        address, addr_p, addrlen_p = self._addrbuf()
         try:
-            res = _c.socketgetpeername(self.fd, address.addr, addrlen_p)
+            res = _c.socketgetpeername(self.fd, addr_p, addrlen_p)
             addrlen = addrlen_p[0]
         finally:
             lltype.free(addrlen_p, flavor='raw')
+            address.unlock()
         if res < 0:
             raise self.error_handler()
         address.addrlen = addrlen
@@ -672,12 +714,13 @@
 
     def getsockname(self):
         """Return the address of the local endpoint."""
-        address, addrlen_p = self._addrbuf()
+        address, addr_p, addrlen_p = self._addrbuf()
         try:
-            res = _c.socketgetsockname(self.fd, address.addr, addrlen_p)
+            res = _c.socketgetsockname(self.fd, addr_p, addrlen_p)
             addrlen = addrlen_p[0]
         finally:
             lltype.free(addrlen_p, flavor='raw')
+            address.unlock()
         if res < 0:
             raise self.error_handler()
         address.addrlen = addrlen
@@ -765,13 +808,14 @@
         elif timeout == 0:
             buf = mallocbuf(buffersize)
             try:
-                address, addrlen_p = self._addrbuf()
+                address, addr_p, addrlen_p = self._addrbuf()
                 try:
                     read_bytes = _c.recvfrom(self.fd, buf, buffersize, flags,
-                                             address.addr, addrlen_p)
+                                             addr_p, addrlen_p)
                     addrlen = addrlen_p[0]
                 finally:
                     lltype.free(addrlen_p, flavor='raw')
+                    address.unlock()
                 if read_bytes >= 0:
                     if addrlen:
                         address.addrlen = addrlen
@@ -814,8 +858,10 @@
         if timeout == 1:
             raise SocketTimeout
         elif timeout == 0:
+            addr = address.lock()
             res = _c.sendto(self.fd, data, len(data), flags,
-                            address.addr, address.addrlen)
+                            addr, address.addrlen)
+            address.unlock()
         if res < 0:
             raise self.error_handler()
         return res
@@ -1018,8 +1064,11 @@
 def gethostbyaddr(ip):
     # XXX use gethostbyaddr_r() if available, and/or use locks if not
     addr = makeipaddr(ip)
-    p, size = addr.extract_in_addr()
-    hostent =_c.gethostbyaddr(p, size, addr.family)
+    p, size = addr.lock_in_addr()
+    try:
+        hostent = _c.gethostbyaddr(p, size, addr.family)
+    finally:
+        addr.unlock()
     return gethost_common(ip, hostent, addr)
 
 def getaddrinfo(host, port_or_service,
@@ -1078,14 +1127,16 @@
         raise RSocketError("protocol not found")
     return protoent.contents.p_proto
 
-def getnameinfo(addr, flags):
+def getnameinfo(address, flags):
     host = lltype.malloc(rffi.CCHARP.TO, NI_MAXHOST, flavor='raw')
     try:
         serv = lltype.malloc(rffi.CCHARP.TO, NI_MAXSERV, flavor='raw')
         try:
-            error =_c.getnameinfo(addr.addr, addr.addrlen,
+            addr = address.lock()
+            error =_c.getnameinfo(addr, address.addrlen,
                                   host, NI_MAXHOST,
                                   serv, NI_MAXSERV, flags)
+            address.unlock()
             if error:
                 raise GAIError(error)
             return rffi.charp2str(host), rffi.charp2str(serv)

Modified: pypy/dist/pypy/rlib/test/test_rsocket_rffi.py
==============================================================================
--- pypy/dist/pypy/rlib/test/test_rsocket_rffi.py	(original)
+++ pypy/dist/pypy/rlib/test/test_rsocket_rffi.py	Fri Sep 21 11:21:06 2007
@@ -56,6 +56,19 @@
         py.test.fail("could not find the 127.0.0.1 IPv4 address in %r"
                      % (address_list,))
 
+def test_gethostbyaddr():
+    name, aliases, address_list = gethostbyaddr('127.0.0.1')
+    allnames = [name] + aliases
+    for n in allnames:
+        assert isinstance(n, str)
+    assert 'localhost' in allnames
+    for a in address_list:
+        if isinstance(a, INETAddress) and a.get_host() == "127.0.0.1":
+            break  # ok
+    else:
+        py.test.fail("could not find the 127.0.0.1 IPv4 address in %r"
+                     % (address_list,))
+
 def test_socketpair():
     if sys.platform == "win32":
         py.test.skip('No socketpair on Windows')



More information about the Pypy-commit mailing list