? doc/examples/echoclient_tls.py
? doc/examples/echoserv_tls.py
Index: twisted/internet/ssl.py
===================================================================
RCS file: /cvs/Twisted/twisted/internet/ssl.py,v
retrieving revision 1.40
diff -u -r1.40 ssl.py
--- twisted/internet/ssl.py	2 Apr 2003 04:11:32 -0000	1.40
+++ twisted/internet/ssl.py	3 May 2003 06:28:42 -0000
@@ -95,116 +95,13 @@
         return SSL.Context(SSL.SSLv3_METHOD)
 
 
-class Connection(tcp.Connection):
-    """I am an SSL connection.
-    """
-
-    __implements__ = tcp.Connection.__implements__, interfaces.ISSLTransport
-
-    writeBlockedOnRead = 0
-    readBlockedOnWrite= 0
-    sslShutdown = 0
-
-    def getPeerCertificate(self):
-        """Return the certificate for the peer."""
-        return self.socket.get_peer_certificate()
-
-    def _postLoseConnection(self):
-        """Gets called after loseConnection(), after buffered data is sent.
-
-        We close the SSL transport layer, and if the other side hasn't
-        closed it yet we start reading, waiting for a ZeroReturnError
-        which will indicate the SSL shutdown has completed.
-        """
-        try:
-            done = self.socket.shutdown()
-            self.sslShutdown = 1
-        except SSL.Error:
-            return main.CONNECTION_LOST
-        if done:
-            return main.CONNECTION_DONE
-        else:
-            # we wait for other side to close SSL connection -
-            # this will be signaled by SSL.ZeroReturnError when reading
-            # from the socket
-            self.stopWriting()
-            self.startReading()
-            return None # don't close socket just yet
-
-    def doRead(self):
-        """See tcp.Connection.doRead for details.
-        """
-        if self.writeBlockedOnRead:
-            self.writeBlockedOnRead = 0
-            return self.doWrite()
-        try:
-            return tcp.Connection.doRead(self)
-        except SSL.ZeroReturnError:
-            # close SSL layer, since other side has done so, if we haven't
-            if not self.sslShutdown:
-                try:
-                    self.socket.shutdown()
-                    self.sslShutdown = 1
-                except SSL.Error:
-                    pass
-            return main.CONNECTION_DONE
-        except SSL.WantReadError:
-            return
-        except SSL.WantWriteError:
-            self.readBlockedOnWrite = 1
-            self.startWriting()
-            return
-        except SSL.Error:
-            return main.CONNECTION_LOST
-
-    def doWrite(self):
-        if self.readBlockedOnWrite:
-            self.readBlockedOnWrite = 0
-            if not self.dataBuffer: self.stopWriting()
-            return self.doRead()
-        return tcp.Connection.doWrite(self)
-
-    def writeSomeData(self, data):
-        """See tcp.Connection.writeSomeData for details.
-        """
-        if not data:
-            return 0
-
-        try:
-            return tcp.Connection.writeSomeData(self, data)
-        except SSL.WantWriteError:
-            return 0
-        except SSL.WantReadError:
-            self.writeBlockedOnRead = 1
-            return 0
-        except SSL.Error:
-            return main.CONNECTION_LOST
-
-    def _closeSocket(self):
-        """Called to close our socket."""
-        try:
-            self.socket.sock_shutdown(2)
-        except socket.error:
-            try:
-                self.socket.close()
-            except socket.error:
-                log.deferr()
-
-
-
-class Client(Connection, tcp.Client):
+class Client(tcp.Client):
     """I am an SSL client."""
     def __init__(self, host, port, bindAddress, ctxFactory, connector, reactor=None):
         # tcp.Client.__init__ depends on self.ctxFactory being set
         self.ctxFactory = ctxFactory
         tcp.Client.__init__(self, host, port, bindAddress, connector, reactor)
 
