Hi!
We need an SMTP server and found the current implementation
a bit fragile. This patch should make it more robust and also improve RFC 2821
compliance (no guarantees, though).
I didn't touch the SMTP client (apart from adding dot-stuffing, i.e. doubling
leading dots), but I suspect I'll look it over as well.
Comments etc. are welcome. The patch is relative to Twisted CVS as of a
few minutes ago.
Regards,
/Anders
--
-- Of course I'm crazy, but that doesn't mean I'm wrong.
Anders Hammarquist | iko(a)strakt.com
AB Strakt | Tel: +46 31 711 08 70
Göteborg, Sweden. RADIO: SM6XMM and N2JGL | Fax: +46 31 711 08 80
Index: twisted/protocols/basic.py
===================================================================
RCS file: /cvs/Twisted/twisted/protocols/basic.py,v
retrieving revision 1.23
diff -u -u -r1.23 basic.py
--- twisted/protocols/basic.py 23 Sep 2002 08:51:29 -0000 1.23
+++ twisted/protocols/basic.py 30 Sep 2002 14:49:24 -0000
@@ -150,12 +150,16 @@
line, self.__buffer = self.__buffer.split(self.delimiter, 1)
except ValueError:
if len(self.__buffer) > self.MAX_LENGTH:
- self.transport.loseConnection()
+ line, self.__buffer = self.__buffer, ''
+ self.lineLengthExceeded(line)
return
break
else:
if len(line) > self.MAX_LENGTH:
- self.transport.loseConnection()
+ # The split above already moved the line out of the buffer, so pass
+ # the over-long line itself along with whatever followed it.
+ line, self.__buffer = line + self.delimiter + self.__buffer, ''
+ self.lineLengthExceeded(line)
return
self.lineReceived(line)
if self.transport.disconnecting:
@@ -197,6 +199,16 @@
"""Sends a line to the other end of the connection.
"""
self.transport.write(line + self.delimiter)
+
+ def lineLengthExceeded(self, line):
+ """Called when the maximum line length has been reached.
+ line is the buffered data that was discarded because a line
+ exceeded MAX_LENGTH; the receive buffer is empty by the time
+ this is called.
+ Override if it needs to be dealt with in some special way;
+ the default implementation drops the connection.
+ """
+ self.transport.loseConnection()
class Int32StringReceiver(protocol.Protocol):
Index: twisted/protocols/smtp.py
===================================================================
RCS file: /cvs/Twisted/twisted/protocols/smtp.py,v
retrieving revision 1.23
diff -u -u -r1.23 smtp.py
--- twisted/protocols/smtp.py 19 Aug 2002 03:21:58 -0000 1.23
+++ twisted/protocols/smtp.py 30 Sep 2002 14:49:24 -0000
@@ -18,10 +18,10 @@
"""
from twisted.protocols import basic
-from twisted.internet import protocol, defer
+from twisted.internet import protocol, defer, reactor
from twisted.python import log
-import os, time, string, operator
+import os, time, string, operator, re
class SMTPError(Exception):
pass
@@ -49,15 +49,98 @@
self.deferred.errback(arg)
self.done = 1
+class AddressError(SMTPError):
+ "Parse error in address"
+ pass
+
+# Character classes for parsing addresses
+atom = r"-A-Za-z0-9!#$%&'*+/=?^_`{|}~"
+
+class Address:
+ """Parse and hold an RFC 2821 address.
+
+ Source routes are stripped and ignored, UUCP-style bang-paths
+ and %-style routing are not parsed. The parsed parts are kept in
+ .local and .domain (quotes are not removed), the raw input in
+ .addrstr. Raises AddressError on malformed input.
+ """
+
+ qstring = re.compile(r'((?:"[^"]*"|\\.|[' + atom + r'])+|.)')
+
+ def __init__(self, addr):
+ self.local = ''
+ self.domain = ''
+ self.addrstr = addr
+
+ # Tokenize
+ atl = filter(None,self.qstring.split(addr))
+
+ local = []
+ domain = []
+
+ while atl:
+ if atl[0] == '<':
+ if atl[-1] != '>':
+ raise AddressError, "Unbalanced <>"
+ atl = atl[1:-1]
+ elif atl[0] == '@':
+ atl = atl[1:]
+ if not local:
+ # Source route
+ while atl and atl[0] != ':':
+ # remove it
+ atl = atl[1:]
+ if not atl:
+ raise AddressError, "Malformed source route"
+ atl = atl[1:] # remove :
+ elif domain:
+ raise AddressError, "Too many @"
+ else:
+ # Now in domain
+ domain = ['']
+ elif len(atl[0]) == 1 and atl[0] not in atom + '.':
+ raise AddressError, "Parse error at " + atl[0]
+ else:
+ if not domain:
+ local.append(atl[0])
+ else:
+ domain.append(atl[0])
+ atl = atl[1:]
+
+ self.local = ''.join(local)
+ self.domain = ''.join(domain)
+
+ dequotebs = re.compile(r'\\(.)')
+ def dequote(self, addr):
+ "Remove RFC-2821 quotes from address"
+ res = []
+
+ atl = filter(None,self.qstring.split(addr))
+
+ for t in atl:
+ if t[0] == '"' and t[-1] == '"':
+ res.append(t[1:-1])
+ elif '\\' in t:
+ res.append(self.dequotebs.sub(r'\1',t))
+ else:
+ res.append(t)
+
+ return ''.join(res)
+
+ def __str__(self):
+ return '%s%s' % (self.local, self.domain and ("@" + self.domain) or "")
+
+ def __repr__(self):
+ return "%s.%s(%s)" % (self.__module__, self.__class__.__name__,
+ repr(str(self)))
-class User:
+class User(Address):
def __init__(self, destination, helo, protocol, orig):
- try:
- self.name, self.domain = string.split(destination, '@', 1)
- except ValueError:
- self.name = destination
- self.domain = ''
+ Address.__init__(self,destination)
+ # Address keeps the local part in .local; also set the .name
+ # attribute the old User had, for code that still reads it.
+ self.name = self.local
self.helo = helo
self.protocol = protocol
self.orig = orig
@@ -83,23 +161,44 @@
class SMTP(basic.LineReceiver):
- def __init__(self):
+ def __init__(self, domain=None, timeout=600):
self.mode = COMMAND
self.__from = None
self.__helo = None
- self.__to = ()
+ self.__to = []
+ self.timeout = timeout
+ if not domain:
+ import socket
+ domain = socket.getfqdn()
+ self.host = domain
+
+ def timedout(self):
+ "Called when no complete line has arrived for self.timeout seconds."
+ self.sendCode(421, '%s Timeout. Try talking faster next time!' %
+ self.host)
+ self.transport.loseConnection()
def connectionMade(self):
- self.sendCode(220, 'Spammers beware, your ass is on fire')
+ self.sendCode(220, '%s Spammers beware, your ass is on fire' %
+ self.host)
+ self.timeoutID = reactor.callLater(self.timeout, self.timedout)
def sendCode(self, code, message=''):
"Send an SMTP code with a message."
self.transport.write('%d %s\r\n' % (code, message))
def lineReceived(self, line):
+ self.timeoutID.cancel()
+ self.timeoutID = reactor.callLater(self.timeout, self.timedout)
+
if self.mode is DATA:
return self.dataLineReceived(line)
- command = string.split(line, None, 1)[0]
+ # A blank or whitespace-only line has no command word: split()
+ # returns [] for it, so indexing it with [0] would raise IndexError.
+ words = string.split(line, None, 1)
+ command = ''
+ if words:
+ command = words[0]
method = getattr(self, 'do_'+string.upper(command), None)
if method is None:
method = self.do_UNKNOWN
@@ -107,21 +203,65 @@
line = line[len(command):]
return method(string.strip(line))
+ def lineLengthExceeded(self, line):
+ # Abort the DATA transaction in progress, if any, and report the error.
+ # NOTE(review): presumably the client keeps sending the rest of the
+ # message, whose lines will now be parsed as commands; consider
+ # discarding input up to the terminating "." instead.
+ if self.mode is DATA:
+ for message in self.__messages:
+ message.connectionLost()
+ self.mode = COMMAND
+ del self.__messages
+ self.sendCode(500, 'Line too long')
+
def do_UNKNOWN(self, rest):
- self.sendCode(502, 'Command not implemented')
+ self.sendCode(500, 'Command not implemented')
def do_HELO(self, rest):
- self.__helo = rest
- self.sendCode(250, 'Nice to meet you')
+ peer = self.transport.getPeer()[1]
+ self.__helo = (rest, peer)
+ self.sendCode(250, '%s Hello %s, nice to meet you' % (self.host, peer))
def do_QUIT(self, rest):
self.sendCode(221, 'See you later')
self.transport.loseConnection()
+ qstring = r'("[^"]*"|\\.|[' + atom + r'@.])+'
+
+ path_re = re.compile(r"(<(?=.*>))?(?P<addr><>|(?<=<)" + qstring + r"(?=(?<!\\)>)|" + qstring + r")((?<![\\<])>)?$")
+ mail_re = re.compile(r'\s*FROM:\s*(?P<path><>|<' + qstring + r'>|' +
+ qstring + r')\s*(?P<opts>.*)$',re.I)
+ rcpt_re = re.compile(r'\s*TO:\s*(?P<path><' + qstring + r'>|' +
+ qstring + r')\s*(?P<opts>.*)$',re.I)
+
def do_MAIL(self, rest):
- from_ = rest[len("MAIL:<"):-len(">")]
- self.validateFrom(self.__helo, from_, self._fromValid,
- self._fromInvalid)
+ # A session must be opened with HELO before MAIL (RFC 2821 4.1.4).
+ if not self.__helo:
+ self.sendCode(503,"Who are you? Say HELO first")
+ return
+ if self.__from:
+ self.sendCode(503,"Only one sender per message, please")
+ return
+ # Clear old recipient list
+ self.__to = []
+ m = self.mail_re.match(rest)
+ if not m:
+ self.sendCode(501, "Syntax error")
+ return
+ m = self.path_re.match(m.group('path'))
+ if not m:
+ self.sendCode(553, "Unparseable address")
+ return
+
+ try:
+ addr = Address(m.group('addr'))
+ except AddressError, e:
+ self.sendCode(553, str(e))
+ return
+
+ self.validateFrom(self.__helo, addr, self._fromValid,
+ self._fromInvalid)
def _fromValid(self, from_):
self.__from = from_
@@ -131,12 +265,29 @@
self.sendCode(550, 'No mail for you!')
def do_RCPT(self, rest):
- to = rest[len("TO:<"):-len(">")]
- user = User(to, self.__helo, self, self.__from)
+ # RCPT is only accepted once MAIL has established a sender.
+ if not self.__from:
+ self.sendCode(503, "Must have sender before recipient")
+ return
+ m = self.rcpt_re.match(rest)
+ if not m:
+ self.sendCode(501, "Syntax error")
+ return
+ m = self.path_re.match(m.group('path'))
+ if not m:
+ self.sendCode(553, "Unparseable address")
+ return
+
+ try:
+ user = User(m.group('addr'), self.__helo, self, self.__from)
+ except AddressError, e:
+ self.sendCode(553, str(e))
+ return
+
self.validateTo(user, self._toValid, self._toInvalid)
def _toValid(self, to):
- self.__to = self.__to + (to,)
+ self.__to.append(to)
self.sendCode(250, 'Address recognized')
def _toInvalid(self, to):
@@ -144,22 +294,33 @@
def do_DATA(self, rest):
if self.__from is None or not self.__to:
- self.sendCode(550, 'Must have valid receiver and originator')
+ self.sendCode(503, 'Must have valid receiver and originator')
return
self.mode = DATA
helo, origin, recipients = self.__helo, self.__from, self.__to
self.__from = None
- self.__to = ()
+ self.__to = []
self.__messages = self.startMessage(recipients)
+ for message in self.__messages:
+ message.lineReceived(self.receivedHeader(helo, origin, recipients))
self.sendCode(354, 'Continue')
def connectionLost(self, reason):
+ # self.sendCode(421, 'Losing connection.') # This does nothing...
+ # Ideally, if we (rather than the other side) lose the connection,
+ # we should be able to tell the other side that we are going away.
+ # RFC-2821 requires that we try.
+ # Stop the idle timer so it cannot fire against a closed transport.
+ timeoutID = getattr(self, 'timeoutID', None)
+ if timeoutID is not None and timeoutID.active():
+ timeoutID.cancel()
if self.mode is DATA:
for message in self.__messages:
message.connectionLost()
def do_RSET(self, rest):
- self.__init__()
+ self.__from = None
+ self.__to = []
self.sendCode(250, 'I remember nothing.')
def dataLineReceived(self, line):
@@ -177,6 +334,9 @@
deferred = message.eomReceived()
deferred.addCallback(ndeferred.callback)
deferred.addErrback(ndeferred.errback)
+ # The transaction is finished: drop the message objects that
+ # do_DATA obtained from startMessage().
+ del self.__messages
return
line = line[1:]
for message in self.__messages:
@@ -189,6 +347,23 @@
self.sendCode(550, 'Could not send e-mail')
# overridable methods:
+ def receivedHeader(self, helo, origin, recipents):
+ """Return the Received: trace line prepended to each message.
+
+ helo is the (HELO argument, peer address) pair saved by do_HELO;
+ origin and recipents are not used by this implementation.
+ """
+ # Spell out day and month names rather than using strftime's %a/%b,
+ # which follow the process locale; RFC 2822 dates must be English.
+ t = time.gmtime()
+ date = "%s, %02d %s %04d %02d:%02d:%02d +0000" % (
+ ("Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun")[t[6]], t[2],
+ ("Jan", "Feb", "Mar", "Apr", "May", "Jun",
+ "Jul", "Aug", "Sep", "Oct", "Nov", "Dec")[t[1] - 1],
+ t[0], t[3], t[4], t[5])
+ return "Received: From %s ([%s]) by %s; %s" % (
+ helo[0], helo[1], self.host, date)
+
def validateFrom(self, helo, origin, success, failure):
success(origin)
@@ -281,6 +444,11 @@
self.state = 'afterData'
chunk = string.replace(chunk, "\n", "\r\n")
+ # Dot-stuffing (RFC 2821 section 4.5.2): double a '.' that starts a line.
+ # NOTE(review): a '.' at the very start of this chunk (e.g. the first
+ # line of the message, or a line whose CRLF ended the previous chunk)
+ # is not preceded by "\r\n" here and so is not stuffed - TODO confirm.
+ chunk = string.replace(chunk, "\r\n.", "\r\n..")
self.transport.write(chunk)
def pauseProducing(self):