[pypy-svn] r33679 - in pypy/dist/pypy/module/rsocket: . test

ac at codespeak.net ac at codespeak.net
Tue Oct 24 19:34:56 CEST 2006


Author: ac
Date: Tue Oct 24 19:34:55 2006
New Revision: 33679

Modified:
   pypy/dist/pypy/module/rsocket/ctypes_socket.py
   pypy/dist/pypy/module/rsocket/rsocket.py
   pypy/dist/pypy/module/rsocket/test/test_rsocket.py
Log:
Automatically register address families whenever a new
Address subclass is defined.
Add support for AF_NETLINK.



Modified: pypy/dist/pypy/module/rsocket/ctypes_socket.py
==============================================================================
--- pypy/dist/pypy/module/rsocket/ctypes_socket.py	(original)
+++ pypy/dist/pypy/module/rsocket/ctypes_socket.py	Tue Oct 24 19:34:55 2006
@@ -25,11 +25,14 @@
             'arpa/inet.h',
             'stdint.h', 
             )
+cond_includes = [('AF_NETLINK', 'linux/netlink.h')]  # (macro, header): header is included only if macro is defined
 HEADER = ''.join(['#include <%s>\n' % filename for filename in includes])
+COND_HEADER = ''.join(['#ifdef %s\n#include <%s>\n#endif\n' % cond_include
+                      for cond_include in cond_includes])
 constants = {}
 
 class CConfig:
-    _header_ = HEADER
+    _header_ = HEADER + COND_HEADER
     # constants
     O_NONBLOCK = ctypes_platform.ConstantInteger('O_NONBLOCK')
     F_GETFL = ctypes_platform.ConstantInteger('F_GETFL')
@@ -167,7 +170,15 @@
 
 CConfig.sockaddr_un = ctypes_platform.Struct('struct sockaddr_un',
                                              [('sun_family', c_int),
-                                              ('sun_path', c_ubyte * 0)])
+                                              ('sun_path', c_ubyte * 0)],
+                                             ifdef='AF_UNIX')
+
+CConfig.sockaddr_nl = ctypes_platform.Struct('struct sockaddr_nl',
+                                             [('nl_family', c_int),
+                                              ('nl_pid', c_int),
+                                              ('nl_groups', c_int)],  # NOTE(review): kernel may declare these unsigned; confirm c_int is adjusted
+                                             ifdef='AF_NETLINK')
+
 
 addrinfo_ptr = POINTER("addrinfo")
 CConfig.addrinfo = ctypes_platform.Struct('struct addrinfo',
@@ -258,6 +269,8 @@
 sockaddr_in = cConfig.sockaddr_in
 sockaddr_in6 = cConfig.sockaddr_in6
 sockaddr_un = cConfig.sockaddr_un
+if cConfig.sockaddr_nl is not None:  # presumably None when the AF_NETLINK ifdef is false
+    sockaddr_nl = cConfig.sockaddr_nl
 in_addr = cConfig.in_addr
 in_addr_size = sizeof(in_addr)
 in6_addr = cConfig.in6_addr

Modified: pypy/dist/pypy/module/rsocket/rsocket.py
==============================================================================
--- pypy/dist/pypy/module/rsocket/rsocket.py	(original)
+++ pypy/dist/pypy/module/rsocket/rsocket.py	Tue Oct 24 19:34:55 2006
@@ -27,11 +27,20 @@
 htonl = _c.htonl
 
 
+_FAMILIES = {}  # address family constant -> Address subclass, filled by the metaclass below
 class Address(object):
     """The base class for RPython-level objects representing addresses.
     Fields:  addr    - a _c.sockaddr structure
              addrlen - size used within 'addr'
     """
+    class __metaclass__(type):  # registers each class whose own body sets a non-None 'family'
+        def __new__(cls, name, bases, classdict):
+            family = classdict.get('family')  # own attribute only; inherited values are ignored
+            newclass = type.__new__(cls, name, bases, classdict)
+            if family is not None:
+                _FAMILIES[family] = newclass
+            return newclass
+
     def __init__(self, addr, addrlen):
         self.addr = addr
         self.addrlen = addrlen
@@ -355,14 +364,45 @@
         return UNIXAddress(space.str_w(w_address))
     from_object = staticmethod(from_object)
 
-# ____________________________________________________________
+if 'AF_NETLINK' in constants:
+    class NETLINKAddress(Address):
+        family = AF_NETLINK  # registered in _FAMILIES by Address.__metaclass__
+        struct = _c.sockaddr_nl
+        maxlen = sizeof(struct)
+
+        def __init__(self, pid, groups):
+            addr = _c.sockaddr_nl(nl_family=AF_NETLINK)
+            addr.nl_pid = pid
+            addr.nl_groups = groups
+            self.addr = cast(pointer(addr), _c.sockaddr_ptr).contents  # view as generic sockaddr
+            self.addrlen = sizeof(addr)
+
+        def as_sockaddr_nl(self):
+            if self.addrlen != NETLINKAddress.maxlen:  # must hold a complete sockaddr_nl
+                raise RSocketError("invalid address")
+            return cast(pointer(self.addr), POINTER(_c.sockaddr_nl)).contents
 
-_FAMILIES = {}
-for klass in [INETAddress,
-              INET6Address,
-              UNIXAddress]:
-    if klass.family is not None:
-        _FAMILIES[klass.family] = klass
+        def get_pid(self):
+            return self.as_sockaddr_nl().nl_pid
+
+        def get_groups(self):
+            return self.as_sockaddr_nl().nl_groups
+
+        def __repr__(self):
+            return '<NETLINKAddress %r>' % ((self.get_pid(), self.get_groups()),)  # one %r <- one (pid, groups) tuple
+
+        def as_object(self, space):
+            return space.newtuple([space.wrap(self.get_pid()), space.wrap(self.get_groups())])
+
+        def from_object(space, w_address):
+            try:
+                w_pid, w_groups = space.unpackiterable(w_address, 2)
+            except ValueError:
+                raise TypeError("AF_NETLINK address must be a tuple of length 2")
+            return NETLINKAddress(space.int_w(w_pid), space.int_w(w_groups))
+        from_object = staticmethod(from_object)
+
+# ____________________________________________________________
 
 def familyclass(family):
     return _FAMILIES.get(family, Address)

Modified: pypy/dist/pypy/module/rsocket/test/test_rsocket.py
==============================================================================
--- pypy/dist/pypy/module/rsocket/test/test_rsocket.py	(original)
+++ pypy/dist/pypy/module/rsocket/test/test_rsocket.py	Tue Oct 24 19:34:55 2006
@@ -1,4 +1,5 @@
 import py, errno
+from pypy.module.rsocket import rsocket  # module object, used to probe for the optional AF_NETLINK
 from pypy.module.rsocket.rsocket import *
 
 def test_ipv4_addr():
@@ -19,6 +20,15 @@
     a = UNIXAddress("/tmp/socketname")
     assert a.get_path() == "/tmp/socketname"
 
+def test_netlink_addr():
+    if getattr(rsocket, 'AF_NETLINK', None) is None:
+        py.test.skip('AF_NETLINK not supported.')
+    pid = 1
+    group_mask = 64 + 32  # bits 5 and 6 set
+    a = NETLINKAddress(pid, group_mask)
+    assert a.get_pid() == pid
+    assert a.get_groups() == group_mask
+
 def test_gethostname():
     s = gethostname()
     assert isinstance(s, str)



More information about the Pypy-commit mailing list