-    def createInternetSocket(self):
-        """(internal) create an SSL socket
-        """
-        sock = tcp.Client.createInternetSocket(self)
-        return SSL.Connection(self.ctxFactory.getContext(), sock)
-
     def getHost(self):
         """Returns a tuple of ('SSL', hostname, port).
 
@@ -219,16 +116,14 @@
         """
         return ('SSL',)+self.addr
 
+    def _finishInit(self, whenDone, skt, error, reactor):
+        tcp.Client._finishInit(self, whenDone, skt, error, reactor)
+        self.startTLS(self.ctxFactory)
 
 
-class Server(Connection, tcp.Server):
+class Server(tcp.Server):
     """I am an SSL server.
     """
-
-    def __init__(*args, **kw):
-        # We don't want Connection's __init__
-        tcp.Server.__init__(*args, **kw)
-
     def getHost(self):
         """Returns a tuple of ('SSL', hostname, port).
 
@@ -257,33 +152,12 @@
         """
         sock = tcp.Port.createInternetSocket(self)
         return SSL.Connection(self.ctxFactory.getContext(), sock)
-
-    def doRead(self):
-        """Called when my socket is ready for reading.
-        This accepts a connection and calls self.protocol() to handle the
-        wire-level protocol.
-        """
-        try:
-            try:
-                skt, addr = self.socket.accept()
-            except socket.error, e:
-                if e.args[0] == tcp.EWOULDBLOCK:
-                    return
-                raise
-            except SSL.Error:
-                log.deferr()
-                return
-            protocol = self.factory.buildProtocol(addr)
-            if protocol is None:
-                skt.close()
-                return
-            s = self.sessionno
-            self.sessionno = s+1
-            transport = self.transport(skt, protocol, addr, self, s)
-            protocol.makeConnection(transport)
-        except:
-            log.deferr()
 
+    def _preMakeConnection(self, transport):
+        # *Don't* call startTLS here: the transport's socket is already an
+        # SSL.Connection (see createInternetSocket above), so only flag TLS mode.
+        transport._startTLS()
+        return tcp.Port._preMakeConnection(self, transport)
 
 
 class Connector(base.BaseConnector):
Index: twisted/internet/tcp.py
===================================================================
RCS file: /cvs/Twisted/twisted/internet/tcp.py,v
retrieving revision 1.118
diff -u -r1.118 tcp.py
--- twisted/internet/tcp.py	2 May 2003 04:31:14 -0000	1.118
+++ twisted/internet/tcp.py	3 May 2003 06:28:50 -0000
@@ -39,6 +39,11 @@
 except ImportError:
     fcntl = None
 
+try:
+    from OpenSSL import SSL
+except ImportError:
+    SSL = None
+
 if os.name == 'nt':
     # we hardcode these since windows actually wants e.g.
     # WSAEALREADY rather than EALREADY. Possibly we should
@@ -88,14 +93,46 @@
 
     __implements__ = abstract.FileDescriptor.__implements__, interfaces.ITCPTransport
 
+    if SSL:
+        writeBlockedOnRead = 0
+        readBlockedOnWrite= 0
+        sslShutdown = 0
+    TLS = 0
+
     def __init__(self, skt, protocol, reactor=None):
         abstract.FileDescriptor.__init__(self, reactor=reactor)
         self.socket = skt
         self.socket.setblocking(0)
         self.fileno = skt.fileno
         self.protocol = protocol
+
+    def startTLS(self, ctx):
+        """Switch this connection to TLS using ctx.getContext()."""
+        if not SSL:
+            raise RuntimeError("No SSL support available")
+        assert not self.TLS
+        # NOTE(review): the write buffer is not flushed first; anything still
+        # queued presumably goes out through the TLS writeSomeData.
 
