[Twisted-Python] Improvements to twisted.protocols.smtp
Hi! We have a need for an SMTP server, and found the current implementation a bit fragile. This patch should robustify it, and also improve RFC 2821 compliance (no guarantees though). I didn't touch the SMTP client (apart from adding the double-dot protocol), but I suspect I'll look over it as well. Comments, etc are welcome. The patch is relative to twisted CVS as of a few minutes ago. Regards, /Anders -- -- Of course I'm crazy, but that doesn't mean I'm wrong. Anders Hammarquist | iko@strakt.com AB Strakt | Tel: +46 31 711 08 70 G|teborg, Sweden. RADIO: SM6XMM and N2JGL | Fax: +46 31 711 08 80 Index: twisted/protocols/basic.py =================================================================== RCS file: /cvs/Twisted/twisted/protocols/basic.py,v retrieving revision 1.23 diff -u -u -r1.23 basic.py --- twisted/protocols/basic.py 23 Sep 2002 08:51:29 -0000 1.23 +++ twisted/protocols/basic.py 30 Sep 2002 14:49:24 -0000 @@ -150,12 +150,14 @@ line, self.__buffer = self.__buffer.split(self.delimiter, 1) except ValueError: if len(self.__buffer) > self.MAX_LENGTH: - self.transport.loseConnection() + line, self.__buffer = self.__buffer, '' + self.lineLengthExceeded(line) return break else: if len(line) > self.MAX_LENGTH: - self.transport.loseConnection() + line, self.__buffer = self.__buffer, '' + self.lineLengthExceeded(line) return self.lineReceived(line) if self.transport.disconnecting: @@ -197,6 +199,12 @@ """Sends a line to the other end of the connection. """ self.transport.write(line + self.delimiter) + + def lineLengthExceeded(self, line): + """Called when the maximum line length has been reached. + Override if it needs to be dealt with in some special way. 
+ """ + self.transport.loseConnection() class Int32StringReceiver(protocol.Protocol): Index: twisted/protocols/smtp.py =================================================================== RCS file: /cvs/Twisted/twisted/protocols/smtp.py,v retrieving revision 1.23 diff -u -u -r1.23 smtp.py --- twisted/protocols/smtp.py 19 Aug 2002 03:21:58 -0000 1.23 +++ twisted/protocols/smtp.py 30 Sep 2002 14:49:24 -0000 @@ -18,10 +18,10 @@ """ from twisted.protocols import basic -from twisted.internet import protocol, defer +from twisted.internet import protocol, defer, reactor from twisted.python import log -import os, time, string, operator +import os, time, string, operator, re class SMTPError(Exception): pass @@ -49,15 +49,93 @@ self.deferred.errback(arg) self.done = 1 +class AddressError(SMTPError): + "Parse error in address" + pass + +# Character classes for parsing addresses +atom = r"-A-Za-z0-9!#$%&'*+/=?^_`{|}~" + +class Address: + """Parse and hold an RFC 2821 address. + + Source routes are stipped and ignored, UUCP-style bang-paths + and %-style routing are not parsed. 
+ """ + + qstring = re.compile(r'((?:"[^"]*"|\\.|[' + atom + r'])+|.)') + + def __init__(self, addr): + self.local = '' + self.domain = '' + self.addrstr = addr + + # Tokenize + atl = filter(None,self.qstring.split(addr)) + + local = [] + domain = [] + + while atl: + if atl[0] == '<': + if atl[-1] != '>': + raise AddressError, "Unbalanced <>" + atl = atl[1:-1] + elif atl[0] == '@': + atl = atl[1:] + if not local: + # Source route + while atl and atl[0] != ':': + # remove it + atl = atl[1:] + if not atl: + raise AddressError, "Malformed source route" + atl = atl[1:] # remove : + elif domain: + raise AddressError, "Too many @" + else: + # Now in domain + domain = [''] + elif len(atl[0]) == 1 and atl[0] not in atom + '.': + raise AddressError, "Parse error at " + atl[0] + else: + if not domain: + local.append(atl[0]) + else: + domain.append(atl[0]) + atl = atl[1:] + + self.local = ''.join(local) + self.domain = ''.join(domain) + + dequotebs = re.compile(r'\\(.)') + def dequote(self, addr): + "Remove RFC-2821 quotes from address" + res = [] + + atl = filter(None,self.qstring.split(addr)) + + for t in atl: + if t[0] == '"' and t[-1] == '"': + res.append(t[1:-1]) + elif '\\' in t: + res.append(self.dequotebs.sub(r'\1',t)) + else: + res.append(t) + + return ''.join(res) + + def __str__(self): + return '%s%s' % (self.local, self.domain and ("@" + self.domain) or "") + + def __repr__(self): + return "%s.%s(%s)" % (self.__module__, self.__class__.__name__, + repr(str(self))) -class User: +class User(Address): def __init__(self, destination, helo, protocol, orig): - try: - self.name, self.domain = string.split(destination, '@', 1) - except ValueError: - self.name = destination - self.domain = '' + Address.__init__(self,destination) self.helo = helo self.protocol = protocol self.orig = orig @@ -83,23 +161,41 @@ class SMTP(basic.LineReceiver): - def __init__(self): + def __init__(self, domain=None, timeout=600): self.mode = COMMAND self.__from = None self.__helo = None - 
self.__to = () + self.__to = [] + self.timeout = timeout + if not domain: + import socket + domain = socket.getfqdn() + self.host = domain + + def timedout(self): + self.sendCode(421, '%s Timeout. Try talking faster next time!' % + self.host) + self.transport.loseConnection() def connectionMade(self): - self.sendCode(220, 'Spammers beware, your ass is on fire') + self.sendCode(220, '%s Spammers beware, your ass is on fire' % + self.host) + self.timeoutID = reactor.callLater(self.timeout, self.timedout) def sendCode(self, code, message=''): "Send an SMTP code with a message." self.transport.write('%d %s\r\n' % (code, message)) def lineReceived(self, line): + self.timeoutID.cancel() + self.timeoutID = reactor.callLater(self.timeout, self.timedout) + if self.mode is DATA: return self.dataLineReceived(line) - command = string.split(line, None, 1)[0] + if line: + command = string.split(line, None, 1)[0] + else: + command = '' method = getattr(self, 'do_'+string.upper(command), None) if method is None: method = self.do_UNKNOWN @@ -107,21 +203,59 @@ line = line[len(command):] return method(string.strip(line)) + def lineLengthExceeded(self, line): + if self.mode is DATA: + for message in self.__messages: + message.connectionLost() + self.mode = COMMAND + del self.__messages + self.sendCode(500, 'Line too long') + def do_UNKNOWN(self, rest): - self.sendCode(502, 'Command not implemented') + self.sendCode(500, 'Command not implemented') def do_HELO(self, rest): - self.__helo = rest - self.sendCode(250, 'Nice to meet you') + peer = self.transport.getPeer()[1] + self.__helo = (rest, peer) + self.sendCode(250, '%s Hello %s, nice to meet you' % (self.host, peer)) def do_QUIT(self, rest): self.sendCode(221, 'See you later') self.transport.loseConnection() + qstring = r'("[^"]*"|\\.|[' + atom + r'@.])+' + + path_re = re.compile(r"(<(?=.*>))?(?P<addr><>|(?<=<)" + qstring + r"(?=(?<!\\)>)|" + qstring + r")((?<![\\<])>)?$") + mail_re = re.compile(r'\s*FROM:\s*(?P<path><>|<' + qstring 
+ r'>|' + + qstring + r')\s*(?P<opts>.*)$',re.I) + rcpt_re = re.compile(r'\s*TO:\s*(?P<path><' + qstring + r'>|' + + qstring + r')\s*(?P<opts>.*)$',re.I) + def do_MAIL(self, rest): - from_ = rest[len("MAIL:<"):-len(">")] - self.validateFrom(self.__helo, from_, self._fromValid, - self._fromInvalid) + if not self.__helo: + self.sendCode(503,"Who are you? Say HELO first"); + if self.__from: + self.sendCode(503,"Only one sender per message, please") + return + # Clear old recipient list + self.__to = [] + m = self.mail_re.match(rest) + if not m: + self.sendCode(501, "Syntax error") + return + m = self.path_re.match(m.group('path')) + if not m: + self.sendCode(553, "Unparseable address") + return + + try: + addr = Address(m.group('addr')) + except AddressError, e: + self.sendCode(553, str(e)) + return + + self.validateFrom(self.__helo, addr, self._fromValid, + self._fromInvalid) def _fromValid(self, from_): self.__from = from_ @@ -131,12 +265,28 @@ self.sendCode(550, 'No mail for you!') def do_RCPT(self, rest): - to = rest[len("TO:<"):-len(">")] - user = User(to, self.__helo, self, self.__from) + if not self.__from: + self.sendCode(503, "Must have sender before recpient") + return + m = self.rcpt_re.match(rest) + if not m: + self.sendCode(501, "Syntax error") + return + m = self.path_re.match(m.group('path')) + if not m: + self.sendCode(553, "Unparseable address") + return + + try: + user = User(m.group('addr'), self.__helo, self, self.__from) + except AddressError, e: + self.sendCode(553, str(e)) + return + self.validateTo(user, self._toValid, self._toInvalid) def _toValid(self, to): - self.__to = self.__to + (to,) + self.__to.append(to) self.sendCode(250, 'Address recognized') def _toInvalid(self, to): @@ -144,22 +294,29 @@ def do_DATA(self, rest): if self.__from is None or not self.__to: - self.sendCode(550, 'Must have valid receiver and originator') + self.sendCode(503, 'Must have valid receiver and originator') return self.mode = DATA helo, origin, recipients = 
self.__helo, self.__from, self.__to self.__from = None - self.__to = () + self.__to = [] self.__messages = self.startMessage(recipients) + for message in self.__messages: + message.lineReceived(self.receivedHeader(helo, origin, recipients)) self.sendCode(354, 'Continue') def connectionLost(self, reason): + # self.sendCode(421, 'Loosing connection.') # This does nothing... + # Ideally, if we (rather than the other side) lose the connection, + # we should be able to tell the other side that we are going away. + # RFC-2821 requires that we try. if self.mode is DATA: for message in self.__messages: message.connectionLost() def do_RSET(self, rest): - self.__init__() + self.__from = None + self.__to = [] self.sendCode(250, 'I remember nothing.') def dataLineReceived(self, line): @@ -177,6 +334,7 @@ deferred = message.eomReceived() deferred.addCallback(ndeferred.callback) deferred.addErrback(ndeferred.errback) + del self.__messages return line = line[1:] for message in self.__messages: @@ -189,6 +347,11 @@ self.sendCode(550, 'Could not send e-mail') # overridable methods: + def receivedHeader(self, helo, origin, recipents): + return "Received: From %s ([%s]) by %s; %s" % ( + helo[0], helo[1], self.host, + time.strftime("%a, %d %b %Y %H:%M:%S +0000", time.gmtime())) + def validateFrom(self, helo, origin, success, failure): success(origin) @@ -281,6 +444,7 @@ self.state = 'afterData' chunk = string.replace(chunk, "\n", "\r\n") + chunk = string.replace(chunk, "\r\n.", "\r\n..") self.transport.write(chunk) def pauseProducing(self):
On Mon, 30 Sep 2002 17:01:36 +0200 Anders Hammarquist <iko@strakt.com> wrote:
We have a need for an SMTP server, and found the current implementation a bit fragile. This patch should robustify it, and also improve RFC 2821 compliance (no guarantees though).
Thanks Anders! One of us will try to look it over soon. And a few test cases to go along would be even more appreciated ;) -- Itamar Shtull-Trauring http://itamarst.org/ Available for Python, Twisted, Zope and Java consulting
On Mon, 30 Sep 2002, Anders Hammarquist <iko@strakt.com> wrote:
Comments, etc are welcome. The patch is relative to twisted CVS as of a few minutes ago.
Hi! I have some issues with the code, but I'm currently a bit under the weather, so I find it hard to muster the concentration to detail them. I'll send another e-mail when I feel better; I just wanted to let you know you're not being ignored.
I promised I'd detail my problems with the patch. I'll try to check something in this weekend, possibly with fixes for these problems.
Index: twisted/protocols/basic.py
I'd hesitate to touch protocols.basic without a unit test for this feature.
Index: twisted/protocols/smtp.py
I have a general problem here: the REs are very complicated and undocumented. Please use re.X (re.VERBOSE) and comment each RE.
-class User: +class User(Address):
I'd prefer containment rather than inheritance here (i.e. User should hold an Address instead of subclassing it).
- def __init__(self): + def __init__(self, domain=None, timeout=600): self.mode = COMMAND self.__from = None self.__helo = None - self.__to = () + self.__to = [] + self.timeout = timeout + if not domain: + import socket + domain = socket.getfqdn() + self.host = domain
This is wrong. The domain name should be in the protocol factory, not in the protocol.
+ self.timeoutID = reactor.callLater(self.timeout, self.timedout)
Please allow a timeout of "None" to mean "no timeouts".
+ del self.__messages + self.sendCode(500, 'Line too long')
This seems like it wouldn't work. The client would send the rest of the long line, which would look like the beginning of a command.
def do_MAIL(self, rest): - from_ = rest[len("MAIL:<"):-len(">")] - self.validateFrom(self.__helo, from_, self._fromValid, - self._fromInvalid) + if not self.__helo: + self.sendCode(503,"Who are you? Say HELO first");
Be liberal, at least in the framework. If a user wants to allow clients that do not send HELO, he should be free to do so. You can enforce HELO in validateFrom.
- self.__init__() + self.__from = None + self.__to = [] self.sendCode(250, 'I remember nothing.') what about __messages?
participants (3)
-
Anders Hammarquist
-
Itamar Shtull-Trauring
-
Moshe Zadka