-    def doRead(self):
+        # Wrap the socket before swapping in the TLS methods, so a failure
+        # here leaves the connection in its plain-TCP state.
+        self.socket = SSL.Connection(ctx.getContext(), self.socket)
+        self.fileno = self.socket.fileno
+        self._startTLS()
+
+    def _startTLS(self):
+        """Mark the connection as TLS and install the TLS I/O methods."""
+        self.TLS = 1
+        self.doRead = self._TLS_doRead
+        self.writeSomeData = self._TLS_writeSomeData
+        self.doWrite = self._TLS_doWrite
+        self._closeSocket = self._TLS_closeSocket
+
+    def getPeerCertificate(self):
+        """Return the peer's certificate; only valid once TLS is active."""
+        return self.socket.get_peer_certificate()
+
+    def _NOTLS_doRead(self):
         """Calls self.protocol.dataReceived with all available data.
 
         This reads up to self.bufferSize bytes of data from its socket, then
@@ -114,7 +151,44 @@
             return main.CONNECTION_LOST
         return self.protocol.dataReceived(data)
 
-    def writeSomeData(self, data):
+    doRead = _NOTLS_doRead
+
+    def _TLS_doRead(self):
+        """doRead used once TLS is active; translates SSL errors."""
+        if self.writeBlockedOnRead:
+            self.writeBlockedOnRead = 0
+            return self.doWrite()
+        try:
+            return self._NOTLS_doRead()
+        except SSL.ZeroReturnError:
+            # close SSL layer, since other side has done so, if we haven't
+            if not self.sslShutdown:
+                try:
+                    self.socket.shutdown()
+                    self.sslShutdown = 1
+                except SSL.Error:
+                    pass
+            return main.CONNECTION_DONE
+        except SSL.WantReadError:
+            return
+        except SSL.WantWriteError:
+            self.readBlockedOnWrite = 1
+            self.startWriting()
+            return
+        except SSL.Error:
+            return main.CONNECTION_LOST
+
+    def _TLS_doWrite(self):
+        """doWrite used once TLS is active; resumes a read blocked on write."""
+        if self.readBlockedOnWrite:
+            self.readBlockedOnWrite = 0
+            # XXX - reaches into FileDescriptor's write buffer directly
+            if not self.dataBuffer:
+                self.stopWriting()
+            return self.doRead()
+        return abstract.FileDescriptor.doWrite(self)
+
+    def _NOTLS_writeSomeData(self, data):
         """Connection.writeSomeData(data) -> #of bytes written | CONNECTION_LOST
         This writes as much data as possible to the socket and returns either
         the number of bytes read (which is positive) or a connection error code
@@ -128,7 +202,24 @@
             else:
                 return main.CONNECTION_LOST
 
-    def _closeSocket(self):
+    writeSomeData = _NOTLS_writeSomeData
+
+    def _TLS_writeSomeData(self, data):
+        """writeSomeData used once TLS is active; returns 0 when blocked."""
+        if not data:
+            return 0
+        try:
+            return self._NOTLS_writeSomeData(data)
+        except SSL.WantWriteError:
+            return 0
+        except SSL.WantReadError:
+            # Retried from _TLS_doRead once the socket becomes readable.
+            self.writeBlockedOnRead = 1
+            return 0
+        except SSL.Error:
+            return main.CONNECTION_LOST
+
+    def _NOTLS_closeSocket(self):
         """Called to close our socket."""
         # This used to close() the socket, but that doesn't *really* close if
         # there's another reference to it in the TCP/IP stack, e.g. if it was
@@ -139,6 +230,18 @@
         except socket.error:
             pass
 
+    _closeSocket = _NOTLS_closeSocket
+
+    def _TLS_closeSocket(self):
+        """Called to close our socket when TLS is active."""
+        try:
+            self.socket.sock_shutdown(2)
+        except socket.error:
+            try:
+                self.socket.close()
+            except socket.error:
+                log.deferr()
+
     def connectionLost(self, reason):
         """See abstract.FileDescriptor.connectionLost().
         """
@@ -173,6 +276,34 @@
 
     def setTcpNoDelay(self, enabled):
         self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, enabled)
+
+    def _postLoseConnection(self):
+        """Gets called after loseConnection(), after buffered data is sent.
+
+        Without TLS this defers to abstract.FileDescriptor. With TLS we
+        close the SSL transport layer, and if the other side hasn't
+        closed it yet we start reading, waiting for a ZeroReturnError
+        which will indicate the SSL shutdown has completed.
+        """
+        if not self.TLS:
+            return abstract.FileDescriptor._postLoseConnection(self)
+
+        try:
+            done = self.socket.shutdown()
+            self.sslShutdown = 1
+        except SSL.Error:
+            return main.CONNECTION_LOST
+        if done:
+            return main.CONNECTION_DONE
+        else:
+            # we wait for other side to close SSL connection -
+            # this will be signaled by SSL.ZeroReturnError when reading
+            # from the socket
+            self.stopWriting()
+            self.startReading()
+
+            # don't close socket just yet
+            return None
 
 
 class BaseClient(Connection):
@@ -191,6 +322,12 @@
         else:
             reactor.callLater(0, self.failIfNotConnected, error)
 
+    def startTLS(self, ctx):
+        """Start TLS as the client side of the handshake."""
+        holder = Connection.startTLS(self, ctx)
+        self.socket.set_connect_state()
+        return holder
+
     def stopConnecting(self):
         """Stop attempt to connect."""
         self.failIfNotConnected(error.UserError())
@@ -360,6 +497,12 @@
         """
         return self.repstr
 
+    def startTLS(self, ctx):
+        """Start TLS as the server side of the handshake."""
+        holder = Connection.startTLS(self, ctx)
+        self.socket.set_accept_state()
+        return holder
+
     def getHost(self):
         """Returns a tuple of ('INET', hostname, port).
 
@@ -458,6 +601,7 @@
                     elif e.args[0] == EPERM:
                         continue
                     raise
+
                 protocol = self.factory.buildProtocol(addr)
                 if protocol is None:
                     skt.close()
@@ -465,11 +609,23 @@
                 s = self.sessionno
                 self.sessionno = s+1
                 transport = self.transport(skt, protocol, addr, self, s)
+                transport = self._preMakeConnection(transport)
                 protocol.makeConnection(transport)
             else:
                 self.numberAccepts = self.numberAccepts+20
         except:
+            # Note that in TLS mode, this will possibly catch SSL.Errors
+            # raised by self.socket.accept()
+            #
+            # There is no "except SSL.Error:" above because SSL may be
+            # None if there is no SSL support. In any case, all the
+            # "except SSL.Error:" suite would probably do is log.deferr()
+            # and return, so handling it here works just as well.
             log.deferr()
+
+    def _preMakeConnection(self, transport):
+        """Hook run on each new transport before makeConnection; returns it."""
+        return transport
 
     def loseConnection(self, connDone=failure.Failure(main.CONNECTION_DONE)):
         """Stop accepting connections on this port.
Index: twisted/test/test_ssl.py
===================================================================
RCS file: /cvs/Twisted/twisted/test/test_ssl.py,v
retrieving revision 1.9
diff -u -r1.9 test_ssl.py
--- twisted/test/test_ssl.py	3 May 2003 02:03:54 -0000	1.9
+++ twisted/test/test_ssl.py	3 May 2003 06:28:55 -0000
@@ -17,19 +17,23 @@
 from __future__ import nested_scopes
 from twisted.trial import unittest
 from twisted.internet import protocol, reactor
+from twisted.protocols import basic
+
 try:
-    import OpenSSL
+    from OpenSSL import SSL
     from twisted.internet import ssl
 except ImportError:
-    OpenSSL = None
+    SSL = None
+
 import os
 import test_tcp
 
+certPath = os.path.join(os.path.split(test_tcp.__file__)[0], "server.pem")
+
 
 class StolenTCPTestCase(test_tcp.ProperlyCloseFilesTestCase, test_tcp.WriteDataTestCase):
 
     def setUp(self):
-        certPath = os.path.join(os.path.split(test_tcp.__file__)[0], "server.pem")
         f = protocol.ServerFactory()
         f.protocol = protocol.Protocol
         self.listener = reactor.listenSSL(
@@ -49,5 +53,119 @@
         self.totalConnections = 0
 
 
-if not OpenSSL:
-    del StolenTCPTestCase
+class ClientTLSContext:
+    """Client context factory that forces TLSv1."""
+    isClient = 1
+    def getContext(self):
+        return SSL.Context(SSL.TLSv1_METHOD)
+
+class UnintelligentProtocol(basic.LineReceiver):
+    pretext = [
+        "first line",
+        "last thing before tls starts",
+        "STARTTLS",
+    ]
+
+    posttext = [
+        "first thing after tls started",
+        "last thing ever",
+    ]
+
+    def connectionMade(self):
+        for l in self.pretext:
+            self.sendLine(l)
+
+    def lineReceived(self, line):
+        if line == "READY":
+            self.transport.startTLS(ClientTLSContext())
+            for l in self.posttext:
+                self.sendLine(l)
+            self.transport.loseConnection()
+
+class ServerTLSContext:
+    """Server context factory that forces TLSv1.
+
+    Wraps ssl.DefaultOpenSSLContextFactory instead of subclassing it, so
+    this module can still be imported when OpenSSL is missing.
+    """
+    isClient = 0
+    def __init__(self, *args, **kw):
+        kw['sslmethod'] = SSL.TLSv1_METHOD
+        self._factory = ssl.DefaultOpenSSLContextFactory(*args, **kw)
+
+    def getContext(self):
+        return self._factory.getContext()
+
+class LineCollector(basic.LineReceiver):
+    def __init__(self, doTLS):
+        self.doTLS = doTLS
+
+    def connectionMade(self):
+        self.factory.rawdata = ''
+        self.factory.lines = []
+
+    def lineReceived(self, line):
+        self.factory.lines.append(line)
+        if line == 'STARTTLS':
+            self.sendLine('READY')
+            if self.doTLS:
+                ctx = ServerTLSContext(
+                    privateKeyFileName=certPath,
+                    certificateFileName=certPath,
+                )
+                self.transport.startTLS(ctx)
+            else:
+                self.setRawMode()
+
+    def rawDataReceived(self, data):
+        self.factory.rawdata += data
+        self.factory.done = 1
+
+    def connectionLost(self, reason):
+        self.factory.done = 1
+
+class TLSTestCase(unittest.TestCase):
+    def _runLines(self, doTLS):
+        """Run UnintelligentProtocol against LineCollector(doTLS).
+
+        Iterates the reactor until the server reports done (at most 5000
+        times), stops listening, asserts completion, returns the factory.
+        """
+        cf = protocol.ClientFactory()
+        cf.protocol = UnintelligentProtocol
+
+        sf = protocol.ServerFactory()
+        sf.protocol = lambda: LineCollector(doTLS)
+        sf.done = 0
+
+        port = reactor.listenTCP(0, sf)
+        portNo = port.getHost()[2]
+
+        reactor.connectTCP('127.0.0.1', portNo, cf)
+
+        i = 0
+        while i < 5000 and not sf.done:
+            reactor.iterate(0.01)
+            i += 1
+        port.loseConnection()
+
+        self.failUnless(sf.done, "Never finished reading all lines")
+        return sf
+
+    def testTLS(self):
+        sf = self._runLines(1)
+        self.assertEquals(
+            sf.lines,
+            UnintelligentProtocol.pretext + UnintelligentProtocol.posttext
+        )
+
+    def testUnTLS(self):
+        sf = self._runLines(0)
+        self.assertEquals(
+            sf.lines,
+            UnintelligentProtocol.pretext
+        )
+        self.failUnless(sf.rawdata, "No encrypted bytes received")
+
+if not SSL:
+    del StolenTCPTestCase, TLSTestCase