[Jython-checkins] jython (merge default -> default): Merge
jeff.allen
jython-checkins at python.org
Wed Dec 31 02:41:08 CET 2014
https://hg.python.org/jython/rev/a6d27d1c2c79
changeset: 7492:a6d27d1c2c79
parent: 7490:7bae6be634fb
parent: 7491:1656013e2455
user: Jeff Allen <ja.py at farowl.co.uk>
date: Wed Dec 31 01:39:13 2014 +0000
summary:
Merge
files:
Lib/_socket.py | 37 +-
Lib/_sslcerts.py | 24 -
Lib/email/test/test_email.py | 3561 ++++++++++
Lib/email/test/test_email_renamed.py | 3332 +++++++++
Lib/gzip.py | 106 +-
Lib/json/tests/test_recursion.py | 112 +
Lib/json/tests/test_tool.py | 75 +
Lib/pkgutil.py | 7 +-
Lib/tarfile.py | 43 +-
Lib/test/list_tests.py | 41 +-
Lib/test/regrtest.py | 4 +-
Lib/test/seq_tests.py | 409 +
Lib/test/test_builtin.py | 316 +-
Lib/test/test_cgi.py | 395 -
Lib/test/test_cmp_jy.py | 43 +-
Lib/test/test_codeop_jy.py | 1 +
Lib/test/test_descr.py | 2 +-
Lib/test/test_dict_jy.py | 69 +-
Lib/test/test_fileio.py | 14 +-
Lib/test/test_funcattrs.py | 15 +-
Lib/test/test_hmac.py | 318 -
Lib/test/test_import.py | 662 +
Lib/test/test_inspect.py | 11 +-
Lib/test/test_itertools.py | 112 +-
Lib/test/test_json.py | 20 -
Lib/test/test_jy_internals.py | 26 +-
Lib/test/test_list.py | 26 +-
Lib/test/test_list_jy.py | 30 +-
Lib/test/test_module.py | 142 +-
Lib/test/test_pdb.py | 316 -
Lib/test/test_pkgutil.py | 142 -
Lib/test/test_profilehooks.py | 2 -
Lib/test/test_runpy.py | 402 +
Lib/test/test_select_new.py | 18 +-
Lib/test/test_set_jy.py | 50 +-
Lib/test/test_slice.py | 137 -
Lib/test/test_socket.py | 31 +-
Lib/test/test_sort.py | 17 +-
Lib/test/test_ssl.py | 1395 +++
Lib/test/test_support.py | 16 +
Lib/test/test_tarfile.py | 1568 ----
Lib/test/test_time.py | 1 -
Lib/test/test_types.py | 2 +-
Lib/test/test_weakref.py | 12 +-
Lib/test/test_weakset.py | 8 +-
Lib/test/test_xmlrpc.py | 14 +-
src/org/python/compiler/ClassFile.java | 5 +-
src/org/python/compiler/Filename.java | 9 +
src/org/python/core/AnnotationReader.java | 14 +-
src/org/python/core/BaseSet.java | 4 +
src/org/python/core/JavaIterator.java | 20 +
src/org/python/core/JavaProxyList.java | 637 +
src/org/python/core/JavaProxyMap.java | 478 +
src/org/python/core/JavaProxySet.java | 574 +
src/org/python/core/Py.java | 13 +
src/org/python/core/PyBaseCode.java | 49 +-
src/org/python/core/PyByteArray.java | 4 +-
src/org/python/core/PyDictionary.java | 6 +-
src/org/python/core/PyIterator.java | 21 +
src/org/python/core/PyJavaType.java | 723 +-
src/org/python/core/PyList.java | 20 +
src/org/python/core/PyModule.java | 24 +-
src/org/python/core/PySequence.java | 12 +-
src/org/python/core/PySlice.java | 86 +-
src/org/python/core/PyString.java | 15 +
src/org/python/core/PySystemState.java | 34 +
src/org/python/core/PyType.java | 4 +-
src/org/python/core/PyUnicode.java | 14 +-
src/org/python/core/PyXRange.java | 23 +
src/org/python/core/__builtin__.java | 72 +-
src/org/python/core/imp.java | 198 +-
src/org/python/modules/Setup.java | 1 +
src/org/python/modules/_codecs.java | 40 +-
src/org/python/modules/_imp.java | 6 +-
src/org/python/modules/_json/Encoder.java | 204 +
src/org/python/modules/_json/Scanner.java | 343 +
src/org/python/modules/_json/_json.java | 422 +
src/org/python/modules/bz2/PyBZ2Decompressor.java | 11 +-
src/org/python/modules/itertools/PyTeeIterator.java | 2 +-
src/org/python/modules/itertools/count.java | 92 +-
src/org/python/modules/itertools/repeat.java | 5 +
src/org/python/modules/time/Time.java | 15 +-
82 files changed, 14100 insertions(+), 4184 deletions(-)
diff --git a/Lib/_socket.py b/Lib/_socket.py
--- a/Lib/_socket.py
+++ b/Lib/_socket.py
@@ -580,7 +580,7 @@
def initChannel(self, child_channel):
child = ChildSocket(self.parent_socket)
- log.debug("Initializing child %s", extra={"sock": self.parent_socket})
+ log.debug("Initializing child %s", child, extra={"sock": self.parent_socket})
child.proto = IPPROTO_TCP
child._init_client_mode(child_channel)
@@ -615,6 +615,8 @@
# thread pool
child_channel.closeFuture().addListener(unlatch_child)
+ if self.parent_socket.timeout is None:
+ child._ensure_post_connect()
child._wait_on_latch()
log.debug("Socket initChannel completed waiting on latch", extra={"sock": child})
@@ -732,7 +734,10 @@
self.selectors.addIfAbsent(selector)
def _unregister_selector(self, selector):
- return self.selectors.remove(selector)
+ try:
+ return self.selectors.remove(selector)
+ except ValueError:
+ return None
def _notify_selectors(self, exception=None, hangup=False):
for selector in self.selectors:
@@ -839,8 +844,8 @@
log.debug("Connect to %s", addr, extra={"sock": self})
self.channel = bootstrap.channel()
- connect_future = self.channel.connect(addr)
- self._handle_channel_future(connect_future, "connect")
+ self.connect_future = self.channel.connect(addr)
+ self._handle_channel_future(self.connect_future, "connect")
self.bind_timestamp = time.time()
def _post_connect(self):
@@ -868,12 +873,17 @@
log.debug("Completed connection to %s", addr, extra={"sock": self})
def connect_ex(self, addr):
- self.connect(addr)
- if self.timeout is None:
+ if not self.connected:
+ try:
+ self.connect(addr)
+ except error as e:
+ return e.errno
+ if not self.connect_future.isDone():
+ return errno.EINPROGRESS
+ elif self.connect_future.isSuccess():
return errno.EISCONN
else:
- return errno.EINPROGRESS
-
+ return errno.ENOTCONN
# SERVER METHODS
# Calling listen means this is a server socket
@@ -1011,6 +1021,12 @@
break
log.debug("Closed child socket %s not yet accepted", child, extra={"sock": self})
child.close()
+ else:
+ msgs = []
+ self.incoming.drainTo(msgs)
+ for msg in msgs:
+ if msg is not _PEER_CLOSED:
+ msg.release()
log.debug("Closed socket", extra={"sock": self})
@@ -1024,11 +1040,11 @@
pass # already removed, can safely ignore (presumably)
if how & SHUT_WR:
self._can_write = False
-
+
def _readable(self):
if self.socket_type == CLIENT_SOCKET or self.socket_type == DATAGRAM_SOCKET:
log.debug("Incoming head=%s queue=%s", self.incoming_head, self.incoming, extra={"sock": self})
- return (
+ return bool(
(self.incoming_head is not None and self.incoming_head.readableBytes()) or
self.incoming.peek())
elif self.socket_type == SERVER_SOCKET:
@@ -1329,6 +1345,7 @@
self.active = AtomicBoolean()
self.active_latch = CountDownLatch(1)
self.accepted = False
+ self.timeout = parent_socket.timeout
def _ensure_post_connect(self):
do_post_connect = not self.active.getAndSet(True)
diff --git a/Lib/_sslcerts.py b/Lib/_sslcerts.py
--- a/Lib/_sslcerts.py
+++ b/Lib/_sslcerts.py
@@ -31,33 +31,9 @@
log = logging.getLogger("ssl")
-# FIXME what happens if reloaded?
Security.addProvider(BouncyCastleProvider())
-# build the necessary certificate with a CertificateFactory; this can take the pem format:
-# http://docs.oracle.com/javase/7/docs/api/java/security/cert/CertificateFactory.html#generateCertificate(java.io.InputStream)
-
-# not certain if we can include a private key in the pem file; see
-# http://stackoverflow.com/questions/7216969/getting-rsa-private-key-from-pem-base64-encoded-private-key-file
-
-
-# helpful advice for being able to manage ca_certs outside of Java's keystore
-# specifically the example ReloadableX509TrustManager
-# http://jcalcote.wordpress.com/2010/06/22/managing-a-dynamic-java-trust-store/
-
-# in the case of http://docs.python.org/2/library/ssl.html#ssl.CERT_REQUIRED
-
-# http://docs.python.org/2/library/ssl.html#ssl.CERT_NONE
-# https://github.com/rackerlabs/romper/blob/master/romper/trust.py#L15
-#
-# it looks like CERT_OPTIONAL simply validates certificates if
-# provided, probably something in checkServerTrusted - maybe a None
-# arg? need to verify as usual with a real system... :)
-
-# http://alesaudate.wordpress.com/2010/08/09/how-to-dynamically-select-a-certificate-alias-when-invoking-web-services/
-# is somewhat relevant for managing the keyfile, certfile
-
def _get_ca_certs_trust_manager(ca_certs):
trust_store = KeyStore.getInstance(KeyStore.getDefaultType())
diff --git a/Lib/email/test/test_email.py b/Lib/email/test/test_email.py
new file mode 100644
--- /dev/null
+++ b/Lib/email/test/test_email.py
@@ -0,0 +1,3561 @@
+# Copyright (C) 2001-2010 Python Software Foundation
+# Contact: email-sig at python.org
+# email package unit tests
+
+import os
+import sys
+import time
+import base64
+import difflib
+import unittest
+import warnings
+import textwrap
+from cStringIO import StringIO
+
+import email
+
+from email.Charset import Charset
+from email.Header import Header, decode_header, make_header
+from email.Parser import Parser, HeaderParser
+from email.Generator import Generator, DecodedGenerator
+from email.Message import Message
+from email.MIMEAudio import MIMEAudio
+from email.MIMEText import MIMEText
+from email.MIMEImage import MIMEImage
+from email.MIMEBase import MIMEBase
+from email.MIMEMessage import MIMEMessage
+from email.MIMEMultipart import MIMEMultipart
+from email import Utils
+from email import Errors
+from email import Encoders
+from email import Iterators
+from email import base64MIME
+from email import quopriMIME
+
+from test.test_support import findfile, run_unittest
+from email.test import __file__ as landmark
+from test.test_support import is_jython
+
+NL = '\n'
+EMPTYSTRING = ''
+SPACE = ' '
+
+
+
+def openfile(filename, mode='r'):
+ path = os.path.join(os.path.dirname(landmark), 'data', filename)
+ return open(path, mode)
+
+
+
+# Base test class
+class TestEmailBase(unittest.TestCase):
+ def ndiffAssertEqual(self, first, second):
+ """Like assertEqual except use ndiff for readable output."""
+ if first != second:
+ sfirst = str(first)
+ ssecond = str(second)
+ diff = difflib.ndiff(sfirst.splitlines(), ssecond.splitlines())
+ fp = StringIO()
+ print >> fp, NL, NL.join(diff)
+ raise self.failureException, fp.getvalue()
+
+ def _msgobj(self, filename):
+ fp = openfile(findfile(filename))
+ try:
+ msg = email.message_from_file(fp)
+ finally:
+ fp.close()
+ return msg
+
+
+
+# Test various aspects of the Message class's API
+class TestMessageAPI(TestEmailBase):
+ def test_get_all(self):
+ eq = self.assertEqual
+ msg = self._msgobj('msg_20.txt')
+ eq(msg.get_all('cc'), ['ccc at zzz.org', 'ddd at zzz.org', 'eee at zzz.org'])
+ eq(msg.get_all('xx', 'n/a'), 'n/a')
+
+ def test_getset_charset(self):
+ eq = self.assertEqual
+ msg = Message()
+ eq(msg.get_charset(), None)
+ charset = Charset('iso-8859-1')
+ msg.set_charset(charset)
+ eq(msg['mime-version'], '1.0')
+ eq(msg.get_content_type(), 'text/plain')
+ eq(msg['content-type'], 'text/plain; charset="iso-8859-1"')
+ eq(msg.get_param('charset'), 'iso-8859-1')
+ eq(msg['content-transfer-encoding'], 'quoted-printable')
+ eq(msg.get_charset().input_charset, 'iso-8859-1')
+ # Remove the charset
+ msg.set_charset(None)
+ eq(msg.get_charset(), None)
+ eq(msg['content-type'], 'text/plain')
+ # Try adding a charset when there's already MIME headers present
+ msg = Message()
+ msg['MIME-Version'] = '2.0'
+ msg['Content-Type'] = 'text/x-weird'
+ msg['Content-Transfer-Encoding'] = 'quinted-puntable'
+ msg.set_charset(charset)
+ eq(msg['mime-version'], '2.0')
+ eq(msg['content-type'], 'text/x-weird; charset="iso-8859-1"')
+ eq(msg['content-transfer-encoding'], 'quinted-puntable')
+
+ def test_set_charset_from_string(self):
+ eq = self.assertEqual
+ msg = Message()
+ msg.set_charset('us-ascii')
+ eq(msg.get_charset().input_charset, 'us-ascii')
+ eq(msg['content-type'], 'text/plain; charset="us-ascii"')
+
+ def test_set_payload_with_charset(self):
+ msg = Message()
+ charset = Charset('iso-8859-1')
+ msg.set_payload('This is a string payload', charset)
+ self.assertEqual(msg.get_charset().input_charset, 'iso-8859-1')
+
+ def test_get_charsets(self):
+ eq = self.assertEqual
+
+ msg = self._msgobj('msg_08.txt')
+ charsets = msg.get_charsets()
+ eq(charsets, [None, 'us-ascii', 'iso-8859-1', 'iso-8859-2', 'koi8-r'])
+
+ msg = self._msgobj('msg_09.txt')
+ charsets = msg.get_charsets('dingbat')
+ eq(charsets, ['dingbat', 'us-ascii', 'iso-8859-1', 'dingbat',
+ 'koi8-r'])
+
+ msg = self._msgobj('msg_12.txt')
+ charsets = msg.get_charsets()
+ eq(charsets, [None, 'us-ascii', 'iso-8859-1', None, 'iso-8859-2',
+ 'iso-8859-3', 'us-ascii', 'koi8-r'])
+
+ def test_get_filename(self):
+ eq = self.assertEqual
+
+ msg = self._msgobj('msg_04.txt')
+ filenames = [p.get_filename() for p in msg.get_payload()]
+ eq(filenames, ['msg.txt', 'msg.txt'])
+
+ msg = self._msgobj('msg_07.txt')
+ subpart = msg.get_payload(1)
+ eq(subpart.get_filename(), 'dingusfish.gif')
+
+ def test_get_filename_with_name_parameter(self):
+ eq = self.assertEqual
+
+ msg = self._msgobj('msg_44.txt')
+ filenames = [p.get_filename() for p in msg.get_payload()]
+ eq(filenames, ['msg.txt', 'msg.txt'])
+
+ def test_get_boundary(self):
+ eq = self.assertEqual
+ msg = self._msgobj('msg_07.txt')
+ # No quotes!
+ eq(msg.get_boundary(), 'BOUNDARY')
+
+ def test_set_boundary(self):
+ eq = self.assertEqual
+ # This one has no existing boundary parameter, but the Content-Type:
+ # header appears fifth.
+ msg = self._msgobj('msg_01.txt')
+ msg.set_boundary('BOUNDARY')
+ header, value = msg.items()[4]
+ eq(header.lower(), 'content-type')
+ eq(value, 'text/plain; charset="us-ascii"; boundary="BOUNDARY"')
+ # This one has a Content-Type: header, with a boundary, stuck in the
+ # middle of its headers. Make sure the order is preserved; it should
+ # be fifth.
+ msg = self._msgobj('msg_04.txt')
+ msg.set_boundary('BOUNDARY')
+ header, value = msg.items()[4]
+ eq(header.lower(), 'content-type')
+ eq(value, 'multipart/mixed; boundary="BOUNDARY"')
+ # And this one has no Content-Type: header at all.
+ msg = self._msgobj('msg_03.txt')
+ self.assertRaises(Errors.HeaderParseError,
+ msg.set_boundary, 'BOUNDARY')
+
+ def test_make_boundary(self):
+ msg = MIMEMultipart('form-data')
+ # Note that when the boundary gets created is an implementation
+ # detail and might change.
+ self.assertEqual(msg.items()[0][1], 'multipart/form-data')
+ # Trigger creation of boundary
+ msg.as_string()
+ self.assertEqual(msg.items()[0][1][:33],
+ 'multipart/form-data; boundary="==')
+ # XXX: there ought to be tests of the uniqueness of the boundary, too.
+
+ def test_message_rfc822_only(self):
+ # Issue 7970: message/rfc822 not in multipart parsed by
+ # HeaderParser caused an exception when flattened.
+ fp = openfile(findfile('msg_46.txt'))
+ msgdata = fp.read()
+ parser = email.Parser.HeaderParser()
+ msg = parser.parsestr(msgdata)
+ out = StringIO()
+ gen = email.Generator.Generator(out, True, 0)
+ gen.flatten(msg, False)
+ self.assertEqual(out.getvalue(), msgdata)
+
+ def test_get_decoded_payload(self):
+ eq = self.assertEqual
+ msg = self._msgobj('msg_10.txt')
+ # The outer message is a multipart
+ eq(msg.get_payload(decode=True), None)
+ # Subpart 1 is 7bit encoded
+ eq(msg.get_payload(0).get_payload(decode=True),
+ 'This is a 7bit encoded message.\n')
+ # Subpart 2 is quopri
+ eq(msg.get_payload(1).get_payload(decode=True),
+ '\xa1This is a Quoted Printable encoded message!\n')
+ # Subpart 3 is base64
+ eq(msg.get_payload(2).get_payload(decode=True),
+ 'This is a Base64 encoded message.')
+ # Subpart 4 is base64 with a trailing newline, which
+ # used to be stripped (issue 7143).
+ eq(msg.get_payload(3).get_payload(decode=True),
+ 'This is a Base64 encoded message.\n')
+ # Subpart 5 has no Content-Transfer-Encoding: header.
+ eq(msg.get_payload(4).get_payload(decode=True),
+ 'This has no Content-Transfer-Encoding: header.\n')
+
+ def test_get_decoded_uu_payload(self):
+ eq = self.assertEqual
+ msg = Message()
+ msg.set_payload('begin 666 -\n+:&5L;&\\@=V]R;&0 \n \nend\n')
+ for cte in ('x-uuencode', 'uuencode', 'uue', 'x-uue'):
+ msg['content-transfer-encoding'] = cte
+ eq(msg.get_payload(decode=True), 'hello world')
+ # Now try some bogus data
+ msg.set_payload('foo')
+ eq(msg.get_payload(decode=True), 'foo')
+
+ def test_decode_bogus_uu_payload_quietly(self):
+ msg = Message()
+ msg.set_payload('begin 664 foo.txt\n%<W1F=0000H \n \nend\n')
+ msg['Content-Transfer-Encoding'] = 'x-uuencode'
+ old_stderr = sys.stderr
+ try:
+ sys.stderr = sfp = StringIO()
+ # We don't care about the payload
+ msg.get_payload(decode=True)
+ finally:
+ sys.stderr = old_stderr
+ self.assertEqual(sfp.getvalue(), '')
+
+ def test_decoded_generator(self):
+ eq = self.assertEqual
+ msg = self._msgobj('msg_07.txt')
+ fp = openfile('msg_17.txt')
+ try:
+ text = fp.read()
+ finally:
+ fp.close()
+ s = StringIO()
+ g = DecodedGenerator(s)
+ g.flatten(msg)
+ eq(s.getvalue(), text)
+
+ def test__contains__(self):
+ msg = Message()
+ msg['From'] = 'Me'
+ msg['to'] = 'You'
+ # Check for case insensitivity
+ self.assertTrue('from' in msg)
+ self.assertTrue('From' in msg)
+ self.assertTrue('FROM' in msg)
+ self.assertTrue('to' in msg)
+ self.assertTrue('To' in msg)
+ self.assertTrue('TO' in msg)
+
+ def test_as_string(self):
+ eq = self.assertEqual
+ msg = self._msgobj('msg_01.txt')
+ fp = openfile('msg_01.txt')
+ try:
+ # BAW 30-Mar-2009 Evil be here. So, the generator is broken with
+ # respect to long line breaking. It's also not idempotent when a
+ # header from a parsed message is continued with tabs rather than
+ # spaces. Before we fixed bug 1974 it was reversedly broken,
+ # i.e. headers that were continued with spaces got continued with
+ # tabs. For Python 2.x there's really no good fix and in Python
+ # 3.x all this stuff is re-written to be right(er). Chris Withers
+ # convinced me that using space as the default continuation
+ # character is less bad for more applications.
+ text = fp.read().replace('\t', ' ')
+ finally:
+ fp.close()
+ eq(text, msg.as_string())
+ fullrepr = str(msg)
+ lines = fullrepr.split('\n')
+ self.assertTrue(lines[0].startswith('From '))
+ eq(text, NL.join(lines[1:]))
+
+ def test_bad_param(self):
+ msg = email.message_from_string("Content-Type: blarg; baz; boo\n")
+ self.assertEqual(msg.get_param('baz'), '')
+
+ def test_missing_filename(self):
+ msg = email.message_from_string("From: foo\n")
+ self.assertEqual(msg.get_filename(), None)
+
+ def test_bogus_filename(self):
+ msg = email.message_from_string(
+ "Content-Disposition: blarg; filename\n")
+ self.assertEqual(msg.get_filename(), '')
+
+ def test_missing_boundary(self):
+ msg = email.message_from_string("From: foo\n")
+ self.assertEqual(msg.get_boundary(), None)
+
+ def test_get_params(self):
+ eq = self.assertEqual
+ msg = email.message_from_string(
+ 'X-Header: foo=one; bar=two; baz=three\n')
+ eq(msg.get_params(header='x-header'),
+ [('foo', 'one'), ('bar', 'two'), ('baz', 'three')])
+ msg = email.message_from_string(
+ 'X-Header: foo; bar=one; baz=two\n')
+ eq(msg.get_params(header='x-header'),
+ [('foo', ''), ('bar', 'one'), ('baz', 'two')])
+ eq(msg.get_params(), None)
+ msg = email.message_from_string(
+ 'X-Header: foo; bar="one"; baz=two\n')
+ eq(msg.get_params(header='x-header'),
+ [('foo', ''), ('bar', 'one'), ('baz', 'two')])
+
+ def test_get_param_liberal(self):
+ msg = Message()
+ msg['Content-Type'] = 'Content-Type: Multipart/mixed; boundary = "CPIMSSMTPC06p5f3tG"'
+ self.assertEqual(msg.get_param('boundary'), 'CPIMSSMTPC06p5f3tG')
+
+ def test_get_param(self):
+ eq = self.assertEqual
+ msg = email.message_from_string(
+ "X-Header: foo=one; bar=two; baz=three\n")
+ eq(msg.get_param('bar', header='x-header'), 'two')
+ eq(msg.get_param('quuz', header='x-header'), None)
+ eq(msg.get_param('quuz'), None)
+ msg = email.message_from_string(
+ 'X-Header: foo; bar="one"; baz=two\n')
+ eq(msg.get_param('foo', header='x-header'), '')
+ eq(msg.get_param('bar', header='x-header'), 'one')
+ eq(msg.get_param('baz', header='x-header'), 'two')
+ # XXX: We are not RFC-2045 compliant! We cannot parse:
+ # msg["Content-Type"] = 'text/plain; weird="hey; dolly? [you] @ <\\"home\\">?"'
+ # msg.get_param("weird")
+ # yet.
+
+ def test_get_param_funky_continuation_lines(self):
+ msg = self._msgobj('msg_22.txt')
+ self.assertEqual(msg.get_payload(1).get_param('name'), 'wibble.JPG')
+
+ def test_get_param_with_semis_in_quotes(self):
+ msg = email.message_from_string(
+ 'Content-Type: image/pjpeg; name="Jim&&Jill"\n')
+ self.assertEqual(msg.get_param('name'), 'Jim&&Jill')
+ self.assertEqual(msg.get_param('name', unquote=False),
+ '"Jim&&Jill"')
+
+ def test_get_param_with_quotes(self):
+ msg = email.message_from_string(
+ 'Content-Type: foo; bar*0="baz\\"foobar"; bar*1="\\"baz"')
+ self.assertEqual(msg.get_param('bar'), 'baz"foobar"baz')
+ msg = email.message_from_string(
+ "Content-Type: foo; bar*0=\"baz\\\"foobar\"; bar*1=\"\\\"baz\"")
+ self.assertEqual(msg.get_param('bar'), 'baz"foobar"baz')
+
+ def test_has_key(self):
+ msg = email.message_from_string('Header: exists')
+ self.assertTrue(msg.has_key('header'))
+ self.assertTrue(msg.has_key('Header'))
+ self.assertTrue(msg.has_key('HEADER'))
+ self.assertFalse(msg.has_key('headeri'))
+
+ def test_set_param(self):
+ eq = self.assertEqual
+ msg = Message()
+ msg.set_param('charset', 'iso-2022-jp')
+ eq(msg.get_param('charset'), 'iso-2022-jp')
+ msg.set_param('importance', 'high value')
+ eq(msg.get_param('importance'), 'high value')
+ eq(msg.get_param('importance', unquote=False), '"high value"')
+ eq(msg.get_params(), [('text/plain', ''),
+ ('charset', 'iso-2022-jp'),
+ ('importance', 'high value')])
+ eq(msg.get_params(unquote=False), [('text/plain', ''),
+ ('charset', '"iso-2022-jp"'),
+ ('importance', '"high value"')])
+ msg.set_param('charset', 'iso-9999-xx', header='X-Jimmy')
+ eq(msg.get_param('charset', header='X-Jimmy'), 'iso-9999-xx')
+
+ def test_del_param(self):
+ eq = self.assertEqual
+ msg = self._msgobj('msg_05.txt')
+ eq(msg.get_params(),
+ [('multipart/report', ''), ('report-type', 'delivery-status'),
+ ('boundary', 'D1690A7AC1.996856090/mail.example.com')])
+ old_val = msg.get_param("report-type")
+ msg.del_param("report-type")
+ eq(msg.get_params(),
+ [('multipart/report', ''),
+ ('boundary', 'D1690A7AC1.996856090/mail.example.com')])
+ msg.set_param("report-type", old_val)
+ eq(msg.get_params(),
+ [('multipart/report', ''),
+ ('boundary', 'D1690A7AC1.996856090/mail.example.com'),
+ ('report-type', old_val)])
+
+ def test_del_param_on_other_header(self):
+ msg = Message()
+ msg.add_header('Content-Disposition', 'attachment', filename='bud.gif')
+ msg.del_param('filename', 'content-disposition')
+ self.assertEqual(msg['content-disposition'], 'attachment')
+
+ def test_set_type(self):
+ eq = self.assertEqual
+ msg = Message()
+ self.assertRaises(ValueError, msg.set_type, 'text')
+ msg.set_type('text/plain')
+ eq(msg['content-type'], 'text/plain')
+ msg.set_param('charset', 'us-ascii')
+ eq(msg['content-type'], 'text/plain; charset="us-ascii"')
+ msg.set_type('text/html')
+ eq(msg['content-type'], 'text/html; charset="us-ascii"')
+
+ def test_set_type_on_other_header(self):
+ msg = Message()
+ msg['X-Content-Type'] = 'text/plain'
+ msg.set_type('application/octet-stream', 'X-Content-Type')
+ self.assertEqual(msg['x-content-type'], 'application/octet-stream')
+
+ def test_get_content_type_missing(self):
+ msg = Message()
+ self.assertEqual(msg.get_content_type(), 'text/plain')
+
+ def test_get_content_type_missing_with_default_type(self):
+ msg = Message()
+ msg.set_default_type('message/rfc822')
+ self.assertEqual(msg.get_content_type(), 'message/rfc822')
+
+ def test_get_content_type_from_message_implicit(self):
+ msg = self._msgobj('msg_30.txt')
+ self.assertEqual(msg.get_payload(0).get_content_type(),
+ 'message/rfc822')
+
+ def test_get_content_type_from_message_explicit(self):
+ msg = self._msgobj('msg_28.txt')
+ self.assertEqual(msg.get_payload(0).get_content_type(),
+ 'message/rfc822')
+
+ def test_get_content_type_from_message_text_plain_implicit(self):
+ msg = self._msgobj('msg_03.txt')
+ self.assertEqual(msg.get_content_type(), 'text/plain')
+
+ def test_get_content_type_from_message_text_plain_explicit(self):
+ msg = self._msgobj('msg_01.txt')
+ self.assertEqual(msg.get_content_type(), 'text/plain')
+
+ def test_get_content_maintype_missing(self):
+ msg = Message()
+ self.assertEqual(msg.get_content_maintype(), 'text')
+
+ def test_get_content_maintype_missing_with_default_type(self):
+ msg = Message()
+ msg.set_default_type('message/rfc822')
+ self.assertEqual(msg.get_content_maintype(), 'message')
+
+ def test_get_content_maintype_from_message_implicit(self):
+ msg = self._msgobj('msg_30.txt')
+ self.assertEqual(msg.get_payload(0).get_content_maintype(), 'message')
+
+ def test_get_content_maintype_from_message_explicit(self):
+ msg = self._msgobj('msg_28.txt')
+ self.assertEqual(msg.get_payload(0).get_content_maintype(), 'message')
+
+ def test_get_content_maintype_from_message_text_plain_implicit(self):
+ msg = self._msgobj('msg_03.txt')
+ self.assertEqual(msg.get_content_maintype(), 'text')
+
+ def test_get_content_maintype_from_message_text_plain_explicit(self):
+ msg = self._msgobj('msg_01.txt')
+ self.assertEqual(msg.get_content_maintype(), 'text')
+
+ def test_get_content_subtype_missing(self):
+ msg = Message()
+ self.assertEqual(msg.get_content_subtype(), 'plain')
+
+ def test_get_content_subtype_missing_with_default_type(self):
+ msg = Message()
+ msg.set_default_type('message/rfc822')
+ self.assertEqual(msg.get_content_subtype(), 'rfc822')
+
+ def test_get_content_subtype_from_message_implicit(self):
+ msg = self._msgobj('msg_30.txt')
+ self.assertEqual(msg.get_payload(0).get_content_subtype(), 'rfc822')
+
+ def test_get_content_subtype_from_message_explicit(self):
+ msg = self._msgobj('msg_28.txt')
+ self.assertEqual(msg.get_payload(0).get_content_subtype(), 'rfc822')
+
+ def test_get_content_subtype_from_message_text_plain_implicit(self):
+ msg = self._msgobj('msg_03.txt')
+ self.assertEqual(msg.get_content_subtype(), 'plain')
+
+ def test_get_content_subtype_from_message_text_plain_explicit(self):
+ msg = self._msgobj('msg_01.txt')
+ self.assertEqual(msg.get_content_subtype(), 'plain')
+
+ def test_get_content_maintype_error(self):
+ msg = Message()
+ msg['Content-Type'] = 'no-slash-in-this-string'
+ self.assertEqual(msg.get_content_maintype(), 'text')
+
+ def test_get_content_subtype_error(self):
+ msg = Message()
+ msg['Content-Type'] = 'no-slash-in-this-string'
+ self.assertEqual(msg.get_content_subtype(), 'plain')
+
+ def test_replace_header(self):
+ eq = self.assertEqual
+ msg = Message()
+ msg.add_header('First', 'One')
+ msg.add_header('Second', 'Two')
+ msg.add_header('Third', 'Three')
+ eq(msg.keys(), ['First', 'Second', 'Third'])
+ eq(msg.values(), ['One', 'Two', 'Three'])
+ msg.replace_header('Second', 'Twenty')
+ eq(msg.keys(), ['First', 'Second', 'Third'])
+ eq(msg.values(), ['One', 'Twenty', 'Three'])
+ msg.add_header('First', 'Eleven')
+ msg.replace_header('First', 'One Hundred')
+ eq(msg.keys(), ['First', 'Second', 'Third', 'First'])
+ eq(msg.values(), ['One Hundred', 'Twenty', 'Three', 'Eleven'])
+ self.assertRaises(KeyError, msg.replace_header, 'Fourth', 'Missing')
+
+ def test_broken_base64_payload(self):
+ x = 'AwDp0P7//y6LwKEAcPa/6Q=9'
+ msg = Message()
+ msg['content-type'] = 'audio/x-midi'
+ msg['content-transfer-encoding'] = 'base64'
+ msg.set_payload(x)
+ self.assertEqual(msg.get_payload(decode=True), x)
+
+ def test_get_content_charset(self):
+ msg = Message()
+ msg.set_charset('us-ascii')
+ self.assertEqual('us-ascii', msg.get_content_charset())
+ msg.set_charset(u'us-ascii')
+ self.assertEqual('us-ascii', msg.get_content_charset())
+
+ # Issue 5871: reject an attempt to embed a header inside a header value
+ # (header injection attack).
+ def test_embeded_header_via_Header_rejected(self):
+ msg = Message()
+ msg['Dummy'] = Header('dummy\nX-Injected-Header: test')
+ self.assertRaises(Errors.HeaderParseError, msg.as_string)
+
+ def test_embeded_header_via_string_rejected(self):
+ msg = Message()
+ msg['Dummy'] = 'dummy\nX-Injected-Header: test'
+ self.assertRaises(Errors.HeaderParseError, msg.as_string)
+
+
+# Test the email.Encoders module
+class TestEncoders(unittest.TestCase):
+ def test_encode_empty_payload(self):
+ eq = self.assertEqual
+ msg = Message()
+ msg.set_charset('us-ascii')
+ eq(msg['content-transfer-encoding'], '7bit')
+
+ def test_default_cte(self):
+ eq = self.assertEqual
+ # 7bit data and the default us-ascii _charset
+ msg = MIMEText('hello world')
+ eq(msg['content-transfer-encoding'], '7bit')
+ # Similar, but with 8bit data
+ msg = MIMEText('hello \xf8 world')
+ eq(msg['content-transfer-encoding'], '8bit')
+ # And now with a different charset
+ msg = MIMEText('hello \xf8 world', _charset='iso-8859-1')
+ eq(msg['content-transfer-encoding'], 'quoted-printable')
+
+ def test_encode7or8bit(self):
+ # Make sure a charset whose input character set is 8bit but
+ # whose output character set is 7bit gets a transfer-encoding
+ # of 7bit.
+ eq = self.assertEqual
+ msg = email.MIMEText.MIMEText('\xca\xb8', _charset='euc-jp')
+ eq(msg['content-transfer-encoding'], '7bit')
+
+
+# Test long header wrapping
+class TestLongHeaders(TestEmailBase):
+ def test_split_long_continuation(self):
+ eq = self.ndiffAssertEqual
+ msg = email.message_from_string("""\
+Subject: bug demonstration
+\t12345678911234567892123456789312345678941234567895123456789612345678971234567898112345678911234567892123456789112345678911234567892123456789
+\tmore text
+
+test
+""")
+ sfp = StringIO()
+ g = Generator(sfp)
+ g.flatten(msg)
+ eq(sfp.getvalue(), """\
+Subject: bug demonstration
+ 12345678911234567892123456789312345678941234567895123456789612345678971234567898112345678911234567892123456789112345678911234567892123456789
+ more text
+
+test
+""")
+
+ def test_another_long_almost_unsplittable_header(self):
+ eq = self.ndiffAssertEqual
+ hstr = """\
+bug demonstration
+\t12345678911234567892123456789312345678941234567895123456789612345678971234567898112345678911234567892123456789112345678911234567892123456789
+\tmore text"""
+ h = Header(hstr, continuation_ws='\t')
+ eq(h.encode(), """\
+bug demonstration
+\t12345678911234567892123456789312345678941234567895123456789612345678971234567898112345678911234567892123456789112345678911234567892123456789
+\tmore text""")
+ h = Header(hstr)
+ eq(h.encode(), """\
+bug demonstration
+ 12345678911234567892123456789312345678941234567895123456789612345678971234567898112345678911234567892123456789112345678911234567892123456789
+ more text""")
+
+ def test_long_nonstring(self):
+ eq = self.ndiffAssertEqual
+ g = Charset("iso-8859-1")
+ cz = Charset("iso-8859-2")
+ utf8 = Charset("utf-8")
+ g_head = "Die Mieter treten hier ein werden mit einem Foerderband komfortabel den Korridor entlang, an s\xfcdl\xfcndischen Wandgem\xe4lden vorbei, gegen die rotierenden Klingen bef\xf6rdert. "
+ cz_head = "Finan\xe8ni metropole se hroutily pod tlakem jejich d\xf9vtipu.. "
+ utf8_head = u"\u6b63\u78ba\u306b\u8a00\u3046\u3068\u7ffb\u8a33\u306f\u3055\u308c\u3066\u3044\u307e\u305b\u3093\u3002\u4e00\u90e8\u306f\u30c9\u30a4\u30c4\u8a9e\u3067\u3059\u304c\u3001\u3042\u3068\u306f\u3067\u305f\u3089\u3081\u3067\u3059\u3002\u5b9f\u969b\u306b\u306f\u300cWenn ist das Nunstuck git und Slotermeyer? Ja! Beiherhund das Oder die Flipperwaldt gersput.\u300d\u3068\u8a00\u3063\u3066\u3044\u307e\u3059\u3002".encode("utf-8")
+ h = Header(g_head, g, header_name='Subject')
+ h.append(cz_head, cz)
+ h.append(utf8_head, utf8)
+ msg = Message()
+ msg['Subject'] = h
+ sfp = StringIO()
+ g = Generator(sfp)
+ g.flatten(msg)
+ eq(sfp.getvalue(), """\
+Subject: =?iso-8859-1?q?Die_Mieter_treten_hier_ein_werden_mit_einem_Foerd?=
+ =?iso-8859-1?q?erband_komfortabel_den_Korridor_entlang=2C_an_s=FCdl=FCndi?=
+ =?iso-8859-1?q?schen_Wandgem=E4lden_vorbei=2C_gegen_die_rotierenden_Kling?=
+ =?iso-8859-1?q?en_bef=F6rdert=2E_?= =?iso-8859-2?q?Finan=E8ni_met?=
+ =?iso-8859-2?q?ropole_se_hroutily_pod_tlakem_jejich_d=F9vtipu=2E=2E_?=
+ =?utf-8?b?5q2j56K644Gr6KiA44GG44Go57+76Kiz44Gv44GV44KM44Gm44GE?=
+ =?utf-8?b?44G+44Gb44KT44CC5LiA6YOo44Gv44OJ44Kk44OE6Kqe44Gn44GZ44GM44CB?=
+ =?utf-8?b?44GC44Go44Gv44Gn44Gf44KJ44KB44Gn44GZ44CC5a6f6Zqb44Gr44Gv44CM?=
+ =?utf-8?q?Wenn_ist_das_Nunstuck_git_und_Slotermeyer=3F_Ja!_Beiherhund_das?=
+ =?utf-8?b?IE9kZXIgZGllIEZsaXBwZXJ3YWxkdCBnZXJzcHV0LuOAjeOBqOiogOOBow==?=
+ =?utf-8?b?44Gm44GE44G+44GZ44CC?=
+
+""")
+ eq(h.encode(), """\
+=?iso-8859-1?q?Die_Mieter_treten_hier_ein_werden_mit_einem_Foerd?=
+ =?iso-8859-1?q?erband_komfortabel_den_Korridor_entlang=2C_an_s=FCdl=FCndi?=
+ =?iso-8859-1?q?schen_Wandgem=E4lden_vorbei=2C_gegen_die_rotierenden_Kling?=
+ =?iso-8859-1?q?en_bef=F6rdert=2E_?= =?iso-8859-2?q?Finan=E8ni_met?=
+ =?iso-8859-2?q?ropole_se_hroutily_pod_tlakem_jejich_d=F9vtipu=2E=2E_?=
+ =?utf-8?b?5q2j56K644Gr6KiA44GG44Go57+76Kiz44Gv44GV44KM44Gm44GE?=
+ =?utf-8?b?44G+44Gb44KT44CC5LiA6YOo44Gv44OJ44Kk44OE6Kqe44Gn44GZ44GM44CB?=
+ =?utf-8?b?44GC44Go44Gv44Gn44Gf44KJ44KB44Gn44GZ44CC5a6f6Zqb44Gr44Gv44CM?=
+ =?utf-8?q?Wenn_ist_das_Nunstuck_git_und_Slotermeyer=3F_Ja!_Beiherhund_das?=
+ =?utf-8?b?IE9kZXIgZGllIEZsaXBwZXJ3YWxkdCBnZXJzcHV0LuOAjeOBqOiogOOBow==?=
+ =?utf-8?b?44Gm44GE44G+44GZ44CC?=""")
+
+ def test_long_header_encode(self):
+ eq = self.ndiffAssertEqual
+ h = Header('wasnipoop; giraffes="very-long-necked-animals"; '
+ 'spooge="yummy"; hippos="gargantuan"; marshmallows="gooey"',
+ header_name='X-Foobar-Spoink-Defrobnit')
+ eq(h.encode(), '''\
+wasnipoop; giraffes="very-long-necked-animals";
+ spooge="yummy"; hippos="gargantuan"; marshmallows="gooey"''')
+
+ def test_long_header_encode_with_tab_continuation(self):
+ eq = self.ndiffAssertEqual
+ h = Header('wasnipoop; giraffes="very-long-necked-animals"; '
+ 'spooge="yummy"; hippos="gargantuan"; marshmallows="gooey"',
+ header_name='X-Foobar-Spoink-Defrobnit',
+ continuation_ws='\t')
+ eq(h.encode(), '''\
+wasnipoop; giraffes="very-long-necked-animals";
+\tspooge="yummy"; hippos="gargantuan"; marshmallows="gooey"''')
+
+ def test_header_splitter(self):
+ eq = self.ndiffAssertEqual
+ msg = MIMEText('')
+ # It'd be great if we could use add_header() here, but that doesn't
+ # guarantee an order of the parameters.
+ msg['X-Foobar-Spoink-Defrobnit'] = (
+ 'wasnipoop; giraffes="very-long-necked-animals"; '
+ 'spooge="yummy"; hippos="gargantuan"; marshmallows="gooey"')
+ sfp = StringIO()
+ g = Generator(sfp)
+ g.flatten(msg)
+ eq(sfp.getvalue(), '''\
+Content-Type: text/plain; charset="us-ascii"
+MIME-Version: 1.0
+Content-Transfer-Encoding: 7bit
+X-Foobar-Spoink-Defrobnit: wasnipoop; giraffes="very-long-necked-animals";
+ spooge="yummy"; hippos="gargantuan"; marshmallows="gooey"
+
+''')
+
+ def test_no_semis_header_splitter(self):
+ eq = self.ndiffAssertEqual
+ msg = Message()
+ msg['From'] = 'test at dom.ain'
+ msg['References'] = SPACE.join(['<%d at dom.ain>' % i for i in range(10)])
+ msg.set_payload('Test')
+ sfp = StringIO()
+ g = Generator(sfp)
+ g.flatten(msg)
+ eq(sfp.getvalue(), """\
+From: test at dom.ain
+References: <0 at dom.ain> <1 at dom.ain> <2 at dom.ain> <3 at dom.ain> <4 at dom.ain>
+ <5 at dom.ain> <6 at dom.ain> <7 at dom.ain> <8 at dom.ain> <9 at dom.ain>
+
+Test""")
+
+ def test_no_split_long_header(self):
+ eq = self.ndiffAssertEqual
+ hstr = 'References: ' + 'x' * 80
+ h = Header(hstr, continuation_ws='\t')
+ eq(h.encode(), """\
+References: xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx""")
+
+ def test_splitting_multiple_long_lines(self):
+ eq = self.ndiffAssertEqual
+ hstr = """\
+from babylon.socal-raves.org (localhost [127.0.0.1]); by babylon.socal-raves.org (Postfix) with ESMTP id B570E51B81; for <mailman-admin at babylon.socal-raves.org>; Sat, 2 Feb 2002 17:00:06 -0800 (PST)
+\tfrom babylon.socal-raves.org (localhost [127.0.0.1]); by babylon.socal-raves.org (Postfix) with ESMTP id B570E51B81; for <mailman-admin at babylon.socal-raves.org>; Sat, 2 Feb 2002 17:00:06 -0800 (PST)
+\tfrom babylon.socal-raves.org (localhost [127.0.0.1]); by babylon.socal-raves.org (Postfix) with ESMTP id B570E51B81; for <mailman-admin at babylon.socal-raves.org>; Sat, 2 Feb 2002 17:00:06 -0800 (PST)
+"""
+ h = Header(hstr, continuation_ws='\t')
+ eq(h.encode(), """\
+from babylon.socal-raves.org (localhost [127.0.0.1]);
+\tby babylon.socal-raves.org (Postfix) with ESMTP id B570E51B81;
+\tfor <mailman-admin at babylon.socal-raves.org>;
+\tSat, 2 Feb 2002 17:00:06 -0800 (PST)
+\tfrom babylon.socal-raves.org (localhost [127.0.0.1]);
+\tby babylon.socal-raves.org (Postfix) with ESMTP id B570E51B81;
+\tfor <mailman-admin at babylon.socal-raves.org>;
+\tSat, 2 Feb 2002 17:00:06 -0800 (PST)
+\tfrom babylon.socal-raves.org (localhost [127.0.0.1]);
+\tby babylon.socal-raves.org (Postfix) with ESMTP id B570E51B81;
+\tfor <mailman-admin at babylon.socal-raves.org>;
+\tSat, 2 Feb 2002 17:00:06 -0800 (PST)""")
+
+ def test_splitting_first_line_only_is_long(self):
+ eq = self.ndiffAssertEqual
+ hstr = """\
+from modemcable093.139-201-24.que.mc.videotron.ca ([24.201.139.93] helo=cthulhu.gerg.ca)
+\tby kronos.mems-exchange.org with esmtp (Exim 4.05)
+\tid 17k4h5-00034i-00
+\tfor test at mems-exchange.org; Wed, 28 Aug 2002 11:25:20 -0400"""
+ h = Header(hstr, maxlinelen=78, header_name='Received',
+ continuation_ws='\t')
+ eq(h.encode(), """\
+from modemcable093.139-201-24.que.mc.videotron.ca ([24.201.139.93]
+\thelo=cthulhu.gerg.ca)
+\tby kronos.mems-exchange.org with esmtp (Exim 4.05)
+\tid 17k4h5-00034i-00
+\tfor test at mems-exchange.org; Wed, 28 Aug 2002 11:25:20 -0400""")
+
+ def test_long_8bit_header(self):
+ eq = self.ndiffAssertEqual
+ msg = Message()
+ h = Header('Britische Regierung gibt', 'iso-8859-1',
+ header_name='Subject')
+ h.append('gr\xfcnes Licht f\xfcr Offshore-Windkraftprojekte')
+ msg['Subject'] = h
+ eq(msg.as_string(), """\
+Subject: =?iso-8859-1?q?Britische_Regierung_gibt?= =?iso-8859-1?q?gr=FCnes?=
+ =?iso-8859-1?q?_Licht_f=FCr_Offshore-Windkraftprojekte?=
+
+""")
+
+ def test_long_8bit_header_no_charset(self):
+ eq = self.ndiffAssertEqual
+ msg = Message()
+ msg['Reply-To'] = 'Britische Regierung gibt gr\xfcnes Licht f\xfcr Offshore-Windkraftprojekte <a-very-long-address at example.com>'
+ eq(msg.as_string(), """\
+Reply-To: Britische Regierung gibt gr\xfcnes Licht f\xfcr Offshore-Windkraftprojekte <a-very-long-address at example.com>
+
+""")
+
+ def test_long_to_header(self):
+ eq = self.ndiffAssertEqual
+ to = '"Someone Test #A" <someone at eecs.umich.edu>,<someone at eecs.umich.edu>,"Someone Test #B" <someone at umich.edu>, "Someone Test #C" <someone at eecs.umich.edu>, "Someone Test #D" <someone at eecs.umich.edu>'
+ msg = Message()
+ msg['To'] = to
+ eq(msg.as_string(0), '''\
+To: "Someone Test #A" <someone at eecs.umich.edu>, <someone at eecs.umich.edu>,
+ "Someone Test #B" <someone at umich.edu>,
+ "Someone Test #C" <someone at eecs.umich.edu>,
+ "Someone Test #D" <someone at eecs.umich.edu>
+
+''')
+
+ def test_long_line_after_append(self):
+ eq = self.ndiffAssertEqual
+ s = 'This is an example of string which has almost the limit of header length.'
+ h = Header(s)
+ h.append('Add another line.')
+ eq(h.encode(), """\
+This is an example of string which has almost the limit of header length.
+ Add another line.""")
+
+ def test_shorter_line_with_append(self):
+ eq = self.ndiffAssertEqual
+ s = 'This is a shorter line.'
+ h = Header(s)
+ h.append('Add another sentence. (Surprise?)')
+ eq(h.encode(),
+ 'This is a shorter line. Add another sentence. (Surprise?)')
+
+ def test_long_field_name(self):
+ eq = self.ndiffAssertEqual
+ fn = 'X-Very-Very-Very-Long-Header-Name'
+ gs = "Die Mieter treten hier ein werden mit einem Foerderband komfortabel den Korridor entlang, an s\xfcdl\xfcndischen Wandgem\xe4lden vorbei, gegen die rotierenden Klingen bef\xf6rdert. "
+ h = Header(gs, 'iso-8859-1', header_name=fn)
+ # BAW: this seems broken because the first line is too long
+ eq(h.encode(), """\
+=?iso-8859-1?q?Die_Mieter_treten_hier_?=
+ =?iso-8859-1?q?ein_werden_mit_einem_Foerderband_komfortabel_den_Korridor_?=
+ =?iso-8859-1?q?entlang=2C_an_s=FCdl=FCndischen_Wandgem=E4lden_vorbei=2C_g?=
+ =?iso-8859-1?q?egen_die_rotierenden_Klingen_bef=F6rdert=2E_?=""")
+
+ def test_long_received_header(self):
+ h = 'from FOO.TLD (vizworld.acl.foo.tld [123.452.678.9]) by hrothgar.la.mastaler.com (tmda-ofmipd) with ESMTP; Wed, 05 Mar 2003 18:10:18 -0700'
+ msg = Message()
+ msg['Received-1'] = Header(h, continuation_ws='\t')
+ msg['Received-2'] = h
+ self.assertEqual(msg.as_string(), """\
+Received-1: from FOO.TLD (vizworld.acl.foo.tld [123.452.678.9]) by
+\throthgar.la.mastaler.com (tmda-ofmipd) with ESMTP;
+\tWed, 05 Mar 2003 18:10:18 -0700
+Received-2: from FOO.TLD (vizworld.acl.foo.tld [123.452.678.9]) by
+ hrothgar.la.mastaler.com (tmda-ofmipd) with ESMTP;
+ Wed, 05 Mar 2003 18:10:18 -0700
+
+""")
+
+ def test_string_headerinst_eq(self):
+ h = '<15975.17901.207240.414604 at sgigritzmann1.mathematik.tu-muenchen.de> (David Bremner\'s message of "Thu, 6 Mar 2003 13:58:21 +0100")'
+ msg = Message()
+ msg['Received'] = Header(h, header_name='Received',
+ continuation_ws='\t')
+ msg['Received'] = h
+ self.ndiffAssertEqual(msg.as_string(), """\
+Received: <15975.17901.207240.414604 at sgigritzmann1.mathematik.tu-muenchen.de>
+\t(David Bremner's message of "Thu, 6 Mar 2003 13:58:21 +0100")
+Received: <15975.17901.207240.414604 at sgigritzmann1.mathematik.tu-muenchen.de>
+ (David Bremner's message of "Thu, 6 Mar 2003 13:58:21 +0100")
+
+""")
+
+ def test_long_unbreakable_lines_with_continuation(self):
+ eq = self.ndiffAssertEqual
+ msg = Message()
+ t = """\
+ iVBORw0KGgoAAAANSUhEUgAAADAAAAAwBAMAAAClLOS0AAAAGFBMVEUAAAAkHiJeRUIcGBi9
+ locQDQ4zJykFBAXJfWDjAAACYUlEQVR4nF2TQY/jIAyFc6lydlG5x8Nyp1Y69wj1PN2I5gzp"""
+ msg['Face-1'] = t
+ msg['Face-2'] = Header(t, header_name='Face-2')
+ eq(msg.as_string(), """\
+Face-1: iVBORw0KGgoAAAANSUhEUgAAADAAAAAwBAMAAAClLOS0AAAAGFBMVEUAAAAkHiJeRUIcGBi9
+ locQDQ4zJykFBAXJfWDjAAACYUlEQVR4nF2TQY/jIAyFc6lydlG5x8Nyp1Y69wj1PN2I5gzp
+Face-2: iVBORw0KGgoAAAANSUhEUgAAADAAAAAwBAMAAAClLOS0AAAAGFBMVEUAAAAkHiJeRUIcGBi9
+ locQDQ4zJykFBAXJfWDjAAACYUlEQVR4nF2TQY/jIAyFc6lydlG5x8Nyp1Y69wj1PN2I5gzp
+
+""")
+
+ def test_another_long_multiline_header(self):
+ eq = self.ndiffAssertEqual
+ m = '''\
+Received: from siimage.com ([172.25.1.3]) by zima.siliconimage.com with Microsoft SMTPSVC(5.0.2195.4905);
+ Wed, 16 Oct 2002 07:41:11 -0700'''
+ msg = email.message_from_string(m)
+ eq(msg.as_string(), '''\
+Received: from siimage.com ([172.25.1.3]) by zima.siliconimage.com with
+ Microsoft SMTPSVC(5.0.2195.4905); Wed, 16 Oct 2002 07:41:11 -0700
+
+''')
+
+ def test_long_lines_with_different_header(self):
+ eq = self.ndiffAssertEqual
+ h = """\
+List-Unsubscribe: <https://lists.sourceforge.net/lists/listinfo/spamassassin-talk>,
+ <mailto:spamassassin-talk-request at lists.sourceforge.net?subject=unsubscribe>"""
+ msg = Message()
+ msg['List'] = h
+ msg['List'] = Header(h, header_name='List')
+ eq(msg.as_string(), """\
+List: List-Unsubscribe: <https://lists.sourceforge.net/lists/listinfo/spamassassin-talk>,
+ <mailto:spamassassin-talk-request at lists.sourceforge.net?subject=unsubscribe>
+List: List-Unsubscribe: <https://lists.sourceforge.net/lists/listinfo/spamassassin-talk>,
+ <mailto:spamassassin-talk-request at lists.sourceforge.net?subject=unsubscribe>
+
+""")
+
+
+
+# Test mangling of "From " lines in the body of a message
+class TestFromMangling(unittest.TestCase):
+ def setUp(self):
+ self.msg = Message()
+ self.msg['From'] = 'aaa at bbb.org'
+ self.msg.set_payload("""\
+From the desk of A.A.A.:
+Blah blah blah
+""")
+
+ def test_mangled_from(self):
+ s = StringIO()
+ g = Generator(s, mangle_from_=True)
+ g.flatten(self.msg)
+ self.assertEqual(s.getvalue(), """\
+From: aaa at bbb.org
+
+>From the desk of A.A.A.:
+Blah blah blah
+""")
+
+ def test_dont_mangle_from(self):
+ s = StringIO()
+ g = Generator(s, mangle_from_=False)
+ g.flatten(self.msg)
+ self.assertEqual(s.getvalue(), """\
+From: aaa at bbb.org
+
+From the desk of A.A.A.:
+Blah blah blah
+""")
+
+ def test_mangle_from_in_preamble_and_epilog(self):
+ s = StringIO()
+ g = Generator(s, mangle_from_=True)
+ msg = email.message_from_string(textwrap.dedent("""\
+ From: foo at bar.com
+ Mime-Version: 1.0
+ Content-Type: multipart/mixed; boundary=XXX
+
+ From somewhere unknown
+
+ --XXX
+ Content-Type: text/plain
+
+ foo
+
+ --XXX--
+
+ From somewhere unknowable
+ """))
+ g.flatten(msg)
+ self.assertEqual(len([1 for x in s.getvalue().split('\n')
+ if x.startswith('>From ')]), 2)
+
+
+# Test the basic MIMEAudio class
+class TestMIMEAudio(unittest.TestCase):
+ def setUp(self):
+ # Make sure we pick up the audiotest.au that lives in email/test/data.
+ # In Python, there's an audiotest.au living in Lib/test but that isn't
+ # included in some binary distros that don't include the test
+ # package. The trailing empty string on the .join() is significant
+ # since findfile() will do a dirname().
+ datadir = os.path.join(os.path.dirname(landmark), 'data', '')
+ fp = open(findfile('audiotest.au', datadir), 'rb')
+ try:
+ self._audiodata = fp.read()
+ finally:
+ fp.close()
+ self._au = MIMEAudio(self._audiodata)
+
+ def test_guess_minor_type(self):
+ self.assertEqual(self._au.get_content_type(), 'audio/basic')
+
+ def test_encoding(self):
+ payload = self._au.get_payload()
+ self.assertEqual(base64.decodestring(payload), self._audiodata)
+
+ def test_checkSetMinor(self):
+ au = MIMEAudio(self._audiodata, 'fish')
+ self.assertEqual(au.get_content_type(), 'audio/fish')
+
+ def test_add_header(self):
+ eq = self.assertEqual
+ unless = self.assertTrue
+ self._au.add_header('Content-Disposition', 'attachment',
+ filename='audiotest.au')
+ eq(self._au['content-disposition'],
+ 'attachment; filename="audiotest.au"')
+ eq(self._au.get_params(header='content-disposition'),
+ [('attachment', ''), ('filename', 'audiotest.au')])
+ eq(self._au.get_param('filename', header='content-disposition'),
+ 'audiotest.au')
+ missing = []
+ eq(self._au.get_param('attachment', header='content-disposition'), '')
+ unless(self._au.get_param('foo', failobj=missing,
+ header='content-disposition') is missing)
+ # Try some missing stuff
+ unless(self._au.get_param('foobar', missing) is missing)
+ unless(self._au.get_param('attachment', missing,
+ header='foobar') is missing)
+
+
+
+# Test the basic MIMEImage class
+class TestMIMEImage(unittest.TestCase):
+ def setUp(self):
+ fp = openfile('PyBanner048.gif')
+ try:
+ self._imgdata = fp.read()
+ finally:
+ fp.close()
+ self._im = MIMEImage(self._imgdata)
+
+ def test_guess_minor_type(self):
+ self.assertEqual(self._im.get_content_type(), 'image/gif')
+
+ def test_encoding(self):
+ payload = self._im.get_payload()
+ self.assertEqual(base64.decodestring(payload), self._imgdata)
+
+ def test_checkSetMinor(self):
+ im = MIMEImage(self._imgdata, 'fish')
+ self.assertEqual(im.get_content_type(), 'image/fish')
+
+ def test_add_header(self):
+ eq = self.assertEqual
+ unless = self.assertTrue
+ self._im.add_header('Content-Disposition', 'attachment',
+ filename='dingusfish.gif')
+ eq(self._im['content-disposition'],
+ 'attachment; filename="dingusfish.gif"')
+ eq(self._im.get_params(header='content-disposition'),
+ [('attachment', ''), ('filename', 'dingusfish.gif')])
+ eq(self._im.get_param('filename', header='content-disposition'),
+ 'dingusfish.gif')
+ missing = []
+ eq(self._im.get_param('attachment', header='content-disposition'), '')
+ unless(self._im.get_param('foo', failobj=missing,
+ header='content-disposition') is missing)
+ # Try some missing stuff
+ unless(self._im.get_param('foobar', missing) is missing)
+ unless(self._im.get_param('attachment', missing,
+ header='foobar') is missing)
+
+
+
+# Test the basic MIMEText class
+class TestMIMEText(unittest.TestCase):
+ def setUp(self):
+ self._msg = MIMEText('hello there')
+
+ def test_types(self):
+ eq = self.assertEqual
+ unless = self.assertTrue
+ eq(self._msg.get_content_type(), 'text/plain')
+ eq(self._msg.get_param('charset'), 'us-ascii')
+ missing = []
+ unless(self._msg.get_param('foobar', missing) is missing)
+ unless(self._msg.get_param('charset', missing, header='foobar')
+ is missing)
+
+ def test_payload(self):
+ self.assertEqual(self._msg.get_payload(), 'hello there')
+ self.assertTrue(not self._msg.is_multipart())
+
+ def test_charset(self):
+ eq = self.assertEqual
+ msg = MIMEText('hello there', _charset='us-ascii')
+ eq(msg.get_charset().input_charset, 'us-ascii')
+ eq(msg['content-type'], 'text/plain; charset="us-ascii"')
+
+ def test_7bit_unicode_input(self):
+ eq = self.assertEqual
+ msg = MIMEText(u'hello there', _charset='us-ascii')
+ eq(msg.get_charset().input_charset, 'us-ascii')
+ eq(msg['content-type'], 'text/plain; charset="us-ascii"')
+
+ def test_7bit_unicode_input_no_charset(self):
+ eq = self.assertEqual
+ msg = MIMEText(u'hello there')
+ eq(msg.get_charset(), 'us-ascii')
+ eq(msg['content-type'], 'text/plain; charset="us-ascii"')
+ self.assertTrue('hello there' in msg.as_string())
+
+ def test_8bit_unicode_input(self):
+ teststr = u'\u043a\u0438\u0440\u0438\u043b\u0438\u0446\u0430'
+ eq = self.assertEqual
+ msg = MIMEText(teststr, _charset='utf-8')
+ eq(msg.get_charset().output_charset, 'utf-8')
+ eq(msg['content-type'], 'text/plain; charset="utf-8"')
+ eq(msg.get_payload(decode=True), teststr.encode('utf-8'))
+
+ def test_8bit_unicode_input_no_charset(self):
+ teststr = u'\u043a\u0438\u0440\u0438\u043b\u0438\u0446\u0430'
+ self.assertRaises(UnicodeEncodeError, MIMEText, teststr)
+
+
+
+# Test complicated multipart/* messages
+class TestMultipart(TestEmailBase):
+ def setUp(self):
+ fp = openfile('PyBanner048.gif')
+ try:
+ data = fp.read()
+ finally:
+ fp.close()
+
+ container = MIMEBase('multipart', 'mixed', boundary='BOUNDARY')
+ image = MIMEImage(data, name='dingusfish.gif')
+ image.add_header('content-disposition', 'attachment',
+ filename='dingusfish.gif')
+ intro = MIMEText('''\
+Hi there,
+
+This is the dingus fish.
+''')
+ container.attach(intro)
+ container.attach(image)
+ container['From'] = 'Barry <barry at digicool.com>'
+ container['To'] = 'Dingus Lovers <cravindogs at cravindogs.com>'
+ container['Subject'] = 'Here is your dingus fish'
+
+ now = 987809702.54848599
+ timetuple = time.localtime(now)
+ if timetuple[-1] == 0:
+ tzsecs = time.timezone
+ else:
+ tzsecs = time.altzone
+ if tzsecs > 0:
+ sign = '-'
+ else:
+ sign = '+'
+ tzoffset = ' %s%04d' % (sign, tzsecs // 36)
+ container['Date'] = time.strftime(
+ '%a, %d %b %Y %H:%M:%S',
+ time.localtime(now)) + tzoffset
+ self._msg = container
+ self._im = image
+ self._txt = intro
+
+ def test_hierarchy(self):
+ # convenience
+ eq = self.assertEqual
+ unless = self.assertTrue
+ raises = self.assertRaises
+ # tests
+ m = self._msg
+ unless(m.is_multipart())
+ eq(m.get_content_type(), 'multipart/mixed')
+ eq(len(m.get_payload()), 2)
+ raises(IndexError, m.get_payload, 2)
+ m0 = m.get_payload(0)
+ m1 = m.get_payload(1)
+ unless(m0 is self._txt)
+ unless(m1 is self._im)
+ eq(m.get_payload(), [m0, m1])
+ unless(not m0.is_multipart())
+ unless(not m1.is_multipart())
+
+ def test_empty_multipart_idempotent(self):
+ text = """\
+Content-Type: multipart/mixed; boundary="BOUNDARY"
+MIME-Version: 1.0
+Subject: A subject
+To: aperson at dom.ain
+From: bperson at dom.ain
+
+
+--BOUNDARY
+
+
+--BOUNDARY--
+"""
+ msg = Parser().parsestr(text)
+ self.ndiffAssertEqual(text, msg.as_string())
+
+ def test_no_parts_in_a_multipart_with_none_epilogue(self):
+ outer = MIMEBase('multipart', 'mixed')
+ outer['Subject'] = 'A subject'
+ outer['To'] = 'aperson at dom.ain'
+ outer['From'] = 'bperson at dom.ain'
+ outer.set_boundary('BOUNDARY')
+ self.ndiffAssertEqual(outer.as_string(), '''\
+Content-Type: multipart/mixed; boundary="BOUNDARY"
+MIME-Version: 1.0
+Subject: A subject
+To: aperson at dom.ain
+From: bperson at dom.ain
+
+--BOUNDARY
+
+--BOUNDARY--''')
+
+ def test_no_parts_in_a_multipart_with_empty_epilogue(self):
+ outer = MIMEBase('multipart', 'mixed')
+ outer['Subject'] = 'A subject'
+ outer['To'] = 'aperson at dom.ain'
+ outer['From'] = 'bperson at dom.ain'
+ outer.preamble = ''
+ outer.epilogue = ''
+ outer.set_boundary('BOUNDARY')
+ self.ndiffAssertEqual(outer.as_string(), '''\
+Content-Type: multipart/mixed; boundary="BOUNDARY"
+MIME-Version: 1.0
+Subject: A subject
+To: aperson at dom.ain
+From: bperson at dom.ain
+
+
+--BOUNDARY
+
+--BOUNDARY--
+''')
+
+ def test_one_part_in_a_multipart(self):
+ eq = self.ndiffAssertEqual
+ outer = MIMEBase('multipart', 'mixed')
+ outer['Subject'] = 'A subject'
+ outer['To'] = 'aperson at dom.ain'
+ outer['From'] = 'bperson at dom.ain'
+ outer.set_boundary('BOUNDARY')
+ msg = MIMEText('hello world')
+ outer.attach(msg)
+ eq(outer.as_string(), '''\
+Content-Type: multipart/mixed; boundary="BOUNDARY"
+MIME-Version: 1.0
+Subject: A subject
+To: aperson at dom.ain
+From: bperson at dom.ain
+
+--BOUNDARY
+Content-Type: text/plain; charset="us-ascii"
+MIME-Version: 1.0
+Content-Transfer-Encoding: 7bit
+
+hello world
+--BOUNDARY--''')
+
+ def test_seq_parts_in_a_multipart_with_empty_preamble(self):
+ eq = self.ndiffAssertEqual
+ outer = MIMEBase('multipart', 'mixed')
+ outer['Subject'] = 'A subject'
+ outer['To'] = 'aperson at dom.ain'
+ outer['From'] = 'bperson at dom.ain'
+ outer.preamble = ''
+ msg = MIMEText('hello world')
+ outer.attach(msg)
+ outer.set_boundary('BOUNDARY')
+ eq(outer.as_string(), '''\
+Content-Type: multipart/mixed; boundary="BOUNDARY"
+MIME-Version: 1.0
+Subject: A subject
+To: aperson at dom.ain
+From: bperson at dom.ain
+
+
+--BOUNDARY
+Content-Type: text/plain; charset="us-ascii"
+MIME-Version: 1.0
+Content-Transfer-Encoding: 7bit
+
+hello world
+--BOUNDARY--''')
+
+
+ def test_seq_parts_in_a_multipart_with_none_preamble(self):
+ eq = self.ndiffAssertEqual
+ outer = MIMEBase('multipart', 'mixed')
+ outer['Subject'] = 'A subject'
+ outer['To'] = 'aperson at dom.ain'
+ outer['From'] = 'bperson at dom.ain'
+ outer.preamble = None
+ msg = MIMEText('hello world')
+ outer.attach(msg)
+ outer.set_boundary('BOUNDARY')
+ eq(outer.as_string(), '''\
+Content-Type: multipart/mixed; boundary="BOUNDARY"
+MIME-Version: 1.0
+Subject: A subject
+To: aperson at dom.ain
+From: bperson at dom.ain
+
+--BOUNDARY
+Content-Type: text/plain; charset="us-ascii"
+MIME-Version: 1.0
+Content-Transfer-Encoding: 7bit
+
+hello world
+--BOUNDARY--''')
+
+
+ def test_seq_parts_in_a_multipart_with_none_epilogue(self):
+ eq = self.ndiffAssertEqual
+ outer = MIMEBase('multipart', 'mixed')
+ outer['Subject'] = 'A subject'
+ outer['To'] = 'aperson at dom.ain'
+ outer['From'] = 'bperson at dom.ain'
+ outer.epilogue = None
+ msg = MIMEText('hello world')
+ outer.attach(msg)
+ outer.set_boundary('BOUNDARY')
+ eq(outer.as_string(), '''\
+Content-Type: multipart/mixed; boundary="BOUNDARY"
+MIME-Version: 1.0
+Subject: A subject
+To: aperson at dom.ain
+From: bperson at dom.ain
+
+--BOUNDARY
+Content-Type: text/plain; charset="us-ascii"
+MIME-Version: 1.0
+Content-Transfer-Encoding: 7bit
+
+hello world
+--BOUNDARY--''')
+
+
+ def test_seq_parts_in_a_multipart_with_empty_epilogue(self):
+ eq = self.ndiffAssertEqual
+ outer = MIMEBase('multipart', 'mixed')
+ outer['Subject'] = 'A subject'
+ outer['To'] = 'aperson at dom.ain'
+ outer['From'] = 'bperson at dom.ain'
+ outer.epilogue = ''
+ msg = MIMEText('hello world')
+ outer.attach(msg)
+ outer.set_boundary('BOUNDARY')
+ eq(outer.as_string(), '''\
+Content-Type: multipart/mixed; boundary="BOUNDARY"
+MIME-Version: 1.0
+Subject: A subject
+To: aperson at dom.ain
+From: bperson at dom.ain
+
+--BOUNDARY
+Content-Type: text/plain; charset="us-ascii"
+MIME-Version: 1.0
+Content-Transfer-Encoding: 7bit
+
+hello world
+--BOUNDARY--
+''')
+
+
+ def test_seq_parts_in_a_multipart_with_nl_epilogue(self):
+ eq = self.ndiffAssertEqual
+ outer = MIMEBase('multipart', 'mixed')
+ outer['Subject'] = 'A subject'
+ outer['To'] = 'aperson at dom.ain'
+ outer['From'] = 'bperson at dom.ain'
+ outer.epilogue = '\n'
+ msg = MIMEText('hello world')
+ outer.attach(msg)
+ outer.set_boundary('BOUNDARY')
+ eq(outer.as_string(), '''\
+Content-Type: multipart/mixed; boundary="BOUNDARY"
+MIME-Version: 1.0
+Subject: A subject
+To: aperson at dom.ain
+From: bperson at dom.ain
+
+--BOUNDARY
+Content-Type: text/plain; charset="us-ascii"
+MIME-Version: 1.0
+Content-Transfer-Encoding: 7bit
+
+hello world
+--BOUNDARY--
+
+''')
+
+ def test_message_external_body(self):
+ eq = self.assertEqual
+ msg = self._msgobj('msg_36.txt')
+ eq(len(msg.get_payload()), 2)
+ msg1 = msg.get_payload(1)
+ eq(msg1.get_content_type(), 'multipart/alternative')
+ eq(len(msg1.get_payload()), 2)
+ for subpart in msg1.get_payload():
+ eq(subpart.get_content_type(), 'message/external-body')
+ eq(len(subpart.get_payload()), 1)
+ subsubpart = subpart.get_payload(0)
+ eq(subsubpart.get_content_type(), 'text/plain')
+
+ def test_double_boundary(self):
+ # msg_37.txt is a multipart that contains two dash-boundary's in a
+ # row. Our interpretation of RFC 2046 calls for ignoring the second
+ # and subsequent boundaries.
+ msg = self._msgobj('msg_37.txt')
+ self.assertEqual(len(msg.get_payload()), 3)
+
+ def test_nested_inner_contains_outer_boundary(self):
+ eq = self.ndiffAssertEqual
+ # msg_38.txt has an inner part that contains outer boundaries. My
+ # interpretation of RFC 2046 (based on sections 5.1 and 5.1.2) say
+ # these are illegal and should be interpreted as unterminated inner
+ # parts.
+ msg = self._msgobj('msg_38.txt')
+ sfp = StringIO()
+ Iterators._structure(msg, sfp)
+ eq(sfp.getvalue(), """\
+multipart/mixed
+ multipart/mixed
+ multipart/alternative
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+""")
+
+ def test_nested_with_same_boundary(self):
+ eq = self.ndiffAssertEqual
+ # msg 39.txt is similarly evil in that it's got inner parts that use
+ # the same boundary as outer parts. Again, I believe the way this is
+ # parsed is closest to the spirit of RFC 2046
+ msg = self._msgobj('msg_39.txt')
+ sfp = StringIO()
+ Iterators._structure(msg, sfp)
+ eq(sfp.getvalue(), """\
+multipart/mixed
+ multipart/mixed
+ multipart/alternative
+ application/octet-stream
+ application/octet-stream
+ text/plain
+""")
+
+ def test_boundary_in_non_multipart(self):
+ msg = self._msgobj('msg_40.txt')
+ self.assertEqual(msg.as_string(), '''\
+MIME-Version: 1.0
+Content-Type: text/html; boundary="--961284236552522269"
+
+----961284236552522269
+Content-Type: text/html;
+Content-Transfer-Encoding: 7Bit
+
+<html></html>
+
+----961284236552522269--
+''')
+
+ def test_boundary_with_leading_space(self):
+ eq = self.assertEqual
+ msg = email.message_from_string('''\
+MIME-Version: 1.0
+Content-Type: multipart/mixed; boundary=" XXXX"
+
+-- XXXX
+Content-Type: text/plain
+
+
+-- XXXX
+Content-Type: text/plain
+
+-- XXXX--
+''')
+ self.assertTrue(msg.is_multipart())
+ eq(msg.get_boundary(), ' XXXX')
+ eq(len(msg.get_payload()), 2)
+
+ def test_boundary_without_trailing_newline(self):
+ m = Parser().parsestr("""\
+Content-Type: multipart/mixed; boundary="===============0012394164=="
+MIME-Version: 1.0
+
+--===============0012394164==
+Content-Type: image/file1.jpg
+MIME-Version: 1.0
+Content-Transfer-Encoding: base64
+
+YXNkZg==
+--===============0012394164==--""")
+ self.assertEqual(m.get_payload(0).get_payload(), 'YXNkZg==')
+
+
+
+# Test some badly formatted messages
+class TestNonConformant(TestEmailBase):
+ def test_parse_missing_minor_type(self):
+ eq = self.assertEqual
+ msg = self._msgobj('msg_14.txt')
+ eq(msg.get_content_type(), 'text/plain')
+ eq(msg.get_content_maintype(), 'text')
+ eq(msg.get_content_subtype(), 'plain')
+
+ def test_same_boundary_inner_outer(self):
+ unless = self.assertTrue
+ msg = self._msgobj('msg_15.txt')
+ # XXX We can probably eventually do better
+ inner = msg.get_payload(0)
+ unless(hasattr(inner, 'defects'))
+ self.assertEqual(len(inner.defects), 1)
+ unless(isinstance(inner.defects[0],
+ Errors.StartBoundaryNotFoundDefect))
+
+ def test_multipart_no_boundary(self):
+ unless = self.assertTrue
+ msg = self._msgobj('msg_25.txt')
+ unless(isinstance(msg.get_payload(), str))
+ self.assertEqual(len(msg.defects), 2)
+ unless(isinstance(msg.defects[0], Errors.NoBoundaryInMultipartDefect))
+ unless(isinstance(msg.defects[1],
+ Errors.MultipartInvariantViolationDefect))
+
+ def test_invalid_content_type(self):
+ eq = self.assertEqual
+ neq = self.ndiffAssertEqual
+ msg = Message()
+ # RFC 2045, $5.2 says invalid yields text/plain
+ msg['Content-Type'] = 'text'
+ eq(msg.get_content_maintype(), 'text')
+ eq(msg.get_content_subtype(), 'plain')
+ eq(msg.get_content_type(), 'text/plain')
+ # Clear the old value and try something /really/ invalid
+ del msg['content-type']
+ msg['Content-Type'] = 'foo'
+ eq(msg.get_content_maintype(), 'text')
+ eq(msg.get_content_subtype(), 'plain')
+ eq(msg.get_content_type(), 'text/plain')
+ # Still, make sure that the message is idempotently generated
+ s = StringIO()
+ g = Generator(s)
+ g.flatten(msg)
+ neq(s.getvalue(), 'Content-Type: foo\n\n')
+
+ def test_no_start_boundary(self):
+ eq = self.ndiffAssertEqual
+ msg = self._msgobj('msg_31.txt')
+ eq(msg.get_payload(), """\
+--BOUNDARY
+Content-Type: text/plain
+
+message 1
+
+--BOUNDARY
+Content-Type: text/plain
+
+message 2
+
+--BOUNDARY--
+""")
+
+ def test_no_separating_blank_line(self):
+ eq = self.ndiffAssertEqual
+ msg = self._msgobj('msg_35.txt')
+ eq(msg.as_string(), """\
+From: aperson at dom.ain
+To: bperson at dom.ain
+Subject: here's something interesting
+
+counter to RFC 2822, there's no separating newline here
+""")
+
+ def test_lying_multipart(self):
+ unless = self.assertTrue
+ msg = self._msgobj('msg_41.txt')
+ unless(hasattr(msg, 'defects'))
+ self.assertEqual(len(msg.defects), 2)
+ unless(isinstance(msg.defects[0], Errors.NoBoundaryInMultipartDefect))
+ unless(isinstance(msg.defects[1],
+ Errors.MultipartInvariantViolationDefect))
+
+ def test_missing_start_boundary(self):
+ outer = self._msgobj('msg_42.txt')
+ # The message structure is:
+ #
+ # multipart/mixed
+ # text/plain
+ # message/rfc822
+ # multipart/mixed [*]
+ #
+ # [*] This message is missing its start boundary
+ bad = outer.get_payload(1).get_payload(0)
+ self.assertEqual(len(bad.defects), 1)
+ self.assertTrue(isinstance(bad.defects[0],
+ Errors.StartBoundaryNotFoundDefect))
+
+ def test_first_line_is_continuation_header(self):
+ eq = self.assertEqual
+ m = ' Line 1\nLine 2\nLine 3'
+ msg = email.message_from_string(m)
+ eq(msg.keys(), [])
+ eq(msg.get_payload(), 'Line 2\nLine 3')
+ eq(len(msg.defects), 1)
+ self.assertTrue(isinstance(msg.defects[0],
+ Errors.FirstHeaderLineIsContinuationDefect))
+ eq(msg.defects[0].line, ' Line 1\n')
+
+
+
+
+# Test RFC 2047 header encoding and decoding
+class TestRFC2047(unittest.TestCase):
+ def test_rfc2047_multiline(self):
+ eq = self.assertEqual
+ s = """Re: =?mac-iceland?q?r=8Aksm=9Arg=8Cs?= baz
+ foo bar =?mac-iceland?q?r=8Aksm=9Arg=8Cs?="""
+ dh = decode_header(s)
+ eq(dh, [
+ ('Re:', None),
+ ('r\x8aksm\x9arg\x8cs', 'mac-iceland'),
+ ('baz foo bar', None),
+ ('r\x8aksm\x9arg\x8cs', 'mac-iceland')])
+ eq(str(make_header(dh)),
+ """Re: =?mac-iceland?q?r=8Aksm=9Arg=8Cs?= baz foo bar
+ =?mac-iceland?q?r=8Aksm=9Arg=8Cs?=""")
+
+ def test_whitespace_eater_unicode(self):
+ eq = self.assertEqual
+ s = '=?ISO-8859-1?Q?Andr=E9?= Pirard <pirard at dom.ain>'
+ dh = decode_header(s)
+ eq(dh, [('Andr\xe9', 'iso-8859-1'), ('Pirard <pirard at dom.ain>', None)])
+ hu = unicode(make_header(dh)).encode('latin-1')
+ eq(hu, 'Andr\xe9 Pirard <pirard at dom.ain>')
+
+ def test_whitespace_eater_unicode_2(self):
+ eq = self.assertEqual
+ s = 'The =?iso-8859-1?b?cXVpY2sgYnJvd24gZm94?= jumped over the =?iso-8859-1?b?bGF6eSBkb2c=?='
+ dh = decode_header(s)
+ eq(dh, [('The', None), ('quick brown fox', 'iso-8859-1'),
+ ('jumped over the', None), ('lazy dog', 'iso-8859-1')])
+ hu = make_header(dh).__unicode__()
+ eq(hu, u'The quick brown fox jumped over the lazy dog')
+
+ def test_rfc2047_without_whitespace(self):
+ s = 'Sm=?ISO-8859-1?B?9g==?=rg=?ISO-8859-1?B?5Q==?=sbord'
+ dh = decode_header(s)
+ self.assertEqual(dh, [(s, None)])
+
+ def test_rfc2047_with_whitespace(self):
+ s = 'Sm =?ISO-8859-1?B?9g==?= rg =?ISO-8859-1?B?5Q==?= sbord'
+ dh = decode_header(s)
+ self.assertEqual(dh, [('Sm', None), ('\xf6', 'iso-8859-1'),
+ ('rg', None), ('\xe5', 'iso-8859-1'),
+ ('sbord', None)])
+
+ def test_rfc2047_B_bad_padding(self):
+ s = '=?iso-8859-1?B?%s?='
+ data = [ # only test complete bytes
+ ('dm==', 'v'), ('dm=', 'v'), ('dm', 'v'),
+ ('dmk=', 'vi'), ('dmk', 'vi')
+ ]
+ for q, a in data:
+ dh = decode_header(s % q)
+ self.assertEqual(dh, [(a, 'iso-8859-1')])
+
+ def test_rfc2047_Q_invalid_digits(self):
+ # issue 10004.
+ s = '=?iso-8659-1?Q?andr=e9=zz?='
+ self.assertEqual(decode_header(s),
+ [(b'andr\xe9=zz', 'iso-8659-1')])
+
+
+# Test the MIMEMessage class
+class TestMIMEMessage(TestEmailBase):
+ def setUp(self):
+ fp = openfile('msg_11.txt')
+ try:
+ self._text = fp.read()
+ finally:
+ fp.close()
+
+ def test_type_error(self):
+ self.assertRaises(TypeError, MIMEMessage, 'a plain string')
+
+ def test_valid_argument(self):
+ eq = self.assertEqual
+ unless = self.assertTrue
+ subject = 'A sub-message'
+ m = Message()
+ m['Subject'] = subject
+ r = MIMEMessage(m)
+ eq(r.get_content_type(), 'message/rfc822')
+ payload = r.get_payload()
+ unless(isinstance(payload, list))
+ eq(len(payload), 1)
+ subpart = payload[0]
+ unless(subpart is m)
+ eq(subpart['subject'], subject)
+
+ def test_bad_multipart(self):
+ eq = self.assertEqual
+ msg1 = Message()
+ msg1['Subject'] = 'subpart 1'
+ msg2 = Message()
+ msg2['Subject'] = 'subpart 2'
+ r = MIMEMessage(msg1)
+ self.assertRaises(Errors.MultipartConversionError, r.attach, msg2)
+
+ def test_generate(self):
+ # First craft the message to be encapsulated
+ m = Message()
+ m['Subject'] = 'An enclosed message'
+ m.set_payload('Here is the body of the message.\n')
+ r = MIMEMessage(m)
+ r['Subject'] = 'The enclosing message'
+ s = StringIO()
+ g = Generator(s)
+ g.flatten(r)
+ self.assertEqual(s.getvalue(), """\
+Content-Type: message/rfc822
+MIME-Version: 1.0
+Subject: The enclosing message
+
+Subject: An enclosed message
+
+Here is the body of the message.
+""")
+
+ def test_parse_message_rfc822(self):
+ eq = self.assertEqual
+ unless = self.assertTrue
+ msg = self._msgobj('msg_11.txt')
+ eq(msg.get_content_type(), 'message/rfc822')
+ payload = msg.get_payload()
+ unless(isinstance(payload, list))
+ eq(len(payload), 1)
+ submsg = payload[0]
+ self.assertTrue(isinstance(submsg, Message))
+ eq(submsg['subject'], 'An enclosed message')
+ eq(submsg.get_payload(), 'Here is the body of the message.\n')
+
+ def test_dsn(self):
+ eq = self.assertEqual
+ unless = self.assertTrue
+ # msg 16 is a Delivery Status Notification, see RFC 1894
+ msg = self._msgobj('msg_16.txt')
+ eq(msg.get_content_type(), 'multipart/report')
+ unless(msg.is_multipart())
+ eq(len(msg.get_payload()), 3)
+ # Subpart 1 is a text/plain, human readable section
+ subpart = msg.get_payload(0)
+ eq(subpart.get_content_type(), 'text/plain')
+ eq(subpart.get_payload(), """\
+This report relates to a message you sent with the following header fields:
+
+ Message-id: <002001c144a6$8752e060$56104586 at oxy.edu>
+ Date: Sun, 23 Sep 2001 20:10:55 -0700
+ From: "Ian T. Henry" <henryi at oxy.edu>
+ To: SoCal Raves <scr at socal-raves.org>
+ Subject: [scr] yeah for Ians!!
+
+Your message cannot be delivered to the following recipients:
+
+ Recipient address: jangel1 at cougar.noc.ucla.edu
+ Reason: recipient reached disk quota
+
+""")
+ # Subpart 2 contains the machine parsable DSN information. It
+ # consists of two blocks of headers, represented by two nested Message
+ # objects.
+ subpart = msg.get_payload(1)
+ eq(subpart.get_content_type(), 'message/delivery-status')
+ eq(len(subpart.get_payload()), 2)
+ # message/delivery-status should treat each block as a bunch of
+ # headers, i.e. a bunch of Message objects.
+ dsn1 = subpart.get_payload(0)
+ unless(isinstance(dsn1, Message))
+ eq(dsn1['original-envelope-id'], '0GK500B4HD0888 at cougar.noc.ucla.edu')
+ eq(dsn1.get_param('dns', header='reporting-mta'), '')
+ # Try a missing one <wink>
+ eq(dsn1.get_param('nsd', header='reporting-mta'), None)
+ dsn2 = subpart.get_payload(1)
+ unless(isinstance(dsn2, Message))
+ eq(dsn2['action'], 'failed')
+ eq(dsn2.get_params(header='original-recipient'),
+ [('rfc822', ''), ('jangel1 at cougar.noc.ucla.edu', '')])
+ eq(dsn2.get_param('rfc822', header='final-recipient'), '')
+ # Subpart 3 is the original message
+ subpart = msg.get_payload(2)
+ eq(subpart.get_content_type(), 'message/rfc822')
+ payload = subpart.get_payload()
+ unless(isinstance(payload, list))
+ eq(len(payload), 1)
+ subsubpart = payload[0]
+ unless(isinstance(subsubpart, Message))
+ eq(subsubpart.get_content_type(), 'text/plain')
+ eq(subsubpart['message-id'],
+ '<002001c144a6$8752e060$56104586 at oxy.edu>')
+
+ def test_epilogue(self):
+ eq = self.ndiffAssertEqual
+ fp = openfile('msg_21.txt')
+ try:
+ text = fp.read()
+ finally:
+ fp.close()
+ msg = Message()
+ msg['From'] = 'aperson at dom.ain'
+ msg['To'] = 'bperson at dom.ain'
+ msg['Subject'] = 'Test'
+ msg.preamble = 'MIME message'
+ msg.epilogue = 'End of MIME message\n'
+ msg1 = MIMEText('One')
+ msg2 = MIMEText('Two')
+ msg.add_header('Content-Type', 'multipart/mixed', boundary='BOUNDARY')
+ msg.attach(msg1)
+ msg.attach(msg2)
+ sfp = StringIO()
+ g = Generator(sfp)
+ g.flatten(msg)
+ eq(sfp.getvalue(), text)
+
+ def test_no_nl_preamble(self):
+ eq = self.ndiffAssertEqual
+ msg = Message()
+ msg['From'] = 'aperson at dom.ain'
+ msg['To'] = 'bperson at dom.ain'
+ msg['Subject'] = 'Test'
+ msg.preamble = 'MIME message'
+ msg.epilogue = ''
+ msg1 = MIMEText('One')
+ msg2 = MIMEText('Two')
+ msg.add_header('Content-Type', 'multipart/mixed', boundary='BOUNDARY')
+ msg.attach(msg1)
+ msg.attach(msg2)
+ eq(msg.as_string(), """\
+From: aperson at dom.ain
+To: bperson at dom.ain
+Subject: Test
+Content-Type: multipart/mixed; boundary="BOUNDARY"
+
+MIME message
+--BOUNDARY
+Content-Type: text/plain; charset="us-ascii"
+MIME-Version: 1.0
+Content-Transfer-Encoding: 7bit
+
+One
+--BOUNDARY
+Content-Type: text/plain; charset="us-ascii"
+MIME-Version: 1.0
+Content-Transfer-Encoding: 7bit
+
+Two
+--BOUNDARY--
+""")
+
+ def test_default_type(self):
+ eq = self.assertEqual
+ fp = openfile('msg_30.txt')
+ try:
+ msg = email.message_from_file(fp)
+ finally:
+ fp.close()
+ container1 = msg.get_payload(0)
+ eq(container1.get_default_type(), 'message/rfc822')
+ eq(container1.get_content_type(), 'message/rfc822')
+ container2 = msg.get_payload(1)
+ eq(container2.get_default_type(), 'message/rfc822')
+ eq(container2.get_content_type(), 'message/rfc822')
+ container1a = container1.get_payload(0)
+ eq(container1a.get_default_type(), 'text/plain')
+ eq(container1a.get_content_type(), 'text/plain')
+ container2a = container2.get_payload(0)
+ eq(container2a.get_default_type(), 'text/plain')
+ eq(container2a.get_content_type(), 'text/plain')
+
+ def test_default_type_with_explicit_container_type(self):
+ eq = self.assertEqual
+ fp = openfile('msg_28.txt')
+ try:
+ msg = email.message_from_file(fp)
+ finally:
+ fp.close()
+ container1 = msg.get_payload(0)
+ eq(container1.get_default_type(), 'message/rfc822')
+ eq(container1.get_content_type(), 'message/rfc822')
+ container2 = msg.get_payload(1)
+ eq(container2.get_default_type(), 'message/rfc822')
+ eq(container2.get_content_type(), 'message/rfc822')
+ container1a = container1.get_payload(0)
+ eq(container1a.get_default_type(), 'text/plain')
+ eq(container1a.get_content_type(), 'text/plain')
+ container2a = container2.get_payload(0)
+ eq(container2a.get_default_type(), 'text/plain')
+ eq(container2a.get_content_type(), 'text/plain')
+
+ def test_default_type_non_parsed(self):
+ eq = self.assertEqual
+ neq = self.ndiffAssertEqual
+ # Set up container
+ container = MIMEMultipart('digest', 'BOUNDARY')
+ container.epilogue = ''
+ # Set up subparts
+ subpart1a = MIMEText('message 1\n')
+ subpart2a = MIMEText('message 2\n')
+ subpart1 = MIMEMessage(subpart1a)
+ subpart2 = MIMEMessage(subpart2a)
+ container.attach(subpart1)
+ container.attach(subpart2)
+ eq(subpart1.get_content_type(), 'message/rfc822')
+ eq(subpart1.get_default_type(), 'message/rfc822')
+ eq(subpart2.get_content_type(), 'message/rfc822')
+ eq(subpart2.get_default_type(), 'message/rfc822')
+ neq(container.as_string(0), '''\
+Content-Type: multipart/digest; boundary="BOUNDARY"
+MIME-Version: 1.0
+
+--BOUNDARY
+Content-Type: message/rfc822
+MIME-Version: 1.0
+
+Content-Type: text/plain; charset="us-ascii"
+MIME-Version: 1.0
+Content-Transfer-Encoding: 7bit
+
+message 1
+
+--BOUNDARY
+Content-Type: message/rfc822
+MIME-Version: 1.0
+
+Content-Type: text/plain; charset="us-ascii"
+MIME-Version: 1.0
+Content-Transfer-Encoding: 7bit
+
+message 2
+
+--BOUNDARY--
+''')
+ del subpart1['content-type']
+ del subpart1['mime-version']
+ del subpart2['content-type']
+ del subpart2['mime-version']
+ eq(subpart1.get_content_type(), 'message/rfc822')
+ eq(subpart1.get_default_type(), 'message/rfc822')
+ eq(subpart2.get_content_type(), 'message/rfc822')
+ eq(subpart2.get_default_type(), 'message/rfc822')
+ neq(container.as_string(0), '''\
+Content-Type: multipart/digest; boundary="BOUNDARY"
+MIME-Version: 1.0
+
+--BOUNDARY
+
+Content-Type: text/plain; charset="us-ascii"
+MIME-Version: 1.0
+Content-Transfer-Encoding: 7bit
+
+message 1
+
+--BOUNDARY
+
+Content-Type: text/plain; charset="us-ascii"
+MIME-Version: 1.0
+Content-Transfer-Encoding: 7bit
+
+message 2
+
+--BOUNDARY--
+''')
+
+ def test_mime_attachments_in_constructor(self):
+ eq = self.assertEqual
+ text1 = MIMEText('')
+ text2 = MIMEText('')
+ msg = MIMEMultipart(_subparts=(text1, text2))
+ eq(len(msg.get_payload()), 2)
+ eq(msg.get_payload(0), text1)
+ eq(msg.get_payload(1), text2)
+
+ def test_default_multipart_constructor(self):
+ msg = MIMEMultipart()
+ self.assertTrue(msg.is_multipart())
+
+
+# A general test of parser->model->generator idempotency. IOW, read a message
+# in, parse it into a message object tree, then without touching the tree,
+# regenerate the plain text. The original text and the transformed text
+# should be identical. Note: that we ignore the Unix-From since that may
+# contain a changed date.
+class TestIdempotent(TestEmailBase):
+ def _msgobj(self, filename):
+ fp = openfile(filename)
+ try:
+ data = fp.read()
+ finally:
+ fp.close()
+ msg = email.message_from_string(data)
+ return msg, data
+
+ def _idempotent(self, msg, text):
+ eq = self.ndiffAssertEqual
+ s = StringIO()
+ g = Generator(s, maxheaderlen=0)
+ g.flatten(msg)
+ eq(text, s.getvalue())
+
+ def test_parse_text_message(self):
+ eq = self.assertEqual
+ msg, text = self._msgobj('msg_01.txt')
+ eq(msg.get_content_type(), 'text/plain')
+ eq(msg.get_content_maintype(), 'text')
+ eq(msg.get_content_subtype(), 'plain')
+ eq(msg.get_params()[1], ('charset', 'us-ascii'))
+ eq(msg.get_param('charset'), 'us-ascii')
+ eq(msg.preamble, None)
+ eq(msg.epilogue, None)
+ self._idempotent(msg, text)
+
+ def test_parse_untyped_message(self):
+ eq = self.assertEqual
+ msg, text = self._msgobj('msg_03.txt')
+ eq(msg.get_content_type(), 'text/plain')
+ eq(msg.get_params(), None)
+ eq(msg.get_param('charset'), None)
+ self._idempotent(msg, text)
+
+ def test_simple_multipart(self):
+ msg, text = self._msgobj('msg_04.txt')
+ self._idempotent(msg, text)
+
+ def test_MIME_digest(self):
+ msg, text = self._msgobj('msg_02.txt')
+ self._idempotent(msg, text)
+
+ def test_long_header(self):
+ msg, text = self._msgobj('msg_27.txt')
+ self._idempotent(msg, text)
+
+ def test_MIME_digest_with_part_headers(self):
+ msg, text = self._msgobj('msg_28.txt')
+ self._idempotent(msg, text)
+
+ def test_mixed_with_image(self):
+ msg, text = self._msgobj('msg_06.txt')
+ self._idempotent(msg, text)
+
+ def test_multipart_report(self):
+ msg, text = self._msgobj('msg_05.txt')
+ self._idempotent(msg, text)
+
+ def test_dsn(self):
+ msg, text = self._msgobj('msg_16.txt')
+ self._idempotent(msg, text)
+
+ def test_preamble_epilogue(self):
+ msg, text = self._msgobj('msg_21.txt')
+ self._idempotent(msg, text)
+
+ def test_multipart_one_part(self):
+ msg, text = self._msgobj('msg_23.txt')
+ self._idempotent(msg, text)
+
+ def test_multipart_no_parts(self):
+ msg, text = self._msgobj('msg_24.txt')
+ self._idempotent(msg, text)
+
+ def test_no_start_boundary(self):
+ msg, text = self._msgobj('msg_31.txt')
+ self._idempotent(msg, text)
+
+ def test_rfc2231_charset(self):
+ msg, text = self._msgobj('msg_32.txt')
+ self._idempotent(msg, text)
+
+ def test_more_rfc2231_parameters(self):
+ msg, text = self._msgobj('msg_33.txt')
+ self._idempotent(msg, text)
+
+ def test_text_plain_in_a_multipart_digest(self):
+ msg, text = self._msgobj('msg_34.txt')
+ self._idempotent(msg, text)
+
+ def test_nested_multipart_mixeds(self):
+ msg, text = self._msgobj('msg_12a.txt')
+ self._idempotent(msg, text)
+
+ def test_message_external_body_idempotent(self):
+ msg, text = self._msgobj('msg_36.txt')
+ self._idempotent(msg, text)
+
+ def test_content_type(self):
+ eq = self.assertEqual
+ unless = self.assertTrue
+ # Get a message object and reset the seek pointer for other tests
+ msg, text = self._msgobj('msg_05.txt')
+ eq(msg.get_content_type(), 'multipart/report')
+ # Test the Content-Type: parameters
+ params = {}
+ for pk, pv in msg.get_params():
+ params[pk] = pv
+ eq(params['report-type'], 'delivery-status')
+ eq(params['boundary'], 'D1690A7AC1.996856090/mail.example.com')
+ eq(msg.preamble, 'This is a MIME-encapsulated message.\n')
+ eq(msg.epilogue, '\n')
+ eq(len(msg.get_payload()), 3)
+ # Make sure the subparts are what we expect
+ msg1 = msg.get_payload(0)
+ eq(msg1.get_content_type(), 'text/plain')
+ eq(msg1.get_payload(), 'Yadda yadda yadda\n')
+ msg2 = msg.get_payload(1)
+ eq(msg2.get_content_type(), 'text/plain')
+ eq(msg2.get_payload(), 'Yadda yadda yadda\n')
+ msg3 = msg.get_payload(2)
+ eq(msg3.get_content_type(), 'message/rfc822')
+ self.assertTrue(isinstance(msg3, Message))
+ payload = msg3.get_payload()
+ unless(isinstance(payload, list))
+ eq(len(payload), 1)
+ msg4 = payload[0]
+ unless(isinstance(msg4, Message))
+ eq(msg4.get_payload(), 'Yadda yadda yadda\n')
+
+ def test_parser(self):
+ eq = self.assertEqual
+ unless = self.assertTrue
+ msg, text = self._msgobj('msg_06.txt')
+ # Check some of the outer headers
+ eq(msg.get_content_type(), 'message/rfc822')
+ # Make sure the payload is a list of exactly one sub-Message, and that
+ # that submessage has a type of text/plain
+ payload = msg.get_payload()
+ unless(isinstance(payload, list))
+ eq(len(payload), 1)
+ msg1 = payload[0]
+ self.assertTrue(isinstance(msg1, Message))
+ eq(msg1.get_content_type(), 'text/plain')
+ self.assertTrue(isinstance(msg1.get_payload(), str))
+ eq(msg1.get_payload(), '\n')
+
+
+
+# Test various other bits of the package's functionality
+class TestMiscellaneous(TestEmailBase):
+ def test_message_from_string(self):
+ fp = openfile('msg_01.txt')
+ try:
+ text = fp.read()
+ finally:
+ fp.close()
+ msg = email.message_from_string(text)
+ s = StringIO()
+ # Don't wrap/continue long headers since we're trying to test
+ # idempotency.
+ g = Generator(s, maxheaderlen=0)
+ g.flatten(msg)
+ self.assertEqual(text, s.getvalue())
+
+ def test_message_from_file(self):
+ fp = openfile('msg_01.txt')
+ try:
+ text = fp.read()
+ fp.seek(0)
+ msg = email.message_from_file(fp)
+ s = StringIO()
+ # Don't wrap/continue long headers since we're trying to test
+ # idempotency.
+ g = Generator(s, maxheaderlen=0)
+ g.flatten(msg)
+ self.assertEqual(text, s.getvalue())
+ finally:
+ fp.close()
+
+ def test_message_from_string_with_class(self):
+ unless = self.assertTrue
+ fp = openfile('msg_01.txt')
+ try:
+ text = fp.read()
+ finally:
+ fp.close()
+ # Create a subclass
+ class MyMessage(Message):
+ pass
+
+ msg = email.message_from_string(text, MyMessage)
+ unless(isinstance(msg, MyMessage))
+ # Try something more complicated
+ fp = openfile('msg_02.txt')
+ try:
+ text = fp.read()
+ finally:
+ fp.close()
+ msg = email.message_from_string(text, MyMessage)
+ for subpart in msg.walk():
+ unless(isinstance(subpart, MyMessage))
+
+ def test_message_from_file_with_class(self):
+ unless = self.assertTrue
+ # Create a subclass
+ class MyMessage(Message):
+ pass
+
+ fp = openfile('msg_01.txt')
+ try:
+ msg = email.message_from_file(fp, MyMessage)
+ finally:
+ fp.close()
+ unless(isinstance(msg, MyMessage))
+ # Try something more complicated
+ fp = openfile('msg_02.txt')
+ try:
+ msg = email.message_from_file(fp, MyMessage)
+ finally:
+ fp.close()
+ for subpart in msg.walk():
+ unless(isinstance(subpart, MyMessage))
+
+ def test__all__(self):
+ module = __import__('email')
+ all = module.__all__
+ all.sort()
+ self.assertEqual(all, [
+ # Old names
+ 'Charset', 'Encoders', 'Errors', 'Generator',
+ 'Header', 'Iterators', 'MIMEAudio', 'MIMEBase',
+ 'MIMEImage', 'MIMEMessage', 'MIMEMultipart',
+ 'MIMENonMultipart', 'MIMEText', 'Message',
+ 'Parser', 'Utils', 'base64MIME',
+ # new names
+ 'base64mime', 'charset', 'encoders', 'errors', 'generator',
+ 'header', 'iterators', 'message', 'message_from_file',
+ 'message_from_string', 'mime', 'parser',
+ 'quopriMIME', 'quoprimime', 'utils',
+ ])
+
+ def test_formatdate(self):
+ now = time.time()
+ self.assertEqual(Utils.parsedate(Utils.formatdate(now))[:6],
+ time.gmtime(now)[:6])
+
+ def test_formatdate_localtime(self):
+ now = time.time()
+ self.assertEqual(
+ Utils.parsedate(Utils.formatdate(now, localtime=True))[:6],
+ time.localtime(now)[:6])
+
+ def test_formatdate_usegmt(self):
+ now = time.time()
+ self.assertEqual(
+ Utils.formatdate(now, localtime=False),
+ time.strftime('%a, %d %b %Y %H:%M:%S -0000', time.gmtime(now)))
+ self.assertEqual(
+ Utils.formatdate(now, localtime=False, usegmt=True),
+ time.strftime('%a, %d %b %Y %H:%M:%S GMT', time.gmtime(now)))
+
+ def test_parsedate_none(self):
+ self.assertEqual(Utils.parsedate(''), None)
+
+ def test_parsedate_compact(self):
+ # The FWS after the comma is optional
+ self.assertEqual(Utils.parsedate('Wed,3 Apr 2002 14:58:26 +0800'),
+ Utils.parsedate('Wed, 3 Apr 2002 14:58:26 +0800'))
+
+ def test_parsedate_no_dayofweek(self):
+ eq = self.assertEqual
+ eq(Utils.parsedate_tz('25 Feb 2003 13:47:26 -0800'),
+ (2003, 2, 25, 13, 47, 26, 0, 1, -1, -28800))
+
+ def test_parsedate_compact_no_dayofweek(self):
+ eq = self.assertEqual
+ eq(Utils.parsedate_tz('5 Feb 2003 13:47:26 -0800'),
+ (2003, 2, 5, 13, 47, 26, 0, 1, -1, -28800))
+
+ def test_parsedate_acceptable_to_time_functions(self):
+ eq = self.assertEqual
+ timetup = Utils.parsedate('5 Feb 2003 13:47:26 -0800')
+ t = int(time.mktime(timetup))
+ eq(time.localtime(t)[:6], timetup[:6])
+ eq(int(time.strftime('%Y', timetup)), 2003)
+ timetup = Utils.parsedate_tz('5 Feb 2003 13:47:26 -0800')
+ t = int(time.mktime(timetup[:9]))
+ eq(time.localtime(t)[:6], timetup[:6])
+ eq(int(time.strftime('%Y', timetup[:9])), 2003)
+
+ def test_mktime_tz(self):
+ self.assertEqual(Utils.mktime_tz((1970, 1, 1, 0, 0, 0,
+ -1, -1, -1, 0)), 0)
+ self.assertEqual(Utils.mktime_tz((1970, 1, 1, 0, 0, 0,
+ -1, -1, -1, 1234)), -1234)
+
+ def test_parsedate_y2k(self):
+ """Test for parsing a date with a two-digit year.
+
+ Parsing a date with a two-digit year should return the correct
+ four-digit year. RFC822 allows two-digit years, but RFC2822 (which
+ obsoletes RFC822) requires four-digit years.
+
+ """
+ self.assertEqual(Utils.parsedate_tz('25 Feb 03 13:47:26 -0800'),
+ Utils.parsedate_tz('25 Feb 2003 13:47:26 -0800'))
+ self.assertEqual(Utils.parsedate_tz('25 Feb 71 13:47:26 -0800'),
+ Utils.parsedate_tz('25 Feb 1971 13:47:26 -0800'))
+
+ def test_parseaddr_empty(self):
+ self.assertEqual(Utils.parseaddr('<>'), ('', ''))
+ self.assertEqual(Utils.formataddr(Utils.parseaddr('<>')), '')
+
+ def test_noquote_dump(self):
+ self.assertEqual(
+ Utils.formataddr(('A Silly Person', 'person at dom.ain')),
+ 'A Silly Person <person at dom.ain>')
+
+ def test_escape_dump(self):
+ self.assertEqual(
+ Utils.formataddr(('A (Very) Silly Person', 'person at dom.ain')),
+ r'"A \(Very\) Silly Person" <person at dom.ain>')
+ a = r'A \(Special\) Person'
+ b = 'person at dom.ain'
+ self.assertEqual(Utils.parseaddr(Utils.formataddr((a, b))), (a, b))
+
+ def test_escape_backslashes(self):
+ self.assertEqual(
+ Utils.formataddr(('Arthur \Backslash\ Foobar', 'person at dom.ain')),
+ r'"Arthur \\Backslash\\ Foobar" <person at dom.ain>')
+ a = r'Arthur \Backslash\ Foobar'
+ b = 'person at dom.ain'
+ self.assertEqual(Utils.parseaddr(Utils.formataddr((a, b))), (a, b))
+
+ def test_name_with_dot(self):
+ x = 'John X. Doe <jxd at example.com>'
+ y = '"John X. Doe" <jxd at example.com>'
+ a, b = ('John X. Doe', 'jxd at example.com')
+ self.assertEqual(Utils.parseaddr(x), (a, b))
+ self.assertEqual(Utils.parseaddr(y), (a, b))
+ # formataddr() quotes the name if there's a dot in it
+ self.assertEqual(Utils.formataddr((a, b)), y)
+
+ def test_parseaddr_preserves_quoted_pairs_in_addresses(self):
+ # issue 10005. Note that in the third test the second pair of
+ # backslashes is not actually a quoted pair because it is not inside a
+ # comment or quoted string: the address being parsed has a quoted
+ # string containing a quoted backslash, followed by 'example' and two
+ # backslashes, followed by another quoted string containing a space and
+ # the word 'example'. parseaddr copies those two backslashes
+ # literally. Per rfc5322 this is not technically correct since a \ may
+ # not appear in an address outside of a quoted string. It is probably
+ # a sensible Postel interpretation, though.
+ eq = self.assertEqual
+ eq(Utils.parseaddr('""example" example"@example.com'),
+ ('', '""example" example"@example.com'))
+ eq(Utils.parseaddr('"\\"example\\" example"@example.com'),
+ ('', '"\\"example\\" example"@example.com'))
+ eq(Utils.parseaddr('"\\\\"example\\\\" example"@example.com'),
+ ('', '"\\\\"example\\\\" example"@example.com'))
+
+ def test_multiline_from_comment(self):
+ x = """\
+Foo
+\tBar <foo at example.com>"""
+ self.assertEqual(Utils.parseaddr(x), ('Foo Bar', 'foo at example.com'))
+
+ def test_quote_dump(self):
+ self.assertEqual(
+ Utils.formataddr(('A Silly; Person', 'person at dom.ain')),
+ r'"A Silly; Person" <person at dom.ain>')
+
+ def test_fix_eols(self):
+ eq = self.assertEqual
+ eq(Utils.fix_eols('hello'), 'hello')
+ eq(Utils.fix_eols('hello\n'), 'hello\r\n')
+ eq(Utils.fix_eols('hello\r'), 'hello\r\n')
+ eq(Utils.fix_eols('hello\r\n'), 'hello\r\n')
+ eq(Utils.fix_eols('hello\n\r'), 'hello\r\n\r\n')
+
+ def test_charset_richcomparisons(self):
+ eq = self.assertEqual
+ ne = self.assertNotEqual
+ cset1 = Charset()
+ cset2 = Charset()
+ eq(cset1, 'us-ascii')
+ eq(cset1, 'US-ASCII')
+ eq(cset1, 'Us-AsCiI')
+ eq('us-ascii', cset1)
+ eq('US-ASCII', cset1)
+ eq('Us-AsCiI', cset1)
+ ne(cset1, 'usascii')
+ ne(cset1, 'USASCII')
+ ne(cset1, 'UsAsCiI')
+ ne('usascii', cset1)
+ ne('USASCII', cset1)
+ ne('UsAsCiI', cset1)
+ eq(cset1, cset2)
+ eq(cset2, cset1)
+
+ def test_getaddresses(self):
+ eq = self.assertEqual
+ eq(Utils.getaddresses(['aperson at dom.ain (Al Person)',
+ 'Bud Person <bperson at dom.ain>']),
+ [('Al Person', 'aperson at dom.ain'),
+ ('Bud Person', 'bperson at dom.ain')])
+
+ def test_getaddresses_nasty(self):
+ eq = self.assertEqual
+ eq(Utils.getaddresses(['foo: ;']), [('', '')])
+ eq(Utils.getaddresses(
+ ['[]*-- =~$']),
+ [('', ''), ('', ''), ('', '*--')])
+ eq(Utils.getaddresses(
+ ['foo: ;', '"Jason R. Mastaler" <jason at dom.ain>']),
+ [('', ''), ('Jason R. Mastaler', 'jason at dom.ain')])
+
+ def test_getaddresses_embedded_comment(self):
+ """Test proper handling of a nested comment"""
+ eq = self.assertEqual
+ addrs = Utils.getaddresses(['User ((nested comment)) <foo at bar.com>'])
+ eq(addrs[0][1], 'foo at bar.com')
+
+ def test_utils_quote_unquote(self):
+ eq = self.assertEqual
+ msg = Message()
+ msg.add_header('content-disposition', 'attachment',
+ filename='foo\\wacky"name')
+ eq(msg.get_filename(), 'foo\\wacky"name')
+
+ def test_get_body_encoding_with_bogus_charset(self):
+ charset = Charset('not a charset')
+ self.assertEqual(charset.get_body_encoding(), 'base64')
+
+ def test_get_body_encoding_with_uppercase_charset(self):
+ eq = self.assertEqual
+ msg = Message()
+ msg['Content-Type'] = 'text/plain; charset=UTF-8'
+ eq(msg['content-type'], 'text/plain; charset=UTF-8')
+ charsets = msg.get_charsets()
+ eq(len(charsets), 1)
+ eq(charsets[0], 'utf-8')
+ charset = Charset(charsets[0])
+ eq(charset.get_body_encoding(), 'base64')
+ msg.set_payload('hello world', charset=charset)
+ eq(msg.get_payload(), 'aGVsbG8gd29ybGQ=\n')
+ eq(msg.get_payload(decode=True), 'hello world')
+ eq(msg['content-transfer-encoding'], 'base64')
+ # Try another one
+ msg = Message()
+ msg['Content-Type'] = 'text/plain; charset="US-ASCII"'
+ charsets = msg.get_charsets()
+ eq(len(charsets), 1)
+ eq(charsets[0], 'us-ascii')
+ charset = Charset(charsets[0])
+ eq(charset.get_body_encoding(), Encoders.encode_7or8bit)
+ msg.set_payload('hello world', charset=charset)
+ eq(msg.get_payload(), 'hello world')
+ eq(msg['content-transfer-encoding'], '7bit')
+
+ def test_charsets_case_insensitive(self):
+ lc = Charset('us-ascii')
+ uc = Charset('US-ASCII')
+ self.assertEqual(lc.get_body_encoding(), uc.get_body_encoding())
+
+ def test_partial_falls_inside_message_delivery_status(self):
+ eq = self.ndiffAssertEqual
+ # The Parser interface provides chunks of data to FeedParser in 8192
+ # byte gulps. SF bug #1076485 found one of those chunks inside
+ # message/delivery-status header block, which triggered an
+ # unreadline() of NeedMoreData.
+ msg = self._msgobj('msg_43.txt')
+ sfp = StringIO()
+ Iterators._structure(msg, sfp)
+ eq(sfp.getvalue(), """\
+multipart/report
+ text/plain
+ message/delivery-status
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/rfc822-headers
+""")
+
+
+
+# Test the iterator/generators
+class TestIterators(TestEmailBase):
+ def test_body_line_iterator(self):
+ eq = self.assertEqual
+ neq = self.ndiffAssertEqual
+ # First a simple non-multipart message
+ msg = self._msgobj('msg_01.txt')
+ it = Iterators.body_line_iterator(msg)
+ lines = list(it)
+ eq(len(lines), 6)
+ neq(EMPTYSTRING.join(lines), msg.get_payload())
+ # Now a more complicated multipart
+ msg = self._msgobj('msg_02.txt')
+ it = Iterators.body_line_iterator(msg)
+ lines = list(it)
+ eq(len(lines), 43)
+ fp = openfile('msg_19.txt')
+ try:
+ neq(EMPTYSTRING.join(lines), fp.read())
+ finally:
+ fp.close()
+
+ def test_typed_subpart_iterator(self):
+ eq = self.assertEqual
+ msg = self._msgobj('msg_04.txt')
+ it = Iterators.typed_subpart_iterator(msg, 'text')
+ lines = []
+ subparts = 0
+ for subpart in it:
+ subparts += 1
+ lines.append(subpart.get_payload())
+ eq(subparts, 2)
+ eq(EMPTYSTRING.join(lines), """\
+a simple kind of mirror
+to reflect upon our own
+a simple kind of mirror
+to reflect upon our own
+""")
+
+ def test_typed_subpart_iterator_default_type(self):
+ eq = self.assertEqual
+ msg = self._msgobj('msg_03.txt')
+ it = Iterators.typed_subpart_iterator(msg, 'text', 'plain')
+ lines = []
+ subparts = 0
+ for subpart in it:
+ subparts += 1
+ lines.append(subpart.get_payload())
+ eq(subparts, 1)
+ eq(EMPTYSTRING.join(lines), """\
+
+Hi,
+
+Do you like this message?
+
+-Me
+""")
+
+ def test_pushCR_LF(self):
+ '''FeedParser BufferedSubFile.push() assumed it received complete
+ line endings. A CR ending one push() followed by a LF starting
+ the next push() added an empty line.
+ '''
+ imt = [
+ ("a\r \n", 2),
+ ("b", 0),
+ ("c\n", 1),
+ ("", 0),
+ ("d\r\n", 1),
+ ("e\r", 0),
+ ("\nf", 1),
+ ("\r\n", 1),
+ ]
+ from email.feedparser import BufferedSubFile, NeedMoreData
+ bsf = BufferedSubFile()
+ om = []
+ nt = 0
+ for il, n in imt:
+ bsf.push(il)
+ nt += n
+ n1 = 0
+ while True:
+ ol = bsf.readline()
+ if ol == NeedMoreData:
+ break
+ om.append(ol)
+ n1 += 1
+ self.assertTrue(n == n1)
+ self.assertTrue(len(om) == nt)
+ self.assertTrue(''.join([il for il, n in imt]) == ''.join(om))
+
+
+
+class TestParsers(TestEmailBase):
+ def test_header_parser(self):
+ eq = self.assertEqual
+ # Parse only the headers of a complex multipart MIME document
+ fp = openfile('msg_02.txt')
+ try:
+ msg = HeaderParser().parse(fp)
+ finally:
+ fp.close()
+ eq(msg['from'], 'ppp-request at zzz.org')
+ eq(msg['to'], 'ppp at zzz.org')
+ eq(msg.get_content_type(), 'multipart/mixed')
+ self.assertFalse(msg.is_multipart())
+ self.assertTrue(isinstance(msg.get_payload(), str))
+
+ def test_whitespace_continuation(self):
+ eq = self.assertEqual
+ # This message contains a line after the Subject: header that has only
+ # whitespace, but it is not empty!
+ msg = email.message_from_string("""\
+From: aperson at dom.ain
+To: bperson at dom.ain
+Subject: the next line has a space on it
+\x20
+Date: Mon, 8 Apr 2002 15:09:19 -0400
+Message-ID: spam
+
+Here's the message body
+""")
+ eq(msg['subject'], 'the next line has a space on it\n ')
+ eq(msg['message-id'], 'spam')
+ eq(msg.get_payload(), "Here's the message body\n")
+
+ def test_whitespace_continuation_last_header(self):
+ eq = self.assertEqual
+ # Like the previous test, but the subject line is the last
+ # header.
+ msg = email.message_from_string("""\
+From: aperson at dom.ain
+To: bperson at dom.ain
+Date: Mon, 8 Apr 2002 15:09:19 -0400
+Message-ID: spam
+Subject: the next line has a space on it
+\x20
+
+Here's the message body
+""")
+ eq(msg['subject'], 'the next line has a space on it\n ')
+ eq(msg['message-id'], 'spam')
+ eq(msg.get_payload(), "Here's the message body\n")
+
+ def test_crlf_separation(self):
+ eq = self.assertEqual
+ fp = openfile('msg_26.txt', mode='rb')
+ try:
+ msg = Parser().parse(fp)
+ finally:
+ fp.close()
+ eq(len(msg.get_payload()), 2)
+ part1 = msg.get_payload(0)
+ eq(part1.get_content_type(), 'text/plain')
+ eq(part1.get_payload(), 'Simple email with attachment.\r\n\r\n')
+ part2 = msg.get_payload(1)
+ eq(part2.get_content_type(), 'application/riscos')
+
+ def test_multipart_digest_with_extra_mime_headers(self):
+ eq = self.assertEqual
+ neq = self.ndiffAssertEqual
+ fp = openfile('msg_28.txt')
+ try:
+ msg = email.message_from_file(fp)
+ finally:
+ fp.close()
+ # Structure is:
+ # multipart/digest
+ # message/rfc822
+ # text/plain
+ # message/rfc822
+ # text/plain
+ eq(msg.is_multipart(), 1)
+ eq(len(msg.get_payload()), 2)
+ part1 = msg.get_payload(0)
+ eq(part1.get_content_type(), 'message/rfc822')
+ eq(part1.is_multipart(), 1)
+ eq(len(part1.get_payload()), 1)
+ part1a = part1.get_payload(0)
+ eq(part1a.is_multipart(), 0)
+ eq(part1a.get_content_type(), 'text/plain')
+ neq(part1a.get_payload(), 'message 1\n')
+ # next message/rfc822
+ part2 = msg.get_payload(1)
+ eq(part2.get_content_type(), 'message/rfc822')
+ eq(part2.is_multipart(), 1)
+ eq(len(part2.get_payload()), 1)
+ part2a = part2.get_payload(0)
+ eq(part2a.is_multipart(), 0)
+ eq(part2a.get_content_type(), 'text/plain')
+ neq(part2a.get_payload(), 'message 2\n')
+
+ def test_three_lines(self):
+ # A bug report by Andrew McNamara
+ lines = ['From: Andrew Person <aperson at dom.ain',
+ 'Subject: Test',
+ 'Date: Tue, 20 Aug 2002 16:43:45 +1000']
+ msg = email.message_from_string(NL.join(lines))
+ self.assertEqual(msg['date'], 'Tue, 20 Aug 2002 16:43:45 +1000')
+
+ def test_strip_line_feed_and_carriage_return_in_headers(self):
+ eq = self.assertEqual
+ # For [ 1002475 ] email message parser doesn't handle \r\n correctly
+ value1 = 'text'
+ value2 = 'more text'
+ m = 'Header: %s\r\nNext-Header: %s\r\n\r\nBody\r\n\r\n' % (
+ value1, value2)
+ msg = email.message_from_string(m)
+ eq(msg.get('Header'), value1)
+ eq(msg.get('Next-Header'), value2)
+
+ def test_rfc2822_header_syntax(self):
+ eq = self.assertEqual
+ m = '>From: foo\nFrom: bar\n!"#QUX;~: zoo\n\nbody'
+ msg = email.message_from_string(m)
+ eq(len(msg.keys()), 3)
+ keys = msg.keys()
+ keys.sort()
+ eq(keys, ['!"#QUX;~', '>From', 'From'])
+ eq(msg.get_payload(), 'body')
+
+ def test_rfc2822_space_not_allowed_in_header(self):
+ eq = self.assertEqual
+ m = '>From foo at example.com 11:25:53\nFrom: bar\n!"#QUX;~: zoo\n\nbody'
+ msg = email.message_from_string(m)
+ eq(len(msg.keys()), 0)
+
+ def test_rfc2822_one_character_header(self):
+ eq = self.assertEqual
+ m = 'A: first header\nB: second header\nCC: third header\n\nbody'
+ msg = email.message_from_string(m)
+ headers = msg.keys()
+ headers.sort()
+ eq(headers, ['A', 'B', 'CC'])
+ eq(msg.get_payload(), 'body')
+
+ def test_CRLFLF_at_end_of_part(self):
+ # issue 5610: feedparser should not eat two chars from body part ending
+ # with "\r\n\n".
+ m = (
+ "From: foo at bar.com\n"
+ "To: baz\n"
+ "Mime-Version: 1.0\n"
+ "Content-Type: multipart/mixed; boundary=BOUNDARY\n"
+ "\n"
+ "--BOUNDARY\n"
+ "Content-Type: text/plain\n"
+ "\n"
+ "body ending with CRLF newline\r\n"
+ "\n"
+ "--BOUNDARY--\n"
+ )
+ msg = email.message_from_string(m)
+ self.assertTrue(msg.get_payload(0).get_payload().endswith('\r\n'))
+
+
+class TestBase64(unittest.TestCase):
+ def test_len(self):
+ eq = self.assertEqual
+ eq(base64MIME.base64_len('hello'),
+ len(base64MIME.encode('hello', eol='')))
+ for size in range(15):
+ if size == 0 : bsize = 0
+ elif size <= 3 : bsize = 4
+ elif size <= 6 : bsize = 8
+ elif size <= 9 : bsize = 12
+ elif size <= 12: bsize = 16
+ else : bsize = 20
+ eq(base64MIME.base64_len('x'*size), bsize)
+
+ def test_decode(self):
+ eq = self.assertEqual
+ eq(base64MIME.decode(''), '')
+ eq(base64MIME.decode('aGVsbG8='), 'hello')
+ eq(base64MIME.decode('aGVsbG8=', 'X'), 'hello')
+ eq(base64MIME.decode('aGVsbG8NCndvcmxk\n', 'X'), 'helloXworld')
+
+ def test_encode(self):
+ eq = self.assertEqual
+ eq(base64MIME.encode(''), '')
+ eq(base64MIME.encode('hello'), 'aGVsbG8=\n')
+ # Test the binary flag
+ eq(base64MIME.encode('hello\n'), 'aGVsbG8K\n')
+ eq(base64MIME.encode('hello\n', 0), 'aGVsbG8NCg==\n')
+ # Test the maxlinelen arg
+ eq(base64MIME.encode('xxxx ' * 20, maxlinelen=40), """\
+eHh4eCB4eHh4IHh4eHggeHh4eCB4eHh4IHh4eHgg
+eHh4eCB4eHh4IHh4eHggeHh4eCB4eHh4IHh4eHgg
+eHh4eCB4eHh4IHh4eHggeHh4eCB4eHh4IHh4eHgg
+eHh4eCB4eHh4IA==
+""")
+ # Test the eol argument
+ eq(base64MIME.encode('xxxx ' * 20, maxlinelen=40, eol='\r\n'), """\
+eHh4eCB4eHh4IHh4eHggeHh4eCB4eHh4IHh4eHgg\r
+eHh4eCB4eHh4IHh4eHggeHh4eCB4eHh4IHh4eHgg\r
+eHh4eCB4eHh4IHh4eHggeHh4eCB4eHh4IHh4eHgg\r
+eHh4eCB4eHh4IA==\r
+""")
+
+ def test_header_encode(self):
+ eq = self.assertEqual
+ he = base64MIME.header_encode
+ eq(he('hello'), '=?iso-8859-1?b?aGVsbG8=?=')
+ eq(he('hello\nworld'), '=?iso-8859-1?b?aGVsbG8NCndvcmxk?=')
+ # Test the charset option
+ eq(he('hello', charset='iso-8859-2'), '=?iso-8859-2?b?aGVsbG8=?=')
+ # Test the keep_eols flag
+ eq(he('hello\nworld', keep_eols=True),
+ '=?iso-8859-1?b?aGVsbG8Kd29ybGQ=?=')
+ # Test the maxlinelen argument
+ eq(he('xxxx ' * 20, maxlinelen=40), """\
+=?iso-8859-1?b?eHh4eCB4eHh4IHh4eHggeHg=?=
+ =?iso-8859-1?b?eHggeHh4eCB4eHh4IHh4eHg=?=
+ =?iso-8859-1?b?IHh4eHggeHh4eCB4eHh4IHg=?=
+ =?iso-8859-1?b?eHh4IHh4eHggeHh4eCB4eHg=?=
+ =?iso-8859-1?b?eCB4eHh4IHh4eHggeHh4eCA=?=
+ =?iso-8859-1?b?eHh4eCB4eHh4IHh4eHgg?=""")
+ # Test the eol argument
+ eq(he('xxxx ' * 20, maxlinelen=40, eol='\r\n'), """\
+=?iso-8859-1?b?eHh4eCB4eHh4IHh4eHggeHg=?=\r
+ =?iso-8859-1?b?eHggeHh4eCB4eHh4IHh4eHg=?=\r
+ =?iso-8859-1?b?IHh4eHggeHh4eCB4eHh4IHg=?=\r
+ =?iso-8859-1?b?eHh4IHh4eHggeHh4eCB4eHg=?=\r
+ =?iso-8859-1?b?eCB4eHh4IHh4eHggeHh4eCA=?=\r
+ =?iso-8859-1?b?eHh4eCB4eHh4IHh4eHgg?=""")
+
+
+
+class TestQuopri(unittest.TestCase):
+ def setUp(self):
+ self.hlit = [chr(x) for x in range(ord('a'), ord('z')+1)] + \
+ [chr(x) for x in range(ord('A'), ord('Z')+1)] + \
+ [chr(x) for x in range(ord('0'), ord('9')+1)] + \
+ ['!', '*', '+', '-', '/', ' ']
+ self.hnon = [chr(x) for x in range(256) if chr(x) not in self.hlit]
+ assert len(self.hlit) + len(self.hnon) == 256
+ self.blit = [chr(x) for x in range(ord(' '), ord('~')+1)] + ['\t']
+ self.blit.remove('=')
+ self.bnon = [chr(x) for x in range(256) if chr(x) not in self.blit]
+ assert len(self.blit) + len(self.bnon) == 256
+
+ def test_header_quopri_check(self):
+ for c in self.hlit:
+ self.assertFalse(quopriMIME.header_quopri_check(c))
+ for c in self.hnon:
+ self.assertTrue(quopriMIME.header_quopri_check(c))
+
+ def test_body_quopri_check(self):
+ for c in self.blit:
+ self.assertFalse(quopriMIME.body_quopri_check(c))
+ for c in self.bnon:
+ self.assertTrue(quopriMIME.body_quopri_check(c))
+
+ def test_header_quopri_len(self):
+ eq = self.assertEqual
+ hql = quopriMIME.header_quopri_len
+ enc = quopriMIME.header_encode
+ for s in ('hello', 'h at e@l at l@o@'):
+ # Empty charset and no line-endings. 7 == RFC chrome
+ eq(hql(s), len(enc(s, charset='', eol=''))-7)
+ for c in self.hlit:
+ eq(hql(c), 1)
+ for c in self.hnon:
+ eq(hql(c), 3)
+
+ def test_body_quopri_len(self):
+ eq = self.assertEqual
+ bql = quopriMIME.body_quopri_len
+ for c in self.blit:
+ eq(bql(c), 1)
+ for c in self.bnon:
+ eq(bql(c), 3)
+
+ def test_quote_unquote_idempotent(self):
+ for x in range(256):
+ c = chr(x)
+ self.assertEqual(quopriMIME.unquote(quopriMIME.quote(c)), c)
+
+ def test_header_encode(self):
+ eq = self.assertEqual
+ he = quopriMIME.header_encode
+ eq(he('hello'), '=?iso-8859-1?q?hello?=')
+ eq(he('hello\nworld'), '=?iso-8859-1?q?hello=0D=0Aworld?=')
+ # Test the charset option
+ eq(he('hello', charset='iso-8859-2'), '=?iso-8859-2?q?hello?=')
+ # Test the keep_eols flag
+ eq(he('hello\nworld', keep_eols=True), '=?iso-8859-1?q?hello=0Aworld?=')
+ # Test a non-ASCII character
+ eq(he('hello\xc7there'), '=?iso-8859-1?q?hello=C7there?=')
+ # Test the maxlinelen argument
+ eq(he('xxxx ' * 20, maxlinelen=40), """\
+=?iso-8859-1?q?xxxx_xxxx_xxxx_xxxx_xx?=
+ =?iso-8859-1?q?xx_xxxx_xxxx_xxxx_xxxx?=
+ =?iso-8859-1?q?_xxxx_xxxx_xxxx_xxxx_x?=
+ =?iso-8859-1?q?xxx_xxxx_xxxx_xxxx_xxx?=
+ =?iso-8859-1?q?x_xxxx_xxxx_?=""")
+ # Test the eol argument
+ eq(he('xxxx ' * 20, maxlinelen=40, eol='\r\n'), """\
+=?iso-8859-1?q?xxxx_xxxx_xxxx_xxxx_xx?=\r
+ =?iso-8859-1?q?xx_xxxx_xxxx_xxxx_xxxx?=\r
+ =?iso-8859-1?q?_xxxx_xxxx_xxxx_xxxx_x?=\r
+ =?iso-8859-1?q?xxx_xxxx_xxxx_xxxx_xxx?=\r
+ =?iso-8859-1?q?x_xxxx_xxxx_?=""")
+
+ def test_decode(self):
+ eq = self.assertEqual
+ eq(quopriMIME.decode(''), '')
+ eq(quopriMIME.decode('hello'), 'hello')
+ eq(quopriMIME.decode('hello', 'X'), 'hello')
+ eq(quopriMIME.decode('hello\nworld', 'X'), 'helloXworld')
+
+ def test_encode(self):
+ eq = self.assertEqual
+ eq(quopriMIME.encode(''), '')
+ eq(quopriMIME.encode('hello'), 'hello')
+ # Test the binary flag
+ eq(quopriMIME.encode('hello\r\nworld'), 'hello\nworld')
+ eq(quopriMIME.encode('hello\r\nworld', 0), 'hello\nworld')
+ # Test the maxlinelen arg
+ eq(quopriMIME.encode('xxxx ' * 20, maxlinelen=40), """\
+xxxx xxxx xxxx xxxx xxxx xxxx xxxx xxxx=
+ xxxx xxxx xxxx xxxx xxxx xxxx xxxx xxx=
+x xxxx xxxx xxxx xxxx=20""")
+ # Test the eol argument
+ eq(quopriMIME.encode('xxxx ' * 20, maxlinelen=40, eol='\r\n'), """\
+xxxx xxxx xxxx xxxx xxxx xxxx xxxx xxxx=\r
+ xxxx xxxx xxxx xxxx xxxx xxxx xxxx xxx=\r
+x xxxx xxxx xxxx xxxx=20""")
+ eq(quopriMIME.encode("""\
+one line
+
+two line"""), """\
+one line
+
+two line""")
+
+
+
+# Test the Charset class
+class TestCharset(unittest.TestCase):
+ def tearDown(self):
+ from email import Charset as CharsetModule
+ try:
+ del CharsetModule.CHARSETS['fake']
+ except KeyError:
+ pass
+
+ def test_idempotent(self):
+ eq = self.assertEqual
+ # Make sure us-ascii = no Unicode conversion
+ c = Charset('us-ascii')
+ s = 'Hello World!'
+ sp = c.to_splittable(s)
+ eq(s, c.from_splittable(sp))
+ # test 8-bit idempotency with us-ascii
+ s = '\xa4\xa2\xa4\xa4\xa4\xa6\xa4\xa8\xa4\xaa'
+ sp = c.to_splittable(s)
+ eq(s, c.from_splittable(sp))
+
+ def test_body_encode(self):
+ eq = self.assertEqual
+ # Try a charset with QP body encoding
+ c = Charset('iso-8859-1')
+ eq('hello w=F6rld', c.body_encode('hello w\xf6rld'))
+ # Try a charset with Base64 body encoding
+ c = Charset('utf-8')
+ eq('aGVsbG8gd29ybGQ=\n', c.body_encode('hello world'))
+ # Try a charset with None body encoding
+ c = Charset('us-ascii')
+ eq('hello world', c.body_encode('hello world'))
+ # Try the convert argument, where input codec != output codec
+ c = Charset('euc-jp')
+ # With apologies to Tokio Kikuchi ;)
+ if not is_jython:
+ # TODO Jython with its Java-based codecs does not
+ # currently support trailing bytes in CJK texts
+ try:
+ eq('\x1b$B5FCO;~IW\x1b(B',
+ c.body_encode('\xb5\xc6\xc3\xcf\xbb\xfe\xc9\xd7'))
+ eq('\xb5\xc6\xc3\xcf\xbb\xfe\xc9\xd7',
+ c.body_encode('\xb5\xc6\xc3\xcf\xbb\xfe\xc9\xd7', False))
+ except LookupError:
+ # We probably don't have the Japanese codecs installed
+ pass
+ # Testing SF bug #625509, which we have to fake, since there are no
+ # built-in encodings where the header encoding is QP but the body
+ # encoding is not.
+ from email import Charset as CharsetModule
+ CharsetModule.add_charset('fake', CharsetModule.QP, None)
+ c = Charset('fake')
+ eq('hello w\xf6rld', c.body_encode('hello w\xf6rld'))
+
+ def test_unicode_charset_name(self):
+ charset = Charset(u'us-ascii')
+ self.assertEqual(str(charset), 'us-ascii')
+ self.assertRaises(Errors.CharsetError, Charset, 'asc\xffii')
+
+ def test_codecs_aliases_accepted(self):
+ charset = Charset('utf8')
+ self.assertEqual(str(charset), 'utf-8')
+
+
+# Test multilingual MIME headers.
+class TestHeader(TestEmailBase):
+ def test_simple(self):
+ eq = self.ndiffAssertEqual
+ h = Header('Hello World!')
+ eq(h.encode(), 'Hello World!')
+ h.append(' Goodbye World!')
+ eq(h.encode(), 'Hello World! Goodbye World!')
+
+ def test_simple_surprise(self):
+ eq = self.ndiffAssertEqual
+ h = Header('Hello World!')
+ eq(h.encode(), 'Hello World!')
+ h.append('Goodbye World!')
+ eq(h.encode(), 'Hello World! Goodbye World!')
+
+ def test_header_needs_no_decoding(self):
+ h = 'no decoding needed'
+ self.assertEqual(decode_header(h), [(h, None)])
+
+ def test_long(self):
+ h = Header("I am the very model of a modern Major-General; I've information vegetable, animal, and mineral; I know the kings of England, and I quote the fights historical from Marathon to Waterloo, in order categorical; I'm very well acquainted, too, with matters mathematical; I understand equations, both the simple and quadratical; about binomial theorem I'm teeming with a lot o' news, with many cheerful facts about the square of the hypotenuse.",
+ maxlinelen=76)
+ for l in h.encode(splitchars=' ').split('\n '):
+ self.assertTrue(len(l) <= 76)
+
+ def test_multilingual(self):
+ eq = self.ndiffAssertEqual
+ g = Charset("iso-8859-1")
+ cz = Charset("iso-8859-2")
+ utf8 = Charset("utf-8")
+ g_head = "Die Mieter treten hier ein werden mit einem Foerderband komfortabel den Korridor entlang, an s\xfcdl\xfcndischen Wandgem\xe4lden vorbei, gegen die rotierenden Klingen bef\xf6rdert. "
+ cz_head = "Finan\xe8ni metropole se hroutily pod tlakem jejich d\xf9vtipu.. "
+ utf8_head = u"\u6b63\u78ba\u306b\u8a00\u3046\u3068\u7ffb\u8a33\u306f\u3055\u308c\u3066\u3044\u307e\u305b\u3093\u3002\u4e00\u90e8\u306f\u30c9\u30a4\u30c4\u8a9e\u3067\u3059\u304c\u3001\u3042\u3068\u306f\u3067\u305f\u3089\u3081\u3067\u3059\u3002\u5b9f\u969b\u306b\u306f\u300cWenn ist das Nunstuck git und Slotermeyer? Ja! Beiherhund das Oder die Flipperwaldt gersput.\u300d\u3068\u8a00\u3063\u3066\u3044\u307e\u3059\u3002".encode("utf-8")
+ h = Header(g_head, g)
+ h.append(cz_head, cz)
+ h.append(utf8_head, utf8)
+ enc = h.encode()
+ eq(enc, """\
+=?iso-8859-1?q?Die_Mieter_treten_hier_ein_werden_mit_einem_Foerderband_ko?=
+ =?iso-8859-1?q?mfortabel_den_Korridor_entlang=2C_an_s=FCdl=FCndischen_Wan?=
+ =?iso-8859-1?q?dgem=E4lden_vorbei=2C_gegen_die_rotierenden_Klingen_bef=F6?=
+ =?iso-8859-1?q?rdert=2E_?= =?iso-8859-2?q?Finan=E8ni_metropole_se_hroutily?=
+ =?iso-8859-2?q?_pod_tlakem_jejich_d=F9vtipu=2E=2E_?= =?utf-8?b?5q2j56K6?=
+ =?utf-8?b?44Gr6KiA44GG44Go57+76Kiz44Gv44GV44KM44Gm44GE44G+44Gb44KT44CC?=
+ =?utf-8?b?5LiA6YOo44Gv44OJ44Kk44OE6Kqe44Gn44GZ44GM44CB44GC44Go44Gv44Gn?=
+ =?utf-8?b?44Gf44KJ44KB44Gn44GZ44CC5a6f6Zqb44Gr44Gv44CMV2VubiBpc3QgZGFz?=
+ =?utf-8?q?_Nunstuck_git_und_Slotermeyer=3F_Ja!_Beiherhund_das_Oder_die_Fl?=
+ =?utf-8?b?aXBwZXJ3YWxkdCBnZXJzcHV0LuOAjeOBqOiogOOBo+OBpuOBhOOBvuOBmQ==?=
+ =?utf-8?b?44CC?=""")
+ eq(decode_header(enc),
+ [(g_head, "iso-8859-1"), (cz_head, "iso-8859-2"),
+ (utf8_head, "utf-8")])
+ ustr = unicode(h)
+ eq(ustr.encode('utf-8'),
+ 'Die Mieter treten hier ein werden mit einem Foerderband '
+ 'komfortabel den Korridor entlang, an s\xc3\xbcdl\xc3\xbcndischen '
+ 'Wandgem\xc3\xa4lden vorbei, gegen die rotierenden Klingen '
+ 'bef\xc3\xb6rdert. Finan\xc4\x8dni metropole se hroutily pod '
+ 'tlakem jejich d\xc5\xafvtipu.. \xe6\xad\xa3\xe7\xa2\xba\xe3\x81'
+ '\xab\xe8\xa8\x80\xe3\x81\x86\xe3\x81\xa8\xe7\xbf\xbb\xe8\xa8\xb3'
+ '\xe3\x81\xaf\xe3\x81\x95\xe3\x82\x8c\xe3\x81\xa6\xe3\x81\x84\xe3'
+ '\x81\xbe\xe3\x81\x9b\xe3\x82\x93\xe3\x80\x82\xe4\xb8\x80\xe9\x83'
+ '\xa8\xe3\x81\xaf\xe3\x83\x89\xe3\x82\xa4\xe3\x83\x84\xe8\xaa\x9e'
+ '\xe3\x81\xa7\xe3\x81\x99\xe3\x81\x8c\xe3\x80\x81\xe3\x81\x82\xe3'
+ '\x81\xa8\xe3\x81\xaf\xe3\x81\xa7\xe3\x81\x9f\xe3\x82\x89\xe3\x82'
+ '\x81\xe3\x81\xa7\xe3\x81\x99\xe3\x80\x82\xe5\xae\x9f\xe9\x9a\x9b'
+ '\xe3\x81\xab\xe3\x81\xaf\xe3\x80\x8cWenn ist das Nunstuck git '
+ 'und Slotermeyer? Ja! Beiherhund das Oder die Flipperwaldt '
+ 'gersput.\xe3\x80\x8d\xe3\x81\xa8\xe8\xa8\x80\xe3\x81\xa3\xe3\x81'
+ '\xa6\xe3\x81\x84\xe3\x81\xbe\xe3\x81\x99\xe3\x80\x82')
+ # Test make_header()
+ newh = make_header(decode_header(enc))
+ eq(newh, enc)
+
+ def test_header_ctor_default_args(self):
+ eq = self.ndiffAssertEqual
+ h = Header()
+ eq(h, '')
+ h.append('foo', Charset('iso-8859-1'))
+ eq(h, '=?iso-8859-1?q?foo?=')
+
+ def test_explicit_maxlinelen(self):
+ eq = self.ndiffAssertEqual
+ hstr = 'A very long line that must get split to something other than at the 76th character boundary to test the non-default behavior'
+ h = Header(hstr)
+ eq(h.encode(), '''\
+A very long line that must get split to something other than at the 76th
+ character boundary to test the non-default behavior''')
+ h = Header(hstr, header_name='Subject')
+ eq(h.encode(), '''\
+A very long line that must get split to something other than at the
+ 76th character boundary to test the non-default behavior''')
+ h = Header(hstr, maxlinelen=1024, header_name='Subject')
+ eq(h.encode(), hstr)
+
+ def test_us_ascii_header(self):
+ eq = self.assertEqual
+ s = 'hello'
+ x = decode_header(s)
+ eq(x, [('hello', None)])
+ h = make_header(x)
+ eq(s, h.encode())
+
+ def test_string_charset(self):
+ eq = self.assertEqual
+ h = Header()
+ h.append('hello', 'iso-8859-1')
+ eq(h, '=?iso-8859-1?q?hello?=')
+
+## def test_unicode_error(self):
+## raises = self.assertRaises
+## raises(UnicodeError, Header, u'[P\xf6stal]', 'us-ascii')
+## raises(UnicodeError, Header, '[P\xf6stal]', 'us-ascii')
+## h = Header()
+## raises(UnicodeError, h.append, u'[P\xf6stal]', 'us-ascii')
+## raises(UnicodeError, h.append, '[P\xf6stal]', 'us-ascii')
+## raises(UnicodeError, Header, u'\u83ca\u5730\u6642\u592b', 'iso-8859-1')
+
+ def test_utf8_shortest(self):
+ eq = self.assertEqual
+ h = Header(u'p\xf6stal', 'utf-8')
+ eq(h.encode(), '=?utf-8?q?p=C3=B6stal?=')
+ h = Header(u'\u83ca\u5730\u6642\u592b', 'utf-8')
+ eq(h.encode(), '=?utf-8?b?6I+K5Zyw5pmC5aSr?=')
+
+ def test_bad_8bit_header(self):
+ raises = self.assertRaises
+ eq = self.assertEqual
+ x = 'Ynwp4dUEbay Auction Semiar- No Charge \x96 Earn Big'
+ raises(UnicodeError, Header, x)
+ h = Header()
+ raises(UnicodeError, h.append, x)
+ eq(str(Header(x, errors='replace')), x)
+ h.append(x, errors='replace')
+ eq(str(h), x)
+
+ def test_encoded_adjacent_nonencoded(self):
+ eq = self.assertEqual
+ h = Header()
+ h.append('hello', 'iso-8859-1')
+ h.append('world')
+ s = h.encode()
+ eq(s, '=?iso-8859-1?q?hello?= world')
+ h = make_header(decode_header(s))
+ eq(h.encode(), s)
+
+ def test_whitespace_eater(self):
+ eq = self.assertEqual
+ s = 'Subject: =?koi8-r?b?8NLP18XSy8EgzsEgxsnOwczYztk=?= =?koi8-r?q?=CA?= zz.'
+ parts = decode_header(s)
+ eq(parts, [('Subject:', None), ('\xf0\xd2\xcf\xd7\xc5\xd2\xcb\xc1 \xce\xc1 \xc6\xc9\xce\xc1\xcc\xd8\xce\xd9\xca', 'koi8-r'), ('zz.', None)])
+ hdr = make_header(parts)
+ eq(hdr.encode(),
+ 'Subject: =?koi8-r?b?8NLP18XSy8EgzsEgxsnOwczYztnK?= zz.')
+
+ def test_broken_base64_header(self):
+ raises = self.assertRaises
+ s = 'Subject: =?EUC-KR?B?CSixpLDtKSC/7Liuvsax4iC6uLmwMcijIKHaILzSwd/H0SC8+LCjwLsgv7W/+Mj3I ?='
+ raises(Errors.HeaderParseError, decode_header, s)
+
+ # Issue 1078919
+ def test_ascii_add_header(self):
+ msg = Message()
+ msg.add_header('Content-Disposition', 'attachment',
+ filename='bud.gif')
+ self.assertEqual('attachment; filename="bud.gif"',
+ msg['Content-Disposition'])
+
+ def test_nonascii_add_header_via_triple(self):
+ msg = Message()
+ msg.add_header('Content-Disposition', 'attachment',
+ filename=('iso-8859-1', '', 'Fu\xdfballer.ppt'))
+ self.assertEqual(
+ 'attachment; filename*="iso-8859-1\'\'Fu%DFballer.ppt"',
+ msg['Content-Disposition'])
+
+ def test_encode_unaliased_charset(self):
+ # Issue 1379416: when the charset has no output conversion,
+ # output was accidentally getting coerced to unicode.
+ res = Header('abc','iso-8859-2').encode()
+ self.assertEqual(res, '=?iso-8859-2?q?abc?=')
+ self.assertIsInstance(res, str)
+
+
+# Test RFC 2231 header parameters (en/de)coding
+class TestRFC2231(TestEmailBase):
+ def test_get_param(self):
+ eq = self.assertEqual
+ msg = self._msgobj('msg_29.txt')
+ eq(msg.get_param('title'),
+ ('us-ascii', 'en', 'This is even more ***fun*** isn\'t it!'))
+ eq(msg.get_param('title', unquote=False),
+ ('us-ascii', 'en', '"This is even more ***fun*** isn\'t it!"'))
+
+ def test_set_param(self):
+ eq = self.assertEqual
+ msg = Message()
+ msg.set_param('title', 'This is even more ***fun*** isn\'t it!',
+ charset='us-ascii')
+ eq(msg.get_param('title'),
+ ('us-ascii', '', 'This is even more ***fun*** isn\'t it!'))
+ msg.set_param('title', 'This is even more ***fun*** isn\'t it!',
+ charset='us-ascii', language='en')
+ eq(msg.get_param('title'),
+ ('us-ascii', 'en', 'This is even more ***fun*** isn\'t it!'))
+ msg = self._msgobj('msg_01.txt')
+ msg.set_param('title', 'This is even more ***fun*** isn\'t it!',
+ charset='us-ascii', language='en')
+ self.ndiffAssertEqual(msg.as_string(), """\
+Return-Path: <bbb at zzz.org>
+Delivered-To: bbb at zzz.org
+Received: by mail.zzz.org (Postfix, from userid 889)
+ id 27CEAD38CC; Fri, 4 May 2001 14:05:44 -0400 (EDT)
+MIME-Version: 1.0
+Content-Transfer-Encoding: 7bit
+Message-ID: <15090.61304.110929.45684 at aaa.zzz.org>
+From: bbb at ddd.com (John X. Doe)
+To: bbb at zzz.org
+Subject: This is a test message
+Date: Fri, 4 May 2001 14:05:44 -0400
+Content-Type: text/plain; charset=us-ascii;
+ title*="us-ascii'en'This%20is%20even%20more%20%2A%2A%2Afun%2A%2A%2A%20isn%27t%20it%21"
+
+
+Hi,
+
+Do you like this message?
+
+-Me
+""")
+
+ def test_del_param(self):
+ eq = self.ndiffAssertEqual
+ msg = self._msgobj('msg_01.txt')
+ msg.set_param('foo', 'bar', charset='us-ascii', language='en')
+ msg.set_param('title', 'This is even more ***fun*** isn\'t it!',
+ charset='us-ascii', language='en')
+ msg.del_param('foo', header='Content-Type')
+ eq(msg.as_string(), """\
+Return-Path: <bbb at zzz.org>
+Delivered-To: bbb at zzz.org
+Received: by mail.zzz.org (Postfix, from userid 889)
+ id 27CEAD38CC; Fri, 4 May 2001 14:05:44 -0400 (EDT)
+MIME-Version: 1.0
+Content-Transfer-Encoding: 7bit
+Message-ID: <15090.61304.110929.45684 at aaa.zzz.org>
+From: bbb at ddd.com (John X. Doe)
+To: bbb at zzz.org
+Subject: This is a test message
+Date: Fri, 4 May 2001 14:05:44 -0400
+Content-Type: text/plain; charset="us-ascii";
+ title*="us-ascii'en'This%20is%20even%20more%20%2A%2A%2Afun%2A%2A%2A%20isn%27t%20it%21"
+
+
+Hi,
+
+Do you like this message?
+
+-Me
+""")
+
+ def test_rfc2231_get_content_charset(self):
+ eq = self.assertEqual
+ msg = self._msgobj('msg_32.txt')
+ eq(msg.get_content_charset(), 'us-ascii')
+
+ def test_rfc2231_no_language_or_charset(self):
+ m = '''\
+Content-Transfer-Encoding: 8bit
+Content-Disposition: inline; filename="file____C__DOCUMENTS_20AND_20SETTINGS_FABIEN_LOCAL_20SETTINGS_TEMP_nsmail.htm"
+Content-Type: text/html; NAME*0=file____C__DOCUMENTS_20AND_20SETTINGS_FABIEN_LOCAL_20SETTINGS_TEM; NAME*1=P_nsmail.htm
+
+'''
+ msg = email.message_from_string(m)
+ param = msg.get_param('NAME')
+ self.assertFalse(isinstance(param, tuple))
+ self.assertEqual(
+ param,
+ 'file____C__DOCUMENTS_20AND_20SETTINGS_FABIEN_LOCAL_20SETTINGS_TEMP_nsmail.htm')
+
+ def test_rfc2231_no_language_or_charset_in_filename(self):
+ m = '''\
+Content-Disposition: inline;
+\tfilename*0*="''This%20is%20even%20more%20";
+\tfilename*1*="%2A%2A%2Afun%2A%2A%2A%20";
+\tfilename*2="is it not.pdf"
+
+'''
+ msg = email.message_from_string(m)
+ self.assertEqual(msg.get_filename(),
+ 'This is even more ***fun*** is it not.pdf')
+
+ def test_rfc2231_no_language_or_charset_in_filename_encoded(self):
+ m = '''\
+Content-Disposition: inline;
+\tfilename*0*="''This%20is%20even%20more%20";
+\tfilename*1*="%2A%2A%2Afun%2A%2A%2A%20";
+\tfilename*2="is it not.pdf"
+
+'''
+ msg = email.message_from_string(m)
+ self.assertEqual(msg.get_filename(),
+ 'This is even more ***fun*** is it not.pdf')
+
+ def test_rfc2231_partly_encoded(self):
+ m = '''\
+Content-Disposition: inline;
+\tfilename*0="''This%20is%20even%20more%20";
+\tfilename*1*="%2A%2A%2Afun%2A%2A%2A%20";
+\tfilename*2="is it not.pdf"
+
+'''
+ msg = email.message_from_string(m)
+ self.assertEqual(
+ msg.get_filename(),
+ 'This%20is%20even%20more%20***fun*** is it not.pdf')
+
+ def test_rfc2231_partly_nonencoded(self):
+ m = '''\
+Content-Disposition: inline;
+\tfilename*0="This%20is%20even%20more%20";
+\tfilename*1="%2A%2A%2Afun%2A%2A%2A%20";
+\tfilename*2="is it not.pdf"
+
+'''
+ msg = email.message_from_string(m)
+ self.assertEqual(
+ msg.get_filename(),
+ 'This%20is%20even%20more%20%2A%2A%2Afun%2A%2A%2A%20is it not.pdf')
+
+ def test_rfc2231_no_language_or_charset_in_boundary(self):
+ m = '''\
+Content-Type: multipart/alternative;
+\tboundary*0*="''This%20is%20even%20more%20";
+\tboundary*1*="%2A%2A%2Afun%2A%2A%2A%20";
+\tboundary*2="is it not.pdf"
+
+'''
+ msg = email.message_from_string(m)
+ self.assertEqual(msg.get_boundary(),
+ 'This is even more ***fun*** is it not.pdf')
+
+ def test_rfc2231_no_language_or_charset_in_charset(self):
+ # This is a nonsensical charset value, but tests the code anyway
+ m = '''\
+Content-Type: text/plain;
+\tcharset*0*="This%20is%20even%20more%20";
+\tcharset*1*="%2A%2A%2Afun%2A%2A%2A%20";
+\tcharset*2="is it not.pdf"
+
+'''
+ msg = email.message_from_string(m)
+ self.assertEqual(msg.get_content_charset(),
+ 'this is even more ***fun*** is it not.pdf')
+
+ def test_rfc2231_bad_encoding_in_filename(self):
+ m = '''\
+Content-Disposition: inline;
+\tfilename*0*="bogus'xx'This%20is%20even%20more%20";
+\tfilename*1*="%2A%2A%2Afun%2A%2A%2A%20";
+\tfilename*2="is it not.pdf"
+
+'''
+ msg = email.message_from_string(m)
+ self.assertEqual(msg.get_filename(),
+ 'This is even more ***fun*** is it not.pdf')
+
+ def test_rfc2231_bad_encoding_in_charset(self):
+ m = """\
+Content-Type: text/plain; charset*=bogus''utf-8%E2%80%9D
+
+"""
+ msg = email.message_from_string(m)
+ # This should return None because non-ascii characters in the charset
+ # are not allowed.
+ self.assertEqual(msg.get_content_charset(), None)
+
+ def test_rfc2231_bad_character_in_charset(self):
+ m = """\
+Content-Type: text/plain; charset*=ascii''utf-8%E2%80%9D
+
+"""
+ msg = email.message_from_string(m)
+ # This should return None because non-ascii characters in the charset
+ # are not allowed.
+ self.assertEqual(msg.get_content_charset(), None)
+
+ def test_rfc2231_bad_character_in_filename(self):
+ m = '''\
+Content-Disposition: inline;
+\tfilename*0*="ascii'xx'This%20is%20even%20more%20";
+\tfilename*1*="%2A%2A%2Afun%2A%2A%2A%20";
+\tfilename*2*="is it not.pdf%E2"
+
+'''
+ msg = email.message_from_string(m)
+ self.assertEqual(msg.get_filename(),
+ u'This is even more ***fun*** is it not.pdf\ufffd')
+
+ def test_rfc2231_unknown_encoding(self):
+ m = """\
+Content-Transfer-Encoding: 8bit
+Content-Disposition: inline; filename*=X-UNKNOWN''myfile.txt
+
+"""
+ msg = email.message_from_string(m)
+ self.assertEqual(msg.get_filename(), 'myfile.txt')
+
+ def test_rfc2231_single_tick_in_filename_extended(self):
+ eq = self.assertEqual
+ m = """\
+Content-Type: application/x-foo;
+\tname*0*=\"Frank's\"; name*1*=\" Document\"
+
+"""
+ msg = email.message_from_string(m)
+ charset, language, s = msg.get_param('name')
+ eq(charset, None)
+ eq(language, None)
+ eq(s, "Frank's Document")
+
+ def test_rfc2231_single_tick_in_filename(self):
+ m = """\
+Content-Type: application/x-foo; name*0=\"Frank's\"; name*1=\" Document\"
+
+"""
+ msg = email.message_from_string(m)
+ param = msg.get_param('name')
+ self.assertFalse(isinstance(param, tuple))
+ self.assertEqual(param, "Frank's Document")
+
+ def test_rfc2231_tick_attack_extended(self):
+ eq = self.assertEqual
+ m = """\
+Content-Type: application/x-foo;
+\tname*0*=\"us-ascii'en-us'Frank's\"; name*1*=\" Document\"
+
+"""
+ msg = email.message_from_string(m)
+ charset, language, s = msg.get_param('name')
+ eq(charset, 'us-ascii')
+ eq(language, 'en-us')
+ eq(s, "Frank's Document")
+
+ def test_rfc2231_tick_attack(self):
+ m = """\
+Content-Type: application/x-foo;
+\tname*0=\"us-ascii'en-us'Frank's\"; name*1=\" Document\"
+
+"""
+ msg = email.message_from_string(m)
+ param = msg.get_param('name')
+ self.assertFalse(isinstance(param, tuple))
+ self.assertEqual(param, "us-ascii'en-us'Frank's Document")
+
+ def test_rfc2231_no_extended_values(self):
+ eq = self.assertEqual
+ m = """\
+Content-Type: application/x-foo; name=\"Frank's Document\"
+
+"""
+ msg = email.message_from_string(m)
+ eq(msg.get_param('name'), "Frank's Document")
+
+ def test_rfc2231_encoded_then_unencoded_segments(self):
+ eq = self.assertEqual
+ m = """\
+Content-Type: application/x-foo;
+\tname*0*=\"us-ascii'en-us'My\";
+\tname*1=\" Document\";
+\tname*2*=\" For You\"
+
+"""
+ msg = email.message_from_string(m)
+ charset, language, s = msg.get_param('name')
+ eq(charset, 'us-ascii')
+ eq(language, 'en-us')
+ eq(s, 'My Document For You')
+
+ def test_rfc2231_unencoded_then_encoded_segments(self):
+ eq = self.assertEqual
+ m = """\
+Content-Type: application/x-foo;
+\tname*0=\"us-ascii'en-us'My\";
+\tname*1*=\" Document\";
+\tname*2*=\" For You\"
+
+"""
+ msg = email.message_from_string(m)
+ charset, language, s = msg.get_param('name')
+ eq(charset, 'us-ascii')
+ eq(language, 'en-us')
+ eq(s, 'My Document For You')
+
+
+
+# Tests to ensure that signed parts of an email are completely preserved, as
+# required by RFC1847 section 2.1. Note that these are incomplete, because the
+# email package does not currently always preserve the body. See issue 1670765.
+class TestSigned(TestEmailBase):
+
+ def _msg_and_obj(self, filename):
+ fp = openfile(findfile(filename))
+ try:
+ original = fp.read()
+ msg = email.message_from_string(original)
+ finally:
+ fp.close()
+ return original, msg
+
+ def _signed_parts_eq(self, original, result):
+ # Extract the first mime part of each message
+ import re
+ repart = re.compile(r'^--([^\n]+)\n(.*?)\n--\1$', re.S | re.M)
+ inpart = repart.search(original).group(2)
+ outpart = repart.search(result).group(2)
+ self.assertEqual(outpart, inpart)
+
+ def test_long_headers_as_string(self):
+ original, msg = self._msg_and_obj('msg_45.txt')
+ result = msg.as_string()
+ self._signed_parts_eq(original, result)
+
+ def test_long_headers_flatten(self):
+ original, msg = self._msg_and_obj('msg_45.txt')
+ fp = StringIO()
+ Generator(fp).flatten(msg)
+ result = fp.getvalue()
+ self._signed_parts_eq(original, result)
+
+
+
+def _testclasses():
+ mod = sys.modules[__name__]
+ return [getattr(mod, name) for name in dir(mod) if name.startswith('Test')]
+
+
+def suite():
+ suite = unittest.TestSuite()
+ for testclass in _testclasses():
+ suite.addTest(unittest.makeSuite(testclass))
+ return suite
+
+
+def test_main():
+ for testclass in _testclasses():
+ run_unittest(testclass)
+
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='suite')
diff --git a/Lib/email/test/test_email_renamed.py b/Lib/email/test/test_email_renamed.py
new file mode 100644
--- /dev/null
+++ b/Lib/email/test/test_email_renamed.py
@@ -0,0 +1,3332 @@
+# Copyright (C) 2001-2007 Python Software Foundation
+# Contact: email-sig at python.org
+# email package unit tests
+
+import os
+import sys
+import time
+import base64
+import difflib
+import unittest
+import warnings
+from cStringIO import StringIO
+
+import email
+
+from email.charset import Charset
+from email.header import Header, decode_header, make_header
+from email.parser import Parser, HeaderParser
+from email.generator import Generator, DecodedGenerator
+from email.message import Message
+from email.mime.application import MIMEApplication
+from email.mime.audio import MIMEAudio
+from email.mime.text import MIMEText
+from email.mime.image import MIMEImage
+from email.mime.base import MIMEBase
+from email.mime.message import MIMEMessage
+from email.mime.multipart import MIMEMultipart
+from email import utils
+from email import errors
+from email import encoders
+from email import iterators
+from email import base64mime
+from email import quoprimime
+
+from test.test_support import findfile, run_unittest, is_jython
+from email.test import __file__ as landmark
+
+
+NL = '\n'
+EMPTYSTRING = ''
+SPACE = ' '
+
+
+
def openfile(filename, mode='r'):
    """Open *filename* from the email/test/data directory next to *landmark*."""
    data_dir = os.path.join(os.path.dirname(landmark), 'data')
    return open(os.path.join(data_dir, filename), mode)
+
+
+
+# Base test class
class TestEmailBase(unittest.TestCase):
    """Shared helpers for the email package test cases."""

    def ndiffAssertEqual(self, first, second):
        """Like assertEqual except use ndiff for readable output."""
        if first == second:
            return
        delta = difflib.ndiff(str(first).splitlines(),
                              str(second).splitlines())
        fp = StringIO()
        # Reproduce exactly what ``print >> fp, NL, NL.join(delta)``
        # emitted: the two values separated by a space, then a newline.
        fp.write(NL)
        fp.write(' ')
        fp.write(NL.join(delta))
        fp.write(NL)
        raise self.failureException(fp.getvalue())

    def _msgobj(self, filename):
        """Parse the named fixture file and return its Message object."""
        fp = openfile(findfile(filename))
        try:
            return email.message_from_file(fp)
        finally:
            fp.close()
+
+
+
+# Test various aspects of the Message class's API
class TestMessageAPI(TestEmailBase):
    """Exercise the public API of email.message.Message.

    Fixture messages (msg_NN.txt) are loaded from the package data
    directory via TestEmailBase._msgobj().
    """

    def test_get_all(self):
        eq = self.assertEqual
        msg = self._msgobj('msg_20.txt')
        eq(msg.get_all('cc'), ['ccc at zzz.org', 'ddd at zzz.org', 'eee at zzz.org'])
        eq(msg.get_all('xx', 'n/a'), 'n/a')

    def test_getset_charset(self):
        eq = self.assertEqual
        msg = Message()
        eq(msg.get_charset(), None)
        charset = Charset('iso-8859-1')
        msg.set_charset(charset)
        # set_charset() adds MIME-Version/Content-Type/CTE headers as needed.
        eq(msg['mime-version'], '1.0')
        eq(msg.get_content_type(), 'text/plain')
        eq(msg['content-type'], 'text/plain; charset="iso-8859-1"')
        eq(msg.get_param('charset'), 'iso-8859-1')
        eq(msg['content-transfer-encoding'], 'quoted-printable')
        eq(msg.get_charset().input_charset, 'iso-8859-1')
        # Remove the charset
        msg.set_charset(None)
        eq(msg.get_charset(), None)
        eq(msg['content-type'], 'text/plain')
        # Try adding a charset when there's already MIME headers present
        msg = Message()
        msg['MIME-Version'] = '2.0'
        msg['Content-Type'] = 'text/x-weird'
        msg['Content-Transfer-Encoding'] = 'quinted-puntable'
        msg.set_charset(charset)
        eq(msg['mime-version'], '2.0')
        eq(msg['content-type'], 'text/x-weird; charset="iso-8859-1"')
        eq(msg['content-transfer-encoding'], 'quinted-puntable')

    def test_set_charset_from_string(self):
        eq = self.assertEqual
        msg = Message()
        msg.set_charset('us-ascii')
        eq(msg.get_charset().input_charset, 'us-ascii')
        eq(msg['content-type'], 'text/plain; charset="us-ascii"')

    def test_set_payload_with_charset(self):
        msg = Message()
        charset = Charset('iso-8859-1')
        msg.set_payload('This is a string payload', charset)
        self.assertEqual(msg.get_charset().input_charset, 'iso-8859-1')

    def test_get_charsets(self):
        eq = self.assertEqual

        msg = self._msgobj('msg_08.txt')
        charsets = msg.get_charsets()
        eq(charsets, [None, 'us-ascii', 'iso-8859-1', 'iso-8859-2', 'koi8-r'])

        msg = self._msgobj('msg_09.txt')
        charsets = msg.get_charsets('dingbat')
        eq(charsets, ['dingbat', 'us-ascii', 'iso-8859-1', 'dingbat',
                      'koi8-r'])

        msg = self._msgobj('msg_12.txt')
        charsets = msg.get_charsets()
        eq(charsets, [None, 'us-ascii', 'iso-8859-1', None, 'iso-8859-2',
                      'iso-8859-3', 'us-ascii', 'koi8-r'])

    def test_get_filename(self):
        eq = self.assertEqual

        msg = self._msgobj('msg_04.txt')
        filenames = [p.get_filename() for p in msg.get_payload()]
        eq(filenames, ['msg.txt', 'msg.txt'])

        msg = self._msgobj('msg_07.txt')
        subpart = msg.get_payload(1)
        eq(subpart.get_filename(), 'dingusfish.gif')

    def test_get_filename_with_name_parameter(self):
        eq = self.assertEqual

        msg = self._msgobj('msg_44.txt')
        filenames = [p.get_filename() for p in msg.get_payload()]
        eq(filenames, ['msg.txt', 'msg.txt'])

    def test_get_boundary(self):
        eq = self.assertEqual
        msg = self._msgobj('msg_07.txt')
        # No quotes!
        eq(msg.get_boundary(), 'BOUNDARY')

    def test_set_boundary(self):
        eq = self.assertEqual
        # This one has no existing boundary parameter, but the Content-Type:
        # header appears fifth.
        msg = self._msgobj('msg_01.txt')
        msg.set_boundary('BOUNDARY')
        header, value = msg.items()[4]
        eq(header.lower(), 'content-type')
        eq(value, 'text/plain; charset="us-ascii"; boundary="BOUNDARY"')
        # This one has a Content-Type: header, with a boundary, stuck in the
        # middle of its headers.  Make sure the order is preserved; it should
        # be fifth.
        msg = self._msgobj('msg_04.txt')
        msg.set_boundary('BOUNDARY')
        header, value = msg.items()[4]
        eq(header.lower(), 'content-type')
        eq(value, 'multipart/mixed; boundary="BOUNDARY"')
        # And this one has no Content-Type: header at all.
        msg = self._msgobj('msg_03.txt')
        self.assertRaises(errors.HeaderParseError,
                          msg.set_boundary, 'BOUNDARY')

    def test_get_decoded_payload(self):
        eq = self.assertEqual
        msg = self._msgobj('msg_10.txt')
        # The outer message is a multipart
        eq(msg.get_payload(decode=True), None)
        # Subpart 1 is 7bit encoded
        eq(msg.get_payload(0).get_payload(decode=True),
           'This is a 7bit encoded message.\n')
        # Subpart 2 is quopri
        eq(msg.get_payload(1).get_payload(decode=True),
           '\xa1This is a Quoted Printable encoded message!\n')
        # Subpart 3 is base64
        eq(msg.get_payload(2).get_payload(decode=True),
           'This is a Base64 encoded message.')
        # Subpart 4 is base64 with a trailing newline, which
        # used to be stripped (issue 7143).
        eq(msg.get_payload(3).get_payload(decode=True),
           'This is a Base64 encoded message.\n')
        # Subpart 5 has no Content-Transfer-Encoding: header.
        eq(msg.get_payload(4).get_payload(decode=True),
           'This has no Content-Transfer-Encoding: header.\n')

    def test_get_decoded_uu_payload(self):
        eq = self.assertEqual
        msg = Message()
        msg.set_payload('begin 666 -\n+:&5L;&\\@=V]R;&0 \n \nend\n')
        for cte in ('x-uuencode', 'uuencode', 'uue', 'x-uue'):
            msg['content-transfer-encoding'] = cte
            eq(msg.get_payload(decode=True), 'hello world')
        # Now try some bogus data
        msg.set_payload('foo')
        eq(msg.get_payload(decode=True), 'foo')

    def test_decoded_generator(self):
        eq = self.assertEqual
        msg = self._msgobj('msg_07.txt')
        fp = openfile('msg_17.txt')
        try:
            text = fp.read()
        finally:
            fp.close()
        s = StringIO()
        g = DecodedGenerator(s)
        g.flatten(msg)
        eq(s.getvalue(), text)

    def test__contains__(self):
        msg = Message()
        msg['From'] = 'Me'
        msg['to'] = 'You'
        # Check for case insensitivity
        self.assertTrue('from' in msg)
        self.assertTrue('From' in msg)
        self.assertTrue('FROM' in msg)
        self.assertTrue('to' in msg)
        self.assertTrue('To' in msg)
        self.assertTrue('TO' in msg)

    def test_as_string(self):
        eq = self.assertEqual
        msg = self._msgobj('msg_01.txt')
        fp = openfile('msg_01.txt')
        try:
            # BAW 30-Mar-2009 Evil be here.  So, the generator is broken with
            # respect to long line breaking.  It's also not idempotent when a
            # header from a parsed message is continued with tabs rather than
            # spaces.  Before we fixed bug 1974 it was reversedly broken,
            # i.e. headers that were continued with spaces got continued with
            # tabs.  For Python 2.x there's really no good fix and in Python
            # 3.x all this stuff is re-written to be right(er).  Chris Withers
            # convinced me that using space as the default continuation
            # character is less bad for more applications.
            text = fp.read().replace('\t', ' ')
        finally:
            fp.close()
        self.ndiffAssertEqual(text, msg.as_string())
        fullrepr = str(msg)
        lines = fullrepr.split('\n')
        self.assertTrue(lines[0].startswith('From '))
        eq(text, NL.join(lines[1:]))

    def test_bad_param(self):
        msg = email.message_from_string("Content-Type: blarg; baz; boo\n")
        self.assertEqual(msg.get_param('baz'), '')

    def test_missing_filename(self):
        msg = email.message_from_string("From: foo\n")
        self.assertEqual(msg.get_filename(), None)

    def test_bogus_filename(self):
        msg = email.message_from_string(
            "Content-Disposition: blarg; filename\n")
        self.assertEqual(msg.get_filename(), '')

    def test_missing_boundary(self):
        msg = email.message_from_string("From: foo\n")
        self.assertEqual(msg.get_boundary(), None)

    def test_get_params(self):
        eq = self.assertEqual
        msg = email.message_from_string(
            'X-Header: foo=one; bar=two; baz=three\n')
        eq(msg.get_params(header='x-header'),
           [('foo', 'one'), ('bar', 'two'), ('baz', 'three')])
        msg = email.message_from_string(
            'X-Header: foo; bar=one; baz=two\n')
        eq(msg.get_params(header='x-header'),
           [('foo', ''), ('bar', 'one'), ('baz', 'two')])
        eq(msg.get_params(), None)
        msg = email.message_from_string(
            'X-Header: foo; bar="one"; baz=two\n')
        eq(msg.get_params(header='x-header'),
           [('foo', ''), ('bar', 'one'), ('baz', 'two')])

    def test_get_param_liberal(self):
        msg = Message()
        msg['Content-Type'] = 'Content-Type: Multipart/mixed; boundary = "CPIMSSMTPC06p5f3tG"'
        self.assertEqual(msg.get_param('boundary'), 'CPIMSSMTPC06p5f3tG')

    def test_get_param(self):
        eq = self.assertEqual
        msg = email.message_from_string(
            "X-Header: foo=one; bar=two; baz=three\n")
        eq(msg.get_param('bar', header='x-header'), 'two')
        eq(msg.get_param('quuz', header='x-header'), None)
        eq(msg.get_param('quuz'), None)
        msg = email.message_from_string(
            'X-Header: foo; bar="one"; baz=two\n')
        eq(msg.get_param('foo', header='x-header'), '')
        eq(msg.get_param('bar', header='x-header'), 'one')
        eq(msg.get_param('baz', header='x-header'), 'two')
        # XXX: We are not RFC-2045 compliant!  We cannot parse:
        # msg["Content-Type"] = 'text/plain; weird="hey; dolly? [you] @ <\\"home\\">?"'
        # msg.get_param("weird")
        # yet.

    def test_get_param_funky_continuation_lines(self):
        msg = self._msgobj('msg_22.txt')
        self.assertEqual(msg.get_payload(1).get_param('name'), 'wibble.JPG')

    def test_get_param_with_semis_in_quotes(self):
        msg = email.message_from_string(
            'Content-Type: image/pjpeg; name="Jim&&Jill"\n')
        self.assertEqual(msg.get_param('name'), 'Jim&&Jill')
        self.assertEqual(msg.get_param('name', unquote=False),
                         '"Jim&&Jill"')

    def test_has_key(self):
        # has_key() is case-insensitive on header names.
        msg = email.message_from_string('Header: exists')
        self.assertTrue(msg.has_key('header'))
        self.assertTrue(msg.has_key('Header'))
        self.assertTrue(msg.has_key('HEADER'))
        self.assertFalse(msg.has_key('headeri'))

    def test_set_param(self):
        eq = self.assertEqual
        msg = Message()
        msg.set_param('charset', 'iso-2022-jp')
        eq(msg.get_param('charset'), 'iso-2022-jp')
        msg.set_param('importance', 'high value')
        eq(msg.get_param('importance'), 'high value')
        eq(msg.get_param('importance', unquote=False), '"high value"')
        eq(msg.get_params(), [('text/plain', ''),
                              ('charset', 'iso-2022-jp'),
                              ('importance', 'high value')])
        eq(msg.get_params(unquote=False), [('text/plain', ''),
                                           ('charset', '"iso-2022-jp"'),
                                           ('importance', '"high value"')])
        msg.set_param('charset', 'iso-9999-xx', header='X-Jimmy')
        eq(msg.get_param('charset', header='X-Jimmy'), 'iso-9999-xx')

    def test_del_param(self):
        eq = self.assertEqual
        msg = self._msgobj('msg_05.txt')
        eq(msg.get_params(),
           [('multipart/report', ''), ('report-type', 'delivery-status'),
            ('boundary', 'D1690A7AC1.996856090/mail.example.com')])
        old_val = msg.get_param("report-type")
        msg.del_param("report-type")
        eq(msg.get_params(),
           [('multipart/report', ''),
            ('boundary', 'D1690A7AC1.996856090/mail.example.com')])
        msg.set_param("report-type", old_val)
        eq(msg.get_params(),
           [('multipart/report', ''),
            ('boundary', 'D1690A7AC1.996856090/mail.example.com'),
            ('report-type', old_val)])

    def test_del_param_on_other_header(self):
        msg = Message()
        msg.add_header('Content-Disposition', 'attachment', filename='bud.gif')
        msg.del_param('filename', 'content-disposition')
        self.assertEqual(msg['content-disposition'], 'attachment')

    def test_set_type(self):
        eq = self.assertEqual
        msg = Message()
        # set_type() requires a maintype/subtype pair.
        self.assertRaises(ValueError, msg.set_type, 'text')
        msg.set_type('text/plain')
        eq(msg['content-type'], 'text/plain')
        msg.set_param('charset', 'us-ascii')
        eq(msg['content-type'], 'text/plain; charset="us-ascii"')
        msg.set_type('text/html')
        eq(msg['content-type'], 'text/html; charset="us-ascii"')

    def test_set_type_on_other_header(self):
        msg = Message()
        msg['X-Content-Type'] = 'text/plain'
        msg.set_type('application/octet-stream', 'X-Content-Type')
        self.assertEqual(msg['x-content-type'], 'application/octet-stream')

    def test_get_content_type_missing(self):
        msg = Message()
        self.assertEqual(msg.get_content_type(), 'text/plain')

    def test_get_content_type_missing_with_default_type(self):
        msg = Message()
        msg.set_default_type('message/rfc822')
        self.assertEqual(msg.get_content_type(), 'message/rfc822')

    def test_get_content_type_from_message_implicit(self):
        msg = self._msgobj('msg_30.txt')
        self.assertEqual(msg.get_payload(0).get_content_type(),
                         'message/rfc822')

    def test_get_content_type_from_message_explicit(self):
        msg = self._msgobj('msg_28.txt')
        self.assertEqual(msg.get_payload(0).get_content_type(),
                         'message/rfc822')

    def test_get_content_type_from_message_text_plain_implicit(self):
        msg = self._msgobj('msg_03.txt')
        self.assertEqual(msg.get_content_type(), 'text/plain')

    def test_get_content_type_from_message_text_plain_explicit(self):
        msg = self._msgobj('msg_01.txt')
        self.assertEqual(msg.get_content_type(), 'text/plain')

    def test_get_content_maintype_missing(self):
        msg = Message()
        self.assertEqual(msg.get_content_maintype(), 'text')

    def test_get_content_maintype_missing_with_default_type(self):
        msg = Message()
        msg.set_default_type('message/rfc822')
        self.assertEqual(msg.get_content_maintype(), 'message')

    def test_get_content_maintype_from_message_implicit(self):
        msg = self._msgobj('msg_30.txt')
        self.assertEqual(msg.get_payload(0).get_content_maintype(), 'message')

    def test_get_content_maintype_from_message_explicit(self):
        msg = self._msgobj('msg_28.txt')
        self.assertEqual(msg.get_payload(0).get_content_maintype(), 'message')

    def test_get_content_maintype_from_message_text_plain_implicit(self):
        msg = self._msgobj('msg_03.txt')
        self.assertEqual(msg.get_content_maintype(), 'text')

    def test_get_content_maintype_from_message_text_plain_explicit(self):
        msg = self._msgobj('msg_01.txt')
        self.assertEqual(msg.get_content_maintype(), 'text')

    def test_get_content_subtype_missing(self):
        msg = Message()
        self.assertEqual(msg.get_content_subtype(), 'plain')

    def test_get_content_subtype_missing_with_default_type(self):
        msg = Message()
        msg.set_default_type('message/rfc822')
        self.assertEqual(msg.get_content_subtype(), 'rfc822')

    def test_get_content_subtype_from_message_implicit(self):
        msg = self._msgobj('msg_30.txt')
        self.assertEqual(msg.get_payload(0).get_content_subtype(), 'rfc822')

    def test_get_content_subtype_from_message_explicit(self):
        msg = self._msgobj('msg_28.txt')
        self.assertEqual(msg.get_payload(0).get_content_subtype(), 'rfc822')

    def test_get_content_subtype_from_message_text_plain_implicit(self):
        msg = self._msgobj('msg_03.txt')
        self.assertEqual(msg.get_content_subtype(), 'plain')

    def test_get_content_subtype_from_message_text_plain_explicit(self):
        msg = self._msgobj('msg_01.txt')
        self.assertEqual(msg.get_content_subtype(), 'plain')

    def test_get_content_maintype_error(self):
        # A malformed Content-Type falls back to the text/plain default.
        msg = Message()
        msg['Content-Type'] = 'no-slash-in-this-string'
        self.assertEqual(msg.get_content_maintype(), 'text')

    def test_get_content_subtype_error(self):
        msg = Message()
        msg['Content-Type'] = 'no-slash-in-this-string'
        self.assertEqual(msg.get_content_subtype(), 'plain')

    def test_replace_header(self):
        eq = self.assertEqual
        msg = Message()
        msg.add_header('First', 'One')
        msg.add_header('Second', 'Two')
        msg.add_header('Third', 'Three')
        eq(msg.keys(), ['First', 'Second', 'Third'])
        eq(msg.values(), ['One', 'Two', 'Three'])
        msg.replace_header('Second', 'Twenty')
        eq(msg.keys(), ['First', 'Second', 'Third'])
        eq(msg.values(), ['One', 'Twenty', 'Three'])
        msg.add_header('First', 'Eleven')
        msg.replace_header('First', 'One Hundred')
        eq(msg.keys(), ['First', 'Second', 'Third', 'First'])
        eq(msg.values(), ['One Hundred', 'Twenty', 'Three', 'Eleven'])
        self.assertRaises(KeyError, msg.replace_header, 'Fourth', 'Missing')

    def test_broken_base64_payload(self):
        # Payload that fails to decode is returned as-is.
        x = 'AwDp0P7//y6LwKEAcPa/6Q=9'
        msg = Message()
        msg['content-type'] = 'audio/x-midi'
        msg['content-transfer-encoding'] = 'base64'
        msg.set_payload(x)
        self.assertEqual(msg.get_payload(decode=True), x)
+
+
+
+# Test the email.encoders module
+class TestEncoders(unittest.TestCase):
+ def test_encode_empty_payload(self):
+ eq = self.assertEqual
+ msg = Message()
+ msg.set_charset('us-ascii')
+ eq(msg['content-transfer-encoding'], '7bit')
+
+ def test_default_cte(self):
+ eq = self.assertEqual
+ msg = MIMEText('hello world')
+ eq(msg['content-transfer-encoding'], '7bit')
+
+ def test_default_cte(self):
+ eq = self.assertEqual
+ # With no explicit _charset its us-ascii, and all are 7-bit
+ msg = MIMEText('hello world')
+ eq(msg['content-transfer-encoding'], '7bit')
+ # Similar, but with 8-bit data
+ msg = MIMEText('hello \xf8 world')
+ eq(msg['content-transfer-encoding'], '8bit')
+ # And now with a different charset
+ msg = MIMEText('hello \xf8 world', _charset='iso-8859-1')
+ eq(msg['content-transfer-encoding'], 'quoted-printable')
+
+
+
+# Test long header wrapping
class TestLongHeaders(TestEmailBase):
    """Test wrapping/folding of long headers by Header and Generator.

    The expected values below pin the exact folded output, including
    continuation whitespace, so the literals must not be reformatted.
    """

    def test_split_long_continuation(self):
        eq = self.ndiffAssertEqual
        msg = email.message_from_string("""\
Subject: bug demonstration
\t12345678911234567892123456789312345678941234567895123456789612345678971234567898112345678911234567892123456789112345678911234567892123456789
\tmore text

test
""")
        sfp = StringIO()
        g = Generator(sfp)
        g.flatten(msg)
        eq(sfp.getvalue(), """\
Subject: bug demonstration
 12345678911234567892123456789312345678941234567895123456789612345678971234567898112345678911234567892123456789112345678911234567892123456789
 more text

test
""")

    def test_another_long_almost_unsplittable_header(self):
        eq = self.ndiffAssertEqual
        hstr = """\
bug demonstration
\t12345678911234567892123456789312345678941234567895123456789612345678971234567898112345678911234567892123456789112345678911234567892123456789
\tmore text"""
        h = Header(hstr, continuation_ws='\t')
        eq(h.encode(), """\
bug demonstration
\t12345678911234567892123456789312345678941234567895123456789612345678971234567898112345678911234567892123456789112345678911234567892123456789
\tmore text""")
        h = Header(hstr)
        eq(h.encode(), """\
bug demonstration
 12345678911234567892123456789312345678941234567895123456789612345678971234567898112345678911234567892123456789112345678911234567892123456789
 more text""")

    def test_long_nonstring(self):
        eq = self.ndiffAssertEqual
        g = Charset("iso-8859-1")
        cz = Charset("iso-8859-2")
        utf8 = Charset("utf-8")
        g_head = "Die Mieter treten hier ein werden mit einem Foerderband komfortabel den Korridor entlang, an s\xfcdl\xfcndischen Wandgem\xe4lden vorbei, gegen die rotierenden Klingen bef\xf6rdert. "
        cz_head = "Finan\xe8ni metropole se hroutily pod tlakem jejich d\xf9vtipu.. "
        utf8_head = u"\u6b63\u78ba\u306b\u8a00\u3046\u3068\u7ffb\u8a33\u306f\u3055\u308c\u3066\u3044\u307e\u305b\u3093\u3002\u4e00\u90e8\u306f\u30c9\u30a4\u30c4\u8a9e\u3067\u3059\u304c\u3001\u3042\u3068\u306f\u3067\u305f\u3089\u3081\u3067\u3059\u3002\u5b9f\u969b\u306b\u306f\u300cWenn ist das Nunstuck git und Slotermeyer? Ja! Beiherhund das Oder die Flipperwaldt gersput.\u300d\u3068\u8a00\u3063\u3066\u3044\u307e\u3059\u3002".encode("utf-8")
        h = Header(g_head, g, header_name='Subject')
        h.append(cz_head, cz)
        h.append(utf8_head, utf8)
        msg = Message()
        msg['Subject'] = h
        sfp = StringIO()
        g = Generator(sfp)
        g.flatten(msg)
        eq(sfp.getvalue(), """\
Subject: =?iso-8859-1?q?Die_Mieter_treten_hier_ein_werden_mit_einem_Foerd?=
 =?iso-8859-1?q?erband_komfortabel_den_Korridor_entlang=2C_an_s=FCdl=FCndi?=
 =?iso-8859-1?q?schen_Wandgem=E4lden_vorbei=2C_gegen_die_rotierenden_Kling?=
 =?iso-8859-1?q?en_bef=F6rdert=2E_?= =?iso-8859-2?q?Finan=E8ni_met?=
 =?iso-8859-2?q?ropole_se_hroutily_pod_tlakem_jejich_d=F9vtipu=2E=2E_?=
 =?utf-8?b?5q2j56K644Gr6KiA44GG44Go57+76Kiz44Gv44GV44KM44Gm44GE?=
 =?utf-8?b?44G+44Gb44KT44CC5LiA6YOo44Gv44OJ44Kk44OE6Kqe44Gn44GZ44GM44CB?=
 =?utf-8?b?44GC44Go44Gv44Gn44Gf44KJ44KB44Gn44GZ44CC5a6f6Zqb44Gr44Gv44CM?=
 =?utf-8?q?Wenn_ist_das_Nunstuck_git_und_Slotermeyer=3F_Ja!_Beiherhund_das?=
 =?utf-8?b?IE9kZXIgZGllIEZsaXBwZXJ3YWxkdCBnZXJzcHV0LuOAjeOBqOiogOOBow==?=
 =?utf-8?b?44Gm44GE44G+44GZ44CC?=

""")
        eq(h.encode(), """\
=?iso-8859-1?q?Die_Mieter_treten_hier_ein_werden_mit_einem_Foerd?=
 =?iso-8859-1?q?erband_komfortabel_den_Korridor_entlang=2C_an_s=FCdl=FCndi?=
 =?iso-8859-1?q?schen_Wandgem=E4lden_vorbei=2C_gegen_die_rotierenden_Kling?=
 =?iso-8859-1?q?en_bef=F6rdert=2E_?= =?iso-8859-2?q?Finan=E8ni_met?=
 =?iso-8859-2?q?ropole_se_hroutily_pod_tlakem_jejich_d=F9vtipu=2E=2E_?=
 =?utf-8?b?5q2j56K644Gr6KiA44GG44Go57+76Kiz44Gv44GV44KM44Gm44GE?=
 =?utf-8?b?44G+44Gb44KT44CC5LiA6YOo44Gv44OJ44Kk44OE6Kqe44Gn44GZ44GM44CB?=
 =?utf-8?b?44GC44Go44Gv44Gn44Gf44KJ44KB44Gn44GZ44CC5a6f6Zqb44Gr44Gv44CM?=
 =?utf-8?q?Wenn_ist_das_Nunstuck_git_und_Slotermeyer=3F_Ja!_Beiherhund_das?=
 =?utf-8?b?IE9kZXIgZGllIEZsaXBwZXJ3YWxkdCBnZXJzcHV0LuOAjeOBqOiogOOBow==?=
 =?utf-8?b?44Gm44GE44G+44GZ44CC?=""")

    def test_long_header_encode(self):
        eq = self.ndiffAssertEqual
        h = Header('wasnipoop; giraffes="very-long-necked-animals"; '
                   'spooge="yummy"; hippos="gargantuan"; marshmallows="gooey"',
                   header_name='X-Foobar-Spoink-Defrobnit')
        eq(h.encode(), '''\
wasnipoop; giraffes="very-long-necked-animals";
 spooge="yummy"; hippos="gargantuan"; marshmallows="gooey"''')

    def test_long_header_encode_with_tab_continuation(self):
        eq = self.ndiffAssertEqual
        h = Header('wasnipoop; giraffes="very-long-necked-animals"; '
                   'spooge="yummy"; hippos="gargantuan"; marshmallows="gooey"',
                   header_name='X-Foobar-Spoink-Defrobnit',
                   continuation_ws='\t')
        eq(h.encode(), '''\
wasnipoop; giraffes="very-long-necked-animals";
\tspooge="yummy"; hippos="gargantuan"; marshmallows="gooey"''')

    def test_header_splitter(self):
        eq = self.ndiffAssertEqual
        msg = MIMEText('')
        # It'd be great if we could use add_header() here, but that doesn't
        # guarantee an order of the parameters.
        msg['X-Foobar-Spoink-Defrobnit'] = (
            'wasnipoop; giraffes="very-long-necked-animals"; '
            'spooge="yummy"; hippos="gargantuan"; marshmallows="gooey"')
        sfp = StringIO()
        g = Generator(sfp)
        g.flatten(msg)
        eq(sfp.getvalue(), '''\
Content-Type: text/plain; charset="us-ascii"
MIME-Version: 1.0
Content-Transfer-Encoding: 7bit
X-Foobar-Spoink-Defrobnit: wasnipoop; giraffes="very-long-necked-animals";
 spooge="yummy"; hippos="gargantuan"; marshmallows="gooey"

''')

    def test_no_semis_header_splitter(self):
        eq = self.ndiffAssertEqual
        msg = Message()
        msg['From'] = 'test at dom.ain'
        msg['References'] = SPACE.join(['<%d at dom.ain>' % i for i in range(10)])
        msg.set_payload('Test')
        sfp = StringIO()
        g = Generator(sfp)
        g.flatten(msg)
        eq(sfp.getvalue(), """\
From: test at dom.ain
References: <0 at dom.ain> <1 at dom.ain> <2 at dom.ain> <3 at dom.ain> <4 at dom.ain>
 <5 at dom.ain> <6 at dom.ain> <7 at dom.ain> <8 at dom.ain> <9 at dom.ain>

Test""")

    def test_no_split_long_header(self):
        eq = self.ndiffAssertEqual
        hstr = 'References: ' + 'x' * 80
        h = Header(hstr, continuation_ws='\t')
        eq(h.encode(), """\
References: xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx""")

    def test_splitting_multiple_long_lines(self):
        eq = self.ndiffAssertEqual
        hstr = """\
from babylon.socal-raves.org (localhost [127.0.0.1]); by babylon.socal-raves.org (Postfix) with ESMTP id B570E51B81; for <mailman-admin at babylon.socal-raves.org>; Sat, 2 Feb 2002 17:00:06 -0800 (PST)
\tfrom babylon.socal-raves.org (localhost [127.0.0.1]); by babylon.socal-raves.org (Postfix) with ESMTP id B570E51B81; for <mailman-admin at babylon.socal-raves.org>; Sat, 2 Feb 2002 17:00:06 -0800 (PST)
\tfrom babylon.socal-raves.org (localhost [127.0.0.1]); by babylon.socal-raves.org (Postfix) with ESMTP id B570E51B81; for <mailman-admin at babylon.socal-raves.org>; Sat, 2 Feb 2002 17:00:06 -0800 (PST)
"""
        h = Header(hstr, continuation_ws='\t')
        eq(h.encode(), """\
from babylon.socal-raves.org (localhost [127.0.0.1]);
\tby babylon.socal-raves.org (Postfix) with ESMTP id B570E51B81;
\tfor <mailman-admin at babylon.socal-raves.org>;
\tSat, 2 Feb 2002 17:00:06 -0800 (PST)
\tfrom babylon.socal-raves.org (localhost [127.0.0.1]);
\tby babylon.socal-raves.org (Postfix) with ESMTP id B570E51B81;
\tfor <mailman-admin at babylon.socal-raves.org>;
\tSat, 2 Feb 2002 17:00:06 -0800 (PST)
\tfrom babylon.socal-raves.org (localhost [127.0.0.1]);
\tby babylon.socal-raves.org (Postfix) with ESMTP id B570E51B81;
\tfor <mailman-admin at babylon.socal-raves.org>;
\tSat, 2 Feb 2002 17:00:06 -0800 (PST)""")

    def test_splitting_first_line_only_is_long(self):
        eq = self.ndiffAssertEqual
        hstr = """\
from modemcable093.139-201-24.que.mc.videotron.ca ([24.201.139.93] helo=cthulhu.gerg.ca)
\tby kronos.mems-exchange.org with esmtp (Exim 4.05)
\tid 17k4h5-00034i-00
\tfor test at mems-exchange.org; Wed, 28 Aug 2002 11:25:20 -0400"""
        h = Header(hstr, maxlinelen=78, header_name='Received',
                   continuation_ws='\t')
        eq(h.encode(), """\
from modemcable093.139-201-24.que.mc.videotron.ca ([24.201.139.93]
\thelo=cthulhu.gerg.ca)
\tby kronos.mems-exchange.org with esmtp (Exim 4.05)
\tid 17k4h5-00034i-00
\tfor test at mems-exchange.org; Wed, 28 Aug 2002 11:25:20 -0400""")

    def test_long_8bit_header(self):
        eq = self.ndiffAssertEqual
        msg = Message()
        h = Header('Britische Regierung gibt', 'iso-8859-1',
                   header_name='Subject')
        h.append('gr\xfcnes Licht f\xfcr Offshore-Windkraftprojekte')
        msg['Subject'] = h
        eq(msg.as_string(), """\
Subject: =?iso-8859-1?q?Britische_Regierung_gibt?= =?iso-8859-1?q?gr=FCnes?=
 =?iso-8859-1?q?_Licht_f=FCr_Offshore-Windkraftprojekte?=

""")

    def test_long_8bit_header_no_charset(self):
        eq = self.ndiffAssertEqual
        msg = Message()
        msg['Reply-To'] = 'Britische Regierung gibt gr\xfcnes Licht f\xfcr Offshore-Windkraftprojekte <a-very-long-address at example.com>'
        eq(msg.as_string(), """\
Reply-To: Britische Regierung gibt gr\xfcnes Licht f\xfcr Offshore-Windkraftprojekte <a-very-long-address at example.com>

""")

    def test_long_to_header(self):
        eq = self.ndiffAssertEqual
        to = '"Someone Test #A" <someone at eecs.umich.edu>,<someone at eecs.umich.edu>,"Someone Test #B" <someone at umich.edu>, "Someone Test #C" <someone at eecs.umich.edu>, "Someone Test #D" <someone at eecs.umich.edu>'
        msg = Message()
        msg['To'] = to
        eq(msg.as_string(0), '''\
To: "Someone Test #A" <someone at eecs.umich.edu>, <someone at eecs.umich.edu>,
 "Someone Test #B" <someone at umich.edu>,
 "Someone Test #C" <someone at eecs.umich.edu>,
 "Someone Test #D" <someone at eecs.umich.edu>

''')

    def test_long_line_after_append(self):
        eq = self.ndiffAssertEqual
        s = 'This is an example of string which has almost the limit of header length.'
        h = Header(s)
        h.append('Add another line.')
        eq(h.encode(), """\
This is an example of string which has almost the limit of header length.
 Add another line.""")

    def test_shorter_line_with_append(self):
        eq = self.ndiffAssertEqual
        s = 'This is a shorter line.'
        h = Header(s)
        h.append('Add another sentence. (Surprise?)')
        eq(h.encode(),
           'This is a shorter line. Add another sentence. (Surprise?)')

    def test_long_field_name(self):
        eq = self.ndiffAssertEqual
        fn = 'X-Very-Very-Very-Long-Header-Name'
        gs = "Die Mieter treten hier ein werden mit einem Foerderband komfortabel den Korridor entlang, an s\xfcdl\xfcndischen Wandgem\xe4lden vorbei, gegen die rotierenden Klingen bef\xf6rdert. "
        h = Header(gs, 'iso-8859-1', header_name=fn)
        # BAW: this seems broken because the first line is too long
        eq(h.encode(), """\
=?iso-8859-1?q?Die_Mieter_treten_hier_?=
 =?iso-8859-1?q?ein_werden_mit_einem_Foerderband_komfortabel_den_Korridor_?=
 =?iso-8859-1?q?entlang=2C_an_s=FCdl=FCndischen_Wandgem=E4lden_vorbei=2C_g?=
 =?iso-8859-1?q?egen_die_rotierenden_Klingen_bef=F6rdert=2E_?=""")

    def test_long_received_header(self):
        h = 'from FOO.TLD (vizworld.acl.foo.tld [123.452.678.9]) by hrothgar.la.mastaler.com (tmda-ofmipd) with ESMTP; Wed, 05 Mar 2003 18:10:18 -0700'
        msg = Message()
        msg['Received-1'] = Header(h, continuation_ws='\t')
        msg['Received-2'] = h
        self.ndiffAssertEqual(msg.as_string(), """\
Received-1: from FOO.TLD (vizworld.acl.foo.tld [123.452.678.9]) by
\throthgar.la.mastaler.com (tmda-ofmipd) with ESMTP;
\tWed, 05 Mar 2003 18:10:18 -0700
Received-2: from FOO.TLD (vizworld.acl.foo.tld [123.452.678.9]) by
 hrothgar.la.mastaler.com (tmda-ofmipd) with ESMTP;
 Wed, 05 Mar 2003 18:10:18 -0700

""")

    def test_string_headerinst_eq(self):
        h = '<15975.17901.207240.414604 at sgigritzmann1.mathematik.tu-muenchen.de> (David Bremner\'s message of "Thu, 6 Mar 2003 13:58:21 +0100")'
        msg = Message()
        msg['Received'] = Header(h, header_name='Received-1',
                                 continuation_ws='\t')
        msg['Received'] = h
        self.ndiffAssertEqual(msg.as_string(), """\
Received: <15975.17901.207240.414604 at sgigritzmann1.mathematik.tu-muenchen.de>
\t(David Bremner's message of "Thu, 6 Mar 2003 13:58:21 +0100")
Received: <15975.17901.207240.414604 at sgigritzmann1.mathematik.tu-muenchen.de>
 (David Bremner's message of "Thu, 6 Mar 2003 13:58:21 +0100")

""")

    def test_long_unbreakable_lines_with_continuation(self):
        eq = self.ndiffAssertEqual
        msg = Message()
        t = """\
 iVBORw0KGgoAAAANSUhEUgAAADAAAAAwBAMAAAClLOS0AAAAGFBMVEUAAAAkHiJeRUIcGBi9
 locQDQ4zJykFBAXJfWDjAAACYUlEQVR4nF2TQY/jIAyFc6lydlG5x8Nyp1Y69wj1PN2I5gzp"""
        msg['Face-1'] = t
        msg['Face-2'] = Header(t, header_name='Face-2')
        eq(msg.as_string(), """\
Face-1: iVBORw0KGgoAAAANSUhEUgAAADAAAAAwBAMAAAClLOS0AAAAGFBMVEUAAAAkHiJeRUIcGBi9
 locQDQ4zJykFBAXJfWDjAAACYUlEQVR4nF2TQY/jIAyFc6lydlG5x8Nyp1Y69wj1PN2I5gzp
Face-2: iVBORw0KGgoAAAANSUhEUgAAADAAAAAwBAMAAAClLOS0AAAAGFBMVEUAAAAkHiJeRUIcGBi9
 locQDQ4zJykFBAXJfWDjAAACYUlEQVR4nF2TQY/jIAyFc6lydlG5x8Nyp1Y69wj1PN2I5gzp

""")

    def test_another_long_multiline_header(self):
        eq = self.ndiffAssertEqual
        m = '''\
Received: from siimage.com ([172.25.1.3]) by zima.siliconimage.com with Microsoft SMTPSVC(5.0.2195.4905);
 Wed, 16 Oct 2002 07:41:11 -0700'''
        msg = email.message_from_string(m)
        eq(msg.as_string(), '''\
Received: from siimage.com ([172.25.1.3]) by zima.siliconimage.com with
 Microsoft SMTPSVC(5.0.2195.4905); Wed, 16 Oct 2002 07:41:11 -0700

''')

    def test_long_lines_with_different_header(self):
        eq = self.ndiffAssertEqual
        h = """\
List-Unsubscribe: <https://lists.sourceforge.net/lists/listinfo/spamassassin-talk>,
 <mailto:spamassassin-talk-request at lists.sourceforge.net?subject=unsubscribe>"""
        msg = Message()
        msg['List'] = h
        msg['List'] = Header(h, header_name='List')
        self.ndiffAssertEqual(msg.as_string(), """\
List: List-Unsubscribe: <https://lists.sourceforge.net/lists/listinfo/spamassassin-talk>,
 <mailto:spamassassin-talk-request at lists.sourceforge.net?subject=unsubscribe>
List: List-Unsubscribe: <https://lists.sourceforge.net/lists/listinfo/spamassassin-talk>,
 <mailto:spamassassin-talk-request at lists.sourceforge.net?subject=unsubscribe>

""")
+
+
+
+# Test mangling of "From " lines in the body of a message
+class TestFromMangling(unittest.TestCase):
+ def setUp(self):
+ self.msg = Message()
+ self.msg['From'] = 'aaa at bbb.org'
+ self.msg.set_payload("""\
+From the desk of A.A.A.:
+Blah blah blah
+""")
+
+ def test_mangled_from(self):
+ s = StringIO()
+ g = Generator(s, mangle_from_=True)
+ g.flatten(self.msg)
+ self.assertEqual(s.getvalue(), """\
+From: aaa at bbb.org
+
+>From the desk of A.A.A.:
+Blah blah blah
+""")
+
+ def test_dont_mangle_from(self):
+ s = StringIO()
+ g = Generator(s, mangle_from_=False)
+ g.flatten(self.msg)
+ self.assertEqual(s.getvalue(), """\
+From: aaa at bbb.org
+
+From the desk of A.A.A.:
+Blah blah blah
+""")
+
+
+
+# Test the basic MIMEAudio class
+class TestMIMEAudio(unittest.TestCase):
+ def setUp(self):
+ # Make sure we pick up the audiotest.au that lives in email/test/data.
+ # In Python, there's an audiotest.au living in Lib/test but that isn't
+ # included in some binary distros that don't include the test
+ # package. The trailing empty string on the .join() is significant
+ # since findfile() will do a dirname().
+ datadir = os.path.join(os.path.dirname(landmark), 'data', '')
+ fp = open(findfile('audiotest.au', datadir), 'rb')
+ try:
+ self._audiodata = fp.read()
+ finally:
+ fp.close()
+ self._au = MIMEAudio(self._audiodata)
+
+ def test_guess_minor_type(self):
+ self.assertEqual(self._au.get_content_type(), 'audio/basic')
+
+ def test_encoding(self):
+ payload = self._au.get_payload()
+ self.assertEqual(base64.decodestring(payload), self._audiodata)
+
+ def test_checkSetMinor(self):
+ au = MIMEAudio(self._audiodata, 'fish')
+ self.assertEqual(au.get_content_type(), 'audio/fish')
+
+ def test_add_header(self):
+ eq = self.assertEqual
+ unless = self.assertTrue
+ self._au.add_header('Content-Disposition', 'attachment',
+ filename='audiotest.au')
+ eq(self._au['content-disposition'],
+ 'attachment; filename="audiotest.au"')
+ eq(self._au.get_params(header='content-disposition'),
+ [('attachment', ''), ('filename', 'audiotest.au')])
+ eq(self._au.get_param('filename', header='content-disposition'),
+ 'audiotest.au')
+ missing = []
+ eq(self._au.get_param('attachment', header='content-disposition'), '')
+ unless(self._au.get_param('foo', failobj=missing,
+ header='content-disposition') is missing)
+ # Try some missing stuff
+ unless(self._au.get_param('foobar', missing) is missing)
+ unless(self._au.get_param('attachment', missing,
+ header='foobar') is missing)
+
+
+
+# Test the basic MIMEImage class
+class TestMIMEImage(unittest.TestCase):
+ def setUp(self):
+ fp = openfile('PyBanner048.gif')
+ try:
+ self._imgdata = fp.read()
+ finally:
+ fp.close()
+ self._im = MIMEImage(self._imgdata)
+
+ def test_guess_minor_type(self):
+ self.assertEqual(self._im.get_content_type(), 'image/gif')
+
+ def test_encoding(self):
+ payload = self._im.get_payload()
+ self.assertEqual(base64.decodestring(payload), self._imgdata)
+
+ def test_checkSetMinor(self):
+ im = MIMEImage(self._imgdata, 'fish')
+ self.assertEqual(im.get_content_type(), 'image/fish')
+
+ def test_add_header(self):
+ eq = self.assertEqual
+ unless = self.assertTrue
+ self._im.add_header('Content-Disposition', 'attachment',
+ filename='dingusfish.gif')
+ eq(self._im['content-disposition'],
+ 'attachment; filename="dingusfish.gif"')
+ eq(self._im.get_params(header='content-disposition'),
+ [('attachment', ''), ('filename', 'dingusfish.gif')])
+ eq(self._im.get_param('filename', header='content-disposition'),
+ 'dingusfish.gif')
+ missing = []
+ eq(self._im.get_param('attachment', header='content-disposition'), '')
+ unless(self._im.get_param('foo', failobj=missing,
+ header='content-disposition') is missing)
+ # Try some missing stuff
+ unless(self._im.get_param('foobar', missing) is missing)
+ unless(self._im.get_param('attachment', missing,
+ header='foobar') is missing)
+
+
+
+# Test the basic MIMEApplication class
+class TestMIMEApplication(unittest.TestCase):
+ def test_headers(self):
+ eq = self.assertEqual
+ msg = MIMEApplication('\xfa\xfb\xfc\xfd\xfe\xff')
+ eq(msg.get_content_type(), 'application/octet-stream')
+ eq(msg['content-transfer-encoding'], 'base64')
+
+ def test_body(self):
+ eq = self.assertEqual
+ bytes = '\xfa\xfb\xfc\xfd\xfe\xff'
+ msg = MIMEApplication(bytes)
+ eq(msg.get_payload(), '+vv8/f7/')
+ eq(msg.get_payload(decode=True), bytes)
+
+ def test_binary_body_with_encode_7or8bit(self):
+ # Issue 17171.
+ bytesdata = b'\xfa\xfb\xfc\xfd\xfe\xff'
+ msg = MIMEApplication(bytesdata, _encoder=encoders.encode_7or8bit)
+ # Treated as a string, this will be invalid code points.
+ self.assertEqual(msg.get_payload(), bytesdata)
+ self.assertEqual(msg.get_payload(decode=True), bytesdata)
+ self.assertEqual(msg['Content-Transfer-Encoding'], '8bit')
+ s = StringIO()
+ g = Generator(s)
+ g.flatten(msg)
+ wireform = s.getvalue()
+ msg2 = email.message_from_string(wireform)
+ self.assertEqual(msg.get_payload(), bytesdata)
+ self.assertEqual(msg2.get_payload(decode=True), bytesdata)
+ self.assertEqual(msg2['Content-Transfer-Encoding'], '8bit')
+
+ def test_binary_body_with_encode_noop(self):
+ # Issue 16564: This does not produce an RFC valid message, since to be
+ # valid it should have a CTE of binary. But the below works, and is
+ # documented as working this way.
+ bytesdata = b'\xfa\xfb\xfc\xfd\xfe\xff'
+ msg = MIMEApplication(bytesdata, _encoder=encoders.encode_noop)
+ self.assertEqual(msg.get_payload(), bytesdata)
+ self.assertEqual(msg.get_payload(decode=True), bytesdata)
+ s = StringIO()
+ g = Generator(s)
+ g.flatten(msg)
+ wireform = s.getvalue()
+ msg2 = email.message_from_string(wireform)
+ self.assertEqual(msg.get_payload(), bytesdata)
+ self.assertEqual(msg2.get_payload(decode=True), bytesdata)
+
+
+# Test the basic MIMEText class
+class TestMIMEText(unittest.TestCase):
+ def setUp(self):
+ self._msg = MIMEText('hello there')
+
+ def test_types(self):
+ eq = self.assertEqual
+ unless = self.assertTrue
+ eq(self._msg.get_content_type(), 'text/plain')
+ eq(self._msg.get_param('charset'), 'us-ascii')
+ missing = []
+ unless(self._msg.get_param('foobar', missing) is missing)
+ unless(self._msg.get_param('charset', missing, header='foobar')
+ is missing)
+
+ def test_payload(self):
+ self.assertEqual(self._msg.get_payload(), 'hello there')
+ self.assertTrue(not self._msg.is_multipart())
+
+ def test_charset(self):
+ eq = self.assertEqual
+ msg = MIMEText('hello there', _charset='us-ascii')
+ eq(msg.get_charset().input_charset, 'us-ascii')
+ eq(msg['content-type'], 'text/plain; charset="us-ascii"')
+
+
+
+# Test complicated multipart/* messages
+class TestMultipart(TestEmailBase):
+ def setUp(self):
+ fp = openfile('PyBanner048.gif')
+ try:
+ data = fp.read()
+ finally:
+ fp.close()
+
+ container = MIMEBase('multipart', 'mixed', boundary='BOUNDARY')
+ image = MIMEImage(data, name='dingusfish.gif')
+ image.add_header('content-disposition', 'attachment',
+ filename='dingusfish.gif')
+ intro = MIMEText('''\
+Hi there,
+
+This is the dingus fish.
+''')
+ container.attach(intro)
+ container.attach(image)
+ container['From'] = 'Barry <barry at digicool.com>'
+ container['To'] = 'Dingus Lovers <cravindogs at cravindogs.com>'
+ container['Subject'] = 'Here is your dingus fish'
+
+ now = 987809702.54848599
+ timetuple = time.localtime(now)
+ if timetuple[-1] == 0:
+ tzsecs = time.timezone
+ else:
+ tzsecs = time.altzone
+ if tzsecs > 0:
+ sign = '-'
+ else:
+ sign = '+'
+ tzoffset = ' %s%04d' % (sign, tzsecs // 36)
+ container['Date'] = time.strftime(
+ '%a, %d %b %Y %H:%M:%S',
+ time.localtime(now)) + tzoffset
+ self._msg = container
+ self._im = image
+ self._txt = intro
+
+ def test_hierarchy(self):
+ # convenience
+ eq = self.assertEqual
+ unless = self.assertTrue
+ raises = self.assertRaises
+ # tests
+ m = self._msg
+ unless(m.is_multipart())
+ eq(m.get_content_type(), 'multipart/mixed')
+ eq(len(m.get_payload()), 2)
+ raises(IndexError, m.get_payload, 2)
+ m0 = m.get_payload(0)
+ m1 = m.get_payload(1)
+ unless(m0 is self._txt)
+ unless(m1 is self._im)
+ eq(m.get_payload(), [m0, m1])
+ unless(not m0.is_multipart())
+ unless(not m1.is_multipart())
+
+ def test_empty_multipart_idempotent(self):
+ text = """\
+Content-Type: multipart/mixed; boundary="BOUNDARY"
+MIME-Version: 1.0
+Subject: A subject
+To: aperson at dom.ain
+From: bperson at dom.ain
+
+
+--BOUNDARY
+
+
+--BOUNDARY--
+"""
+ msg = Parser().parsestr(text)
+ self.ndiffAssertEqual(text, msg.as_string())
+
+ def test_no_parts_in_a_multipart_with_none_epilogue(self):
+ outer = MIMEBase('multipart', 'mixed')
+ outer['Subject'] = 'A subject'
+ outer['To'] = 'aperson at dom.ain'
+ outer['From'] = 'bperson at dom.ain'
+ outer.set_boundary('BOUNDARY')
+ self.ndiffAssertEqual(outer.as_string(), '''\
+Content-Type: multipart/mixed; boundary="BOUNDARY"
+MIME-Version: 1.0
+Subject: A subject
+To: aperson at dom.ain
+From: bperson at dom.ain
+
+--BOUNDARY
+
+--BOUNDARY--''')
+
+ def test_no_parts_in_a_multipart_with_empty_epilogue(self):
+ outer = MIMEBase('multipart', 'mixed')
+ outer['Subject'] = 'A subject'
+ outer['To'] = 'aperson at dom.ain'
+ outer['From'] = 'bperson at dom.ain'
+ outer.preamble = ''
+ outer.epilogue = ''
+ outer.set_boundary('BOUNDARY')
+ self.ndiffAssertEqual(outer.as_string(), '''\
+Content-Type: multipart/mixed; boundary="BOUNDARY"
+MIME-Version: 1.0
+Subject: A subject
+To: aperson at dom.ain
+From: bperson at dom.ain
+
+
+--BOUNDARY
+
+--BOUNDARY--
+''')
+
+ def test_one_part_in_a_multipart(self):
+ eq = self.ndiffAssertEqual
+ outer = MIMEBase('multipart', 'mixed')
+ outer['Subject'] = 'A subject'
+ outer['To'] = 'aperson at dom.ain'
+ outer['From'] = 'bperson at dom.ain'
+ outer.set_boundary('BOUNDARY')
+ msg = MIMEText('hello world')
+ outer.attach(msg)
+ eq(outer.as_string(), '''\
+Content-Type: multipart/mixed; boundary="BOUNDARY"
+MIME-Version: 1.0
+Subject: A subject
+To: aperson at dom.ain
+From: bperson at dom.ain
+
+--BOUNDARY
+Content-Type: text/plain; charset="us-ascii"
+MIME-Version: 1.0
+Content-Transfer-Encoding: 7bit
+
+hello world
+--BOUNDARY--''')
+
+ def test_seq_parts_in_a_multipart_with_empty_preamble(self):
+ eq = self.ndiffAssertEqual
+ outer = MIMEBase('multipart', 'mixed')
+ outer['Subject'] = 'A subject'
+ outer['To'] = 'aperson at dom.ain'
+ outer['From'] = 'bperson at dom.ain'
+ outer.preamble = ''
+ msg = MIMEText('hello world')
+ outer.attach(msg)
+ outer.set_boundary('BOUNDARY')
+ eq(outer.as_string(), '''\
+Content-Type: multipart/mixed; boundary="BOUNDARY"
+MIME-Version: 1.0
+Subject: A subject
+To: aperson at dom.ain
+From: bperson at dom.ain
+
+
+--BOUNDARY
+Content-Type: text/plain; charset="us-ascii"
+MIME-Version: 1.0
+Content-Transfer-Encoding: 7bit
+
+hello world
+--BOUNDARY--''')
+
+
+ def test_seq_parts_in_a_multipart_with_none_preamble(self):
+ eq = self.ndiffAssertEqual
+ outer = MIMEBase('multipart', 'mixed')
+ outer['Subject'] = 'A subject'
+ outer['To'] = 'aperson at dom.ain'
+ outer['From'] = 'bperson at dom.ain'
+ outer.preamble = None
+ msg = MIMEText('hello world')
+ outer.attach(msg)
+ outer.set_boundary('BOUNDARY')
+ eq(outer.as_string(), '''\
+Content-Type: multipart/mixed; boundary="BOUNDARY"
+MIME-Version: 1.0
+Subject: A subject
+To: aperson at dom.ain
+From: bperson at dom.ain
+
+--BOUNDARY
+Content-Type: text/plain; charset="us-ascii"
+MIME-Version: 1.0
+Content-Transfer-Encoding: 7bit
+
+hello world
+--BOUNDARY--''')
+
+
+ def test_seq_parts_in_a_multipart_with_none_epilogue(self):
+ eq = self.ndiffAssertEqual
+ outer = MIMEBase('multipart', 'mixed')
+ outer['Subject'] = 'A subject'
+ outer['To'] = 'aperson at dom.ain'
+ outer['From'] = 'bperson at dom.ain'
+ outer.epilogue = None
+ msg = MIMEText('hello world')
+ outer.attach(msg)
+ outer.set_boundary('BOUNDARY')
+ eq(outer.as_string(), '''\
+Content-Type: multipart/mixed; boundary="BOUNDARY"
+MIME-Version: 1.0
+Subject: A subject
+To: aperson at dom.ain
+From: bperson at dom.ain
+
+--BOUNDARY
+Content-Type: text/plain; charset="us-ascii"
+MIME-Version: 1.0
+Content-Transfer-Encoding: 7bit
+
+hello world
+--BOUNDARY--''')
+
+
+ def test_seq_parts_in_a_multipart_with_empty_epilogue(self):
+ eq = self.ndiffAssertEqual
+ outer = MIMEBase('multipart', 'mixed')
+ outer['Subject'] = 'A subject'
+ outer['To'] = 'aperson at dom.ain'
+ outer['From'] = 'bperson at dom.ain'
+ outer.epilogue = ''
+ msg = MIMEText('hello world')
+ outer.attach(msg)
+ outer.set_boundary('BOUNDARY')
+ eq(outer.as_string(), '''\
+Content-Type: multipart/mixed; boundary="BOUNDARY"
+MIME-Version: 1.0
+Subject: A subject
+To: aperson at dom.ain
+From: bperson at dom.ain
+
+--BOUNDARY
+Content-Type: text/plain; charset="us-ascii"
+MIME-Version: 1.0
+Content-Transfer-Encoding: 7bit
+
+hello world
+--BOUNDARY--
+''')
+
+
+ def test_seq_parts_in_a_multipart_with_nl_epilogue(self):
+ eq = self.ndiffAssertEqual
+ outer = MIMEBase('multipart', 'mixed')
+ outer['Subject'] = 'A subject'
+ outer['To'] = 'aperson at dom.ain'
+ outer['From'] = 'bperson at dom.ain'
+ outer.epilogue = '\n'
+ msg = MIMEText('hello world')
+ outer.attach(msg)
+ outer.set_boundary('BOUNDARY')
+ eq(outer.as_string(), '''\
+Content-Type: multipart/mixed; boundary="BOUNDARY"
+MIME-Version: 1.0
+Subject: A subject
+To: aperson at dom.ain
+From: bperson at dom.ain
+
+--BOUNDARY
+Content-Type: text/plain; charset="us-ascii"
+MIME-Version: 1.0
+Content-Transfer-Encoding: 7bit
+
+hello world
+--BOUNDARY--
+
+''')
+
+ def test_message_external_body(self):
+ eq = self.assertEqual
+ msg = self._msgobj('msg_36.txt')
+ eq(len(msg.get_payload()), 2)
+ msg1 = msg.get_payload(1)
+ eq(msg1.get_content_type(), 'multipart/alternative')
+ eq(len(msg1.get_payload()), 2)
+ for subpart in msg1.get_payload():
+ eq(subpart.get_content_type(), 'message/external-body')
+ eq(len(subpart.get_payload()), 1)
+ subsubpart = subpart.get_payload(0)
+ eq(subsubpart.get_content_type(), 'text/plain')
+
+ def test_double_boundary(self):
+ # msg_37.txt is a multipart that contains two dash-boundary's in a
+ # row. Our interpretation of RFC 2046 calls for ignoring the second
+ # and subsequent boundaries.
+ msg = self._msgobj('msg_37.txt')
+ self.assertEqual(len(msg.get_payload()), 3)
+
+ def test_nested_inner_contains_outer_boundary(self):
+ eq = self.ndiffAssertEqual
+ # msg_38.txt has an inner part that contains outer boundaries. My
+ # interpretation of RFC 2046 (based on sections 5.1 and 5.1.2) say
+ # these are illegal and should be interpreted as unterminated inner
+ # parts.
+ msg = self._msgobj('msg_38.txt')
+ sfp = StringIO()
+ iterators._structure(msg, sfp)
+ eq(sfp.getvalue(), """\
+multipart/mixed
+ multipart/mixed
+ multipart/alternative
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+""")
+
+ def test_nested_with_same_boundary(self):
+ eq = self.ndiffAssertEqual
+ # msg 39.txt is similarly evil in that it's got inner parts that use
+ # the same boundary as outer parts. Again, I believe the way this is
+ # parsed is closest to the spirit of RFC 2046
+ msg = self._msgobj('msg_39.txt')
+ sfp = StringIO()
+ iterators._structure(msg, sfp)
+ eq(sfp.getvalue(), """\
+multipart/mixed
+ multipart/mixed
+ multipart/alternative
+ application/octet-stream
+ application/octet-stream
+ text/plain
+""")
+
+ def test_boundary_in_non_multipart(self):
+ msg = self._msgobj('msg_40.txt')
+ self.assertEqual(msg.as_string(), '''\
+MIME-Version: 1.0
+Content-Type: text/html; boundary="--961284236552522269"
+
+----961284236552522269
+Content-Type: text/html;
+Content-Transfer-Encoding: 7Bit
+
+<html></html>
+
+----961284236552522269--
+''')
+
+ def test_boundary_with_leading_space(self):
+ eq = self.assertEqual
+ msg = email.message_from_string('''\
+MIME-Version: 1.0
+Content-Type: multipart/mixed; boundary=" XXXX"
+
+-- XXXX
+Content-Type: text/plain
+
+
+-- XXXX
+Content-Type: text/plain
+
+-- XXXX--
+''')
+ self.assertTrue(msg.is_multipart())
+ eq(msg.get_boundary(), ' XXXX')
+ eq(len(msg.get_payload()), 2)
+
+ def test_boundary_without_trailing_newline(self):
+ m = Parser().parsestr("""\
+Content-Type: multipart/mixed; boundary="===============0012394164=="
+MIME-Version: 1.0
+
+--===============0012394164==
+Content-Type: image/file1.jpg
+MIME-Version: 1.0
+Content-Transfer-Encoding: base64
+
+YXNkZg==
+--===============0012394164==--""")
+ self.assertEqual(m.get_payload(0).get_payload(), 'YXNkZg==')
+
+
+
+# Test some badly formatted messages
+class TestNonConformant(TestEmailBase):
+ def test_parse_missing_minor_type(self):
+ eq = self.assertEqual
+ msg = self._msgobj('msg_14.txt')
+ eq(msg.get_content_type(), 'text/plain')
+ eq(msg.get_content_maintype(), 'text')
+ eq(msg.get_content_subtype(), 'plain')
+
+ def test_same_boundary_inner_outer(self):
+ unless = self.assertTrue
+ msg = self._msgobj('msg_15.txt')
+ # XXX We can probably eventually do better
+ inner = msg.get_payload(0)
+ unless(hasattr(inner, 'defects'))
+ self.assertEqual(len(inner.defects), 1)
+ unless(isinstance(inner.defects[0],
+ errors.StartBoundaryNotFoundDefect))
+
+ def test_multipart_no_boundary(self):
+ unless = self.assertTrue
+ msg = self._msgobj('msg_25.txt')
+ unless(isinstance(msg.get_payload(), str))
+ self.assertEqual(len(msg.defects), 2)
+ unless(isinstance(msg.defects[0], errors.NoBoundaryInMultipartDefect))
+ unless(isinstance(msg.defects[1],
+ errors.MultipartInvariantViolationDefect))
+
+ def test_invalid_content_type(self):
+ eq = self.assertEqual
+ neq = self.ndiffAssertEqual
+ msg = Message()
+ # RFC 2045, $5.2 says invalid yields text/plain
+ msg['Content-Type'] = 'text'
+ eq(msg.get_content_maintype(), 'text')
+ eq(msg.get_content_subtype(), 'plain')
+ eq(msg.get_content_type(), 'text/plain')
+ # Clear the old value and try something /really/ invalid
+ del msg['content-type']
+ msg['Content-Type'] = 'foo'
+ eq(msg.get_content_maintype(), 'text')
+ eq(msg.get_content_subtype(), 'plain')
+ eq(msg.get_content_type(), 'text/plain')
+ # Still, make sure that the message is idempotently generated
+ s = StringIO()
+ g = Generator(s)
+ g.flatten(msg)
+ neq(s.getvalue(), 'Content-Type: foo\n\n')
+
+ def test_no_start_boundary(self):
+ eq = self.ndiffAssertEqual
+ msg = self._msgobj('msg_31.txt')
+ eq(msg.get_payload(), """\
+--BOUNDARY
+Content-Type: text/plain
+
+message 1
+
+--BOUNDARY
+Content-Type: text/plain
+
+message 2
+
+--BOUNDARY--
+""")
+
+ def test_no_separating_blank_line(self):
+ eq = self.ndiffAssertEqual
+ msg = self._msgobj('msg_35.txt')
+ eq(msg.as_string(), """\
+From: aperson at dom.ain
+To: bperson at dom.ain
+Subject: here's something interesting
+
+counter to RFC 2822, there's no separating newline here
+""")
+
+ def test_lying_multipart(self):
+ unless = self.assertTrue
+ msg = self._msgobj('msg_41.txt')
+ unless(hasattr(msg, 'defects'))
+ self.assertEqual(len(msg.defects), 2)
+ unless(isinstance(msg.defects[0], errors.NoBoundaryInMultipartDefect))
+ unless(isinstance(msg.defects[1],
+ errors.MultipartInvariantViolationDefect))
+
+ def test_missing_start_boundary(self):
+ outer = self._msgobj('msg_42.txt')
+ # The message structure is:
+ #
+ # multipart/mixed
+ # text/plain
+ # message/rfc822
+ # multipart/mixed [*]
+ #
+ # [*] This message is missing its start boundary
+ bad = outer.get_payload(1).get_payload(0)
+ self.assertEqual(len(bad.defects), 1)
+ self.assertTrue(isinstance(bad.defects[0],
+ errors.StartBoundaryNotFoundDefect))
+
+ def test_first_line_is_continuation_header(self):
+ eq = self.assertEqual
+ m = ' Line 1\nLine 2\nLine 3'
+ msg = email.message_from_string(m)
+ eq(msg.keys(), [])
+ eq(msg.get_payload(), 'Line 2\nLine 3')
+ eq(len(msg.defects), 1)
+ self.assertTrue(isinstance(msg.defects[0],
+ errors.FirstHeaderLineIsContinuationDefect))
+ eq(msg.defects[0].line, ' Line 1\n')
+
+
+
+# Test RFC 2047 header encoding and decoding
+class TestRFC2047(unittest.TestCase):
+ def test_rfc2047_multiline(self):
+ eq = self.assertEqual
+ s = """Re: =?mac-iceland?q?r=8Aksm=9Arg=8Cs?= baz
+ foo bar =?mac-iceland?q?r=8Aksm=9Arg=8Cs?="""
+ dh = decode_header(s)
+ eq(dh, [
+ ('Re:', None),
+ ('r\x8aksm\x9arg\x8cs', 'mac-iceland'),
+ ('baz foo bar', None),
+ ('r\x8aksm\x9arg\x8cs', 'mac-iceland')])
+ eq(str(make_header(dh)),
+ """Re: =?mac-iceland?q?r=8Aksm=9Arg=8Cs?= baz foo bar
+ =?mac-iceland?q?r=8Aksm=9Arg=8Cs?=""")
+
+ def test_whitespace_eater_unicode(self):
+ eq = self.assertEqual
+ s = '=?ISO-8859-1?Q?Andr=E9?= Pirard <pirard at dom.ain>'
+ dh = decode_header(s)
+ eq(dh, [('Andr\xe9', 'iso-8859-1'), ('Pirard <pirard at dom.ain>', None)])
+ hu = unicode(make_header(dh)).encode('latin-1')
+ eq(hu, 'Andr\xe9 Pirard <pirard at dom.ain>')
+
+ def test_whitespace_eater_unicode_2(self):
+ eq = self.assertEqual
+ s = 'The =?iso-8859-1?b?cXVpY2sgYnJvd24gZm94?= jumped over the =?iso-8859-1?b?bGF6eSBkb2c=?='
+ dh = decode_header(s)
+ eq(dh, [('The', None), ('quick brown fox', 'iso-8859-1'),
+ ('jumped over the', None), ('lazy dog', 'iso-8859-1')])
+ hu = make_header(dh).__unicode__()
+ eq(hu, u'The quick brown fox jumped over the lazy dog')
+
+ def test_rfc2047_missing_whitespace(self):
+ s = 'Sm=?ISO-8859-1?B?9g==?=rg=?ISO-8859-1?B?5Q==?=sbord'
+ dh = decode_header(s)
+ self.assertEqual(dh, [(s, None)])
+
+ def test_rfc2047_with_whitespace(self):
+ s = 'Sm =?ISO-8859-1?B?9g==?= rg =?ISO-8859-1?B?5Q==?= sbord'
+ dh = decode_header(s)
+ self.assertEqual(dh, [('Sm', None), ('\xf6', 'iso-8859-1'),
+ ('rg', None), ('\xe5', 'iso-8859-1'),
+ ('sbord', None)])
+
+
+
+# Test the MIMEMessage class
+class TestMIMEMessage(TestEmailBase):
+ def setUp(self):
+ fp = openfile('msg_11.txt')
+ try:
+ self._text = fp.read()
+ finally:
+ fp.close()
+
+ def test_type_error(self):
+ self.assertRaises(TypeError, MIMEMessage, 'a plain string')
+
+ def test_valid_argument(self):
+ eq = self.assertEqual
+ unless = self.assertTrue
+ subject = 'A sub-message'
+ m = Message()
+ m['Subject'] = subject
+ r = MIMEMessage(m)
+ eq(r.get_content_type(), 'message/rfc822')
+ payload = r.get_payload()
+ unless(isinstance(payload, list))
+ eq(len(payload), 1)
+ subpart = payload[0]
+ unless(subpart is m)
+ eq(subpart['subject'], subject)
+
+ def test_bad_multipart(self):
+ eq = self.assertEqual
+ msg1 = Message()
+ msg1['Subject'] = 'subpart 1'
+ msg2 = Message()
+ msg2['Subject'] = 'subpart 2'
+ r = MIMEMessage(msg1)
+ self.assertRaises(errors.MultipartConversionError, r.attach, msg2)
+
+ def test_generate(self):
+ # First craft the message to be encapsulated
+ m = Message()
+ m['Subject'] = 'An enclosed message'
+ m.set_payload('Here is the body of the message.\n')
+ r = MIMEMessage(m)
+ r['Subject'] = 'The enclosing message'
+ s = StringIO()
+ g = Generator(s)
+ g.flatten(r)
+ self.assertEqual(s.getvalue(), """\
+Content-Type: message/rfc822
+MIME-Version: 1.0
+Subject: The enclosing message
+
+Subject: An enclosed message
+
+Here is the body of the message.
+""")
+
+ def test_parse_message_rfc822(self):
+ eq = self.assertEqual
+ unless = self.assertTrue
+ msg = self._msgobj('msg_11.txt')
+ eq(msg.get_content_type(), 'message/rfc822')
+ payload = msg.get_payload()
+ unless(isinstance(payload, list))
+ eq(len(payload), 1)
+ submsg = payload[0]
+ self.assertTrue(isinstance(submsg, Message))
+ eq(submsg['subject'], 'An enclosed message')
+ eq(submsg.get_payload(), 'Here is the body of the message.\n')
+
+ def test_dsn(self):
+ eq = self.assertEqual
+ unless = self.assertTrue
+ # msg 16 is a Delivery Status Notification, see RFC 1894
+ msg = self._msgobj('msg_16.txt')
+ eq(msg.get_content_type(), 'multipart/report')
+ unless(msg.is_multipart())
+ eq(len(msg.get_payload()), 3)
+ # Subpart 1 is a text/plain, human readable section
+ subpart = msg.get_payload(0)
+ eq(subpart.get_content_type(), 'text/plain')
+ eq(subpart.get_payload(), """\
+This report relates to a message you sent with the following header fields:
+
+ Message-id: <002001c144a6$8752e060$56104586 at oxy.edu>
+ Date: Sun, 23 Sep 2001 20:10:55 -0700
+ From: "Ian T. Henry" <henryi at oxy.edu>
+ To: SoCal Raves <scr at socal-raves.org>
+ Subject: [scr] yeah for Ians!!
+
+Your message cannot be delivered to the following recipients:
+
+ Recipient address: jangel1 at cougar.noc.ucla.edu
+ Reason: recipient reached disk quota
+
+""")
+ # Subpart 2 contains the machine parsable DSN information. It
+ # consists of two blocks of headers, represented by two nested Message
+ # objects.
+ subpart = msg.get_payload(1)
+ eq(subpart.get_content_type(), 'message/delivery-status')
+ eq(len(subpart.get_payload()), 2)
+ # message/delivery-status should treat each block as a bunch of
+ # headers, i.e. a bunch of Message objects.
+ dsn1 = subpart.get_payload(0)
+ unless(isinstance(dsn1, Message))
+ eq(dsn1['original-envelope-id'], '0GK500B4HD0888 at cougar.noc.ucla.edu')
+ eq(dsn1.get_param('dns', header='reporting-mta'), '')
+ # Try a missing one <wink>
+ eq(dsn1.get_param('nsd', header='reporting-mta'), None)
+ dsn2 = subpart.get_payload(1)
+ unless(isinstance(dsn2, Message))
+ eq(dsn2['action'], 'failed')
+ eq(dsn2.get_params(header='original-recipient'),
+ [('rfc822', ''), ('jangel1 at cougar.noc.ucla.edu', '')])
+ eq(dsn2.get_param('rfc822', header='final-recipient'), '')
+ # Subpart 3 is the original message
+ subpart = msg.get_payload(2)
+ eq(subpart.get_content_type(), 'message/rfc822')
+ payload = subpart.get_payload()
+ unless(isinstance(payload, list))
+ eq(len(payload), 1)
+ subsubpart = payload[0]
+ unless(isinstance(subsubpart, Message))
+ eq(subsubpart.get_content_type(), 'text/plain')
+ eq(subsubpart['message-id'],
+ '<002001c144a6$8752e060$56104586 at oxy.edu>')
+
+ def test_epilogue(self):
+ eq = self.ndiffAssertEqual
+ fp = openfile('msg_21.txt')
+ try:
+ text = fp.read()
+ finally:
+ fp.close()
+ msg = Message()
+ msg['From'] = 'aperson at dom.ain'
+ msg['To'] = 'bperson at dom.ain'
+ msg['Subject'] = 'Test'
+ msg.preamble = 'MIME message'
+ msg.epilogue = 'End of MIME message\n'
+ msg1 = MIMEText('One')
+ msg2 = MIMEText('Two')
+ msg.add_header('Content-Type', 'multipart/mixed', boundary='BOUNDARY')
+ msg.attach(msg1)
+ msg.attach(msg2)
+ sfp = StringIO()
+ g = Generator(sfp)
+ g.flatten(msg)
+ eq(sfp.getvalue(), text)
+
+ def test_no_nl_preamble(self):
+ eq = self.ndiffAssertEqual
+ msg = Message()
+ msg['From'] = 'aperson at dom.ain'
+ msg['To'] = 'bperson at dom.ain'
+ msg['Subject'] = 'Test'
+ msg.preamble = 'MIME message'
+ msg.epilogue = ''
+ msg1 = MIMEText('One')
+ msg2 = MIMEText('Two')
+ msg.add_header('Content-Type', 'multipart/mixed', boundary='BOUNDARY')
+ msg.attach(msg1)
+ msg.attach(msg2)
+ eq(msg.as_string(), """\
+From: aperson at dom.ain
+To: bperson at dom.ain
+Subject: Test
+Content-Type: multipart/mixed; boundary="BOUNDARY"
+
+MIME message
+--BOUNDARY
+Content-Type: text/plain; charset="us-ascii"
+MIME-Version: 1.0
+Content-Transfer-Encoding: 7bit
+
+One
+--BOUNDARY
+Content-Type: text/plain; charset="us-ascii"
+MIME-Version: 1.0
+Content-Transfer-Encoding: 7bit
+
+Two
+--BOUNDARY--
+""")
+
+ def test_default_type(self):
+ eq = self.assertEqual
+ fp = openfile('msg_30.txt')
+ try:
+ msg = email.message_from_file(fp)
+ finally:
+ fp.close()
+ container1 = msg.get_payload(0)
+ eq(container1.get_default_type(), 'message/rfc822')
+ eq(container1.get_content_type(), 'message/rfc822')
+ container2 = msg.get_payload(1)
+ eq(container2.get_default_type(), 'message/rfc822')
+ eq(container2.get_content_type(), 'message/rfc822')
+ container1a = container1.get_payload(0)
+ eq(container1a.get_default_type(), 'text/plain')
+ eq(container1a.get_content_type(), 'text/plain')
+ container2a = container2.get_payload(0)
+ eq(container2a.get_default_type(), 'text/plain')
+ eq(container2a.get_content_type(), 'text/plain')
+
+ def test_default_type_with_explicit_container_type(self):
+ eq = self.assertEqual
+ fp = openfile('msg_28.txt')
+ try:
+ msg = email.message_from_file(fp)
+ finally:
+ fp.close()
+ container1 = msg.get_payload(0)
+ eq(container1.get_default_type(), 'message/rfc822')
+ eq(container1.get_content_type(), 'message/rfc822')
+ container2 = msg.get_payload(1)
+ eq(container2.get_default_type(), 'message/rfc822')
+ eq(container2.get_content_type(), 'message/rfc822')
+ container1a = container1.get_payload(0)
+ eq(container1a.get_default_type(), 'text/plain')
+ eq(container1a.get_content_type(), 'text/plain')
+ container2a = container2.get_payload(0)
+ eq(container2a.get_default_type(), 'text/plain')
+ eq(container2a.get_content_type(), 'text/plain')
+
+ def test_default_type_non_parsed(self):
+ eq = self.assertEqual
+ neq = self.ndiffAssertEqual
+ # Set up container
+ container = MIMEMultipart('digest', 'BOUNDARY')
+ container.epilogue = ''
+ # Set up subparts
+ subpart1a = MIMEText('message 1\n')
+ subpart2a = MIMEText('message 2\n')
+ subpart1 = MIMEMessage(subpart1a)
+ subpart2 = MIMEMessage(subpart2a)
+ container.attach(subpart1)
+ container.attach(subpart2)
+ eq(subpart1.get_content_type(), 'message/rfc822')
+ eq(subpart1.get_default_type(), 'message/rfc822')
+ eq(subpart2.get_content_type(), 'message/rfc822')
+ eq(subpart2.get_default_type(), 'message/rfc822')
+ neq(container.as_string(0), '''\
+Content-Type: multipart/digest; boundary="BOUNDARY"
+MIME-Version: 1.0
+
+--BOUNDARY
+Content-Type: message/rfc822
+MIME-Version: 1.0
+
+Content-Type: text/plain; charset="us-ascii"
+MIME-Version: 1.0
+Content-Transfer-Encoding: 7bit
+
+message 1
+
+--BOUNDARY
+Content-Type: message/rfc822
+MIME-Version: 1.0
+
+Content-Type: text/plain; charset="us-ascii"
+MIME-Version: 1.0
+Content-Transfer-Encoding: 7bit
+
+message 2
+
+--BOUNDARY--
+''')
+ del subpart1['content-type']
+ del subpart1['mime-version']
+ del subpart2['content-type']
+ del subpart2['mime-version']
+ eq(subpart1.get_content_type(), 'message/rfc822')
+ eq(subpart1.get_default_type(), 'message/rfc822')
+ eq(subpart2.get_content_type(), 'message/rfc822')
+ eq(subpart2.get_default_type(), 'message/rfc822')
+ neq(container.as_string(0), '''\
+Content-Type: multipart/digest; boundary="BOUNDARY"
+MIME-Version: 1.0
+
+--BOUNDARY
+
+Content-Type: text/plain; charset="us-ascii"
+MIME-Version: 1.0
+Content-Transfer-Encoding: 7bit
+
+message 1
+
+--BOUNDARY
+
+Content-Type: text/plain; charset="us-ascii"
+MIME-Version: 1.0
+Content-Transfer-Encoding: 7bit
+
+message 2
+
+--BOUNDARY--
+''')
+
+ def test_mime_attachments_in_constructor(self):
+ eq = self.assertEqual
+ text1 = MIMEText('')
+ text2 = MIMEText('')
+ msg = MIMEMultipart(_subparts=(text1, text2))
+ eq(len(msg.get_payload()), 2)
+ eq(msg.get_payload(0), text1)
+ eq(msg.get_payload(1), text2)
+
+
+
+# A general test of parser->model->generator idempotency. IOW, read a message
+# in, parse it into a message object tree, then without touching the tree,
+# regenerate the plain text. The original text and the transformed text
+# should be identical. Note: that we ignore the Unix-From since that may
+# contain a changed date.
+class TestIdempotent(TestEmailBase):
+ def _msgobj(self, filename):
+ fp = openfile(filename)
+ try:
+ data = fp.read()
+ finally:
+ fp.close()
+ msg = email.message_from_string(data)
+ return msg, data
+
+ def _idempotent(self, msg, text):
+ eq = self.ndiffAssertEqual
+ s = StringIO()
+ g = Generator(s, maxheaderlen=0)
+ g.flatten(msg)
+ eq(text, s.getvalue())
+
+ def test_parse_text_message(self):
+ eq = self.assertEqual
+ msg, text = self._msgobj('msg_01.txt')
+ eq(msg.get_content_type(), 'text/plain')
+ eq(msg.get_content_maintype(), 'text')
+ eq(msg.get_content_subtype(), 'plain')
+ eq(msg.get_params()[1], ('charset', 'us-ascii'))
+ eq(msg.get_param('charset'), 'us-ascii')
+ eq(msg.preamble, None)
+ eq(msg.epilogue, None)
+ self._idempotent(msg, text)
+
+ def test_parse_untyped_message(self):
+ eq = self.assertEqual
+ msg, text = self._msgobj('msg_03.txt')
+ eq(msg.get_content_type(), 'text/plain')
+ eq(msg.get_params(), None)
+ eq(msg.get_param('charset'), None)
+ self._idempotent(msg, text)
+
+ def test_simple_multipart(self):
+ msg, text = self._msgobj('msg_04.txt')
+ self._idempotent(msg, text)
+
+ def test_MIME_digest(self):
+ msg, text = self._msgobj('msg_02.txt')
+ self._idempotent(msg, text)
+
+ def test_long_header(self):
+ msg, text = self._msgobj('msg_27.txt')
+ self._idempotent(msg, text)
+
+ def test_MIME_digest_with_part_headers(self):
+ msg, text = self._msgobj('msg_28.txt')
+ self._idempotent(msg, text)
+
+ def test_mixed_with_image(self):
+ msg, text = self._msgobj('msg_06.txt')
+ self._idempotent(msg, text)
+
+ def test_multipart_report(self):
+ msg, text = self._msgobj('msg_05.txt')
+ self._idempotent(msg, text)
+
+ def test_dsn(self):
+ msg, text = self._msgobj('msg_16.txt')
+ self._idempotent(msg, text)
+
+ def test_preamble_epilogue(self):
+ msg, text = self._msgobj('msg_21.txt')
+ self._idempotent(msg, text)
+
+ def test_multipart_one_part(self):
+ msg, text = self._msgobj('msg_23.txt')
+ self._idempotent(msg, text)
+
+ def test_multipart_no_parts(self):
+ msg, text = self._msgobj('msg_24.txt')
+ self._idempotent(msg, text)
+
+ def test_no_start_boundary(self):
+ msg, text = self._msgobj('msg_31.txt')
+ self._idempotent(msg, text)
+
+ def test_rfc2231_charset(self):
+ msg, text = self._msgobj('msg_32.txt')
+ self._idempotent(msg, text)
+
+ def test_more_rfc2231_parameters(self):
+ msg, text = self._msgobj('msg_33.txt')
+ self._idempotent(msg, text)
+
+ def test_text_plain_in_a_multipart_digest(self):
+ msg, text = self._msgobj('msg_34.txt')
+ self._idempotent(msg, text)
+
+ def test_nested_multipart_mixeds(self):
+ msg, text = self._msgobj('msg_12a.txt')
+ self._idempotent(msg, text)
+
+ def test_message_external_body_idempotent(self):
+ msg, text = self._msgobj('msg_36.txt')
+ self._idempotent(msg, text)
+
+ def test_content_type(self):
+ eq = self.assertEqual
+ unless = self.assertTrue
+ # Get a message object and reset the seek pointer for other tests
+ msg, text = self._msgobj('msg_05.txt')
+ eq(msg.get_content_type(), 'multipart/report')
+ # Test the Content-Type: parameters
+ params = {}
+ for pk, pv in msg.get_params():
+ params[pk] = pv
+ eq(params['report-type'], 'delivery-status')
+ eq(params['boundary'], 'D1690A7AC1.996856090/mail.example.com')
+ eq(msg.preamble, 'This is a MIME-encapsulated message.\n')
+ eq(msg.epilogue, '\n')
+ eq(len(msg.get_payload()), 3)
+ # Make sure the subparts are what we expect
+ msg1 = msg.get_payload(0)
+ eq(msg1.get_content_type(), 'text/plain')
+ eq(msg1.get_payload(), 'Yadda yadda yadda\n')
+ msg2 = msg.get_payload(1)
+ eq(msg2.get_content_type(), 'text/plain')
+ eq(msg2.get_payload(), 'Yadda yadda yadda\n')
+ msg3 = msg.get_payload(2)
+ eq(msg3.get_content_type(), 'message/rfc822')
+ self.assertTrue(isinstance(msg3, Message))
+ payload = msg3.get_payload()
+ unless(isinstance(payload, list))
+ eq(len(payload), 1)
+ msg4 = payload[0]
+ unless(isinstance(msg4, Message))
+ eq(msg4.get_payload(), 'Yadda yadda yadda\n')
+
+ def test_parser(self):
+ eq = self.assertEqual
+ unless = self.assertTrue
+ msg, text = self._msgobj('msg_06.txt')
+ # Check some of the outer headers
+ eq(msg.get_content_type(), 'message/rfc822')
+ # Make sure the payload is a list of exactly one sub-Message, and that
+ # that submessage has a type of text/plain
+ payload = msg.get_payload()
+ unless(isinstance(payload, list))
+ eq(len(payload), 1)
+ msg1 = payload[0]
+ self.assertTrue(isinstance(msg1, Message))
+ eq(msg1.get_content_type(), 'text/plain')
+ self.assertTrue(isinstance(msg1.get_payload(), str))
+ eq(msg1.get_payload(), '\n')
+
+
+
+# Test various other bits of the package's functionality
+class TestMiscellaneous(TestEmailBase):
+ def test_message_from_string(self):
+ fp = openfile('msg_01.txt')
+ try:
+ text = fp.read()
+ finally:
+ fp.close()
+ msg = email.message_from_string(text)
+ s = StringIO()
+ # Don't wrap/continue long headers since we're trying to test
+ # idempotency.
+ g = Generator(s, maxheaderlen=0)
+ g.flatten(msg)
+ self.assertEqual(text, s.getvalue())
+
+ def test_message_from_file(self):
+ fp = openfile('msg_01.txt')
+ try:
+ text = fp.read()
+ fp.seek(0)
+ msg = email.message_from_file(fp)
+ s = StringIO()
+ # Don't wrap/continue long headers since we're trying to test
+ # idempotency.
+ g = Generator(s, maxheaderlen=0)
+ g.flatten(msg)
+ self.assertEqual(text, s.getvalue())
+ finally:
+ fp.close()
+
+ def test_message_from_string_with_class(self):
+ unless = self.assertTrue
+ fp = openfile('msg_01.txt')
+ try:
+ text = fp.read()
+ finally:
+ fp.close()
+ # Create a subclass
+ class MyMessage(Message):
+ pass
+
+ msg = email.message_from_string(text, MyMessage)
+ unless(isinstance(msg, MyMessage))
+ # Try something more complicated
+ fp = openfile('msg_02.txt')
+ try:
+ text = fp.read()
+ finally:
+ fp.close()
+ msg = email.message_from_string(text, MyMessage)
+ for subpart in msg.walk():
+ unless(isinstance(subpart, MyMessage))
+
+ def test_message_from_file_with_class(self):
+ unless = self.assertTrue
+ # Create a subclass
+ class MyMessage(Message):
+ pass
+
+ fp = openfile('msg_01.txt')
+ try:
+ msg = email.message_from_file(fp, MyMessage)
+ finally:
+ fp.close()
+ unless(isinstance(msg, MyMessage))
+ # Try something more complicated
+ fp = openfile('msg_02.txt')
+ try:
+ msg = email.message_from_file(fp, MyMessage)
+ finally:
+ fp.close()
+ for subpart in msg.walk():
+ unless(isinstance(subpart, MyMessage))
+
+ def test__all__(self):
+ module = __import__('email')
+ # Can't use sorted() here due to Python 2.3 compatibility
+ all = module.__all__[:]
+ all.sort()
+ self.assertEqual(all, [
+ # Old names
+ 'Charset', 'Encoders', 'Errors', 'Generator',
+ 'Header', 'Iterators', 'MIMEAudio', 'MIMEBase',
+ 'MIMEImage', 'MIMEMessage', 'MIMEMultipart',
+ 'MIMENonMultipart', 'MIMEText', 'Message',
+ 'Parser', 'Utils', 'base64MIME',
+ # new names
+ 'base64mime', 'charset', 'encoders', 'errors', 'generator',
+ 'header', 'iterators', 'message', 'message_from_file',
+ 'message_from_string', 'mime', 'parser',
+ 'quopriMIME', 'quoprimime', 'utils',
+ ])
+
+ def test_formatdate(self):
+ now = time.time()
+ self.assertEqual(utils.parsedate(utils.formatdate(now))[:6],
+ time.gmtime(now)[:6])
+
+ def test_formatdate_localtime(self):
+ now = time.time()
+ self.assertEqual(
+ utils.parsedate(utils.formatdate(now, localtime=True))[:6],
+ time.localtime(now)[:6])
+
+ def test_formatdate_usegmt(self):
+ now = time.time()
+ self.assertEqual(
+ utils.formatdate(now, localtime=False),
+ time.strftime('%a, %d %b %Y %H:%M:%S -0000', time.gmtime(now)))
+ self.assertEqual(
+ utils.formatdate(now, localtime=False, usegmt=True),
+ time.strftime('%a, %d %b %Y %H:%M:%S GMT', time.gmtime(now)))
+
+ def test_parsedate_none(self):
+ self.assertEqual(utils.parsedate(''), None)
+
+ def test_parsedate_compact(self):
+ # The FWS after the comma is optional
+ self.assertEqual(utils.parsedate('Wed,3 Apr 2002 14:58:26 +0800'),
+ utils.parsedate('Wed, 3 Apr 2002 14:58:26 +0800'))
+
+ def test_parsedate_no_dayofweek(self):
+ eq = self.assertEqual
+ eq(utils.parsedate_tz('25 Feb 2003 13:47:26 -0800'),
+ (2003, 2, 25, 13, 47, 26, 0, 1, -1, -28800))
+
+ def test_parsedate_compact_no_dayofweek(self):
+ eq = self.assertEqual
+ eq(utils.parsedate_tz('5 Feb 2003 13:47:26 -0800'),
+ (2003, 2, 5, 13, 47, 26, 0, 1, -1, -28800))
+
+ def test_parsedate_acceptable_to_time_functions(self):
+ eq = self.assertEqual
+ timetup = utils.parsedate('5 Feb 2003 13:47:26 -0800')
+ t = int(time.mktime(timetup))
+ eq(time.localtime(t)[:6], timetup[:6])
+ eq(int(time.strftime('%Y', timetup)), 2003)
+ timetup = utils.parsedate_tz('5 Feb 2003 13:47:26 -0800')
+ t = int(time.mktime(timetup[:9]))
+ eq(time.localtime(t)[:6], timetup[:6])
+ eq(int(time.strftime('%Y', timetup[:9])), 2003)
+
+ def test_parseaddr_empty(self):
+ self.assertEqual(utils.parseaddr('<>'), ('', ''))
+ self.assertEqual(utils.formataddr(utils.parseaddr('<>')), '')
+
+ def test_noquote_dump(self):
+ self.assertEqual(
+ utils.formataddr(('A Silly Person', 'person at dom.ain')),
+ 'A Silly Person <person at dom.ain>')
+
+ def test_escape_dump(self):
+ self.assertEqual(
+ utils.formataddr(('A (Very) Silly Person', 'person at dom.ain')),
+ r'"A \(Very\) Silly Person" <person at dom.ain>')
+ a = r'A \(Special\) Person'
+ b = 'person at dom.ain'
+ self.assertEqual(utils.parseaddr(utils.formataddr((a, b))), (a, b))
+
+ def test_escape_backslashes(self):
+ self.assertEqual(
+ utils.formataddr(('Arthur \Backslash\ Foobar', 'person at dom.ain')),
+ r'"Arthur \\Backslash\\ Foobar" <person at dom.ain>')
+ a = r'Arthur \Backslash\ Foobar'
+ b = 'person at dom.ain'
+ self.assertEqual(utils.parseaddr(utils.formataddr((a, b))), (a, b))
+
+ def test_name_with_dot(self):
+ x = 'John X. Doe <jxd at example.com>'
+ y = '"John X. Doe" <jxd at example.com>'
+ a, b = ('John X. Doe', 'jxd at example.com')
+ self.assertEqual(utils.parseaddr(x), (a, b))
+ self.assertEqual(utils.parseaddr(y), (a, b))
+ # formataddr() quotes the name if there's a dot in it
+ self.assertEqual(utils.formataddr((a, b)), y)
+
+ def test_multiline_from_comment(self):
+ x = """\
+Foo
+\tBar <foo at example.com>"""
+ self.assertEqual(utils.parseaddr(x), ('Foo Bar', 'foo at example.com'))
+
+ def test_quote_dump(self):
+ self.assertEqual(
+ utils.formataddr(('A Silly; Person', 'person at dom.ain')),
+ r'"A Silly; Person" <person at dom.ain>')
+
+ def test_fix_eols(self):
+ eq = self.assertEqual
+ eq(utils.fix_eols('hello'), 'hello')
+ eq(utils.fix_eols('hello\n'), 'hello\r\n')
+ eq(utils.fix_eols('hello\r'), 'hello\r\n')
+ eq(utils.fix_eols('hello\r\n'), 'hello\r\n')
+ eq(utils.fix_eols('hello\n\r'), 'hello\r\n\r\n')
+
+ def test_charset_richcomparisons(self):
+ eq = self.assertEqual
+ ne = self.assertNotEqual
+ cset1 = Charset()
+ cset2 = Charset()
+ eq(cset1, 'us-ascii')
+ eq(cset1, 'US-ASCII')
+ eq(cset1, 'Us-AsCiI')
+ eq('us-ascii', cset1)
+ eq('US-ASCII', cset1)
+ eq('Us-AsCiI', cset1)
+ ne(cset1, 'usascii')
+ ne(cset1, 'USASCII')
+ ne(cset1, 'UsAsCiI')
+ ne('usascii', cset1)
+ ne('USASCII', cset1)
+ ne('UsAsCiI', cset1)
+ eq(cset1, cset2)
+ eq(cset2, cset1)
+
+ def test_getaddresses(self):
+ eq = self.assertEqual
+ eq(utils.getaddresses(['aperson at dom.ain (Al Person)',
+ 'Bud Person <bperson at dom.ain>']),
+ [('Al Person', 'aperson at dom.ain'),
+ ('Bud Person', 'bperson at dom.ain')])
+
+ def test_getaddresses_nasty(self):
+ eq = self.assertEqual
+ eq(utils.getaddresses(['foo: ;']), [('', '')])
+ eq(utils.getaddresses(
+ ['[]*-- =~$']),
+ [('', ''), ('', ''), ('', '*--')])
+ eq(utils.getaddresses(
+ ['foo: ;', '"Jason R. Mastaler" <jason at dom.ain>']),
+ [('', ''), ('Jason R. Mastaler', 'jason at dom.ain')])
+
+ def test_getaddresses_embedded_comment(self):
+ """Test proper handling of a nested comment"""
+ eq = self.assertEqual
+ addrs = utils.getaddresses(['User ((nested comment)) <foo at bar.com>'])
+ eq(addrs[0][1], 'foo at bar.com')
+
+ def test_utils_quote_unquote(self):
+ eq = self.assertEqual
+ msg = Message()
+ msg.add_header('content-disposition', 'attachment',
+ filename='foo\\wacky"name')
+ eq(msg.get_filename(), 'foo\\wacky"name')
+
+ def test_get_body_encoding_with_bogus_charset(self):
+ charset = Charset('not a charset')
+ self.assertEqual(charset.get_body_encoding(), 'base64')
+
+ def test_get_body_encoding_with_uppercase_charset(self):
+ eq = self.assertEqual
+ msg = Message()
+ msg['Content-Type'] = 'text/plain; charset=UTF-8'
+ eq(msg['content-type'], 'text/plain; charset=UTF-8')
+ charsets = msg.get_charsets()
+ eq(len(charsets), 1)
+ eq(charsets[0], 'utf-8')
+ charset = Charset(charsets[0])
+ eq(charset.get_body_encoding(), 'base64')
+ msg.set_payload('hello world', charset=charset)
+ eq(msg.get_payload(), 'aGVsbG8gd29ybGQ=\n')
+ eq(msg.get_payload(decode=True), 'hello world')
+ eq(msg['content-transfer-encoding'], 'base64')
+ # Try another one
+ msg = Message()
+ msg['Content-Type'] = 'text/plain; charset="US-ASCII"'
+ charsets = msg.get_charsets()
+ eq(len(charsets), 1)
+ eq(charsets[0], 'us-ascii')
+ charset = Charset(charsets[0])
+ eq(charset.get_body_encoding(), encoders.encode_7or8bit)
+ msg.set_payload('hello world', charset=charset)
+ eq(msg.get_payload(), 'hello world')
+ eq(msg['content-transfer-encoding'], '7bit')
+
+ def test_charsets_case_insensitive(self):
+ lc = Charset('us-ascii')
+ uc = Charset('US-ASCII')
+ self.assertEqual(lc.get_body_encoding(), uc.get_body_encoding())
+
+ def test_partial_falls_inside_message_delivery_status(self):
+ eq = self.ndiffAssertEqual
+ # The Parser interface provides chunks of data to FeedParser in 8192
+ # byte gulps. SF bug #1076485 found one of those chunks inside
+ # message/delivery-status header block, which triggered an
+ # unreadline() of NeedMoreData.
+ msg = self._msgobj('msg_43.txt')
+ sfp = StringIO()
+ iterators._structure(msg, sfp)
+ eq(sfp.getvalue(), """\
+multipart/report
+ text/plain
+ message/delivery-status
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/plain
+ text/rfc822-headers
+""")
+
+
+
+# Test the iterator/generators
+class TestIterators(TestEmailBase):
+ def test_body_line_iterator(self):
+ eq = self.assertEqual
+ neq = self.ndiffAssertEqual
+ # First a simple non-multipart message
+ msg = self._msgobj('msg_01.txt')
+ it = iterators.body_line_iterator(msg)
+ lines = list(it)
+ eq(len(lines), 6)
+ neq(EMPTYSTRING.join(lines), msg.get_payload())
+ # Now a more complicated multipart
+ msg = self._msgobj('msg_02.txt')
+ it = iterators.body_line_iterator(msg)
+ lines = list(it)
+ eq(len(lines), 43)
+ fp = openfile('msg_19.txt')
+ try:
+ neq(EMPTYSTRING.join(lines), fp.read())
+ finally:
+ fp.close()
+
+ def test_typed_subpart_iterator(self):
+ eq = self.assertEqual
+ msg = self._msgobj('msg_04.txt')
+ it = iterators.typed_subpart_iterator(msg, 'text')
+ lines = []
+ subparts = 0
+ for subpart in it:
+ subparts += 1
+ lines.append(subpart.get_payload())
+ eq(subparts, 2)
+ eq(EMPTYSTRING.join(lines), """\
+a simple kind of mirror
+to reflect upon our own
+a simple kind of mirror
+to reflect upon our own
+""")
+
+ def test_typed_subpart_iterator_default_type(self):
+ eq = self.assertEqual
+ msg = self._msgobj('msg_03.txt')
+ it = iterators.typed_subpart_iterator(msg, 'text', 'plain')
+ lines = []
+ subparts = 0
+ for subpart in it:
+ subparts += 1
+ lines.append(subpart.get_payload())
+ eq(subparts, 1)
+ eq(EMPTYSTRING.join(lines), """\
+
+Hi,
+
+Do you like this message?
+
+-Me
+""")
+
+
+
+class TestParsers(TestEmailBase):
+ def test_header_parser(self):
+ eq = self.assertEqual
+ # Parse only the headers of a complex multipart MIME document
+ fp = openfile('msg_02.txt')
+ try:
+ msg = HeaderParser().parse(fp)
+ finally:
+ fp.close()
+ eq(msg['from'], 'ppp-request at zzz.org')
+ eq(msg['to'], 'ppp at zzz.org')
+ eq(msg.get_content_type(), 'multipart/mixed')
+ self.assertFalse(msg.is_multipart())
+ self.assertTrue(isinstance(msg.get_payload(), str))
+
+ def test_whitespace_continuation(self):
+ eq = self.assertEqual
+ # This message contains a line after the Subject: header that has only
+ # whitespace, but it is not empty!
+ msg = email.message_from_string("""\
+From: aperson at dom.ain
+To: bperson at dom.ain
+Subject: the next line has a space on it
+\x20
+Date: Mon, 8 Apr 2002 15:09:19 -0400
+Message-ID: spam
+
+Here's the message body
+""")
+ eq(msg['subject'], 'the next line has a space on it\n ')
+ eq(msg['message-id'], 'spam')
+ eq(msg.get_payload(), "Here's the message body\n")
+
+ def test_whitespace_continuation_last_header(self):
+ eq = self.assertEqual
+ # Like the previous test, but the subject line is the last
+ # header.
+ msg = email.message_from_string("""\
+From: aperson at dom.ain
+To: bperson at dom.ain
+Date: Mon, 8 Apr 2002 15:09:19 -0400
+Message-ID: spam
+Subject: the next line has a space on it
+\x20
+
+Here's the message body
+""")
+ eq(msg['subject'], 'the next line has a space on it\n ')
+ eq(msg['message-id'], 'spam')
+ eq(msg.get_payload(), "Here's the message body\n")
+
+ def test_crlf_separation(self):
+ eq = self.assertEqual
+ fp = openfile('msg_26.txt', mode='rb')
+ try:
+ msg = Parser().parse(fp)
+ finally:
+ fp.close()
+ eq(len(msg.get_payload()), 2)
+ part1 = msg.get_payload(0)
+ eq(part1.get_content_type(), 'text/plain')
+ eq(part1.get_payload(), 'Simple email with attachment.\r\n\r\n')
+ part2 = msg.get_payload(1)
+ eq(part2.get_content_type(), 'application/riscos')
+
+ def test_multipart_digest_with_extra_mime_headers(self):
+ eq = self.assertEqual
+ neq = self.ndiffAssertEqual
+ fp = openfile('msg_28.txt')
+ try:
+ msg = email.message_from_file(fp)
+ finally:
+ fp.close()
+ # Structure is:
+ # multipart/digest
+ # message/rfc822
+ # text/plain
+ # message/rfc822
+ # text/plain
+ eq(msg.is_multipart(), 1)
+ eq(len(msg.get_payload()), 2)
+ part1 = msg.get_payload(0)
+ eq(part1.get_content_type(), 'message/rfc822')
+ eq(part1.is_multipart(), 1)
+ eq(len(part1.get_payload()), 1)
+ part1a = part1.get_payload(0)
+ eq(part1a.is_multipart(), 0)
+ eq(part1a.get_content_type(), 'text/plain')
+ neq(part1a.get_payload(), 'message 1\n')
+ # next message/rfc822
+ part2 = msg.get_payload(1)
+ eq(part2.get_content_type(), 'message/rfc822')
+ eq(part2.is_multipart(), 1)
+ eq(len(part2.get_payload()), 1)
+ part2a = part2.get_payload(0)
+ eq(part2a.is_multipart(), 0)
+ eq(part2a.get_content_type(), 'text/plain')
+ neq(part2a.get_payload(), 'message 2\n')
+
+ def test_three_lines(self):
+ # A bug report by Andrew McNamara
+ lines = ['From: Andrew Person <aperson at dom.ain',
+ 'Subject: Test',
+ 'Date: Tue, 20 Aug 2002 16:43:45 +1000']
+ msg = email.message_from_string(NL.join(lines))
+ self.assertEqual(msg['date'], 'Tue, 20 Aug 2002 16:43:45 +1000')
+
+ def test_strip_line_feed_and_carriage_return_in_headers(self):
+ eq = self.assertEqual
+ # For [ 1002475 ] email message parser doesn't handle \r\n correctly
+ value1 = 'text'
+ value2 = 'more text'
+ m = 'Header: %s\r\nNext-Header: %s\r\n\r\nBody\r\n\r\n' % (
+ value1, value2)
+ msg = email.message_from_string(m)
+ eq(msg.get('Header'), value1)
+ eq(msg.get('Next-Header'), value2)
+
+ def test_rfc2822_header_syntax(self):
+ eq = self.assertEqual
+ m = '>From: foo\nFrom: bar\n!"#QUX;~: zoo\n\nbody'
+ msg = email.message_from_string(m)
+ eq(len(msg.keys()), 3)
+ keys = msg.keys()
+ keys.sort()
+ eq(keys, ['!"#QUX;~', '>From', 'From'])
+ eq(msg.get_payload(), 'body')
+
+ def test_rfc2822_space_not_allowed_in_header(self):
+ eq = self.assertEqual
+ m = '>From foo at example.com 11:25:53\nFrom: bar\n!"#QUX;~: zoo\n\nbody'
+ msg = email.message_from_string(m)
+ eq(len(msg.keys()), 0)
+
+ def test_rfc2822_one_character_header(self):
+ eq = self.assertEqual
+ m = 'A: first header\nB: second header\nCC: third header\n\nbody'
+ msg = email.message_from_string(m)
+ headers = msg.keys()
+ headers.sort()
+ eq(headers, ['A', 'B', 'CC'])
+ eq(msg.get_payload(), 'body')
+
+
+
+class TestBase64(unittest.TestCase):
+ def test_len(self):
+ eq = self.assertEqual
+ eq(base64mime.base64_len('hello'),
+ len(base64mime.encode('hello', eol='')))
+ for size in range(15):
+ if size == 0 : bsize = 0
+ elif size <= 3 : bsize = 4
+ elif size <= 6 : bsize = 8
+ elif size <= 9 : bsize = 12
+ elif size <= 12: bsize = 16
+ else : bsize = 20
+ eq(base64mime.base64_len('x'*size), bsize)
+
+ def test_decode(self):
+ eq = self.assertEqual
+ eq(base64mime.decode(''), '')
+ eq(base64mime.decode('aGVsbG8='), 'hello')
+ eq(base64mime.decode('aGVsbG8=', 'X'), 'hello')
+ eq(base64mime.decode('aGVsbG8NCndvcmxk\n', 'X'), 'helloXworld')
+
+ def test_encode(self):
+ eq = self.assertEqual
+ eq(base64mime.encode(''), '')
+ eq(base64mime.encode('hello'), 'aGVsbG8=\n')
+ # Test the binary flag
+ eq(base64mime.encode('hello\n'), 'aGVsbG8K\n')
+ eq(base64mime.encode('hello\n', 0), 'aGVsbG8NCg==\n')
+ # Test the maxlinelen arg
+ eq(base64mime.encode('xxxx ' * 20, maxlinelen=40), """\
+eHh4eCB4eHh4IHh4eHggeHh4eCB4eHh4IHh4eHgg
+eHh4eCB4eHh4IHh4eHggeHh4eCB4eHh4IHh4eHgg
+eHh4eCB4eHh4IHh4eHggeHh4eCB4eHh4IHh4eHgg
+eHh4eCB4eHh4IA==
+""")
+ # Test the eol argument
+ eq(base64mime.encode('xxxx ' * 20, maxlinelen=40, eol='\r\n'), """\
+eHh4eCB4eHh4IHh4eHggeHh4eCB4eHh4IHh4eHgg\r
+eHh4eCB4eHh4IHh4eHggeHh4eCB4eHh4IHh4eHgg\r
+eHh4eCB4eHh4IHh4eHggeHh4eCB4eHh4IHh4eHgg\r
+eHh4eCB4eHh4IA==\r
+""")
+
+ def test_header_encode(self):
+ eq = self.assertEqual
+ he = base64mime.header_encode
+ eq(he('hello'), '=?iso-8859-1?b?aGVsbG8=?=')
+ eq(he('hello\nworld'), '=?iso-8859-1?b?aGVsbG8NCndvcmxk?=')
+ # Test the charset option
+ eq(he('hello', charset='iso-8859-2'), '=?iso-8859-2?b?aGVsbG8=?=')
+ # Test the keep_eols flag
+ eq(he('hello\nworld', keep_eols=True),
+ '=?iso-8859-1?b?aGVsbG8Kd29ybGQ=?=')
+ # Test the maxlinelen argument
+ eq(he('xxxx ' * 20, maxlinelen=40), """\
+=?iso-8859-1?b?eHh4eCB4eHh4IHh4eHggeHg=?=
+ =?iso-8859-1?b?eHggeHh4eCB4eHh4IHh4eHg=?=
+ =?iso-8859-1?b?IHh4eHggeHh4eCB4eHh4IHg=?=
+ =?iso-8859-1?b?eHh4IHh4eHggeHh4eCB4eHg=?=
+ =?iso-8859-1?b?eCB4eHh4IHh4eHggeHh4eCA=?=
+ =?iso-8859-1?b?eHh4eCB4eHh4IHh4eHgg?=""")
+ # Test the eol argument
+ eq(he('xxxx ' * 20, maxlinelen=40, eol='\r\n'), """\
+=?iso-8859-1?b?eHh4eCB4eHh4IHh4eHggeHg=?=\r
+ =?iso-8859-1?b?eHggeHh4eCB4eHh4IHh4eHg=?=\r
+ =?iso-8859-1?b?IHh4eHggeHh4eCB4eHh4IHg=?=\r
+ =?iso-8859-1?b?eHh4IHh4eHggeHh4eCB4eHg=?=\r
+ =?iso-8859-1?b?eCB4eHh4IHh4eHggeHh4eCA=?=\r
+ =?iso-8859-1?b?eHh4eCB4eHh4IHh4eHgg?=""")
+
+
+
+class TestQuopri(unittest.TestCase):
+ def setUp(self):
+ self.hlit = [chr(x) for x in range(ord('a'), ord('z')+1)] + \
+ [chr(x) for x in range(ord('A'), ord('Z')+1)] + \
+ [chr(x) for x in range(ord('0'), ord('9')+1)] + \
+ ['!', '*', '+', '-', '/', ' ']
+ self.hnon = [chr(x) for x in range(256) if chr(x) not in self.hlit]
+ assert len(self.hlit) + len(self.hnon) == 256
+ self.blit = [chr(x) for x in range(ord(' '), ord('~')+1)] + ['\t']
+ self.blit.remove('=')
+ self.bnon = [chr(x) for x in range(256) if chr(x) not in self.blit]
+ assert len(self.blit) + len(self.bnon) == 256
+
+ def test_header_quopri_check(self):
+ for c in self.hlit:
+ self.assertFalse(quoprimime.header_quopri_check(c))
+ for c in self.hnon:
+ self.assertTrue(quoprimime.header_quopri_check(c))
+
+ def test_body_quopri_check(self):
+ for c in self.blit:
+ self.assertFalse(quoprimime.body_quopri_check(c))
+ for c in self.bnon:
+ self.assertTrue(quoprimime.body_quopri_check(c))
+
+ def test_header_quopri_len(self):
+ eq = self.assertEqual
+ hql = quoprimime.header_quopri_len
+ enc = quoprimime.header_encode
+ for s in ('hello', 'h at e@l at l@o@'):
+ # Empty charset and no line-endings. 7 == RFC chrome
+ eq(hql(s), len(enc(s, charset='', eol=''))-7)
+ for c in self.hlit:
+ eq(hql(c), 1)
+ for c in self.hnon:
+ eq(hql(c), 3)
+
+ def test_body_quopri_len(self):
+ eq = self.assertEqual
+ bql = quoprimime.body_quopri_len
+ for c in self.blit:
+ eq(bql(c), 1)
+ for c in self.bnon:
+ eq(bql(c), 3)
+
+ def test_quote_unquote_idempotent(self):
+ for x in range(256):
+ c = chr(x)
+ self.assertEqual(quoprimime.unquote(quoprimime.quote(c)), c)
+
+ def test_header_encode(self):
+ eq = self.assertEqual
+ he = quoprimime.header_encode
+ eq(he('hello'), '=?iso-8859-1?q?hello?=')
+ eq(he('hello\nworld'), '=?iso-8859-1?q?hello=0D=0Aworld?=')
+ # Test the charset option
+ eq(he('hello', charset='iso-8859-2'), '=?iso-8859-2?q?hello?=')
+ # Test the keep_eols flag
+ eq(he('hello\nworld', keep_eols=True), '=?iso-8859-1?q?hello=0Aworld?=')
+ # Test a non-ASCII character
+ eq(he('hello\xc7there'), '=?iso-8859-1?q?hello=C7there?=')
+ # Test the maxlinelen argument
+ eq(he('xxxx ' * 20, maxlinelen=40), """\
+=?iso-8859-1?q?xxxx_xxxx_xxxx_xxxx_xx?=
+ =?iso-8859-1?q?xx_xxxx_xxxx_xxxx_xxxx?=
+ =?iso-8859-1?q?_xxxx_xxxx_xxxx_xxxx_x?=
+ =?iso-8859-1?q?xxx_xxxx_xxxx_xxxx_xxx?=
+ =?iso-8859-1?q?x_xxxx_xxxx_?=""")
+ # Test the eol argument
+ eq(he('xxxx ' * 20, maxlinelen=40, eol='\r\n'), """\
+=?iso-8859-1?q?xxxx_xxxx_xxxx_xxxx_xx?=\r
+ =?iso-8859-1?q?xx_xxxx_xxxx_xxxx_xxxx?=\r
+ =?iso-8859-1?q?_xxxx_xxxx_xxxx_xxxx_x?=\r
+ =?iso-8859-1?q?xxx_xxxx_xxxx_xxxx_xxx?=\r
+ =?iso-8859-1?q?x_xxxx_xxxx_?=""")
+
+ def test_decode(self):
+ eq = self.assertEqual
+ eq(quoprimime.decode(''), '')
+ eq(quoprimime.decode('hello'), 'hello')
+ eq(quoprimime.decode('hello', 'X'), 'hello')
+ eq(quoprimime.decode('hello\nworld', 'X'), 'helloXworld')
+
+ def test_encode(self):
+ eq = self.assertEqual
+ eq(quoprimime.encode(''), '')
+ eq(quoprimime.encode('hello'), 'hello')
+ # Test the binary flag
+ eq(quoprimime.encode('hello\r\nworld'), 'hello\nworld')
+ eq(quoprimime.encode('hello\r\nworld', 0), 'hello\nworld')
+ # Test the maxlinelen arg
+ eq(quoprimime.encode('xxxx ' * 20, maxlinelen=40), """\
+xxxx xxxx xxxx xxxx xxxx xxxx xxxx xxxx=
+ xxxx xxxx xxxx xxxx xxxx xxxx xxxx xxx=
+x xxxx xxxx xxxx xxxx=20""")
+ # Test the eol argument
+ eq(quoprimime.encode('xxxx ' * 20, maxlinelen=40, eol='\r\n'), """\
+xxxx xxxx xxxx xxxx xxxx xxxx xxxx xxxx=\r
+ xxxx xxxx xxxx xxxx xxxx xxxx xxxx xxx=\r
+x xxxx xxxx xxxx xxxx=20""")
+ eq(quoprimime.encode("""\
+one line
+
+two line"""), """\
+one line
+
+two line""")
+
+
+
+# Test the Charset class
+class TestCharset(unittest.TestCase):
+ def tearDown(self):
+ from email import charset as CharsetModule
+ try:
+ del CharsetModule.CHARSETS['fake']
+ except KeyError:
+ pass
+
+ def test_idempotent(self):
+ eq = self.assertEqual
+ # Make sure us-ascii = no Unicode conversion
+ c = Charset('us-ascii')
+ s = 'Hello World!'
+ sp = c.to_splittable(s)
+ eq(s, c.from_splittable(sp))
+ # test 8-bit idempotency with us-ascii
+ s = '\xa4\xa2\xa4\xa4\xa4\xa6\xa4\xa8\xa4\xaa'
+ sp = c.to_splittable(s)
+ eq(s, c.from_splittable(sp))
+
+ def test_body_encode(self):
+ eq = self.assertEqual
+ # Try a charset with QP body encoding
+ c = Charset('iso-8859-1')
+ eq('hello w=F6rld', c.body_encode('hello w\xf6rld'))
+ # Try a charset with Base64 body encoding
+ c = Charset('utf-8')
+ eq('aGVsbG8gd29ybGQ=\n', c.body_encode('hello world'))
+ # Try a charset with None body encoding
+ c = Charset('us-ascii')
+ eq('hello world', c.body_encode('hello world'))
+ # Try the convert argument, where input codec != output codec
+ c = Charset('euc-jp')
+ # With apologies to Tokio Kikuchi ;)
+ if not is_jython:
+ # TODO Jython with its Java-based codecs does not
+ # currently support trailing bytes in CJK texts
+ try:
+ eq('\x1b$B5FCO;~IW\x1b(B',
+ c.body_encode('\xb5\xc6\xc3\xcf\xbb\xfe\xc9\xd7'))
+ eq('\xb5\xc6\xc3\xcf\xbb\xfe\xc9\xd7',
+ c.body_encode('\xb5\xc6\xc3\xcf\xbb\xfe\xc9\xd7', False))
+ except LookupError:
+ # We probably don't have the Japanese codecs installed
+ pass
+ # Testing SF bug #625509, which we have to fake, since there are no
+ # built-in encodings where the header encoding is QP but the body
+ # encoding is not.
+ from email import charset as CharsetModule
+ CharsetModule.add_charset('fake', CharsetModule.QP, None)
+ c = Charset('fake')
+ eq('hello w\xf6rld', c.body_encode('hello w\xf6rld'))
+
+ def test_unicode_charset_name(self):
+ charset = Charset(u'us-ascii')
+ self.assertEqual(str(charset), 'us-ascii')
+ self.assertRaises(errors.CharsetError, Charset, 'asc\xffii')
+
+
+
+# Test multilingual MIME headers.
+class TestHeader(TestEmailBase):
+ def test_simple(self):
+ eq = self.ndiffAssertEqual
+ h = Header('Hello World!')
+ eq(h.encode(), 'Hello World!')
+ h.append(' Goodbye World!')
+ eq(h.encode(), 'Hello World! Goodbye World!')
+
+ def test_simple_surprise(self):
+ eq = self.ndiffAssertEqual
+ h = Header('Hello World!')
+ eq(h.encode(), 'Hello World!')
+ h.append('Goodbye World!')
+ eq(h.encode(), 'Hello World! Goodbye World!')
+
+ def test_header_needs_no_decoding(self):
+ h = 'no decoding needed'
+ self.assertEqual(decode_header(h), [(h, None)])
+
+ def test_long(self):
+ h = Header("I am the very model of a modern Major-General; I've information vegetable, animal, and mineral; I know the kings of England, and I quote the fights historical from Marathon to Waterloo, in order categorical; I'm very well acquainted, too, with matters mathematical; I understand equations, both the simple and quadratical; about binomial theorem I'm teeming with a lot o' news, with many cheerful facts about the square of the hypotenuse.",
+ maxlinelen=76)
+ for l in h.encode(splitchars=' ').split('\n '):
+ self.assertTrue(len(l) <= 76)
+
+ def test_multilingual(self):
+ eq = self.ndiffAssertEqual
+ g = Charset("iso-8859-1")
+ cz = Charset("iso-8859-2")
+ utf8 = Charset("utf-8")
+ g_head = "Die Mieter treten hier ein werden mit einem Foerderband komfortabel den Korridor entlang, an s\xfcdl\xfcndischen Wandgem\xe4lden vorbei, gegen die rotierenden Klingen bef\xf6rdert. "
+ cz_head = "Finan\xe8ni metropole se hroutily pod tlakem jejich d\xf9vtipu.. "
+ utf8_head = u"\u6b63\u78ba\u306b\u8a00\u3046\u3068\u7ffb\u8a33\u306f\u3055\u308c\u3066\u3044\u307e\u305b\u3093\u3002\u4e00\u90e8\u306f\u30c9\u30a4\u30c4\u8a9e\u3067\u3059\u304c\u3001\u3042\u3068\u306f\u3067\u305f\u3089\u3081\u3067\u3059\u3002\u5b9f\u969b\u306b\u306f\u300cWenn ist das Nunstuck git und Slotermeyer? Ja! Beiherhund das Oder die Flipperwaldt gersput.\u300d\u3068\u8a00\u3063\u3066\u3044\u307e\u3059\u3002".encode("utf-8")
+ h = Header(g_head, g)
+ h.append(cz_head, cz)
+ h.append(utf8_head, utf8)
+ enc = h.encode()
+ eq(enc, """\
+=?iso-8859-1?q?Die_Mieter_treten_hier_ein_werden_mit_einem_Foerderband_ko?=
+ =?iso-8859-1?q?mfortabel_den_Korridor_entlang=2C_an_s=FCdl=FCndischen_Wan?=
+ =?iso-8859-1?q?dgem=E4lden_vorbei=2C_gegen_die_rotierenden_Klingen_bef=F6?=
+ =?iso-8859-1?q?rdert=2E_?= =?iso-8859-2?q?Finan=E8ni_metropole_se_hroutily?=
+ =?iso-8859-2?q?_pod_tlakem_jejich_d=F9vtipu=2E=2E_?= =?utf-8?b?5q2j56K6?=
+ =?utf-8?b?44Gr6KiA44GG44Go57+76Kiz44Gv44GV44KM44Gm44GE44G+44Gb44KT44CC?=
+ =?utf-8?b?5LiA6YOo44Gv44OJ44Kk44OE6Kqe44Gn44GZ44GM44CB44GC44Go44Gv44Gn?=
+ =?utf-8?b?44Gf44KJ44KB44Gn44GZ44CC5a6f6Zqb44Gr44Gv44CMV2VubiBpc3QgZGFz?=
+ =?utf-8?q?_Nunstuck_git_und_Slotermeyer=3F_Ja!_Beiherhund_das_Oder_die_Fl?=
+ =?utf-8?b?aXBwZXJ3YWxkdCBnZXJzcHV0LuOAjeOBqOiogOOBo+OBpuOBhOOBvuOBmQ==?=
+ =?utf-8?b?44CC?=""")
+ eq(decode_header(enc),
+ [(g_head, "iso-8859-1"), (cz_head, "iso-8859-2"),
+ (utf8_head, "utf-8")])
+ ustr = unicode(h)
+ eq(ustr.encode('utf-8'),
+ 'Die Mieter treten hier ein werden mit einem Foerderband '
+ 'komfortabel den Korridor entlang, an s\xc3\xbcdl\xc3\xbcndischen '
+ 'Wandgem\xc3\xa4lden vorbei, gegen die rotierenden Klingen '
+ 'bef\xc3\xb6rdert. Finan\xc4\x8dni metropole se hroutily pod '
+ 'tlakem jejich d\xc5\xafvtipu.. \xe6\xad\xa3\xe7\xa2\xba\xe3\x81'
+ '\xab\xe8\xa8\x80\xe3\x81\x86\xe3\x81\xa8\xe7\xbf\xbb\xe8\xa8\xb3'
+ '\xe3\x81\xaf\xe3\x81\x95\xe3\x82\x8c\xe3\x81\xa6\xe3\x81\x84\xe3'
+ '\x81\xbe\xe3\x81\x9b\xe3\x82\x93\xe3\x80\x82\xe4\xb8\x80\xe9\x83'
+ '\xa8\xe3\x81\xaf\xe3\x83\x89\xe3\x82\xa4\xe3\x83\x84\xe8\xaa\x9e'
+ '\xe3\x81\xa7\xe3\x81\x99\xe3\x81\x8c\xe3\x80\x81\xe3\x81\x82\xe3'
+ '\x81\xa8\xe3\x81\xaf\xe3\x81\xa7\xe3\x81\x9f\xe3\x82\x89\xe3\x82'
+ '\x81\xe3\x81\xa7\xe3\x81\x99\xe3\x80\x82\xe5\xae\x9f\xe9\x9a\x9b'
+ '\xe3\x81\xab\xe3\x81\xaf\xe3\x80\x8cWenn ist das Nunstuck git '
+ 'und Slotermeyer? Ja! Beiherhund das Oder die Flipperwaldt '
+ 'gersput.\xe3\x80\x8d\xe3\x81\xa8\xe8\xa8\x80\xe3\x81\xa3\xe3\x81'
+ '\xa6\xe3\x81\x84\xe3\x81\xbe\xe3\x81\x99\xe3\x80\x82')
+ # Test make_header()
+ newh = make_header(decode_header(enc))
+ eq(newh, enc)
+
+ def test_header_ctor_default_args(self):
+ eq = self.ndiffAssertEqual
+ h = Header()
+ eq(h, '')
+ h.append('foo', Charset('iso-8859-1'))
+ eq(h, '=?iso-8859-1?q?foo?=')
+
+ def test_explicit_maxlinelen(self):
+ eq = self.ndiffAssertEqual
+ hstr = 'A very long line that must get split to something other than at the 76th character boundary to test the non-default behavior'
+ h = Header(hstr)
+ eq(h.encode(), '''\
+A very long line that must get split to something other than at the 76th
+ character boundary to test the non-default behavior''')
+ h = Header(hstr, header_name='Subject')
+ eq(h.encode(), '''\
+A very long line that must get split to something other than at the
+ 76th character boundary to test the non-default behavior''')
+ h = Header(hstr, maxlinelen=1024, header_name='Subject')
+ eq(h.encode(), hstr)
+
+ def test_us_ascii_header(self):
+ eq = self.assertEqual
+ s = 'hello'
+ x = decode_header(s)
+ eq(x, [('hello', None)])
+ h = make_header(x)
+ eq(s, h.encode())
+
+ def test_string_charset(self):
+ eq = self.assertEqual
+ h = Header()
+ h.append('hello', 'iso-8859-1')
+ eq(h, '=?iso-8859-1?q?hello?=')
+
+## def test_unicode_error(self):
+## raises = self.assertRaises
+## raises(UnicodeError, Header, u'[P\xf6stal]', 'us-ascii')
+## raises(UnicodeError, Header, '[P\xf6stal]', 'us-ascii')
+## h = Header()
+## raises(UnicodeError, h.append, u'[P\xf6stal]', 'us-ascii')
+## raises(UnicodeError, h.append, '[P\xf6stal]', 'us-ascii')
+## raises(UnicodeError, Header, u'\u83ca\u5730\u6642\u592b', 'iso-8859-1')
+
+ def test_utf8_shortest(self):
+ eq = self.assertEqual
+ h = Header(u'p\xf6stal', 'utf-8')
+ eq(h.encode(), '=?utf-8?q?p=C3=B6stal?=')
+ h = Header(u'\u83ca\u5730\u6642\u592b', 'utf-8')
+ eq(h.encode(), '=?utf-8?b?6I+K5Zyw5pmC5aSr?=')
+
+ def test_bad_8bit_header(self):
+ raises = self.assertRaises
+ eq = self.assertEqual
+ x = 'Ynwp4dUEbay Auction Semiar- No Charge \x96 Earn Big'
+ raises(UnicodeError, Header, x)
+ h = Header()
+ raises(UnicodeError, h.append, x)
+ eq(str(Header(x, errors='replace')), x)
+ h.append(x, errors='replace')
+ eq(str(h), x)
+
+ def test_encoded_adjacent_nonencoded(self):
+ eq = self.assertEqual
+ h = Header()
+ h.append('hello', 'iso-8859-1')
+ h.append('world')
+ s = h.encode()
+ eq(s, '=?iso-8859-1?q?hello?= world')
+ h = make_header(decode_header(s))
+ eq(h.encode(), s)
+
+ def test_whitespace_eater(self):
+ eq = self.assertEqual
+ s = 'Subject: =?koi8-r?b?8NLP18XSy8EgzsEgxsnOwczYztk=?= =?koi8-r?q?=CA?= zz.'
+ parts = decode_header(s)
+ eq(parts, [('Subject:', None), ('\xf0\xd2\xcf\xd7\xc5\xd2\xcb\xc1 \xce\xc1 \xc6\xc9\xce\xc1\xcc\xd8\xce\xd9\xca', 'koi8-r'), ('zz.', None)])
+ hdr = make_header(parts)
+ eq(hdr.encode(),
+ 'Subject: =?koi8-r?b?8NLP18XSy8EgzsEgxsnOwczYztnK?= zz.')
+
+ def test_broken_base64_header(self):
+ raises = self.assertRaises
+ s = 'Subject: =?EUC-KR?B?CSixpLDtKSC/7Liuvsax4iC6uLmwMcijIKHaILzSwd/H0SC8+LCjwLsgv7W/+Mj3I ?='
+ raises(errors.HeaderParseError, decode_header, s)
+
+
+
+# Test RFC 2231 header parameters (en/de)coding
+class TestRFC2231(TestEmailBase):
+ def test_get_param(self):
+ eq = self.assertEqual
+ msg = self._msgobj('msg_29.txt')
+ eq(msg.get_param('title'),
+ ('us-ascii', 'en', 'This is even more ***fun*** isn\'t it!'))
+ eq(msg.get_param('title', unquote=False),
+ ('us-ascii', 'en', '"This is even more ***fun*** isn\'t it!"'))
+
+ def test_set_param(self):
+ eq = self.assertEqual
+ msg = Message()
+ msg.set_param('title', 'This is even more ***fun*** isn\'t it!',
+ charset='us-ascii')
+ eq(msg.get_param('title'),
+ ('us-ascii', '', 'This is even more ***fun*** isn\'t it!'))
+ msg.set_param('title', 'This is even more ***fun*** isn\'t it!',
+ charset='us-ascii', language='en')
+ eq(msg.get_param('title'),
+ ('us-ascii', 'en', 'This is even more ***fun*** isn\'t it!'))
+ msg = self._msgobj('msg_01.txt')
+ msg.set_param('title', 'This is even more ***fun*** isn\'t it!',
+ charset='us-ascii', language='en')
+ self.ndiffAssertEqual(msg.as_string(), """\
+Return-Path: <bbb at zzz.org>
+Delivered-To: bbb at zzz.org
+Received: by mail.zzz.org (Postfix, from userid 889)
+ id 27CEAD38CC; Fri, 4 May 2001 14:05:44 -0400 (EDT)
+MIME-Version: 1.0
+Content-Transfer-Encoding: 7bit
+Message-ID: <15090.61304.110929.45684 at aaa.zzz.org>
+From: bbb at ddd.com (John X. Doe)
+To: bbb at zzz.org
+Subject: This is a test message
+Date: Fri, 4 May 2001 14:05:44 -0400
+Content-Type: text/plain; charset=us-ascii;
+ title*="us-ascii'en'This%20is%20even%20more%20%2A%2A%2Afun%2A%2A%2A%20isn%27t%20it%21"
+
+
+Hi,
+
+Do you like this message?
+
+-Me
+""")
+
+ def test_del_param(self):
+ eq = self.ndiffAssertEqual
+ msg = self._msgobj('msg_01.txt')
+ msg.set_param('foo', 'bar', charset='us-ascii', language='en')
+ msg.set_param('title', 'This is even more ***fun*** isn\'t it!',
+ charset='us-ascii', language='en')
+ msg.del_param('foo', header='Content-Type')
+ eq(msg.as_string(), """\
+Return-Path: <bbb at zzz.org>
+Delivered-To: bbb at zzz.org
+Received: by mail.zzz.org (Postfix, from userid 889)
+ id 27CEAD38CC; Fri, 4 May 2001 14:05:44 -0400 (EDT)
+MIME-Version: 1.0
+Content-Transfer-Encoding: 7bit
+Message-ID: <15090.61304.110929.45684 at aaa.zzz.org>
+From: bbb at ddd.com (John X. Doe)
+To: bbb at zzz.org
+Subject: This is a test message
+Date: Fri, 4 May 2001 14:05:44 -0400
+Content-Type: text/plain; charset="us-ascii";
+ title*="us-ascii'en'This%20is%20even%20more%20%2A%2A%2Afun%2A%2A%2A%20isn%27t%20it%21"
+
+
+Hi,
+
+Do you like this message?
+
+-Me
+""")
+
+ def test_rfc2231_get_content_charset(self):
+ eq = self.assertEqual
+ msg = self._msgobj('msg_32.txt')
+ eq(msg.get_content_charset(), 'us-ascii')
+
+ def test_rfc2231_no_language_or_charset(self):
+ m = '''\
+Content-Transfer-Encoding: 8bit
+Content-Disposition: inline; filename="file____C__DOCUMENTS_20AND_20SETTINGS_FABIEN_LOCAL_20SETTINGS_TEMP_nsmail.htm"
+Content-Type: text/html; NAME*0=file____C__DOCUMENTS_20AND_20SETTINGS_FABIEN_LOCAL_20SETTINGS_TEM; NAME*1=P_nsmail.htm
+
+'''
+ msg = email.message_from_string(m)
+ param = msg.get_param('NAME')
+ self.assertFalse(isinstance(param, tuple))
+ self.assertEqual(
+ param,
+ 'file____C__DOCUMENTS_20AND_20SETTINGS_FABIEN_LOCAL_20SETTINGS_TEMP_nsmail.htm')
+
+ def test_rfc2231_no_language_or_charset_in_filename(self):
+ m = '''\
+Content-Disposition: inline;
+\tfilename*0*="''This%20is%20even%20more%20";
+\tfilename*1*="%2A%2A%2Afun%2A%2A%2A%20";
+\tfilename*2="is it not.pdf"
+
+'''
+ msg = email.message_from_string(m)
+ self.assertEqual(msg.get_filename(),
+ 'This is even more ***fun*** is it not.pdf')
+
+ def test_rfc2231_no_language_or_charset_in_filename_encoded(self):
+ m = '''\
+Content-Disposition: inline;
+\tfilename*0*="''This%20is%20even%20more%20";
+\tfilename*1*="%2A%2A%2Afun%2A%2A%2A%20";
+\tfilename*2="is it not.pdf"
+
+'''
+ msg = email.message_from_string(m)
+ self.assertEqual(msg.get_filename(),
+ 'This is even more ***fun*** is it not.pdf')
+
+ def test_rfc2231_partly_encoded(self):
+ m = '''\
+Content-Disposition: inline;
+\tfilename*0="''This%20is%20even%20more%20";
+\tfilename*1*="%2A%2A%2Afun%2A%2A%2A%20";
+\tfilename*2="is it not.pdf"
+
+'''
+ msg = email.message_from_string(m)
+ self.assertEqual(
+ msg.get_filename(),
+ 'This%20is%20even%20more%20***fun*** is it not.pdf')
+
+ def test_rfc2231_partly_nonencoded(self):
+ m = '''\
+Content-Disposition: inline;
+\tfilename*0="This%20is%20even%20more%20";
+\tfilename*1="%2A%2A%2Afun%2A%2A%2A%20";
+\tfilename*2="is it not.pdf"
+
+'''
+ msg = email.message_from_string(m)
+ self.assertEqual(
+ msg.get_filename(),
+ 'This%20is%20even%20more%20%2A%2A%2Afun%2A%2A%2A%20is it not.pdf')
+
+ def test_rfc2231_no_language_or_charset_in_boundary(self):
+ m = '''\
+Content-Type: multipart/alternative;
+\tboundary*0*="''This%20is%20even%20more%20";
+\tboundary*1*="%2A%2A%2Afun%2A%2A%2A%20";
+\tboundary*2="is it not.pdf"
+
+'''
+ msg = email.message_from_string(m)
+ self.assertEqual(msg.get_boundary(),
+ 'This is even more ***fun*** is it not.pdf')
+
+ def test_rfc2231_no_language_or_charset_in_charset(self):
+ # This is a nonsensical charset value, but tests the code anyway
+ m = '''\
+Content-Type: text/plain;
+\tcharset*0*="This%20is%20even%20more%20";
+\tcharset*1*="%2A%2A%2Afun%2A%2A%2A%20";
+\tcharset*2="is it not.pdf"
+
+'''
+ msg = email.message_from_string(m)
+ self.assertEqual(msg.get_content_charset(),
+ 'this is even more ***fun*** is it not.pdf')
+
+ def test_rfc2231_bad_encoding_in_filename(self):
+ m = '''\
+Content-Disposition: inline;
+\tfilename*0*="bogus'xx'This%20is%20even%20more%20";
+\tfilename*1*="%2A%2A%2Afun%2A%2A%2A%20";
+\tfilename*2="is it not.pdf"
+
+'''
+ msg = email.message_from_string(m)
+ self.assertEqual(msg.get_filename(),
+ 'This is even more ***fun*** is it not.pdf')
+
+ def test_rfc2231_bad_encoding_in_charset(self):
+ m = """\
+Content-Type: text/plain; charset*=bogus''utf-8%E2%80%9D
+
+"""
+ msg = email.message_from_string(m)
+ # This should return None because non-ascii characters in the charset
+ # are not allowed.
+ self.assertEqual(msg.get_content_charset(), None)
+
+ def test_rfc2231_bad_character_in_charset(self):
+ m = """\
+Content-Type: text/plain; charset*=ascii''utf-8%E2%80%9D
+
+"""
+ msg = email.message_from_string(m)
+ # This should return None because non-ascii characters in the charset
+ # are not allowed.
+ self.assertEqual(msg.get_content_charset(), None)
+
+ def test_rfc2231_bad_character_in_filename(self):
+ m = '''\
+Content-Disposition: inline;
+\tfilename*0*="ascii'xx'This%20is%20even%20more%20";
+\tfilename*1*="%2A%2A%2Afun%2A%2A%2A%20";
+\tfilename*2*="is it not.pdf%E2"
+
+'''
+ msg = email.message_from_string(m)
+ self.assertEqual(msg.get_filename(),
+ u'This is even more ***fun*** is it not.pdf\ufffd')
+
+ def test_rfc2231_unknown_encoding(self):
+ m = """\
+Content-Transfer-Encoding: 8bit
+Content-Disposition: inline; filename*=X-UNKNOWN''myfile.txt
+
+"""
+ msg = email.message_from_string(m)
+ self.assertEqual(msg.get_filename(), 'myfile.txt')
+
+ def test_rfc2231_single_tick_in_filename_extended(self):
+ eq = self.assertEqual
+ m = """\
+Content-Type: application/x-foo;
+\tname*0*=\"Frank's\"; name*1*=\" Document\"
+
+"""
+ msg = email.message_from_string(m)
+ charset, language, s = msg.get_param('name')
+ eq(charset, None)
+ eq(language, None)
+ eq(s, "Frank's Document")
+
+ def test_rfc2231_single_tick_in_filename(self):
+ m = """\
+Content-Type: application/x-foo; name*0=\"Frank's\"; name*1=\" Document\"
+
+"""
+ msg = email.message_from_string(m)
+ param = msg.get_param('name')
+ self.assertFalse(isinstance(param, tuple))
+ self.assertEqual(param, "Frank's Document")
+
+ def test_rfc2231_tick_attack_extended(self):
+ eq = self.assertEqual
+ m = """\
+Content-Type: application/x-foo;
+\tname*0*=\"us-ascii'en-us'Frank's\"; name*1*=\" Document\"
+
+"""
+ msg = email.message_from_string(m)
+ charset, language, s = msg.get_param('name')
+ eq(charset, 'us-ascii')
+ eq(language, 'en-us')
+ eq(s, "Frank's Document")
+
+ def test_rfc2231_tick_attack(self):
+ m = """\
+Content-Type: application/x-foo;
+\tname*0=\"us-ascii'en-us'Frank's\"; name*1=\" Document\"
+
+"""
+ msg = email.message_from_string(m)
+ param = msg.get_param('name')
+ self.assertFalse(isinstance(param, tuple))
+ self.assertEqual(param, "us-ascii'en-us'Frank's Document")
+
+ def test_rfc2231_no_extended_values(self):
+ eq = self.assertEqual
+ m = """\
+Content-Type: application/x-foo; name=\"Frank's Document\"
+
+"""
+ msg = email.message_from_string(m)
+ eq(msg.get_param('name'), "Frank's Document")
+
+ def test_rfc2231_encoded_then_unencoded_segments(self):
+ eq = self.assertEqual
+ m = """\
+Content-Type: application/x-foo;
+\tname*0*=\"us-ascii'en-us'My\";
+\tname*1=\" Document\";
+\tname*2*=\" For You\"
+
+"""
+ msg = email.message_from_string(m)
+ charset, language, s = msg.get_param('name')
+ eq(charset, 'us-ascii')
+ eq(language, 'en-us')
+ eq(s, 'My Document For You')
+
+ def test_rfc2231_unencoded_then_encoded_segments(self):
+ eq = self.assertEqual
+ m = """\
+Content-Type: application/x-foo;
+\tname*0=\"us-ascii'en-us'My\";
+\tname*1*=\" Document\";
+\tname*2*=\" For You\"
+
+"""
+ msg = email.message_from_string(m)
+ charset, language, s = msg.get_param('name')
+ eq(charset, 'us-ascii')
+ eq(language, 'en-us')
+ eq(s, 'My Document For You')
+
+
+
+def _testclasses():
+ mod = sys.modules[__name__]
+ return [getattr(mod, name) for name in dir(mod) if name.startswith('Test')]
+
+
+def suite():
+ suite = unittest.TestSuite()
+ for testclass in _testclasses():
+ suite.addTest(unittest.makeSuite(testclass))
+ return suite
+
+
+def test_main():
+ for testclass in _testclasses():
+ run_unittest(testclass)
+
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='suite')
diff --git a/Lib/gzip.py b/Lib/gzip.py
--- a/Lib/gzip.py
+++ b/Lib/gzip.py
@@ -21,9 +21,6 @@
# or unsigned.
output.write(struct.pack("<L", value))
-def read32(input):
- return struct.unpack("<I", input.read(4))[0]
-
def open(filename, mode="rb", compresslevel=9):
"""Shorthand for GzipFile(filename, mode, compresslevel).
@@ -40,11 +37,7 @@
"""
myfileobj = None
- # XXX: repeated 10mb chunk reads hurt test_gzip.test_many_append's
- # performance on Jython (maybe CPython's allocator recycles the same
- # 10mb buffer whereas Java's doesn't)
- #max_read_chunk = 10 * 1024 * 1024 # 10Mb
- max_read_chunk = 256 * 1024 # 256kb
+ max_read_chunk = 10 * 1024 * 1024 # 10Mb
def __init__(self, filename=None, mode=None,
compresslevel=9, fileobj=None, mtime=None):
@@ -70,9 +63,10 @@
Be aware that only the 'rb', 'ab', and 'wb' values should be used
for cross-platform portability.
- The compresslevel argument is an integer from 1 to 9 controlling the
+ The compresslevel argument is an integer from 0 to 9 controlling the
level of compression; 1 is fastest and produces the least compression,
- and 9 is slowest and produces the most compression. The default is 9.
+ and 9 is slowest and produces the most compression. 0 is no compression
+ at all. The default is 9.
The mtime argument is an optional numeric timestamp to be written
to the stream when compressing. All gzip compressed streams
@@ -85,6 +79,10 @@
"""
+ # Make sure we don't inadvertently enable universal newlines on the
+ # underlying file object - in read mode, this causes data corruption.
+ if mode:
+ mode = mode.replace('U', '')
# guarantee the file is opened in binary mode on platforms
# that care about that sort of thing
if mode and 'b' not in mode:
@@ -92,8 +90,12 @@
if fileobj is None:
fileobj = self.myfileobj = __builtin__.open(filename, mode or 'rb')
if filename is None:
- if hasattr(fileobj, 'name'): filename = fileobj.name
- else: filename = ''
+ # Issue #13781: os.fdopen() creates a fileobj with a bogus name
+ # attribute. Avoid saving this in the gzip header's filename field.
+ if hasattr(fileobj, 'name') and fileobj.name != '<fdopen>':
+ filename = fileobj.name
+ else:
+ filename = ''
if mode is None:
if hasattr(fileobj, 'mode'): mode = fileobj.mode
else: mode = 'rb'
@@ -179,24 +181,28 @@
self.crc = zlib.crc32("") & 0xffffffffL
self.size = 0
+ def _read_exact(self, n):
+ data = self.fileobj.read(n)
+ while len(data) < n:
+ b = self.fileobj.read(n - len(data))
+ if not b:
+ raise EOFError("Compressed file ended before the "
+ "end-of-stream marker was reached")
+ data += b
+ return data
+
def _read_gzip_header(self):
magic = self.fileobj.read(2)
if magic != '\037\213':
raise IOError, 'Not a gzipped file'
- method = ord( self.fileobj.read(1) )
+
+ method, flag, self.mtime = struct.unpack("<BBIxx", self._read_exact(8))
if method != 8:
raise IOError, 'Unknown compression method'
- flag = ord( self.fileobj.read(1) )
- self.mtime = read32(self.fileobj)
- # extraflag = self.fileobj.read(1)
- # os = self.fileobj.read(1)
- self.fileobj.read(2)
if flag & FEXTRA:
# Read & discard the extra field, if present
- xlen = ord(self.fileobj.read(1))
- xlen = xlen + 256*ord(self.fileobj.read(1))
- self.fileobj.read(xlen)
+ self._read_exact(struct.unpack("<H", self._read_exact(2)))
if flag & FNAME:
# Read and discard a null-terminated string containing the filename
while True:
@@ -210,7 +216,7 @@
if not s or s=='\000':
break
if flag & FHCRC:
- self.fileobj.read(2) # Read & discard the 16-bit header CRC
+ self._read_exact(2) # Read & discard the 16-bit header CRC
def write(self,data):
self._check_closed()
@@ -244,20 +250,16 @@
readsize = 1024
if size < 0: # get the whole thing
- try:
- while True:
- self._read(readsize)
- readsize = min(self.max_read_chunk, readsize * 2)
- except EOFError:
- size = self.extrasize
+ while self._read(readsize):
+ readsize = min(self.max_read_chunk, readsize * 2)
+ size = self.extrasize
else: # just get some more of it
- try:
- while size > self.extrasize:
- self._read(readsize)
- readsize = min(self.max_read_chunk, readsize * 2)
- except EOFError:
- if size > self.extrasize:
- size = self.extrasize
+ while size > self.extrasize:
+ if not self._read(readsize):
+ if size > self.extrasize:
+ size = self.extrasize
+ break
+ readsize = min(self.max_read_chunk, readsize * 2)
offset = self.offset - self.extrastart
chunk = self.extrabuf[offset: offset + size]
@@ -272,7 +274,7 @@
def _read(self, size=1024):
if self.fileobj is None:
- raise EOFError, "Reached EOF"
+ return False
if self._new_member:
# If the _new_member flag is set, we have to
@@ -283,7 +285,7 @@
pos = self.fileobj.tell() # Save current position
self.fileobj.seek(0, 2) # Seek to end of file
if pos == self.fileobj.tell():
- raise EOFError, "Reached EOF"
+ return False
else:
self.fileobj.seek( pos ) # Return to original position
@@ -300,9 +302,10 @@
if buf == "":
uncompress = self.decompress.flush()
+ self.fileobj.seek(-len(self.decompress.unused_data), 1)
self._read_eof()
self._add_read_data( uncompress )
- raise EOFError, 'Reached EOF'
+ return False
uncompress = self.decompress.decompress(buf)
self._add_read_data( uncompress )
@@ -312,13 +315,14 @@
# so seek back to the start of the unused data, finish up
# this member, and read a new gzip header.
# (The number of bytes to seek back is the length of the unused
- # data, minus 8 because _read_eof() will rewind a further 8 bytes)
- self.fileobj.seek( -len(self.decompress.unused_data)+8, 1)
+ # data)
+ self.fileobj.seek(-len(self.decompress.unused_data), 1)
# Check the CRC and file size, and set the flag so we read
# a new member on the next call
self._read_eof()
self._new_member = True
+ return True
def _add_read_data(self, data):
self.crc = zlib.crc32(data, self.crc) & 0xffffffffL
@@ -329,14 +333,11 @@
self.size = self.size + len(data)
def _read_eof(self):
- # We've read to the end of the file, so we have to rewind in order
- # to reread the 8 bytes containing the CRC and the file size.
+ # We've read to the end of the file.
# We check the that the computed CRC and size of the
# uncompressed data matches the stored values. Note that the size
# stored is the true file size mod 2**32.
- self.fileobj.seek(-8, 1)
- crc32 = read32(self.fileobj)
- isize = read32(self.fileobj) # may exceed 2GB
+ crc32, isize = struct.unpack("<II", self._read_exact(8))
if crc32 != self.crc:
raise IOError("CRC check failed %s != %s" % (hex(crc32),
hex(self.crc)))
@@ -371,6 +372,17 @@
self.myfileobj.close()
self.myfileobj = None
+ def __enter__(self):
+ # __enter__ is defined in _jyio._IOBase (aka
+ # org.python.modules._io.PyIOBase), but because we override
+ # the closed property in this class, we need to reproduce this
+ # method here for the calls to properly go through, much like
+ # the difference seen in _check_closed vs _checkClosed
+ self._check_closed()
+ return self
+
+ __iter__ = __enter__
+
if not sys.platform.startswith('java'):
def flush(self,zlib_mode=zlib.Z_SYNC_FLUSH):
self._check_closed()
@@ -424,7 +436,7 @@
if offset < self.offset:
raise IOError('Negative seek in write mode')
count = offset - self.offset
- for i in range(count // 1024):
+ for i in xrange(count // 1024):
self.write(1024 * '\0')
self.write((count % 1024) * '\0')
elif self.mode == READ:
@@ -432,7 +444,7 @@
# for negative seek, rewind and do positive seek
self.rewind()
count = offset - self.offset
- for i in range(count // 1024):
+ for i in xrange(count // 1024):
self.read(1024)
self.read(count % 1024)
diff --git a/Lib/json/tests/test_recursion.py b/Lib/json/tests/test_recursion.py
new file mode 100644
--- /dev/null
+++ b/Lib/json/tests/test_recursion.py
@@ -0,0 +1,112 @@
+from json.tests import PyTest, CTest
+
+
+class JSONTestObject:
+ pass
+
+
+class TestRecursion(object):
+ def test_listrecursion(self):
+ x = []
+ x.append(x)
+ try:
+ self.dumps(x)
+ except ValueError:
+ pass
+ else:
+ self.fail("didn't raise ValueError on list recursion")
+ x = []
+ y = [x]
+ x.append(y)
+ try:
+ self.dumps(x)
+ except ValueError:
+ pass
+ else:
+ self.fail("didn't raise ValueError on alternating list recursion")
+ y = []
+ x = [y, y]
+ # ensure that the marker is cleared
+ self.dumps(x)
+
+ def test_dictrecursion(self):
+ x = {}
+ x["test"] = x
+ try:
+ self.dumps(x)
+ except ValueError:
+ pass
+ else:
+ self.fail("didn't raise ValueError on dict recursion")
+ x = {}
+ y = {"a": x, "b": x}
+ # ensure that the marker is cleared
+ self.dumps(x)
+
+ def test_defaultrecursion(self):
+ class RecursiveJSONEncoder(self.json.JSONEncoder):
+ recurse = False
+ def default(self, o):
+ if o is JSONTestObject:
+ if self.recurse:
+ return [JSONTestObject]
+ else:
+ return 'JSONTestObject'
+ return pyjson.JSONEncoder.default(o)
+
+ enc = RecursiveJSONEncoder()
+ self.assertEqual(enc.encode(JSONTestObject), '"JSONTestObject"')
+ enc.recurse = True
+ try:
+ enc.encode(JSONTestObject)
+ except ValueError:
+ pass
+ else:
+ self.fail("didn't raise ValueError on default recursion")
+
+
+ def test_highly_nested_objects_decoding(self):
+ # test that loading highly-nested objects doesn't segfault when C
+ # accelerations are used. See #12017
+ # str
+ with self.assertRaises(RuntimeError):
+ self.loads('{"a":' * 100000 + '1' + '}' * 100000)
+ with self.assertRaises(RuntimeError):
+ self.loads('{"a":' * 100000 + '[1]' + '}' * 100000)
+ with self.assertRaises(RuntimeError):
+ self.loads('[' * 100000 + '1' + ']' * 100000)
+ # unicode
+ with self.assertRaises(RuntimeError):
+ self.loads(u'{"a":' * 100000 + u'1' + u'}' * 100000)
+ with self.assertRaises(RuntimeError):
+ self.loads(u'{"a":' * 100000 + u'[1]' + u'}' * 100000)
+ with self.assertRaises(RuntimeError):
+ self.loads(u'[' * 100000 + u'1' + u']' * 100000)
+
+ def test_highly_nested_objects_encoding(self):
+ # See #12051
+ l, d = [], {}
+ for x in xrange(100000):
+ l, d = [l], {'k':d}
+ with self.assertRaises(RuntimeError):
+ self.dumps(l)
+ with self.assertRaises(RuntimeError):
+ self.dumps(d)
+
+ def test_endless_recursion(self):
+ # See #12051
+ class EndlessJSONEncoder(self.json.JSONEncoder):
+ def default(self, o):
+ """If check_circular is False, this will keep adding another list."""
+ return [o]
+
+ # NB: Jython interacts with overflows differently than CPython;
+ # given that the default function normally raises a ValueError upon
+ # an overflow, this seems reasonable.
+ with self.assertRaises(Exception) as cm:
+ EndlessJSONEncoder(check_circular=False).encode(5j)
+ self.assertIn(type(cm.exception), [RuntimeError, ValueError])
+
+
+class TestPyRecursion(TestRecursion, PyTest): pass
+class TestCRecursion(TestRecursion, CTest): pass
diff --git a/Lib/json/tests/test_tool.py b/Lib/json/tests/test_tool.py
new file mode 100644
--- /dev/null
+++ b/Lib/json/tests/test_tool.py
@@ -0,0 +1,75 @@
+import os
+import sys
+import textwrap
+import unittest
+import subprocess
+from test import test_support
+from test.script_helper import assert_python_ok
+
+class TestTool(unittest.TestCase):
+ data = """
+
+ [["blorpie"],[ "whoops" ] , [
+ ],\t"d-shtaeou",\r"d-nthiouh",
+ "i-vhbjkhnth", {"nifty":87}, {"morefield" :\tfalse,"field"
+ :"yes"} ]
+ """
+
+ expect = textwrap.dedent("""\
+ [
+ [
+ "blorpie"
+ ],
+ [
+ "whoops"
+ ],
+ [],
+ "d-shtaeou",
+ "d-nthiouh",
+ "i-vhbjkhnth",
+ {
+ "nifty": 87
+ },
+ {
+ "field": "yes",
+ "morefield": false
+ }
+ ]
+ """)
+
+ @unittest.skipIf(test_support.is_jython, "Revisit when http://bugs.jython.org/issue695383 is fixed")
+ def test_stdin_stdout(self):
+ proc = subprocess.Popen(
+ (sys.executable, '-m', 'json.tool'),
+ stdin=subprocess.PIPE, stdout=subprocess.PIPE)
+ out, err = proc.communicate(self.data.encode())
+ self.assertEqual(out.splitlines(), self.expect.encode().splitlines())
+ self.assertEqual(err, None)
+
+ def _create_infile(self):
+ infile = test_support.TESTFN
+ with open(infile, "w") as fp:
+ self.addCleanup(os.remove, infile)
+ fp.write(self.data)
+ return infile
+
+ # This is a problem orthogonal to json support, even for usage of
+ # this tool. Instead it seems to be a problem in simply testing
+ # it. TODO fix this underlying issue that's been outstanding for a
+ # while in Jython.
+ @unittest.skipIf(test_support.is_jython, "Revisit when http://bugs.jython.org/issue695383 is fixed")
+ def test_infile_stdout(self):
+ infile = self._create_infile()
+ rc, out, err = assert_python_ok('-m', 'json.tool', infile)
+ self.assertEqual(out.splitlines(), self.expect.encode().splitlines())
+ self.assertEqual(err, b'')
+
+ def test_infile_outfile(self):
+ infile = self._create_infile()
+ outfile = test_support.TESTFN + '.out'
+ rc, out, err = assert_python_ok('-m', 'json.tool', infile, outfile)
+ self.addCleanup(os.remove, outfile)
+ with open(outfile, "r") as fp:
+ self.assertEqual(fp.read(), self.expect)
+ self.assertEqual(out, b'')
+ self.assertEqual(err, b'')
diff --git a/Lib/pkgutil.py b/Lib/pkgutil.py
--- a/Lib/pkgutil.py
+++ b/Lib/pkgutil.py
@@ -214,7 +214,12 @@
if not modname and os.path.isdir(path) and '.' not in fn:
modname = fn
- for fn in os.listdir(path):
+ try:
+ dircontents = os.listdir(path)
+ except OSError:
+ # ignore unreadable directories like import does
+ dircontents = []
+ for fn in dircontents:
subname = inspect.getmodulename(fn)
if subname=='__init__':
ispkg = True
diff --git a/Lib/tarfile.py b/Lib/tarfile.py
--- a/Lib/tarfile.py
+++ b/Lib/tarfile.py
@@ -35,8 +35,8 @@
version = "0.9.0"
__author__ = "Lars Gustäbel (lars at gustaebel.de)"
-__date__ = "$Date: 2010-10-04 08:37:53 -0700 (ma, 04 loka  2010) $"
-__cvsid__ = "$Id: tarfile.py 85213 2010-10-04 15:37:53Z lars.gustaebel $"
+__date__ = "$Date$"
+__cvsid__ = "$Id$"
__credits__ = "Gustavo Niemeyer, Niels Gustäbel, Richard Townsend."
#---------
@@ -454,6 +454,8 @@
0)
timestamp = struct.pack("<L", long(time.time()))
self.__write("\037\213\010\010%s\002\377" % timestamp)
+ if type(self.name) is unicode:
+ self.name = self.name.encode("iso-8859-1", "replace")
if self.name.endswith(".gz"):
self.name = self.name[:-3]
self.__write(self.name + NUL)
@@ -627,7 +629,7 @@
def getcomptype(self):
if self.buf.startswith("\037\213\010"):
return "gz"
- if self.buf.startswith("BZh91"):
+ if self.buf[0:3] == "BZh" and self.buf[4:10] == "1AY&SY":
return "bz2"
return "tar"
@@ -1721,7 +1723,6 @@
try:
t = cls.taropen(name, mode, fileobj, **kwargs)
except IOError:
- fileobj.close()
raise ReadError("not a gzip file")
t._extfileobj = False
return t
@@ -1741,16 +1742,12 @@
if fileobj is not None:
fileobj = _BZ2Proxy(fileobj, mode)
- extfileobj = True
else:
fileobj = bz2.BZ2File(name, mode, compresslevel=compresslevel)
- extfileobj = False
try:
t = cls.taropen(name, mode, fileobj, **kwargs)
except (IOError, EOFError):
- if not extfileobj:
- fileobj.close()
raise ReadError("not a bzip2 file")
t._extfileobj = False
return t
@@ -1987,9 +1984,8 @@
# Append the tar header and data to the archive.
if tarinfo.isreg():
- f = bltn_open(name, "rb")
- self.addfile(tarinfo, f)
- f.close()
+ with bltn_open(name, "rb") as f:
+ self.addfile(tarinfo, f)
elif tarinfo.isdir():
self.addfile(tarinfo)
@@ -2197,10 +2193,11 @@
"""Make a file called targetpath.
"""
source = self.extractfile(tarinfo)
- target = bltn_open(targetpath, "wb")
- copyfileobj(source, target)
- source.close()
- target.close()
+ try:
+ with bltn_open(targetpath, "wb") as target:
+ copyfileobj(source, target)
+ finally:
+ source.close()
def makeunknown(self, tarinfo, targetpath):
"""Make a file from a TarInfo object with an unknown type
@@ -2241,10 +2238,14 @@
if hasattr(os, "symlink") and hasattr(os, "link"):
# For systems that support symbolic and hard links.
if tarinfo.issym():
+ if os.path.lexists(targetpath):
+ os.unlink(targetpath)
os.symlink(tarinfo.linkname, targetpath)
else:
# See extract().
if os.path.exists(tarinfo._link_target):
+ if os.path.lexists(targetpath):
+ os.unlink(targetpath)
os.link(tarinfo._link_target, targetpath)
else:
self._extract_member(self._find_link_target(tarinfo), targetpath)
@@ -2262,17 +2263,11 @@
try:
g = grp.getgrnam(tarinfo.gname)[2]
except KeyError:
- try:
- g = grp.getgrgid(tarinfo.gid)[2]
- except KeyError:
- g = os.getgid()
+ g = tarinfo.gid
try:
u = pwd.getpwnam(tarinfo.uname)[2]
except KeyError:
- try:
- u = pwd.getpwuid(tarinfo.uid)[2]
- except KeyError:
- u = os.getuid()
+ u = tarinfo.uid
try:
if tarinfo.issym() and hasattr(os, "lchown"):
os.lchown(targetpath, u, g)
@@ -2399,7 +2394,7 @@
"""
if tarinfo.issym():
# Always search the entire archive.
- linkname = os.path.dirname(tarinfo.name) + "/" + tarinfo.linkname
+ linkname = "/".join(filter(None, (os.path.dirname(tarinfo.name), tarinfo.linkname)))
limit = None
else:
# Search the archive before the link, because a hard link is
diff --git a/Lib/test/list_tests.py b/Lib/test/list_tests.py
--- a/Lib/test/list_tests.py
+++ b/Lib/test/list_tests.py
@@ -4,9 +4,14 @@
import sys
import os
+import unittest
from test import test_support, seq_tests
+if test_support.is_jython:
+ from java.util import List as JList
+
+
class CommonTest(seq_tests.CommonTest):
def test_init(self):
@@ -40,12 +45,14 @@
self.assertEqual(str(a2), "[0, 1, 2]")
self.assertEqual(repr(a2), "[0, 1, 2]")
- a2.append(a2)
- a2.append(3)
- self.assertEqual(str(a2), "[0, 1, 2, [...], 3]")
- self.assertEqual(repr(a2), "[0, 1, 2, [...], 3]")
+ if not (test_support.is_jython and issubclass(self.type2test, JList)):
+ # Jython does not support shallow copies of object graphs
+ # when moving back and forth from Java object space
+ a2.append(a2)
+ a2.append(3)
+ self.assertEqual(str(a2), "[0, 1, 2, [...], 3]")
+ self.assertEqual(repr(a2), "[0, 1, 2, [...], 3]")
- #FIXME: not working on Jython
if not test_support.is_jython:
l0 = []
for i in xrange(sys.getrecursionlimit() + 100):
@@ -53,6 +60,8 @@
self.assertRaises(RuntimeError, repr, l0)
def test_print(self):
+ if test_support.is_jython and issubclass(self.type2test, JList):
+ raise unittest.SkipTest("Jython does not support shallow copies of object graphs")
d = self.type2test(xrange(200))
d.append(d)
d.extend(xrange(200,400))
@@ -184,10 +193,14 @@
a[:] = tuple(range(10))
self.assertEqual(a, self.type2test(range(10)))
- self.assertRaises(TypeError, a.__setslice__, 0, 1, 5)
+ if not (test_support.is_jython and issubclass(self.type2test, JList)):
+ # no support for __setslice__ on Jython for
+ # java.util.List, given that method deprecated since 2.0!
+ self.assertRaises(TypeError, a.__setslice__, 0, 1, 5)
self.assertRaises(TypeError, a.__setitem__, slice(0, 1, 5))
- self.assertRaises(TypeError, a.__setslice__)
+ if not (test_support.is_jython and issubclass(self.type2test, JList)):
+ self.assertRaises(TypeError, a.__setslice__)
self.assertRaises(TypeError, a.__setitem__)
def test_delslice(self):
@@ -330,9 +343,12 @@
d = self.type2test(['a', 'b', BadCmp2(), 'c'])
e = self.type2test(d)
self.assertRaises(BadExc, d.remove, 'c')
- for x, y in zip(d, e):
- # verify that original order and values are retained.
- self.assertIs(x, y)
+ if not (test_support.is_jython and issubclass(self.type2test, JList)):
+ # When converting back and forth to Java space, Jython does not
+ # maintain object identity
+ for x, y in zip(d, e):
+ # verify that original order and values are retained.
+ self.assertIs(x, y)
def test_count(self):
a = self.type2test([0, 1, 2])*3
@@ -452,8 +468,13 @@
def selfmodifyingComparison(x,y):
z.append(1)
return cmp(x, y)
+
+ # Need to ensure the comparisons are actually executed by
+ # setting up a list
+ z = self.type2test(range(12))
self.assertRaises(ValueError, z.sort, selfmodifyingComparison)
+ z = self.type2test(range(12))
self.assertRaises(TypeError, z.sort, lambda x, y: 's')
self.assertRaises(TypeError, z.sort, 42, 42, 42, 42)
diff --git a/Lib/test/regrtest.py b/Lib/test/regrtest.py
--- a/Lib/test/regrtest.py
+++ b/Lib/test/regrtest.py
@@ -1266,6 +1266,9 @@
test_asynchat
test_asyncore
+ # Command line testing is hard for Jython to do, but revisit
+ test_cmd_line_script
+
# Tests that should work with socket-reboot, but currently hang
test_ftplib
test_httplib
@@ -1311,7 +1314,6 @@
test_peepholer
test_pyclbr
test_pyexpat
- test_select_new
test_stringprep
test_threadsignals
test_transformer
diff --git a/Lib/test/seq_tests.py b/Lib/test/seq_tests.py
new file mode 100644
--- /dev/null
+++ b/Lib/test/seq_tests.py
@@ -0,0 +1,409 @@
+"""
+Tests common to tuple, list and UserList.UserList
+"""
+
+import unittest
+import sys
+
+from test import test_support
+
+if test_support.is_jython:
+ from java.util import List as JList
+
+# Various iterables
+# This is used for checking the constructor (here and in test_deque.py)
+def iterfunc(seqn):
+ 'Regular generator'
+ for i in seqn:
+ yield i
+
+class Sequence:
+ 'Sequence using __getitem__'
+ def __init__(self, seqn):
+ self.seqn = seqn
+ def __getitem__(self, i):
+ return self.seqn[i]
+
+class IterFunc:
+ 'Sequence using iterator protocol'
+ def __init__(self, seqn):
+ self.seqn = seqn
+ self.i = 0
+ def __iter__(self):
+ return self
+ def next(self):
+ if self.i >= len(self.seqn): raise StopIteration
+ v = self.seqn[self.i]
+ self.i += 1
+ return v
+
+class IterGen:
+ 'Sequence using iterator protocol defined with a generator'
+ def __init__(self, seqn):
+ self.seqn = seqn
+ self.i = 0
+ def __iter__(self):
+ for val in self.seqn:
+ yield val
+
+class IterNextOnly:
+ 'Missing __getitem__ and __iter__'
+ def __init__(self, seqn):
+ self.seqn = seqn
+ self.i = 0
+ def next(self):
+ if self.i >= len(self.seqn): raise StopIteration
+ v = self.seqn[self.i]
+ self.i += 1
+ return v
+
+class IterNoNext:
+ 'Iterator missing next()'
+ def __init__(self, seqn):
+ self.seqn = seqn
+ self.i = 0
+ def __iter__(self):
+ return self
+
+class IterGenExc:
+ 'Test propagation of exceptions'
+ def __init__(self, seqn):
+ self.seqn = seqn
+ self.i = 0
+ def __iter__(self):
+ return self
+ def next(self):
+ 3 // 0
+
+class IterFuncStop:
+ 'Test immediate stop'
+ def __init__(self, seqn):
+ pass
+ def __iter__(self):
+ return self
+ def next(self):
+ raise StopIteration
+
+from itertools import chain, imap
+def itermulti(seqn):
+ 'Test multiple tiers of iterators'
+ return chain(imap(lambda x:x, iterfunc(IterGen(Sequence(seqn)))))
+
+class CommonTest(unittest.TestCase):
+ # The type to be tested
+ type2test = None
+
+ def test_constructors(self):
+ l0 = []
+ l1 = [0]
+ l2 = [0, 1]
+
+ u = self.type2test()
+ u0 = self.type2test(l0)
+ u1 = self.type2test(l1)
+ u2 = self.type2test(l2)
+
+ uu = self.type2test(u)
+ uu0 = self.type2test(u0)
+ uu1 = self.type2test(u1)
+ uu2 = self.type2test(u2)
+
+ v = self.type2test(tuple(u))
+ class OtherSeq:
+ def __init__(self, initseq):
+ self.__data = initseq
+ def __len__(self):
+ return len(self.__data)
+ def __getitem__(self, i):
+ return self.__data[i]
+ if not (test_support.is_jython and issubclass(self.type2test, JList)):
+ # Jython does not currently support converting List-like
+ # objects to Lists in reflected args. This lack of
+ # support should be fixed, but it's tricky.
+ s = OtherSeq(u0)
+ v0 = self.type2test(s)
+ self.assertEqual(len(v0), len(s))
+
+ s = "this is also a sequence"
+ vv = self.type2test(s)
+ self.assertEqual(len(vv), len(s))
+
+
+ if test_support.is_jython and issubclass(self.type2test, JList):
+ # Ditto from above, we need to skip the rest of the test
+ return
+
+ # Create from various iterables
+ for s in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5)):
+ for g in (Sequence, IterFunc, IterGen,
+ itermulti, iterfunc):
+ self.assertEqual(self.type2test(g(s)), self.type2test(s))
+ self.assertEqual(self.type2test(IterFuncStop(s)), self.type2test())
+ self.assertEqual(self.type2test(c for c in "123"), self.type2test("123"))
+ self.assertRaises(TypeError, self.type2test, IterNextOnly(s))
+ self.assertRaises(TypeError, self.type2test, IterNoNext(s))
+ self.assertRaises(ZeroDivisionError, self.type2test, IterGenExc(s))
+
+ def test_truth(self):
+ self.assertFalse(self.type2test())
+ self.assertTrue(self.type2test([42]))
+
+ def test_getitem(self):
+ u = self.type2test([0, 1, 2, 3, 4])
+ for i in xrange(len(u)):
+ self.assertEqual(u[i], i)
+ self.assertEqual(u[long(i)], i)
+ for i in xrange(-len(u), -1):
+ self.assertEqual(u[i], len(u)+i)
+ self.assertEqual(u[long(i)], len(u)+i)
+ self.assertRaises(IndexError, u.__getitem__, -len(u)-1)
+ self.assertRaises(IndexError, u.__getitem__, len(u))
+ self.assertRaises(ValueError, u.__getitem__, slice(0,10,0))
+
+ u = self.type2test()
+ self.assertRaises(IndexError, u.__getitem__, 0)
+ self.assertRaises(IndexError, u.__getitem__, -1)
+
+ self.assertRaises(TypeError, u.__getitem__)
+
+ a = self.type2test([10, 11])
+ self.assertEqual(a[0], 10)
+ self.assertEqual(a[1], 11)
+ self.assertEqual(a[-2], 10)
+ self.assertEqual(a[-1], 11)
+ self.assertRaises(IndexError, a.__getitem__, -3)
+ self.assertRaises(IndexError, a.__getitem__, 3)
+
+ def test_getslice(self):
+ l = [0, 1, 2, 3, 4]
+ u = self.type2test(l)
+
+ self.assertEqual(u[0:0], self.type2test())
+ self.assertEqual(u[1:2], self.type2test([1]))
+ self.assertEqual(u[-2:-1], self.type2test([3]))
+ self.assertEqual(u[-1000:1000], u)
+ self.assertEqual(u[1000:-1000], self.type2test([]))
+ self.assertEqual(u[:], u)
+ self.assertEqual(u[1:None], self.type2test([1, 2, 3, 4]))
+ self.assertEqual(u[None:3], self.type2test([0, 1, 2]))
+
+ # Extended slices
+ self.assertEqual(u[::], u)
+ self.assertEqual(u[::2], self.type2test([0, 2, 4]))
+ self.assertEqual(u[1::2], self.type2test([1, 3]))
+ self.assertEqual(u[::-1], self.type2test([4, 3, 2, 1, 0]))
+ self.assertEqual(u[::-2], self.type2test([4, 2, 0]))
+ self.assertEqual(u[3::-2], self.type2test([3, 1]))
+ self.assertEqual(u[3:3:-2], self.type2test([]))
+ self.assertEqual(u[3:2:-2], self.type2test([3]))
+ self.assertEqual(u[3:1:-2], self.type2test([3]))
+ self.assertEqual(u[3:0:-2], self.type2test([3, 1]))
+ self.assertEqual(u[::-100], self.type2test([4]))
+ self.assertEqual(u[100:-100:], self.type2test([]))
+ self.assertEqual(u[-100:100:], u)
+ self.assertEqual(u[100:-100:-1], u[::-1])
+ self.assertEqual(u[-100:100:-1], self.type2test([]))
+ self.assertEqual(u[-100L:100L:2L], self.type2test([0, 2, 4]))
+
+ # Test extreme cases with long ints
+ a = self.type2test([0,1,2,3,4])
+ self.assertEqual(a[ -pow(2,128L): 3 ], self.type2test([0,1,2]))
+ self.assertEqual(a[ 3: pow(2,145L) ], self.type2test([3,4]))
+
+ if not (test_support.is_jython and issubclass(self.type2test, JList)):
+ # no support for __getslice__ on Jython for
+ # java.util.List, given that the method has been deprecated since 2.0!
+ self.assertRaises(TypeError, u.__getslice__)
+
+ def test_contains(self):
+ u = self.type2test([0, 1, 2])
+ for i in u:
+ self.assertIn(i, u)
+ for i in min(u)-1, max(u)+1:
+ self.assertNotIn(i, u)
+
+ self.assertRaises(TypeError, u.__contains__)
+
+ def test_contains_fake(self):
+ class AllEq:
+ # Sequences must use rich comparison against each item
+ # (unless "is" is true, or an earlier item answered)
+ # So instances of AllEq must be found in all non-empty sequences.
+ def __eq__(self, other):
+ return True
+ __hash__ = None # Can't meet hash invariant requirements
+ self.assertNotIn(AllEq(), self.type2test([]))
+ self.assertIn(AllEq(), self.type2test([1]))
+
+ def test_contains_order(self):
+ # Sequences must test in-order. If a rich comparison has side
+ # effects, these will be visible to tests against later members.
+ # In this test, the "side effect" is a short-circuiting raise.
+ class DoNotTestEq(Exception):
+ pass
+ class StopCompares:
+ def __eq__(self, other):
+ raise DoNotTestEq
+
+ checkfirst = self.type2test([1, StopCompares()])
+ self.assertIn(1, checkfirst)
+ checklast = self.type2test([StopCompares(), 1])
+ self.assertRaises(DoNotTestEq, checklast.__contains__, 1)
+
+ def test_len(self):
+ self.assertEqual(len(self.type2test()), 0)
+ self.assertEqual(len(self.type2test([])), 0)
+ self.assertEqual(len(self.type2test([0])), 1)
+ self.assertEqual(len(self.type2test([0, 1, 2])), 3)
+
+ def test_minmax(self):
+ u = self.type2test([0, 1, 2])
+ self.assertEqual(min(u), 0)
+ self.assertEqual(max(u), 2)
+
+ def test_addmul(self):
+ u1 = self.type2test([0])
+ u2 = self.type2test([0, 1])
+ self.assertEqual(u1, u1 + self.type2test())
+ self.assertEqual(u1, self.type2test() + u1)
+ self.assertEqual(u1 + self.type2test([1]), u2)
+ self.assertEqual(self.type2test([-1]) + u1, self.type2test([-1, 0]))
+ self.assertEqual(self.type2test(), u2*0)
+ self.assertEqual(self.type2test(), 0*u2)
+ self.assertEqual(self.type2test(), u2*0L)
+ self.assertEqual(self.type2test(), 0L*u2)
+ self.assertEqual(u2, u2*1)
+ self.assertEqual(u2, 1*u2)
+ self.assertEqual(u2, u2*1L)
+ self.assertEqual(u2, 1L*u2)
+ self.assertEqual(u2+u2, u2*2)
+ self.assertEqual(u2+u2, 2*u2)
+ self.assertEqual(u2+u2, u2*2L)
+ self.assertEqual(u2+u2, 2L*u2)
+ self.assertEqual(u2+u2+u2, u2*3)
+ self.assertEqual(u2+u2+u2, 3*u2)
+
+ class subclass(self.type2test):
+ pass
+ u3 = subclass([0, 1])
+ self.assertEqual(u3, u3*1)
+ self.assertIsNot(u3, u3*1)
+
+ def test_iadd(self):
+ u = self.type2test([0, 1])
+ u += self.type2test()
+ self.assertEqual(u, self.type2test([0, 1]))
+ u += self.type2test([2, 3])
+ self.assertEqual(u, self.type2test([0, 1, 2, 3]))
+ u += self.type2test([4, 5])
+ self.assertEqual(u, self.type2test([0, 1, 2, 3, 4, 5]))
+
+ u = self.type2test("spam")
+ u += self.type2test("eggs")
+ self.assertEqual(u, self.type2test("spameggs"))
+
+ def test_imul(self):
+ u = self.type2test([0, 1])
+ u *= 3
+ self.assertEqual(u, self.type2test([0, 1, 0, 1, 0, 1]))
+
+ def test_getitemoverwriteiter(self):
+ # Verify that __getitem__ overrides are not recognized by __iter__
+ class T(self.type2test):
+ def __getitem__(self, key):
+ return str(key) + '!!!'
+ self.assertEqual(iter(T((1,2))).next(), 1)
+
+ def test_repeat(self):
+ for m in xrange(4):
+ s = tuple(range(m))
+ for n in xrange(-3, 5):
+ self.assertEqual(self.type2test(s*n), self.type2test(s)*n)
+ self.assertEqual(self.type2test(s)*(-4), self.type2test([]))
+ self.assertEqual(id(s), id(s*1))
+
+ def test_bigrepeat(self):
+ import sys
+ if sys.maxint <= 2147483647:
+ x = self.type2test([0])
+ x *= 2**16
+ self.assertRaises(MemoryError, x.__mul__, 2**16)
+ if hasattr(x, '__imul__'):
+ self.assertRaises(MemoryError, x.__imul__, 2**16)
+
+ def test_subscript(self):
+ a = self.type2test([10, 11])
+ self.assertEqual(a.__getitem__(0L), 10)
+ self.assertEqual(a.__getitem__(1L), 11)
+ self.assertEqual(a.__getitem__(-2L), 10)
+ self.assertEqual(a.__getitem__(-1L), 11)
+ self.assertRaises(IndexError, a.__getitem__, -3)
+ self.assertRaises(IndexError, a.__getitem__, 3)
+ self.assertEqual(a.__getitem__(slice(0,1)), self.type2test([10]))
+ self.assertEqual(a.__getitem__(slice(1,2)), self.type2test([11]))
+ self.assertEqual(a.__getitem__(slice(0,2)), self.type2test([10, 11]))
+ self.assertEqual(a.__getitem__(slice(0,3)), self.type2test([10, 11]))
+ self.assertEqual(a.__getitem__(slice(3,5)), self.type2test([]))
+ self.assertRaises(ValueError, a.__getitem__, slice(0, 10, 0))
+ self.assertRaises(TypeError, a.__getitem__, 'x')
+
+ def test_count(self):
+ a = self.type2test([0, 1, 2])*3
+ self.assertEqual(a.count(0), 3)
+ self.assertEqual(a.count(1), 3)
+ self.assertEqual(a.count(3), 0)
+
+ self.assertRaises(TypeError, a.count)
+
+ class BadExc(Exception):
+ pass
+
+ class BadCmp:
+ def __eq__(self, other):
+ if other == 2:
+ raise BadExc()
+ return False
+
+ self.assertRaises(BadExc, a.count, BadCmp())
+
+ def test_index(self):
+ u = self.type2test([0, 1])
+ self.assertEqual(u.index(0), 0)
+ self.assertEqual(u.index(1), 1)
+ self.assertRaises(ValueError, u.index, 2)
+
+ u = self.type2test([-2, -1, 0, 0, 1, 2])
+ self.assertEqual(u.count(0), 2)
+ self.assertEqual(u.index(0), 2)
+ self.assertEqual(u.index(0, 2), 2)
+ self.assertEqual(u.index(-2, -10), 0)
+ self.assertEqual(u.index(0, 3), 3)
+ self.assertEqual(u.index(0, 3, 4), 3)
+ self.assertRaises(ValueError, u.index, 2, 0, -10)
+
+ self.assertRaises(TypeError, u.index)
+
+ class BadExc(Exception):
+ pass
+
+ class BadCmp:
+ def __eq__(self, other):
+ if other == 2:
+ raise BadExc()
+ return False
+
+ a = self.type2test([0, 1, 2, 3])
+ self.assertRaises(BadExc, a.index, BadCmp())
+
+ a = self.type2test([-2, -1, 0, 0, 1, 2])
+ self.assertEqual(a.index(0), 2)
+ self.assertEqual(a.index(0, 2), 2)
+ self.assertEqual(a.index(0, -4), 2)
+ self.assertEqual(a.index(-2, -10), 0)
+ self.assertEqual(a.index(0, 3), 3)
+ self.assertEqual(a.index(0, -3), 3)
+ self.assertEqual(a.index(0, 3, 4), 3)
+ self.assertEqual(a.index(0, -3, -2), 3)
+ self.assertEqual(a.index(0, -4*sys.maxint, 4*sys.maxint), 2)
+ self.assertRaises(ValueError, a.index, 0, 4*sys.maxint,-4*sys.maxint)
+ self.assertRaises(ValueError, a.index, 2, 0, -10)
diff --git a/Lib/test/test_builtin.py b/Lib/test/test_builtin.py
--- a/Lib/test/test_builtin.py
+++ b/Lib/test/test_builtin.py
@@ -2,13 +2,13 @@
import platform
import unittest
+from test.test_support import fcmp, have_unicode, TESTFN, unlink, \
+ run_unittest, check_py3k_warnings, is_jython
import warnings
-from test.test_support import (fcmp, have_unicode, TESTFN, unlink,
- run_unittest, check_py3k_warnings, check_warnings,
- is_jython)
from operator import neg
import sys, cStringIO, random, UserDict
+
# count the number of test runs.
# used to skip running test_execfile() multiple times
# and to create unique strings to intern in test_intern()
@@ -90,6 +90,16 @@
self.assertEqual(abs(-1234L), 1234L)
# str
self.assertRaises(TypeError, abs, 'a')
+ # bool
+ self.assertEqual(abs(True), 1)
+ self.assertEqual(abs(False), 0)
+ # other
+ self.assertRaises(TypeError, abs)
+ self.assertRaises(TypeError, abs, None)
+ class AbsClass(object):
+ def __abs__(self):
+ return -5
+ self.assertEqual(abs(AbsClass()), -5)
def test_all(self):
self.assertEqual(all([2, 4, 6]), True)
@@ -100,6 +110,7 @@
self.assertRaises(TypeError, all) # No args
self.assertRaises(TypeError, all, [2, 4, 6], []) # Too many args
self.assertEqual(all([]), True) # Empty iterator
+ self.assertEqual(all([0, TestFailingBool()]), False)# Short-circuit
S = [50, 60]
self.assertEqual(all(x > 42 for x in S), True)
S = [50, 40, 60]
@@ -109,11 +120,12 @@
self.assertEqual(any([None, None, None]), False)
self.assertEqual(any([None, 4, None]), True)
self.assertRaises(RuntimeError, any, [None, TestFailingBool(), 6])
- self.assertRaises(RuntimeError, all, TestFailingIter())
+ self.assertRaises(RuntimeError, any, TestFailingIter())
self.assertRaises(TypeError, any, 10) # Non-iterable
self.assertRaises(TypeError, any) # No args
self.assertRaises(TypeError, any, [2, 4, 6], []) # Too many args
self.assertEqual(any([]), False) # Empty iterator
+ self.assertEqual(any([1, TestFailingBool()]), True) # Short-circuit
S = [40, 60, 30]
self.assertEqual(any(x > 42 for x in S), True)
S = [10, 20, 30]
@@ -121,7 +133,7 @@
def test_neg(self):
x = -sys.maxint-1
- self.assert_(isinstance(x, int))
+ self.assertTrue(isinstance(x, int))
self.assertEqual(-x, sys.maxint+1)
def test_apply(self):
@@ -151,20 +163,45 @@
self.assertRaises(TypeError, apply, id, (42,), 42)
def test_callable(self):
- self.assert_(callable(len))
+ self.assertTrue(callable(len))
+ self.assertFalse(callable("a"))
+ self.assertTrue(callable(callable))
+ self.assertTrue(callable(lambda x, y: x + y))
+ self.assertFalse(callable(__builtins__))
def f(): pass
- self.assert_(callable(f))
- class C:
+ self.assertTrue(callable(f))
+
+ class Classic:
def meth(self): pass
- self.assert_(callable(C))
- x = C()
- self.assert_(callable(x.meth))
- self.assert_(not callable(x))
- class D(C):
+ self.assertTrue(callable(Classic))
+ c = Classic()
+ self.assertTrue(callable(c.meth))
+ self.assertFalse(callable(c))
+
+ class NewStyle(object):
+ def meth(self): pass
+ self.assertTrue(callable(NewStyle))
+ n = NewStyle()
+ self.assertTrue(callable(n.meth))
+ self.assertFalse(callable(n))
+
+ # Classic and new-style classes evaluate __call__() differently
+ c.__call__ = None
+ self.assertTrue(callable(c))
+ del c.__call__
+ self.assertFalse(callable(c))
+ n.__call__ = None
+ self.assertFalse(callable(n))
+ del n.__call__
+ self.assertFalse(callable(n))
+
+ class N2(object):
def __call__(self): pass
- y = D()
- self.assert_(callable(y))
- y()
+ n2 = N2()
+ self.assertTrue(callable(n2))
+ class N3(N2): pass
+ n3 = N3()
+ self.assertTrue(callable(n3))
def test_chr(self):
self.assertEqual(chr(32), ' ')
@@ -178,23 +215,29 @@
self.assertEqual(cmp(-1, 1), -1)
self.assertEqual(cmp(1, -1), 1)
self.assertEqual(cmp(1, 1), 0)
- # verify that circular objects are handled for Jython
+ # verify that circular objects are not handled
a = []; a.append(a)
b = []; b.append(b)
from UserList import UserList
c = UserList(); c.append(c)
- self.assertEqual(cmp(a, b), 0)
- self.assertEqual(cmp(b, c), 0)
- self.assertEqual(cmp(c, a), 0)
- self.assertEqual(cmp(a, c), 0)
- # okay, now break the cycles
+ if is_jython:
+ self.assertEqual(cmp(a, b), 0)
+ self.assertEqual(cmp(b, c), 0)
+ self.assertEqual(cmp(c, a), 0)
+ self.assertEqual(cmp(a, c), 0)
+ else:
+ self.assertRaises(RuntimeError, cmp, a, b)
+ self.assertRaises(RuntimeError, cmp, b, c)
+ self.assertRaises(RuntimeError, cmp, c, a)
+ self.assertRaises(RuntimeError, cmp, a, c)
+ # okay, now break the cycles
a.pop(); b.pop(); c.pop()
self.assertRaises(TypeError, cmp)
def test_coerce(self):
- self.assert_(not fcmp(coerce(1, 1.1), (1.0, 1.1)))
+ self.assertTrue(not fcmp(coerce(1, 1.1), (1.0, 1.1)))
self.assertEqual(coerce(1, 1L), (1L, 1L))
- self.assert_(not fcmp(coerce(1L, 1.1), (1.0, 1.1)))
+ self.assertTrue(not fcmp(coerce(1L, 1.1), (1.0, 1.1)))
self.assertRaises(TypeError, coerce)
class BadNumber:
def __coerce__(self, other):
@@ -233,23 +276,22 @@
# dir() - local scope
local_var = 1
- self.assert_('local_var' in dir())
+ self.assertIn('local_var', dir())
# dir(module)
import sys
- self.assert_('exit' in dir(sys))
+ self.assertIn('exit', dir(sys))
# dir(module_with_invalid__dict__)
import types
class Foo(types.ModuleType):
__dict__ = 8
f = Foo("foo")
- if not is_jython: #FIXME #1861
- self.assertRaises(TypeError, dir, f)
+ self.assertRaises(TypeError, dir, f)
# dir(type)
- self.assert_("strip" in dir(str))
- self.assert_("__mro__" not in dir(str))
+ self.assertIn("strip", dir(str))
+ self.assertNotIn("__mro__", dir(str))
# dir(obj)
class Foo(object):
@@ -258,13 +300,13 @@
self.y = 8
self.z = 9
f = Foo()
- self.assert_("y" in dir(f))
+ self.assertIn("y", dir(f))
# dir(obj_no__dict__)
class Foo(object):
__slots__ = []
f = Foo()
- self.assert_("__repr__" in dir(f))
+ self.assertIn("__repr__", dir(f))
# dir(obj_no__class__with__dict__)
# (an ugly trick to cause getattr(f, "__class__") to fail)
@@ -273,24 +315,22 @@
def __init__(self):
self.bar = "wow"
f = Foo()
- self.assert_("__repr__" not in dir(f))
- self.assert_("bar" in dir(f))
+ self.assertNotIn("__repr__", dir(f))
+ self.assertIn("bar", dir(f))
# dir(obj_using __dir__)
class Foo(object):
def __dir__(self):
return ["kan", "ga", "roo"]
f = Foo()
- if not is_jython: #FIXME #1861
- self.assert_(dir(f) == ["ga", "kan", "roo"])
+ self.assertTrue(dir(f) == ["ga", "kan", "roo"])
# dir(obj__dir__not_list)
class Foo(object):
def __dir__(self):
return 7
f = Foo()
- if not is_jython: #FIXME #1861
- self.assertRaises(TypeError, dir, f)
+ self.assertRaises(TypeError, dir, f)
def test_divmod(self):
self.assertEqual(divmod(12, 7), (1, 5))
@@ -311,10 +351,10 @@
self.assertEqual(divmod(-sys.maxint-1, -1),
(sys.maxint+1, 0))
- self.assert_(not fcmp(divmod(3.25, 1.0), (3.0, 0.25)))
- self.assert_(not fcmp(divmod(-3.25, 1.0), (-4.0, 0.75)))
- self.assert_(not fcmp(divmod(3.25, -1.0), (-4.0, -0.75)))
- self.assert_(not fcmp(divmod(-3.25, -1.0), (3.0, -0.25)))
+ self.assertTrue(not fcmp(divmod(3.25, 1.0), (3.0, 0.25)))
+ self.assertTrue(not fcmp(divmod(-3.25, 1.0), (-4.0, 0.75)))
+ self.assertTrue(not fcmp(divmod(3.25, -1.0), (-4.0, -0.75)))
+ self.assertTrue(not fcmp(divmod(-3.25, -1.0), (3.0, -0.25)))
self.assertRaises(TypeError, divmod)
@@ -363,9 +403,12 @@
self.assertEqual(eval('dir()', g, m), list('xyz'))
self.assertEqual(eval('globals()', g, m), g)
self.assertEqual(eval('locals()', g, m), m)
-
- # Jython allows arbitrary mappings for globals
- self.assertEqual(eval('a', m), 12)
+ if is_jython:
+ # Jython allows any mapping to work, including ones that
+ # are read only as in the case of M
+ self.assertEqual(eval('a', m), 12)
+ else:
+ self.assertRaises(TypeError, eval, 'a', m)
class A:
"Non-mapping"
pass
@@ -577,11 +620,11 @@
for func in funcs:
outp = filter(func, cls(inp))
self.assertEqual(outp, exp)
- self.assert_(not isinstance(outp, cls))
+ self.assertTrue(not isinstance(outp, cls))
def test_getattr(self):
import sys
- self.assert_(getattr(sys, 'stdout') is sys.stdout)
+ self.assertTrue(getattr(sys, 'stdout') is sys.stdout)
self.assertRaises(TypeError, getattr, sys, 1)
self.assertRaises(TypeError, getattr, sys, 1, "foo")
self.assertRaises(TypeError, getattr)
@@ -590,7 +633,7 @@
def test_hasattr(self):
import sys
- self.assert_(hasattr(sys, 'stdout'))
+ self.assertTrue(hasattr(sys, 'stdout'))
self.assertRaises(TypeError, hasattr, sys, 1)
self.assertRaises(TypeError, hasattr)
if have_unicode:
@@ -621,15 +664,15 @@
class X:
def __hash__(self):
return 2**100
- self.assertEquals(type(hash(X())), int)
+ self.assertEqual(type(hash(X())), int)
class Y(object):
def __hash__(self):
return 2**100
- self.assertEquals(type(hash(Y())), int)
+ self.assertEqual(type(hash(Y())), int)
class Z(long):
def __hash__(self):
return self
- self.assertEquals(hash(Z(42)), hash(42L))
+ self.assertEqual(hash(Z(42)), hash(42L))
def test_hex(self):
self.assertEqual(hex(16), '0x10')
@@ -650,20 +693,22 @@
# Test input() later, together with raw_input
+ # test_int(): see test_int.py for int() tests.
+
def test_intern(self):
self.assertRaises(TypeError, intern)
# This fails if the test is run twice with a constant string,
# therefore append the run counter
s = "never interned before " + str(numruns)
- self.assert_(intern(s) is s)
+ self.assertTrue(intern(s) is s)
s2 = s.swapcase().swapcase()
- self.assert_(intern(s2) is s)
+ self.assertTrue(intern(s2) is s)
# Subclasses of string can't be interned, because they
# provide too much opportunity for insane things to happen.
# We don't want them in the interned dict and if they aren't
# actually interned, we don't want to create the appearance
- # that they are by allowing intern() to succeeed.
+ # that they are by allowing intern() to succeed.
class S(str):
def __hash__(self):
return 123
@@ -698,11 +743,11 @@
c = C()
d = D()
e = E()
- self.assert_(isinstance(c, C))
- self.assert_(isinstance(d, C))
- self.assert_(not isinstance(e, C))
- self.assert_(not isinstance(c, D))
- self.assert_(not isinstance('foo', E))
+ self.assertTrue(isinstance(c, C))
+ self.assertTrue(isinstance(d, C))
+ self.assertTrue(not isinstance(e, C))
+ self.assertTrue(not isinstance(c, D))
+ self.assertTrue(not isinstance('foo', E))
self.assertRaises(TypeError, isinstance, E, 'foo')
self.assertRaises(TypeError, isinstance)
@@ -716,9 +761,9 @@
c = C()
d = D()
e = E()
- self.assert_(issubclass(D, C))
- self.assert_(issubclass(C, C))
- self.assert_(not issubclass(C, D))
+ self.assertTrue(issubclass(D, C))
+ self.assertTrue(issubclass(C, C))
+ self.assertTrue(not issubclass(C, D))
self.assertRaises(TypeError, issubclass, 'foo', E)
self.assertRaises(TypeError, issubclass, E, 'foo')
self.assertRaises(TypeError, issubclass)
@@ -734,6 +779,11 @@
def __len__(self):
raise ValueError
self.assertRaises(ValueError, len, BadSeq())
+ self.assertRaises(TypeError, len, 2)
+ class ClassicStyle: pass
+ class NewStyle(object): pass
+ self.assertRaises(AttributeError, len, ClassicStyle())
+ self.assertRaises(TypeError, len, NewStyle())
def test_map(self):
self.assertEqual(
@@ -895,7 +945,7 @@
self.assertEqual(next(it), 1)
self.assertRaises(StopIteration, next, it)
self.assertRaises(StopIteration, next, it)
- self.assertEquals(next(it, 42), 42)
+ self.assertEqual(next(it, 42), 42)
class Iter(object):
def __iter__(self):
@@ -904,7 +954,7 @@
raise StopIteration
it = iter(Iter())
- self.assertEquals(next(it, 42), 42)
+ self.assertEqual(next(it, 42), 42)
self.assertRaises(StopIteration, next, it)
def gen():
@@ -912,9 +962,9 @@
return
it = gen()
- self.assertEquals(next(it), 1)
+ self.assertEqual(next(it), 1)
self.assertRaises(StopIteration, next, it)
- self.assertEquals(next(it, 42), 42)
+ self.assertEqual(next(it, 42), 42)
def test_oct(self):
self.assertEqual(oct(100), '0144')
@@ -1050,18 +1100,18 @@
self.assertEqual(range(a+4, a, -2), [a+4, a+2])
seq = range(a, b, c)
- self.assert_(a in seq)
- self.assert_(b not in seq)
+ self.assertIn(a, seq)
+ self.assertNotIn(b, seq)
self.assertEqual(len(seq), 2)
seq = range(b, a, -c)
- self.assert_(b in seq)
- self.assert_(a not in seq)
+ self.assertIn(b, seq)
+ self.assertNotIn(a, seq)
self.assertEqual(len(seq), 2)
seq = range(-a, -b, -c)
- self.assert_(-a in seq)
- self.assert_(-b not in seq)
+ self.assertIn(-a, seq)
+ self.assertNotIn(-b, seq)
self.assertEqual(len(seq), 2)
self.assertRaises(TypeError, range)
@@ -1075,14 +1125,9 @@
__hash__ = None # Invalid cmp makes this unhashable
self.assertRaises(RuntimeError, range, a, a + 1, badzero(1))
- # Reject floats when it would require PyLongs to represent.
- # (smaller floats still accepted, but deprecated)
- with check_warnings() as w:
- warnings.simplefilter("always")
- self.assertRaises(TypeError, range, 1e100, 1e101, 1e101)
- with check_warnings() as w:
- warnings.simplefilter("always")
- self.assertEqual(range(1.0), [0])
+ # Reject floats.
+ self.assertRaises(TypeError, range, 1., 1., 1.)
+ self.assertRaises(TypeError, range, 1e100, 1e101, 1e101)
self.assertRaises(TypeError, range, 0, "spam")
self.assertRaises(TypeError, range, 0, 42, "spam")
@@ -1124,20 +1169,21 @@
# Exercise various combinations of bad arguments, to check
# refcounting logic
- with check_warnings():
- self.assertRaises(TypeError, range, 1e100)
+ self.assertRaises(TypeError, range, 0.0)
- self.assertRaises(TypeError, range, 0, 1e100)
- self.assertRaises(TypeError, range, 1e100, 0)
- self.assertRaises(TypeError, range, 1e100, 1e100)
+ self.assertRaises(TypeError, range, 0, 0.0)
+ self.assertRaises(TypeError, range, 0.0, 0)
+ self.assertRaises(TypeError, range, 0.0, 0.0)
- self.assertRaises(TypeError, range, 0, 0, 1e100)
- self.assertRaises(TypeError, range, 0, 1e100, 1)
- self.assertRaises(TypeError, range, 0, 1e100, 1e100)
- self.assertRaises(TypeError, range, 1e100, 0, 1)
- self.assertRaises(TypeError, range, 1e100, 0, 1e100)
- self.assertRaises(TypeError, range, 1e100, 1e100, 1)
- self.assertRaises(TypeError, range, 1e100, 1e100, 1e100)
+ self.assertRaises(TypeError, range, 0, 0, 1.0)
+ self.assertRaises(TypeError, range, 0, 0.0, 1)
+ self.assertRaises(TypeError, range, 0, 0.0, 1.0)
+ self.assertRaises(TypeError, range, 0.0, 0, 1)
+ self.assertRaises(TypeError, range, 0.0, 0, 1.0)
+ self.assertRaises(TypeError, range, 0.0, 0.0, 1)
+ self.assertRaises(TypeError, range, 0.0, 0.0, 1.0)
+
+
def test_input_and_raw_input(self):
self.write_testfile()
@@ -1197,9 +1243,10 @@
unlink(TESTFN)
def test_reduce(self):
- self.assertEqual(reduce(lambda x, y: x+y, ['a', 'b', 'c'], ''), 'abc')
+ add = lambda x, y: x+y
+ self.assertEqual(reduce(add, ['a', 'b', 'c'], ''), 'abc')
self.assertEqual(
- reduce(lambda x, y: x+y, [['a', 'c'], [], ['d', 'w']], []),
+ reduce(add, [['a', 'c'], [], ['d', 'w']], []),
['a','c','d','w']
)
self.assertEqual(reduce(lambda x, y: x*y, range(2,8), 1), 5040)
@@ -1207,15 +1254,23 @@
reduce(lambda x, y: x*y, range(2,21), 1L),
2432902008176640000L
)
- self.assertEqual(reduce(lambda x, y: x+y, Squares(10)), 285)
- self.assertEqual(reduce(lambda x, y: x+y, Squares(10), 0), 285)
- self.assertEqual(reduce(lambda x, y: x+y, Squares(0), 0), 0)
+ self.assertEqual(reduce(add, Squares(10)), 285)
+ self.assertEqual(reduce(add, Squares(10), 0), 285)
+ self.assertEqual(reduce(add, Squares(0), 0), 0)
self.assertRaises(TypeError, reduce)
+ self.assertRaises(TypeError, reduce, 42)
self.assertRaises(TypeError, reduce, 42, 42)
self.assertRaises(TypeError, reduce, 42, 42, 42)
+ self.assertRaises(TypeError, reduce, None, range(5))
+ self.assertRaises(TypeError, reduce, add, 42)
self.assertEqual(reduce(42, "1"), "1") # func is never called with one item
self.assertEqual(reduce(42, "", "1"), "1") # func is never called with one item
self.assertRaises(TypeError, reduce, 42, (42, 42))
+ self.assertRaises(TypeError, reduce, add, []) # arg 2 must not be empty sequence with no initial value
+ self.assertRaises(TypeError, reduce, add, "")
+ self.assertRaises(TypeError, reduce, add, ())
+ self.assertEqual(reduce(add, [], None), None)
+ self.assertEqual(reduce(add, [], 42), 42)
class BadSeq:
def __getitem__(self, index):
@@ -1318,6 +1373,19 @@
self.assertRaises(TypeError, round, t)
self.assertRaises(TypeError, round, t, 0)
+ # Some versions of glibc for alpha have a bug that affects
+ # float -> integer rounding (floor, ceil, rint, round) for
+ # values in the range [2**52, 2**53). See:
+ #
+ # http://sources.redhat.com/bugzilla/show_bug.cgi?id=5350
+ #
+ # We skip this test on Linux/alpha if it would fail.
+ linux_alpha = (platform.system().startswith('Linux') and
+ platform.machine().startswith('alpha'))
+ system_round_bug = round(5e15+1) != 5e15+1
+ @unittest.skipIf(linux_alpha and system_round_bug,
+ "test will fail; failure is probably due to a "
+ "buggy system round function")
def test_round_large(self):
# Issue #1869: integral floats should remain unchanged
self.assertEqual(round(5e15-1), 5e15-1)
@@ -1353,6 +1421,10 @@
raise ValueError
self.assertRaises(ValueError, sum, BadSeq())
+ empty = []
+ sum(([x] for x in range(10)), empty)
+ self.assertEqual(empty, [])
+
def test_type(self):
self.assertEqual(type(''), type('123'))
self.assertNotEqual(type(''), type(()))
@@ -1368,8 +1440,7 @@
)
self.assertRaises(ValueError, unichr, sys.maxunicode+1)
self.assertRaises(TypeError, unichr)
- if not is_jython: #FIXME #1861
- self.assertRaises((OverflowError, ValueError), unichr, 2**32)
+ self.assertRaises((OverflowError, ValueError), unichr, 2**32)
# We don't want self in vars(), so these are static methods
@@ -1384,6 +1455,11 @@
b = 2
return vars()
+ class C_get_vars(object):
+ def getDict(self):
+ return {'a':2}
+ __dict__ = property(fget=getDict)
+
def test_vars(self):
self.assertEqual(set(vars()), set(dir()))
import sys
@@ -1392,6 +1468,7 @@
self.assertEqual(self.get_vars_f2(), {'a': 1, 'b': 2})
self.assertRaises(TypeError, vars, 42, 42)
self.assertRaises(TypeError, vars, 42)
+ self.assertEqual(vars(self.C_get_vars()), {'a':2})
def test_zip(self):
a = (1, 2, 3)
@@ -1511,8 +1588,7 @@
class BadFormatResult:
def __format__(self, format_spec):
return 1.0
- if not is_jython: #FIXME #1861 check again when __format__ works better.
- self.assertRaises(TypeError, format, BadFormatResult(), "")
+ self.assertRaises(TypeError, format, BadFormatResult(), "")
# TypeError because format_spec is not unicode or str
self.assertRaises(TypeError, format, object(), 4)
@@ -1521,13 +1597,48 @@
# tests for object.__format__ really belong elsewhere, but
# there's no good place to put them
x = object().__format__('')
- self.assert_(x.startswith('<object object at'))
+ self.assertTrue(x.startswith('<object object at'))
# first argument to object.__format__ must be string
self.assertRaises(TypeError, object().__format__, 3)
self.assertRaises(TypeError, object().__format__, object())
self.assertRaises(TypeError, object().__format__, None)
+ # --------------------------------------------------------------------
+ # Issue #7994: object.__format__ with a non-empty format string is
+ # pending deprecated
+ def test_deprecated_format_string(obj, fmt_str, should_raise_warning):
+ with warnings.catch_warnings(record=True) as w:
+ warnings.simplefilter("always", PendingDeprecationWarning)
+ format(obj, fmt_str)
+ if should_raise_warning:
+ self.assertEqual(len(w), 1)
+ self.assertIsInstance(w[0].message, PendingDeprecationWarning)
+ self.assertIn('object.__format__ with a non-empty format '
+ 'string', str(w[0].message))
+ else:
+ self.assertEqual(len(w), 0)
+
+ fmt_strs = ['', 's', u'', u's']
+
+ class A:
+ def __format__(self, fmt_str):
+ return format('', fmt_str)
+
+ for fmt_str in fmt_strs:
+ test_deprecated_format_string(A(), fmt_str, False)
+
+ class B:
+ pass
+
+ class C(object):
+ pass
+
+ for cls in [object, B, C]:
+ for fmt_str in fmt_strs:
+ test_deprecated_format_string(cls(), fmt_str, len(fmt_str) != 0)
+ # --------------------------------------------------------------------
+
# make sure we can take a subclass of str as a format spec
class DerivedFromStr(str): pass
self.assertEqual(format(0, DerivedFromStr('10')), ' 0')
@@ -1592,7 +1703,6 @@
("classic int division", DeprecationWarning)):
run_unittest(*args)
-
def test_main(verbose=None):
test_classes = (BuiltinTest, TestSorted)
diff --git a/Lib/test/test_cgi.py b/Lib/test/test_cgi.py
deleted file mode 100644
--- a/Lib/test/test_cgi.py
+++ /dev/null
@@ -1,395 +0,0 @@
-from test.test_support import run_unittest, check_warnings
-import cgi
-import os
-import sys
-import tempfile
-import unittest
-
-class HackedSysModule:
- # The regression test will have real values in sys.argv, which
- # will completely confuse the test of the cgi module
- argv = []
- stdin = sys.stdin
-
-cgi.sys = HackedSysModule()
-
-try:
- from cStringIO import StringIO
-except ImportError:
- from StringIO import StringIO
-
-class ComparableException:
- def __init__(self, err):
- self.err = err
-
- def __str__(self):
- return str(self.err)
-
- def __cmp__(self, anExc):
- if not isinstance(anExc, Exception):
- return -1
- x = cmp(self.err.__class__, anExc.__class__)
- if x != 0:
- return x
- return cmp(self.err.args, anExc.args)
-
- def __getattr__(self, attr):
- return getattr(self.err, attr)
-
-def do_test(buf, method):
- env = {}
- if method == "GET":
- fp = None
- env['REQUEST_METHOD'] = 'GET'
- env['QUERY_STRING'] = buf
- elif method == "POST":
- fp = StringIO(buf)
- env['REQUEST_METHOD'] = 'POST'
- env['CONTENT_TYPE'] = 'application/x-www-form-urlencoded'
- env['CONTENT_LENGTH'] = str(len(buf))
- else:
- raise ValueError, "unknown method: %s" % method
- try:
- return cgi.parse(fp, env, strict_parsing=1)
- except StandardError, err:
- return ComparableException(err)
-
-parse_strict_test_cases = [
- ("", ValueError("bad query field: ''")),
- ("&", ValueError("bad query field: ''")),
- ("&&", ValueError("bad query field: ''")),
- (";", ValueError("bad query field: ''")),
- (";&;", ValueError("bad query field: ''")),
- # Should the next few really be valid?
- ("=", {}),
- ("=&=", {}),
- ("=;=", {}),
- # This rest seem to make sense
- ("=a", {'': ['a']}),
- ("&=a", ValueError("bad query field: ''")),
- ("=a&", ValueError("bad query field: ''")),
- ("=&a", ValueError("bad query field: 'a'")),
- ("b=a", {'b': ['a']}),
- ("b+=a", {'b ': ['a']}),
- ("a=b=a", {'a': ['b=a']}),
- ("a=+b=a", {'a': [' b=a']}),
- ("&b=a", ValueError("bad query field: ''")),
- ("b&=a", ValueError("bad query field: 'b'")),
-#FIXME: None of these are working in Jython
-# ("a=a+b&b=b+c", {'a': ['a b'], 'b': ['b c']}),
-# ("a=a+b&a=b+a", {'a': ['a b', 'b a']}),
-# ("x=1&y=2.0&z=2-3.%2b0", {'x': ['1'], 'y': ['2.0'], 'z': ['2-3.+0']}),
-# ("x=1;y=2.0&z=2-3.%2b0", {'x': ['1'], 'y': ['2.0'], 'z': ['2-3.+0']}),
-# ("x=1;y=2.0;z=2-3.%2b0", {'x': ['1'], 'y': ['2.0'], 'z': ['2-3.+0']}),
-# ("Hbc5161168c542333633315dee1182227:key_store_seqid=400006&cuyer=r&view=bustomer&order_id=0bb2e248638833d48cb7fed300000f1b&expire=964546263&lobale=en-US&kid=130003.300038&ss=env",
-# {'Hbc5161168c542333633315dee1182227:key_store_seqid': ['400006'],
-# 'cuyer': ['r'],
-# 'expire': ['964546263'],
-# 'kid': ['130003.300038'],
-# 'lobale': ['en-US'],
-# 'order_id': ['0bb2e248638833d48cb7fed300000f1b'],
-# 'ss': ['env'],
-# 'view': ['bustomer'],
-# }),
-#
-# ("group_id=5470&set=custom&_assigned_to=31392&_status=1&_category=100&SUBMIT=Browse",
-# {'SUBMIT': ['Browse'],
-# '_assigned_to': ['31392'],
-# '_category': ['100'],
-# '_status': ['1'],
-# 'group_id': ['5470'],
-# 'set': ['custom'],
-# })
- ]
-
-def first_elts(list):
- return map(lambda x:x[0], list)
-
-def first_second_elts(list):
- return map(lambda p:(p[0], p[1][0]), list)
-
-def gen_result(data, environ):
- fake_stdin = StringIO(data)
- fake_stdin.seek(0)
- form = cgi.FieldStorage(fp=fake_stdin, environ=environ)
-
- result = {}
- for k, v in dict(form).items():
- result[k] = isinstance(v, list) and form.getlist(k) or v.value
-
- return result
-
-class CgiTests(unittest.TestCase):
-
- def test_escape(self):
- self.assertEqual("test & string", cgi.escape("test & string"))
- self.assertEqual("<test string>", cgi.escape("<test string>"))
- self.assertEqual(""test string"", cgi.escape('"test string"', True))
-
- def test_strict(self):
- for orig, expect in parse_strict_test_cases:
- # Test basic parsing
- d = do_test(orig, "GET")
- self.assertEqual(d, expect, "Error parsing %s" % repr(orig))
- d = do_test(orig, "POST")
- self.assertEqual(d, expect, "Error parsing %s" % repr(orig))
-
- env = {'QUERY_STRING': orig}
- fcd = cgi.FormContentDict(env)
- sd = cgi.SvFormContentDict(env)
- fs = cgi.FieldStorage(environ=env)
- if isinstance(expect, dict):
- # test dict interface
- self.assertEqual(len(expect), len(fcd))
- self.assertItemsEqual(expect.keys(), fcd.keys())
- self.assertItemsEqual(expect.values(), fcd.values())
- self.assertItemsEqual(expect.items(), fcd.items())
- self.assertEqual(fcd.get("nonexistent field", "default"), "default")
- self.assertEqual(len(sd), len(fs))
- self.assertItemsEqual(sd.keys(), fs.keys())
- self.assertEqual(fs.getvalue("nonexistent field", "default"), "default")
- # test individual fields
- for key in expect.keys():
- expect_val = expect[key]
- self.assertTrue(fcd.has_key(key))
- self.assertItemsEqual(fcd[key], expect[key])
- self.assertEqual(fcd.get(key, "default"), fcd[key])
- self.assertTrue(fs.has_key(key))
- if len(expect_val) > 1:
- single_value = 0
- else:
- single_value = 1
- try:
- val = sd[key]
- except IndexError:
- self.assertFalse(single_value)
- self.assertEqual(fs.getvalue(key), expect_val)
- else:
- self.assertTrue(single_value)
- self.assertEqual(val, expect_val[0])
- self.assertEqual(fs.getvalue(key), expect_val[0])
- self.assertItemsEqual(sd.getlist(key), expect_val)
- if single_value:
- self.assertItemsEqual(sd.values(),
- first_elts(expect.values()))
- self.assertItemsEqual(sd.items(),
- first_second_elts(expect.items()))
-
- def test_weird_formcontentdict(self):
- # Test the weird FormContentDict classes
- env = {'QUERY_STRING': "x=1&y=2.0&z=2-3.%2b0&1=1abc"}
- expect = {'x': 1, 'y': 2.0, 'z': '2-3.+0', '1': '1abc'}
- d = cgi.InterpFormContentDict(env)
- for k, v in expect.items():
- self.assertEqual(d[k], v)
- for k, v in d.items():
- self.assertEqual(expect[k], v)
- self.assertItemsEqual(expect.values(), d.values())
-
- def test_log(self):
- cgi.log("Testing")
-
- cgi.logfp = StringIO()
- cgi.initlog("%s", "Testing initlog 1")
- cgi.log("%s", "Testing log 2")
- self.assertEqual(cgi.logfp.getvalue(), "Testing initlog 1\nTesting log 2\n")
- if os.path.exists("/dev/null"):
- cgi.logfp = None
- cgi.logfile = "/dev/null"
- cgi.initlog("%s", "Testing log 3")
- cgi.log("Testing log 4")
-
- def test_fieldstorage_readline(self):
- # FieldStorage uses readline, which has the capacity to read all
- # contents of the input file into memory; we use readline's size argument
- # to prevent that for files that do not contain any newlines in
- # non-GET/HEAD requests
- class TestReadlineFile:
- def __init__(self, file):
- self.file = file
- self.numcalls = 0
-
- def readline(self, size=None):
- self.numcalls += 1
- if size:
- return self.file.readline(size)
- else:
- return self.file.readline()
-
- def __getattr__(self, name):
- file = self.__dict__['file']
- a = getattr(file, name)
- if not isinstance(a, int):
- setattr(self, name, a)
- return a
-
- f = TestReadlineFile(tempfile.TemporaryFile())
- f.write('x' * 256 * 1024)
- f.seek(0)
- env = {'REQUEST_METHOD':'PUT'}
- fs = cgi.FieldStorage(fp=f, environ=env)
- # if we're not chunking properly, readline is only called twice
- # (by read_binary); if we are chunking properly, it will be called 5 times
- # as long as the chunksize is 1 << 16.
- self.assertTrue(f.numcalls > 2)
-
- def test_fieldstorage_multipart(self):
- #Test basic FieldStorage multipart parsing
- env = {'REQUEST_METHOD':'POST', 'CONTENT_TYPE':'multipart/form-data; boundary=---------------------------721837373350705526688164684', 'CONTENT_LENGTH':'558'}
- postdata = """-----------------------------721837373350705526688164684
-Content-Disposition: form-data; name="id"
-
-1234
------------------------------721837373350705526688164684
-Content-Disposition: form-data; name="title"
-
-
------------------------------721837373350705526688164684
-Content-Disposition: form-data; name="file"; filename="test.txt"
-Content-Type: text/plain
-
-Testing 123.
-
------------------------------721837373350705526688164684
-Content-Disposition: form-data; name="submit"
-
- Add\x20
------------------------------721837373350705526688164684--
-"""
- fs = cgi.FieldStorage(fp=StringIO(postdata), environ=env)
- self.assertEqual(len(fs.list), 4)
- expect = [{'name':'id', 'filename':None, 'value':'1234'},
- {'name':'title', 'filename':None, 'value':''},
- {'name':'file', 'filename':'test.txt','value':'Testing 123.\n'},
- {'name':'submit', 'filename':None, 'value':' Add '}]
- for x in range(len(fs.list)):
- for k, exp in expect[x].items():
- got = getattr(fs.list[x], k)
- self.assertEqual(got, exp)
-
- _qs_result = {
- 'key1': 'value1',
- 'key2': ['value2x', 'value2y'],
- 'key3': 'value3',
- 'key4': 'value4'
- }
- def testQSAndUrlEncode(self):
- data = "key2=value2x&key3=value3&key4=value4"
- environ = {
- 'CONTENT_LENGTH': str(len(data)),
- 'CONTENT_TYPE': 'application/x-www-form-urlencoded',
- 'QUERY_STRING': 'key1=value1&key2=value2y',
- 'REQUEST_METHOD': 'POST',
- }
- v = gen_result(data, environ)
- self.assertEqual(self._qs_result, v)
-
- def testQSAndFormData(self):
- data = """
----123
-Content-Disposition: form-data; name="key2"
-
-value2y
----123
-Content-Disposition: form-data; name="key3"
-
-value3
----123
-Content-Disposition: form-data; name="key4"
-
-value4
----123--
-"""
- environ = {
- 'CONTENT_LENGTH': str(len(data)),
- 'CONTENT_TYPE': 'multipart/form-data; boundary=-123',
- 'QUERY_STRING': 'key1=value1&key2=value2x',
- 'REQUEST_METHOD': 'POST',
- }
- v = gen_result(data, environ)
- self.assertEqual(self._qs_result, v)
-
- def testQSAndFormDataFile(self):
- data = """
----123
-Content-Disposition: form-data; name="key2"
-
-value2y
----123
-Content-Disposition: form-data; name="key3"
-
-value3
----123
-Content-Disposition: form-data; name="key4"
-
-value4
----123
-Content-Disposition: form-data; name="upload"; filename="fake.txt"
-Content-Type: text/plain
-
-this is the content of the fake file
-
----123--
-"""
- environ = {
- 'CONTENT_LENGTH': str(len(data)),
- 'CONTENT_TYPE': 'multipart/form-data; boundary=-123',
- 'QUERY_STRING': 'key1=value1&key2=value2x',
- 'REQUEST_METHOD': 'POST',
- }
- result = self._qs_result.copy()
- result.update({
- 'upload': 'this is the content of the fake file\n'
- })
- v = gen_result(data, environ)
- self.assertEqual(result, v)
-
- def test_deprecated_parse_qs(self):
- # this func is moved to urlparse, this is just a sanity check
- with check_warnings(('cgi.parse_qs is deprecated, use urlparse.'
- 'parse_qs instead', PendingDeprecationWarning)):
- self.assertEqual({'a': ['A1'], 'B': ['B3'], 'b': ['B2']},
- cgi.parse_qs('a=A1&b=B2&B=B3'))
-
- def test_deprecated_parse_qsl(self):
- # this func is moved to urlparse, this is just a sanity check
- with check_warnings(('cgi.parse_qsl is deprecated, use urlparse.'
- 'parse_qsl instead', PendingDeprecationWarning)):
- self.assertEqual([('a', 'A1'), ('b', 'B2'), ('B', 'B3')],
- cgi.parse_qsl('a=A1&b=B2&B=B3'))
-
- def test_parse_header(self):
- self.assertEqual(
- cgi.parse_header("text/plain"),
- ("text/plain", {}))
- self.assertEqual(
- cgi.parse_header("text/vnd.just.made.this.up ; "),
- ("text/vnd.just.made.this.up", {}))
- self.assertEqual(
- cgi.parse_header("text/plain;charset=us-ascii"),
- ("text/plain", {"charset": "us-ascii"}))
- self.assertEqual(
- cgi.parse_header('text/plain ; charset="us-ascii"'),
- ("text/plain", {"charset": "us-ascii"}))
- self.assertEqual(
- cgi.parse_header('text/plain ; charset="us-ascii"; another=opt'),
- ("text/plain", {"charset": "us-ascii", "another": "opt"}))
- self.assertEqual(
- cgi.parse_header('attachment; filename="silly.txt"'),
- ("attachment", {"filename": "silly.txt"}))
- self.assertEqual(
- cgi.parse_header('attachment; filename="strange;name"'),
- ("attachment", {"filename": "strange;name"}))
- self.assertEqual(
- cgi.parse_header('attachment; filename="strange;name";size=123;'),
- ("attachment", {"filename": "strange;name", "size": "123"}))
- self.assertEqual(
- cgi.parse_header('form-data; name="files"; filename="fo\\"o;bar"'),
- ("form-data", {"name": "files", "filename": 'fo"o;bar'}))
-
-
-def test_main():
- run_unittest(CgiTests)
-
-if __name__ == '__main__':
- test_main()
diff --git a/Lib/test/test_cmp_jy.py b/Lib/test/test_cmp_jy.py
--- a/Lib/test/test_cmp_jy.py
+++ b/Lib/test/test_cmp_jy.py
@@ -45,6 +45,46 @@
assert not (-1 == 'a')
+class ObjectCmp(unittest.TestCase):
+ def testObjectListCompares(self):
+ # Also applies to tuple objects given common PySequence implementation
+ assert not object() == list()
+ assert object() != list()
+ assert not list() == object()
+ assert list() != object()
+
+ # Note that <, > rich comparisons in 2.x are broken by the
+ # lexicographic ordering of the type **name**. Example:
+ # 'object' > 'list'
+ assert not object() < list()
+ assert not object() <= list()
+ assert object() > list()
+ assert object() >= list()
+ assert list() < object()
+ assert list() <= object()
+ assert not list() > object()
+ assert not list() >= object()
+
+ def testObjectDictCompares(self):
+ # Also applies to such objects as defaultdict and Counter
+ assert not object() == dict()
+ assert object() != dict()
+ assert not dict() == object()
+ assert dict() != object()
+
+ # Note that <, > rich comparisons in 2.x are broken by the
+ # lexicographic ordering of the type **name**. Example:
+ # 'object' > 'dict'
+ assert not object() < dict()
+ assert not object() <= dict()
+ assert object() > dict()
+ assert object() >= dict()
+ assert dict() < object()
+ assert dict() <= object()
+ assert not dict() > object()
+ assert not dict() >= object()
+
+
class CustomCmp(unittest.TestCase):
def test___cmp___returns(self):
class Foo(object):
@@ -83,7 +123,8 @@
UnicodeDerivedCmp,
LongDerivedCmp,
IntStrCmp,
- CustomCmp
+ ObjectCmp,
+ CustomCmp,
)
diff --git a/Lib/test/test_codeop_jy.py b/Lib/test/test_codeop_jy.py
--- a/Lib/test/test_codeop_jy.py
+++ b/Lib/test/test_codeop_jy.py
@@ -18,6 +18,7 @@
if values:
d = {}
exec code in d
+ del d['__builtins__']
self.assertEquals(d,values)
elif value is not None:
self.assertEquals(eval(code,self.eval_d),value)
diff --git a/Lib/test/test_descr.py b/Lib/test/test_descr.py
--- a/Lib/test/test_descr.py
+++ b/Lib/test/test_descr.py
@@ -351,7 +351,7 @@
minstance.b = 2
minstance.a = 1
names = [x for x in dir(minstance) if x not in ["__name__", "__doc__"]]
- vereq(names, ['__package__', 'a', 'b'])
+ vereq(names, ['a', 'b'])
class M2(M):
def getdict(self):
diff --git a/Lib/test/test_dict_jy.py b/Lib/test/test_dict_jy.py
--- a/Lib/test/test_dict_jy.py
+++ b/Lib/test/test_dict_jy.py
@@ -1,5 +1,5 @@
from test import test_support
-import java
+from java.util import HashMap, Hashtable
import unittest
from collections import defaultdict
import test_dict
@@ -114,7 +114,7 @@
class JavaIntegrationTest(unittest.TestCase):
"Tests for instantiating dicts from Java maps and hashtables"
def test_hashmap(self):
- x = java.util.HashMap()
+ x = HashMap()
x.put('a', 1)
x.put('b', 2)
x.put('c', 3)
@@ -123,7 +123,7 @@
self.assertEqual(set(y.items()), set([('a', 1), ('b', 2), ('c', 3), ((1,2), "xyz")]))
def test_hashmap_builtin_pymethods(self):
- x = java.util.HashMap()
+ x = HashMap()
x['a'] = 1
x[(1, 2)] = 'xyz'
self.assertEqual({tup for tup in x.iteritems()}, {('a', 1), ((1, 2), 'xyz')})
@@ -132,18 +132,18 @@
def test_hashtable_equal(self):
for d in ({}, {1:2}):
- x = java.util.Hashtable(d)
+ x = Hashtable(d)
self.assertEqual(x, d)
self.assertEqual(d, x)
- self.assertEqual(x, java.util.HashMap(d))
+ self.assertEqual(x, HashMap(d))
def test_hashtable_remove(self):
- x = java.util.Hashtable({})
+ x = Hashtable({})
with self.assertRaises(KeyError):
del x[0]
def test_hashtable(self):
- x = java.util.Hashtable()
+ x = Hashtable()
x.put('a', 1)
x.put('b', 2)
x.put('c', 3)
@@ -154,10 +154,10 @@
class JavaDictTest(test_dict.DictTest):
- _class = java.util.HashMap
+ _class = HashMap
def test_copy_java_hashtable(self):
- x = java.util.Hashtable()
+ x = Hashtable()
xc = x.copy()
self.assertEqual(type(x), type(xc))
@@ -179,6 +179,57 @@
self.assertEqual(x.__delitem__(1), None)
self.assertEqual(len(x), 0)
+ def assert_property(self, prop, a, b):
+ prop(self._make_dict(a), self._make_dict(b))
+ prop(a, self._make_dict(b))
+ prop(self._make_dict(a), b)
+
+ def assert_not_property(self, prop, a, b):
+ with self.assertRaises(AssertionError):
+ prop(self._make_dict(a), self._make_dict(b))
+ with self.assertRaises(AssertionError):
+ prop(a, self._make_dict(b))
+ with self.assertRaises(AssertionError):
+ prop(self._make_dict(a), b)
+
+ # NOTE: when comparing dictionaries below exclusively in Java
+ # space, keys like 1 and 1L are different objects. Only when they
+ # are brought into Python space by Py.java2py, as is needed when
+ # comparing a Python dict with a Java Map, do we see them become
+ # equal.
+
+ def test_le(self):
+ self.assert_property(self.assertLessEqual, {}, {})
+ self.assert_property(self.assertLessEqual, {1: 2}, {1: 2})
+ self.assert_not_property(self.assertLessEqual, {1: 2, 3: 4}, {1: 2})
+ self.assert_property(self.assertLessEqual, {}, {1: 2})
+ self.assertLessEqual(self._make_dict({1: 2}), {1L: 2L, 3L: 4L})
+ self.assertLessEqual({1L: 2L}, self._make_dict({1: 2, 3L: 4L}))
+
+ def test_lt(self):
+ self.assert_not_property(self.assertLess, {}, {})
+ self.assert_not_property(self.assertLess, {1: 2}, {1: 2})
+ self.assert_not_property(self.assertLessEqual, {1: 2, 3: 4}, {1: 2})
+ self.assert_property(self.assertLessEqual, {}, {1: 2})
+ self.assertLess(self._make_dict({1: 2}), {1L: 2L, 3L: 4L})
+ self.assertLess({1L: 2L}, self._make_dict({1: 2, 3L: 4L}))
+
+ def test_ge(self):
+ self.assert_property(self.assertGreaterEqual, {}, {})
+ self.assert_property(self.assertGreaterEqual, {1: 2}, {1: 2})
+ self.assert_not_property(self.assertLessEqual, {1: 2, 3: 4}, {1: 2})
+ self.assert_property(self.assertLessEqual, {}, {1: 2})
+ self.assertGreaterEqual(self._make_dict({1: 2, 3: 4}), {1L: 2L})
+ self.assertGreaterEqual({1L: 2L, 3L: 4L}, self._make_dict({1: 2}))
+
+ def test_gt(self):
+ self.assert_not_property(self.assertGreater, {}, {})
+ self.assert_not_property(self.assertGreater, {1: 2}, {1: 2})
+ self.assert_not_property(self.assertLessEqual, {1: 2, 3: 4}, {1: 2})
+ self.assert_property(self.assertLessEqual, {}, {1: 2})
+ self.assertGreater(self._make_dict({1: 2, 3: 4}), {1L: 2L})
+ self.assertGreater({1L: 2L, 3L: 4L}, self._make_dict({1: 2}))
+
def test_main():
test_support.run_unittest(DictInitTest, DictCmpTest, DerivedDictTest, JavaIntegrationTest, JavaDictTest)
diff --git a/Lib/test/test_fileio.py b/Lib/test/test_fileio.py
--- a/Lib/test/test_fileio.py
+++ b/Lib/test/test_fileio.py
@@ -11,7 +11,7 @@
from functools import wraps
from test.test_support import (TESTFN, check_warnings, run_unittest,
- make_bad_fd, is_jython)
+ make_bad_fd, is_jython, gc_collect)
from test.test_support import py3k_bytes as bytes
from test.script_helper import run_python
@@ -34,7 +34,6 @@
self.f.close()
os.remove(TESTFN)
- @unittest.skipIf(is_jython, "FIXME: not working in Jython")
def testWeakRefs(self):
# verify weak references
p = proxy(self.f)
@@ -42,6 +41,7 @@
self.assertEqual(self.f.tell(), p.tell())
self.f.close()
self.f = None
+ gc_collect()
self.assertRaises(ReferenceError, getattr, p, 'tell')
def testSeekTell(self):
@@ -294,7 +294,15 @@
self.assertEqual(f.isatty(), False)
f.close()
- if sys.platform != "win32":
+ # Jython specific issues:
+ # On OSX, FileIO("/dev/tty", "w").isatty() is False
+ # On Ubuntu, FileIO("/dev/tty", "w").isatty() throws IOError: Illegal seek
+ #
+ # Much like we see on other platforms, we cannot reliably
+ # determine it is not seekable (or special).
+ #
+ # Related bug: http://bugs.jython.org/issue1945
+ if sys.platform != "win32" and not is_jython:
try:
f = self.f = _FileIO("/dev/tty", "a")
except EnvironmentError:
diff --git a/Lib/test/test_funcattrs.py b/Lib/test/test_funcattrs.py
--- a/Lib/test/test_funcattrs.py
+++ b/Lib/test/test_funcattrs.py
@@ -62,9 +62,7 @@
def test_func_globals(self):
self.assertIs(self.b.func_globals, globals())
- self.assertIs(self.b.__globals__, globals())
self.cannot_set_attr(self.b, 'func_globals', 2, TypeError)
- self.cannot_set_attr(self.b, '__globals__', 2, TypeError)
def test_func_closure(self):
a = 12
@@ -150,10 +148,8 @@
return a+b
self.assertEqual(first_func.func_defaults, None)
self.assertEqual(second_func.func_defaults, (1, 2))
- self.assertEqual(second_func.func_defaults, second_func.__defaults__)
first_func.func_defaults = (1, 2)
self.assertEqual(first_func.func_defaults, (1, 2))
- self.assertEqual(first_func.func_defaults, first_func.__defaults__)
self.assertEqual(first_func(), 3)
self.assertEqual(first_func(3), 5)
self.assertEqual(first_func(3, 5), 8)
@@ -312,7 +308,6 @@
class FunctionDocstringTest(FuncAttrsTest):
- @unittest.skipIf(test_support.is_jython, "FIXME: not working in Jython")
def test_set_docstring_attr(self):
self.assertEqual(self.b.__doc__, None)
self.assertEqual(self.b.func_doc, None)
@@ -322,8 +317,14 @@
self.assertEqual(self.b.func_doc, docstr)
self.assertEqual(self.f.a.__doc__, docstr)
self.assertEqual(self.fi.a.__doc__, docstr)
- self.cannot_set_attr(self.f.a, "__doc__", docstr, AttributeError)
- self.cannot_set_attr(self.fi.a, "__doc__", docstr, AttributeError)
+ # Jython is more uniform in its attribute model than CPython.
+ # Unfortunately we have more tests depending on such attempted
+ # settings of read-only attributes resulting in a TypeError
+ # than an AttributeError. But fixing this seems pointless for
+ # now, deferring to Jython 3.x. See
+ # http://bugs.python.org/issue1687163
+ self.cannot_set_attr(self.f.a, "__doc__", docstr, TypeError)
+ self.cannot_set_attr(self.fi.a, "__doc__", docstr, TypeError)
def test_delete_docstring(self):
self.b.__doc__ = "The docstring"
diff --git a/Lib/test/test_hmac.py b/Lib/test/test_hmac.py
deleted file mode 100644
--- a/Lib/test/test_hmac.py
+++ /dev/null
@@ -1,318 +0,0 @@
-import hmac
-import hashlib
-import unittest
-import warnings
-from test import test_support
-
-class TestVectorsTestCase(unittest.TestCase):
-
- def test_md5_vectors(self):
- # Test the HMAC module against test vectors from the RFC.
-
- def md5test(key, data, digest):
- h = hmac.HMAC(key, data)
- self.assertEqual(h.hexdigest().upper(), digest.upper())
-
- md5test(chr(0x0b) * 16,
- "Hi There",
- "9294727A3638BB1C13F48EF8158BFC9D")
-
- md5test("Jefe",
- "what do ya want for nothing?",
- "750c783e6ab0b503eaa86e310a5db738")
-
- md5test(chr(0xAA)*16,
- chr(0xDD)*50,
- "56be34521d144c88dbb8c733f0e8b3f6")
-
- md5test("".join([chr(i) for i in range(1, 26)]),
- chr(0xCD) * 50,
- "697eaf0aca3a3aea3a75164746ffaa79")
-
- md5test(chr(0x0C) * 16,
- "Test With Truncation",
- "56461ef2342edc00f9bab995690efd4c")
-
- md5test(chr(0xAA) * 80,
- "Test Using Larger Than Block-Size Key - Hash Key First",
- "6b1ab7fe4bd7bf8f0b62e6ce61b9d0cd")
-
- md5test(chr(0xAA) * 80,
- ("Test Using Larger Than Block-Size Key "
- "and Larger Than One Block-Size Data"),
- "6f630fad67cda0ee1fb1f562db3aa53e")
-
- def test_sha_vectors(self):
- def shatest(key, data, digest):
- h = hmac.HMAC(key, data, digestmod=hashlib.sha1)
- self.assertEqual(h.hexdigest().upper(), digest.upper())
-
- shatest(chr(0x0b) * 20,
- "Hi There",
- "b617318655057264e28bc0b6fb378c8ef146be00")
-
- shatest("Jefe",
- "what do ya want for nothing?",
- "effcdf6ae5eb2fa2d27416d5f184df9c259a7c79")
-
- shatest(chr(0xAA)*20,
- chr(0xDD)*50,
- "125d7342b9ac11cd91a39af48aa17b4f63f175d3")
-
- shatest("".join([chr(i) for i in range(1, 26)]),
- chr(0xCD) * 50,
- "4c9007f4026250c6bc8414f9bf50c86c2d7235da")
-
- shatest(chr(0x0C) * 20,
- "Test With Truncation",
- "4c1a03424b55e07fe7f27be1d58bb9324a9a5a04")
-
- shatest(chr(0xAA) * 80,
- "Test Using Larger Than Block-Size Key - Hash Key First",
- "aa4ae5e15272d00e95705637ce8a3b55ed402112")
-
- shatest(chr(0xAA) * 80,
- ("Test Using Larger Than Block-Size Key "
- "and Larger Than One Block-Size Data"),
- "e8e99d0f45237d786d6bbaa7965c7808bbff1a91")
-
- def _rfc4231_test_cases(self, hashfunc):
- def hmactest(key, data, hexdigests):
- h = hmac.HMAC(key, data, digestmod=hashfunc)
- self.assertEqual(h.hexdigest().lower(), hexdigests[hashfunc])
-
- # 4.2. Test Case 1
- hmactest(key = '\x0b'*20,
- data = 'Hi There',
- hexdigests = {
- hashlib.sha224: '896fb1128abbdf196832107cd49df33f'
- '47b4b1169912ba4f53684b22',
- hashlib.sha256: 'b0344c61d8db38535ca8afceaf0bf12b'
- '881dc200c9833da726e9376c2e32cff7',
- hashlib.sha384: 'afd03944d84895626b0825f4ab46907f'
- '15f9dadbe4101ec682aa034c7cebc59c'
- 'faea9ea9076ede7f4af152e8b2fa9cb6',
- hashlib.sha512: '87aa7cdea5ef619d4ff0b4241a1d6cb0'
- '2379f4e2ce4ec2787ad0b30545e17cde'
- 'daa833b7d6b8a702038b274eaea3f4e4'
- 'be9d914eeb61f1702e696c203a126854',
- })
-
- # 4.3. Test Case 2
- hmactest(key = 'Jefe',
- data = 'what do ya want for nothing?',
- hexdigests = {
- hashlib.sha224: 'a30e01098bc6dbbf45690f3a7e9e6d0f'
- '8bbea2a39e6148008fd05e44',
- hashlib.sha256: '5bdcc146bf60754e6a042426089575c7'
- '5a003f089d2739839dec58b964ec3843',
- hashlib.sha384: 'af45d2e376484031617f78d2b58a6b1b'
- '9c7ef464f5a01b47e42ec3736322445e'
- '8e2240ca5e69e2c78b3239ecfab21649',
- hashlib.sha512: '164b7a7bfcf819e2e395fbe73b56e0a3'
- '87bd64222e831fd610270cd7ea250554'
- '9758bf75c05a994a6d034f65f8f0e6fd'
- 'caeab1a34d4a6b4b636e070a38bce737',
- })
-
- # 4.4. Test Case 3
- hmactest(key = '\xaa'*20,
- data = '\xdd'*50,
- hexdigests = {
- hashlib.sha224: '7fb3cb3588c6c1f6ffa9694d7d6ad264'
- '9365b0c1f65d69d1ec8333ea',
- hashlib.sha256: '773ea91e36800e46854db8ebd09181a7'
- '2959098b3ef8c122d9635514ced565fe',
- hashlib.sha384: '88062608d3e6ad8a0aa2ace014c8a86f'
- '0aa635d947ac9febe83ef4e55966144b'
- '2a5ab39dc13814b94e3ab6e101a34f27',
- hashlib.sha512: 'fa73b0089d56a284efb0f0756c890be9'
- 'b1b5dbdd8ee81a3655f83e33b2279d39'
- 'bf3e848279a722c806b485a47e67c807'
- 'b946a337bee8942674278859e13292fb',
- })
-
- # 4.5. Test Case 4
- hmactest(key = ''.join([chr(x) for x in xrange(0x01, 0x19+1)]),
- data = '\xcd'*50,
- hexdigests = {
- hashlib.sha224: '6c11506874013cac6a2abc1bb382627c'
- 'ec6a90d86efc012de7afec5a',
- hashlib.sha256: '82558a389a443c0ea4cc819899f2083a'
- '85f0faa3e578f8077a2e3ff46729665b',
- hashlib.sha384: '3e8a69b7783c25851933ab6290af6ca7'
- '7a9981480850009cc5577c6e1f573b4e'
- '6801dd23c4a7d679ccf8a386c674cffb',
- hashlib.sha512: 'b0ba465637458c6990e5a8c5f61d4af7'
- 'e576d97ff94b872de76f8050361ee3db'
- 'a91ca5c11aa25eb4d679275cc5788063'
- 'a5f19741120c4f2de2adebeb10a298dd',
- })
-
- # 4.7. Test Case 6
- hmactest(key = '\xaa'*131,
- data = 'Test Using Larger Than Block-Siz'
- 'e Key - Hash Key First',
- hexdigests = {
- hashlib.sha224: '95e9a0db962095adaebe9b2d6f0dbce2'
- 'd499f112f2d2b7273fa6870e',
- hashlib.sha256: '60e431591ee0b67f0d8a26aacbf5b77f'
- '8e0bc6213728c5140546040f0ee37f54',
- hashlib.sha384: '4ece084485813e9088d2c63a041bc5b4'
- '4f9ef1012a2b588f3cd11f05033ac4c6'
- '0c2ef6ab4030fe8296248df163f44952',
- hashlib.sha512: '80b24263c7c1a3ebb71493c1dd7be8b4'
- '9b46d1f41b4aeec1121b013783f8f352'
- '6b56d037e05f2598bd0fd2215d6a1e52'
- '95e64f73f63f0aec8b915a985d786598',
- })
-
- # 4.8. Test Case 7
- hmactest(key = '\xaa'*131,
- data = 'This is a test using a larger th'
- 'an block-size key and a larger t'
- 'han block-size data. The key nee'
- 'ds to be hashed before being use'
- 'd by the HMAC algorithm.',
- hexdigests = {
- hashlib.sha224: '3a854166ac5d9f023f54d517d0b39dbd'
- '946770db9c2b95c9f6f565d1',
- hashlib.sha256: '9b09ffa71b942fcb27635fbcd5b0e944'
- 'bfdc63644f0713938a7f51535c3a35e2',
- hashlib.sha384: '6617178e941f020d351e2f254e8fd32c'
- '602420feb0b8fb9adccebb82461e99c5'
- 'a678cc31e799176d3860e6110c46523e',
- hashlib.sha512: 'e37b6a775dc87dbaa4dfa9f96e5e3ffd'
- 'debd71f8867289865df5a32d20cdc944'
- 'b6022cac3c4982b10d5eeb55c3e4de15'
- '134676fb6de0446065c97440fa8c6a58',
- })
-
- def test_sha224_rfc4231(self):
- self._rfc4231_test_cases(hashlib.sha224)
-
- def test_sha256_rfc4231(self):
- self._rfc4231_test_cases(hashlib.sha256)
-
- def test_sha384_rfc4231(self):
- self._rfc4231_test_cases(hashlib.sha384)
-
- def test_sha512_rfc4231(self):
- self._rfc4231_test_cases(hashlib.sha512)
-
- def test_legacy_block_size_warnings(self):
- class MockCrazyHash(object):
- """Ain't no block_size attribute here."""
- def __init__(self, *args):
- self._x = hashlib.sha1(*args)
- self.digest_size = self._x.digest_size
- def update(self, v):
- self._x.update(v)
- def digest(self):
- return self._x.digest()
-
- with warnings.catch_warnings():
- warnings.simplefilter('error', RuntimeWarning)
- with self.assertRaises(RuntimeWarning):
- hmac.HMAC('a', 'b', digestmod=MockCrazyHash)
- self.fail('Expected warning about missing block_size')
-
- MockCrazyHash.block_size = 1
- with self.assertRaises(RuntimeWarning):
- hmac.HMAC('a', 'b', digestmod=MockCrazyHash)
- self.fail('Expected warning about small block_size')
-
-
-
-class ConstructorTestCase(unittest.TestCase):
-
- def test_normal(self):
- # Standard constructor call.
- failed = 0
- try:
- h = hmac.HMAC("key")
- except:
- self.fail("Standard constructor call raised exception.")
-
- def test_withtext(self):
- # Constructor call with text.
- try:
- h = hmac.HMAC("key", "hash this!")
- except:
- self.fail("Constructor call with text argument raised exception.")
-
- def test_withmodule(self):
- # Constructor call with text and digest module.
- try:
- h = hmac.HMAC("key", "", hashlib.sha1)
- except:
- self.fail("Constructor call with hashlib.sha1 raised exception.")
-
-class SanityTestCase(unittest.TestCase):
-
- def test_default_is_md5(self):
- # Testing if HMAC defaults to MD5 algorithm.
- # NOTE: this whitebox test depends on the hmac class internals
- h = hmac.HMAC("key")
- self.assertTrue(h.digest_cons == hashlib.md5)
-
- def test_exercise_all_methods(self):
- # Exercising all methods once.
- # This must not raise any exceptions
- try:
- h = hmac.HMAC("my secret key")
- h.update("compute the hash of this text!")
- dig = h.digest()
- dig = h.hexdigest()
- h2 = h.copy()
- except:
- self.fail("Exception raised during normal usage of HMAC class.")
-
-class CopyTestCase(unittest.TestCase):
-
- def test_attributes(self):
- # Testing if attributes are of same type.
- h1 = hmac.HMAC("key")
- h2 = h1.copy()
- self.assertTrue(h1.digest_cons == h2.digest_cons,
- "digest constructors don't match.")
- self.assertTrue(type(h1.inner) == type(h2.inner),
- "Types of inner don't match.")
- self.assertTrue(type(h1.outer) == type(h2.outer),
- "Types of outer don't match.")
-
- def test_realcopy(self):
- # Testing if the copy method created a real copy.
- h1 = hmac.HMAC("key")
- h2 = h1.copy()
- # Using id() in case somebody has overridden __cmp__.
- self.assertTrue(id(h1) != id(h2), "No real copy of the HMAC instance.")
- self.assertTrue(id(h1.inner) != id(h2.inner),
- "No real copy of the attribute 'inner'.")
- self.assertTrue(id(h1.outer) != id(h2.outer),
- "No real copy of the attribute 'outer'.")
-
- def test_equality(self):
- # Testing if the copy has the same digests.
- h1 = hmac.HMAC("key")
- h1.update("some random text")
- h2 = h1.copy()
- self.assertTrue(h1.digest() == h2.digest(),
- "Digest of copy doesn't match original digest.")
- self.assertTrue(h1.hexdigest() == h2.hexdigest(),
- "Hexdigest of copy doesn't match original hexdigest.")
-
-def test_main():
- if test_support.is_jython:
- # XXX: Jython doesn't support sha224
- del TestVectorsTestCase.test_sha224_rfc4231
- hashlib.sha224 = None
- test_support.run_unittest(
- TestVectorsTestCase,
- ConstructorTestCase,
- SanityTestCase,
- CopyTestCase
- )
-
-if __name__ == "__main__":
- test_main()
diff --git a/Lib/test/test_import.py b/Lib/test/test_import.py
new file mode 100644
--- /dev/null
+++ b/Lib/test/test_import.py
@@ -0,0 +1,662 @@
+import errno
+import imp
+import marshal
+import os
+import py_compile
+import random
+import stat
+import struct
+import sys
+import unittest
+import textwrap
+import shutil
+
+from test.test_support import (unlink, TESTFN, unload, run_unittest, rmtree,
+ is_jython, check_warnings, EnvironmentVarGuard)
+from test import symlink_support
+from test import script_helper
+
+def _files(name):
+ return (name + os.extsep + "py",
+ name + os.extsep + "pyc",
+ name + os.extsep + "pyo",
+ name + os.extsep + "pyw",
+ name + "$py.class")
+
+def chmod_files(name):
+ for f in _files(name):
+ try:
+ os.chmod(f, 0600)
+ except OSError as exc:
+ if exc.errno != errno.ENOENT:
+ raise
+
+def remove_files(name):
+ for f in _files(name):
+ unlink(f)
+
+
+class ImportTests(unittest.TestCase):
+
+ def tearDown(self):
+ unload(TESTFN)
+ setUp = tearDown
+
+ def test_case_sensitivity(self):
+ # Brief digression to test that import is case-sensitive: if we got
+ # this far, we know for sure that "random" exists.
+ try:
+ import RAnDoM
+ except ImportError:
+ pass
+ else:
+ self.fail("import of RAnDoM should have failed (case mismatch)")
+
+ def test_double_const(self):
+ # Another brief digression to test the accuracy of manifest float
+ # constants.
+ from test import double_const # don't blink -- that *was* the test
+
+ def test_import(self):
+ def test_with_extension(ext):
+ # The extension is normally ".py", perhaps ".pyw".
+ source = TESTFN + ext
+ pyo = TESTFN + os.extsep + "pyo"
+ if is_jython:
+ pyc = TESTFN + "$py.class"
+ else:
+ pyc = TESTFN + os.extsep + "pyc"
+
+ with open(source, "w") as f:
+ print >> f, ("# This tests Python's ability to import a", ext,
+ "file.")
+ a = random.randrange(1000)
+ b = random.randrange(1000)
+ print >> f, "a =", a
+ print >> f, "b =", b
+
+ try:
+ mod = __import__(TESTFN)
+ except ImportError, err:
+ self.fail("import from %s failed: %s" % (ext, err))
+ else:
+ self.assertEqual(mod.a, a,
+ "module loaded (%s) but contents invalid" % mod)
+ self.assertEqual(mod.b, b,
+ "module loaded (%s) but contents invalid" % mod)
+ finally:
+ unlink(source)
+
+ try:
+ imp.reload(mod)
+ except ImportError, err:
+ self.fail("import from .pyc/.pyo failed: %s" % err)
+ finally:
+ unlink(pyc)
+ unlink(pyo)
+ unload(TESTFN)
+
+ sys.path.insert(0, os.curdir)
+ try:
+ test_with_extension(os.extsep + "py")
+ if sys.platform.startswith("win"):
+ for ext in [".PY", ".Py", ".pY", ".pyw", ".PYW", ".pYw"]:
+ test_with_extension(ext)
+ finally:
+ del sys.path[0]
+
+ @unittest.skipUnless(os.name == 'posix', "test meaningful only on posix systems")
+ def test_execute_bit_not_copied(self):
+ # Issue 6070: under posix .pyc files got their execute bit set if
+ # the .py file had the execute bit set, but they aren't executable.
+ oldmask = os.umask(022)
+ sys.path.insert(0, os.curdir)
+ try:
+ fname = TESTFN + os.extsep + "py"
+ f = open(fname, 'w').close()
+ os.chmod(fname, (stat.S_IRUSR | stat.S_IRGRP | stat.S_IROTH |
+ stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH))
+ __import__(TESTFN)
+ fn = fname + 'c'
+ if not os.path.exists(fn):
+ fn = fname + 'o'
+ if not os.path.exists(fn):
+ self.fail("__import__ did not result in creation of "
+ "either a .pyc or .pyo file")
+ s = os.stat(fn)
+ self.assertEqual(stat.S_IMODE(s.st_mode),
+ stat.S_IRUSR | stat.S_IRGRP | stat.S_IROTH)
+ finally:
+ os.umask(oldmask)
+ remove_files(TESTFN)
+ unload(TESTFN)
+ del sys.path[0]
+
+ def test_rewrite_pyc_with_read_only_source(self):
+ # Issue 6074: a long time ago on posix, and more recently on Windows,
+ # a read only source file resulted in a read only pyc file, which
+ # led to problems with updating it later
+ sys.path.insert(0, os.curdir)
+ fname = TESTFN + os.extsep + "py"
+ try:
+ # Write a Python file, make it read-only and import it
+ with open(fname, 'w') as f:
+ f.write("x = 'original'\n")
+ # Tweak the mtime of the source to ensure pyc gets updated later
+ s = os.stat(fname)
+ os.utime(fname, (s.st_atime, s.st_mtime-100000000))
+ os.chmod(fname, 0400)
+ m1 = __import__(TESTFN)
+ self.assertEqual(m1.x, 'original')
+ # Change the file and then reimport it
+ os.chmod(fname, 0600)
+ with open(fname, 'w') as f:
+ f.write("x = 'rewritten'\n")
+ unload(TESTFN)
+ m2 = __import__(TESTFN)
+ self.assertEqual(m2.x, 'rewritten')
+ # Now delete the source file and check the pyc was rewritten
+ unlink(fname)
+ unload(TESTFN)
+ m3 = __import__(TESTFN)
+ self.assertEqual(m3.x, 'rewritten')
+ finally:
+ chmod_files(TESTFN)
+ remove_files(TESTFN)
+ unload(TESTFN)
+ del sys.path[0]
+
+ def test_imp_module(self):
+ # Verify that the imp module can correctly load and find .py files
+
+ # XXX (ncoghlan): It would be nice to use test_support.CleanImport
+ # here, but that breaks because the os module registers some
+ # handlers in copy_reg on import. Since CleanImport doesn't
+ # revert that registration, the module is left in a broken
+ # state after reversion. Reinitialising the module contents
+ # and just reverting os.environ to its previous state is an OK
+ # workaround
+ orig_path = os.path
+ orig_getenv = os.getenv
+ with EnvironmentVarGuard():
+ x = imp.find_module("os")
+ new_os = imp.load_module("os", *x)
+ self.assertIs(os, new_os)
+ self.assertIs(orig_path, new_os.path)
+ self.assertIsNot(orig_getenv, new_os.getenv)
+
+ def test_module_with_large_stack(self, module='longlist'):
+ # Regression test for http://bugs.python.org/issue561858.
+ filename = module + os.extsep + 'py'
+
+ # Create a file with a list of 65000 elements.
+ with open(filename, 'w+') as f:
+ f.write('d = [\n')
+ for i in range(65000):
+ f.write('"",\n')
+ f.write(']')
+
+ # Compile & remove .py file, we only need .pyc (or .pyo).
+ with open(filename, 'r') as f:
+ py_compile.compile(filename)
+ unlink(filename)
+
+ # Need to be able to load from current dir.
+ sys.path.append('')
+
+ # This used to crash.
+ exec 'import ' + module
+
+ # Cleanup.
+ del sys.path[-1]
+ unlink(filename + 'c')
+ unlink(filename + 'o')
+
+ def test_failing_import_sticks(self):
+ source = TESTFN + os.extsep + "py"
+ with open(source, "w") as f:
+ print >> f, "a = 1 // 0"
+
+ # New in 2.4, we shouldn't be able to import that no matter how often
+ # we try.
+ sys.path.insert(0, os.curdir)
+ try:
+ for i in [1, 2, 3]:
+ self.assertRaises(ZeroDivisionError, __import__, TESTFN)
+ self.assertNotIn(TESTFN, sys.modules,
+ "damaged module in sys.modules on %i try" % i)
+ finally:
+ del sys.path[0]
+ remove_files(TESTFN)
+
+ def test_failing_reload(self):
+ # A failing reload should leave the module object in sys.modules.
+ source = TESTFN + os.extsep + "py"
+ with open(source, "w") as f:
+ print >> f, "a = 1"
+ print >> f, "b = 2"
+
+ sys.path.insert(0, os.curdir)
+ try:
+ mod = __import__(TESTFN)
+ self.assertIn(TESTFN, sys.modules)
+ self.assertEqual(mod.a, 1, "module has wrong attribute values")
+ self.assertEqual(mod.b, 2, "module has wrong attribute values")
+
+ # On WinXP, just replacing the .py file wasn't enough to
+ # convince reload() to reparse it. Maybe the timestamp didn't
+ # move enough. We force it to get reparsed by removing the
+ # compiled file too.
+ remove_files(TESTFN)
+
+ # Now damage the module.
+ with open(source, "w") as f:
+ print >> f, "a = 10"
+ print >> f, "b = 20//0"
+
+ self.assertRaises(ZeroDivisionError, imp.reload, mod)
+
+ # But we still expect the module to be in sys.modules.
+ mod = sys.modules.get(TESTFN)
+ self.assertIsNot(mod, None, "expected module to be in sys.modules")
+
+ # We should have replaced a w/ 10, but the old b value should
+ # stick.
+ self.assertEqual(mod.a, 10, "module has wrong attribute values")
+ self.assertEqual(mod.b, 2, "module has wrong attribute values")
+
+ finally:
+ del sys.path[0]
+ remove_files(TESTFN)
+ unload(TESTFN)
+
+ def test_infinite_reload(self):
+ # http://bugs.python.org/issue742342 reports that Python segfaults
+ # (infinite recursion in C) when faced with self-recursive reload()ing.
+
+ sys.path.insert(0, os.path.dirname(__file__))
+ try:
+ import infinite_reload
+ finally:
+ del sys.path[0]
+
+ def test_import_name_binding(self):
+ # import x.y.z binds x in the current namespace.
+ import test as x
+ import test.test_support
+ self.assertIs(x, test, x.__name__)
+ self.assertTrue(hasattr(test.test_support, "__file__"))
+
+ # import x.y.z as w binds z as w.
+ import test.test_support as y
+ self.assertIs(y, test.test_support, y.__name__)
+
+ def test_import_initless_directory_warning(self):
+ # NOTE: to test this, we have to remove Jython's JavaImporter
+ # (bound to the string '__classpath__'), which of course
+ # supports such directories as possible Java packages.
+ #
+ # For Jython 3.x we really need to rethink what it does, since
+ # it repeatedly causes questions on Jython forums, but too
+ # late to change for 2.7, except perhaps by some option.
+ classpath_entry = sys.path.index('__classpath__')
+ del sys.path[classpath_entry]
+ try:
+ with check_warnings(('', ImportWarning)):
+ # Just a random non-package directory we always expect to be
+ # somewhere in sys.path...
+ self.assertRaises(ImportError, __import__, "site-packages")
+ finally:
+ sys.path.insert(classpath_entry, '__classpath__')
+
+ def test_import_by_filename(self):
+ path = os.path.abspath(TESTFN)
+ with self.assertRaises(ImportError) as c:
+ __import__(path)
+ self.assertEqual("Import by filename is not supported.",
+ c.exception.args[0])
+
+ def test_import_in_del_does_not_crash(self):
+ # Issue 4236
+ testfn = script_helper.make_script('', TESTFN, textwrap.dedent("""\
+ import sys
+ class C:
+ def __del__(self):
+ import imp
+ sys.argv.insert(0, C())
+ """))
+ try:
+ script_helper.assert_python_ok(testfn)
+ finally:
+ unlink(testfn)
+
+ def test_bug7732(self):
+ source = TESTFN + '.py'
+ os.mkdir(source)
+ try:
+ self.assertRaises((ImportError, IOError),
+ imp.find_module, TESTFN, ["."])
+ finally:
+ os.rmdir(source)
+
+ def test_timestamp_overflow(self):
+ # A modification timestamp larger than 2**32 should not be a problem
+ # when importing a module (issue #11235).
+ sys.path.insert(0, os.curdir)
+ try:
+ source = TESTFN + ".py"
+ if is_jython:
+ compiled = TESTFN + "$py.class"
+ else:
+ compiled = source + ('c' if __debug__ else 'o')
+ with open(source, 'w') as f:
+ pass
+ try:
+ os.utime(source, (2 ** 33 - 5, 2 ** 33 - 5))
+ except OverflowError:
+ self.skipTest("cannot set modification time to large integer")
+ except OSError as e:
+ if e.errno != getattr(errno, 'EOVERFLOW', None):
+ raise
+ self.skipTest("cannot set modification time to large integer ({})".format(e))
+ __import__(TESTFN)
+ # The pyc file was created.
+ os.stat(compiled)
+ finally:
+ del sys.path[0]
+ remove_files(TESTFN)
+
+ def test_pyc_mtime(self):
+ # Test for issue #13863: .pyc timestamp sometimes incorrect on Windows.
+ sys.path.insert(0, os.curdir)
+ try:
+ # Jan 1, 2012; Jul 1, 2012.
+ mtimes = 1325376000, 1341100800
+
+ # Different names to avoid running into import caching.
+ tails = "spam", "eggs"
+ for mtime, tail in zip(mtimes, tails):
+ module = TESTFN + tail
+ source = module + ".py"
+ compiled = source + ('c' if __debug__ else 'o')
+
+ # Create a new Python file with the given mtime.
+ with open(source, 'w') as f:
+ f.write("# Just testing\nx=1, 2, 3\n")
+ os.utime(source, (mtime, mtime))
+
+ # Generate the .pyc/o file; if it couldn't be created
+ # for some reason, skip the test.
+ m = __import__(module)
+ if not os.path.exists(compiled):
+ unlink(source)
+ self.skipTest("Couldn't create .pyc/.pyo file.")
+
+ # Actual modification time of .py file.
+ mtime1 = int(os.stat(source).st_mtime) & 0xffffffff
+
+ # mtime that was encoded in the .pyc file.
+ with open(compiled, 'rb') as f:
+ mtime2 = struct.unpack('<L', f.read(8)[4:])[0]
+
+ unlink(compiled)
+ unlink(source)
+
+ self.assertEqual(mtime1, mtime2)
+ finally:
+ sys.path.pop(0)
+
+
+class PycRewritingTests(unittest.TestCase):
+ # Test that the `co_filename` attribute on code objects always points
+ # to the right file, even when various things happen (e.g. both the .py
+ # and the .pyc file are renamed).
+
+ module_name = "unlikely_module_name"
+ module_source = """
+import sys
+code_filename = sys._getframe().f_code.co_filename
+module_filename = __file__
+constant = 1
+def func():
+ pass
+func_filename = func.func_code.co_filename
+"""
+ dir_name = os.path.abspath(TESTFN)
+ file_name = os.path.join(dir_name, module_name) + os.extsep + "py"
+ if is_jython:
+ compiled_name = os.path.join(dir_name, module_name) + "$py.class"
+ else:
+ compiled_name = file_name + ("c" if __debug__ else "o")
+
+ def setUp(self):
+ self.sys_path = sys.path[:]
+ self.orig_module = sys.modules.pop(self.module_name, None)
+ os.mkdir(self.dir_name)
+ with open(self.file_name, "w") as f:
+ f.write(self.module_source)
+ sys.path.insert(0, self.dir_name)
+
+ def tearDown(self):
+ sys.path[:] = self.sys_path
+ if self.orig_module is not None:
+ sys.modules[self.module_name] = self.orig_module
+ else:
+ unload(self.module_name)
+ unlink(self.file_name)
+ unlink(self.compiled_name)
+ rmtree(self.dir_name)
+
+ def import_module(self):
+ ns = globals()
+ __import__(self.module_name, ns, ns)
+ return sys.modules[self.module_name]
+
+ def test_basics(self):
+ mod = self.import_module()
+ self.assertEqual(mod.module_filename, self.file_name)
+ self.assertEqual(mod.code_filename, self.file_name)
+ self.assertEqual(mod.func_filename, self.file_name)
+ del sys.modules[self.module_name]
+ mod = self.import_module()
+ self.assertEqual(mod.module_filename, self.compiled_name)
+ self.assertEqual(mod.code_filename, self.file_name)
+ self.assertEqual(mod.func_filename, self.file_name)
+
+ def test_incorrect_code_name(self):
+ py_compile.compile(self.file_name, dfile="another_module.py")
+ mod = self.import_module()
+ self.assertEqual(mod.module_filename, self.compiled_name)
+ self.assertEqual(mod.code_filename, self.file_name)
+ self.assertEqual(mod.func_filename, self.file_name)
+
+ def test_module_without_source(self):
+ target = "another_module.py"
+ py_compile.compile(self.file_name, dfile=target)
+ os.remove(self.file_name)
+ mod = self.import_module()
+ self.assertEqual(mod.module_filename, self.compiled_name)
+ self.assertEqual(mod.code_filename, target)
+ self.assertEqual(mod.func_filename, target)
+
+ @unittest.skipIf(is_jython, "Jython does not support compilation to Python bytecode (yet)")
+ def test_foreign_code(self):
+ py_compile.compile(self.file_name)
+ with open(self.compiled_name, "rb") as f:
+ header = f.read(8)
+ code = marshal.load(f)
+ constants = list(code.co_consts)
+ foreign_code = test_main.func_code
+ pos = constants.index(1)
+ constants[pos] = foreign_code
+ code = type(code)(code.co_argcount, code.co_nlocals, code.co_stacksize,
+ code.co_flags, code.co_code, tuple(constants),
+ code.co_names, code.co_varnames, code.co_filename,
+ code.co_name, code.co_firstlineno, code.co_lnotab,
+ code.co_freevars, code.co_cellvars)
+ with open(self.compiled_name, "wb") as f:
+ f.write(header)
+ marshal.dump(code, f)
+ mod = self.import_module()
+ self.assertEqual(mod.constant.co_filename, foreign_code.co_filename)
+
+
+class PathsTests(unittest.TestCase):
+ path = TESTFN
+
+ def setUp(self):
+ os.mkdir(self.path)
+ self.syspath = sys.path[:]
+
+ def tearDown(self):
+ rmtree(self.path)
+ sys.path[:] = self.syspath
+
+ # Regression test for http://bugs.python.org/issue1293.
+ def test_trailing_slash(self):
+ with open(os.path.join(self.path, 'test_trailing_slash.py'), 'w') as f:
+ f.write("testdata = 'test_trailing_slash'")
+ sys.path.append(self.path+'/')
+ mod = __import__("test_trailing_slash")
+ self.assertEqual(mod.testdata, 'test_trailing_slash')
+ unload("test_trailing_slash")
+
+ # Regression test for http://bugs.python.org/issue3677.
+ def _test_UNC_path(self):
+ with open(os.path.join(self.path, 'test_trailing_slash.py'), 'w') as f:
+ f.write("testdata = 'test_trailing_slash'")
+ # Create the UNC path, like \\myhost\c$\foo\bar.
+ path = os.path.abspath(self.path)
+ import socket
+ hn = socket.gethostname()
+ drive = path[0]
+ unc = "\\\\%s\\%s$"%(hn, drive)
+ unc += path[2:]
+ try:
+ os.listdir(unc)
+ except OSError as e:
+ if e.errno in (errno.EPERM, errno.EACCES):
+ # See issue #15338
+ self.skipTest("cannot access administrative share %r" % (unc,))
+ raise
+ sys.path.append(path)
+ mod = __import__("test_trailing_slash")
+ self.assertEqual(mod.testdata, 'test_trailing_slash')
+ unload("test_trailing_slash")
+
+ if sys.platform == "win32":
+ test_UNC_path = _test_UNC_path
+
+
+class RelativeImportTests(unittest.TestCase):
+
+ def tearDown(self):
+ unload("test.relimport")
+ setUp = tearDown
+
+ def test_relimport_star(self):
+ # This will import * from .test_import.
+ from . import relimport
+ self.assertTrue(hasattr(relimport, "RelativeImportTests"))
+
+ def test_issue3221(self):
+ # Regression test for http://bugs.python.org/issue3221.
+ def check_absolute():
+ exec "from os import path" in ns
+ def check_relative():
+ exec "from . import relimport" in ns
+
+ # Check both OK with __package__ and __name__ correct
+ ns = dict(__package__='test', __name__='test.notarealmodule')
+ check_absolute()
+ check_relative()
+
+ # Check both OK with only __name__ wrong
+ ns = dict(__package__='test', __name__='notarealpkg.notarealmodule')
+ check_absolute()
+ check_relative()
+
+ # Check relative fails with only __package__ wrong
+ ns = dict(__package__='foo', __name__='test.notarealmodule')
+ with check_warnings(('.+foo', RuntimeWarning)):
+ check_absolute()
+ self.assertRaises(SystemError, check_relative)
+
+ # Check relative fails with __package__ and __name__ wrong
+ ns = dict(__package__='foo', __name__='notarealpkg.notarealmodule')
+ with check_warnings(('.+foo', RuntimeWarning)):
+ check_absolute()
+ self.assertRaises(SystemError, check_relative)
+
+ # Check both fail with package set to a non-string
+ ns = dict(__package__=object())
+ self.assertRaises(ValueError, check_absolute)
+ self.assertRaises(ValueError, check_relative)
+
+ def test_absolute_import_without_future(self):
+ # If explicit relative import syntax is used, then do not try
+ # to perform an absolute import in the face of failure.
+ # Issue #7902.
+ with self.assertRaises(ImportError):
+ from .os import sep
+ self.fail("explicit relative import triggered an "
+ "implicit absolute import")
+
+
+class TestSymbolicallyLinkedPackage(unittest.TestCase):
+ package_name = 'sample'
+
+ def setUp(self):
+ if os.path.exists(self.tagged):
+ shutil.rmtree(self.tagged)
+ if os.path.exists(self.package_name):
+ symlink_support.remove_symlink(self.package_name)
+ self.orig_sys_path = sys.path[:]
+
+ # create a sample package; imagine you have a package with a tag and
+ # you want to symbolically link it from its untagged name.
+ os.mkdir(self.tagged)
+ init_file = os.path.join(self.tagged, '__init__.py')
+ open(init_file, 'w').close()
+ assert os.path.exists(init_file)
+
+ # now create a symlink to the tagged package
+ # sample -> sample-tagged
+ symlink_support.symlink(self.tagged, self.package_name)
+
+ assert os.path.isdir(self.package_name)
+ assert os.path.isfile(os.path.join(self.package_name, '__init__.py'))
+
+ @property
+ def tagged(self):
+ return self.package_name + '-tagged'
+
+ # regression test for issue6727
+ @unittest.skipUnless(
+ not hasattr(sys, 'getwindowsversion')
+ or sys.getwindowsversion() >= (6, 0),
+ "Windows Vista or later required")
+ @symlink_support.skip_unless_symlink
+ def test_symlinked_dir_importable(self):
+ # make sure sample can only be imported from the current directory.
+ sys.path[:] = ['.']
+
+ # and try to import the package
+ __import__(self.package_name)
+
+ def tearDown(self):
+ # now cleanup
+ if os.path.exists(self.package_name):
+ symlink_support.remove_symlink(self.package_name)
+ if os.path.exists(self.tagged):
+ shutil.rmtree(self.tagged)
+ sys.path[:] = self.orig_sys_path
+
+def test_main(verbose=None):
+ run_unittest(ImportTests, PycRewritingTests, PathsTests,
+ RelativeImportTests, TestSymbolicallyLinkedPackage)
+
+if __name__ == '__main__':
+ # Test needs to be a package, so we can do relative imports.
+ from test.test_import import test_main
+ test_main()
diff --git a/Lib/test/test_inspect.py b/Lib/test/test_inspect.py
--- a/Lib/test/test_inspect.py
+++ b/Lib/test/test_inspect.py
@@ -633,8 +633,15 @@
class TestGetcallargsFunctions(unittest.TestCase):
- # tuple parameters are named '.1', '.2', etc.
- is_tuplename = re.compile(r'^\.\d+$').match
+ # It's possible to get both the tuple parameters AND the unpacked
+ # parameters by using locals(). However, they are named
+ # differently in CPython and Jython:
+ #
+ # * For CPython, such tuple parameters are named '.1', '.2', etc.
+ # * For Jython, they are actually the formal parameter, eg '(d, (e, f))'
+ #
+ # In both cases, we ignore them in testing - they are in fact unpacked
+ is_tuplename = re.compile(r'(?:^\.\d+$)|(?:^\()').match
def assertEqualCallArgs(self, func, call_params_string, locs=None):
locs = dict(locs or {}, func=func)
diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py
--- a/Lib/test/test_itertools.py
+++ b/Lib/test/test_itertools.py
@@ -1,5 +1,6 @@
import unittest
from test import test_support
+from test.test_weakref import extra_collect
from itertools import *
from weakref import proxy
from decimal import Decimal
@@ -336,11 +337,8 @@
self.assertEqual(take(2, zip('abc',count(-3))), [('a', -3), ('b', -2)])
self.assertRaises(TypeError, count, 2, 3, 4)
self.assertRaises(TypeError, count, 'a')
-
- #FIXME: not working in Jython
- #self.assertEqual(list(islice(count(maxsize-5), 10)), range(maxsize-5, maxsize+5))
- #self.assertEqual(list(islice(count(-maxsize-5), 10)), range(-maxsize-5, -maxsize+5))
-
+ self.assertEqual(list(islice(count(maxsize-5), 10)), range(maxsize-5, maxsize+5))
+ self.assertEqual(list(islice(count(-maxsize-5), 10)), range(-maxsize-5, -maxsize+5))
c = count(3)
self.assertEqual(repr(c), 'count(3)')
c.next()
@@ -348,27 +346,20 @@
c = count(-9)
self.assertEqual(repr(c), 'count(-9)')
c.next()
+ self.assertEqual(repr(count(10.25)), 'count(10.25)')
+ self.assertEqual(c.next(), -8)
+ for i in (-sys.maxint-5, -sys.maxint+5 ,-10, -1, 0, 10, sys.maxint-5, sys.maxint+5):
+ # Test repr (ignoring the L in longs)
+ r1 = repr(count(i)).replace('L', '')
+ r2 = 'count(%r)'.__mod__(i).replace('L', '')
+ self.assertEqual(r1, r2)
- #FIXME: not working in Jython
- #self.assertEqual(repr(count(10.25)), 'count(10.25)')
- self.assertEqual(c.next(), -8)
-
- #FIXME: not working in Jython
- if not test_support.is_jython:
- for i in (-sys.maxint-5, -sys.maxint+5 ,-10, -1, 0, 10, sys.maxint-5, sys.maxint+5):
- # Test repr (ignoring the L in longs)
- r1 = repr(count(i)).replace('L', '')
- r2 = 'count(%r)'.__mod__(i).replace('L', '')
- self.assertEqual(r1, r2)
-
- #FIXME: not working in Jython
# check copy, deepcopy, pickle
- if not test_support.is_jython:
- for value in -3, 3, sys.maxint-5, sys.maxint+5:
- c = count(value)
- self.assertEqual(next(copy.copy(c)), value)
- self.assertEqual(next(copy.deepcopy(c)), value)
- self.assertEqual(next(pickle.loads(pickle.dumps(c))), value)
+ for value in -3, 3, sys.maxint-5, sys.maxint+5:
+ c = count(value)
+ self.assertEqual(next(copy.copy(c)), value)
+ self.assertEqual(next(copy.deepcopy(c)), value)
+ self.assertEqual(next(pickle.loads(pickle.dumps(c))), value)
def test_count_with_stride(self):
self.assertEqual(zip('abc',count(2,3)), [('a', 2), ('b', 5), ('c', 8)])
@@ -378,17 +369,14 @@
[('a', 0), ('b', -1), ('c', -2)])
self.assertEqual(zip('abc',count(2,0)), [('a', 2), ('b', 2), ('c', 2)])
self.assertEqual(zip('abc',count(2,1)), [('a', 2), ('b', 3), ('c', 4)])
-
- #FIXME: not working in Jython
- #self.assertEqual(take(20, count(maxsize-15, 3)), take(20, range(maxsize-15, maxsize+100, 3)))
- #self.assertEqual(take(20, count(-maxsize-15, 3)), take(20, range(-maxsize-15,-maxsize+100, 3)))
- #self.assertEqual(take(3, count(2, 3.25-4j)), [2, 5.25-4j, 8.5-8j])
- #self.assertEqual(take(3, count(Decimal('1.1'), Decimal('.1'))),
- # [Decimal('1.1'), Decimal('1.2'), Decimal('1.3')])
- #self.assertEqual(take(3, count(Fraction(2,3), Fraction(1,7))),
- # [Fraction(2,3), Fraction(17,21), Fraction(20,21)])
- #self.assertEqual(repr(take(3, count(10, 2.5))), repr([10, 12.5, 15.0]))
-
+ self.assertEqual(take(20, count(maxsize-15, 3)), take(20, range(maxsize-15, maxsize+100, 3)))
+ self.assertEqual(take(20, count(-maxsize-15, 3)), take(20, range(-maxsize-15,-maxsize+100, 3)))
+ self.assertEqual(take(3, count(2, 3.25-4j)), [2, 5.25-4j, 8.5-8j])
+ self.assertEqual(take(3, count(Decimal('1.1'), Decimal('.1'))),
+ [Decimal('1.1'), Decimal('1.2'), Decimal('1.3')])
+ self.assertEqual(take(3, count(Fraction(2,3), Fraction(1,7))),
+ [Fraction(2,3), Fraction(17,21), Fraction(20,21)])
+ self.assertEqual(repr(take(3, count(10, 2.5))), repr([10, 12.5, 15.0]))
c = count(3, 5)
self.assertEqual(repr(c), 'count(3, 5)')
c.next()
@@ -402,23 +390,18 @@
c.next()
self.assertEqual(repr(c), 'count(-12, -3)')
self.assertEqual(repr(c), 'count(-12, -3)')
-
- #FIXME: not working in Jython
- #self.assertEqual(repr(count(10.5, 1.25)), 'count(10.5, 1.25)')
- #self.assertEqual(repr(count(10.5, 1)), 'count(10.5)') # suppress step=1 when it's an int
- #self.assertEqual(repr(count(10.5, 1.00)), 'count(10.5, 1.0)') # do show float values lilke 1.0
-
- #FIXME: not working in Jython
- if not test_support.is_jython:
- for i in (-sys.maxint-5, -sys.maxint+5 ,-10, -1, 0, 10, sys.maxint-5, sys.maxint+5):
- for j in (-sys.maxint-5, -sys.maxint+5 ,-10, -1, 0, 1, 10, sys.maxint-5, sys.maxint+5):
- # Test repr (ignoring the L in longs)
- r1 = repr(count(i, j)).replace('L', '')
- if j == 1:
- r2 = ('count(%r)' % i).replace('L', '')
- else:
- r2 = ('count(%r, %r)' % (i, j)).replace('L', '')
- self.assertEqual(r1, r2)
+ self.assertEqual(repr(count(10.5, 1.25)), 'count(10.5, 1.25)')
+ self.assertEqual(repr(count(10.5, 1)), 'count(10.5)') # suppress step=1 when it's an int
+ self.assertEqual(repr(count(10.5, 1.00)), 'count(10.5, 1.0)') # do show float values like 1.0
+ for i in (-sys.maxint-5, -sys.maxint+5 ,-10, -1, 0, 10, sys.maxint-5, sys.maxint+5):
+ for j in (-sys.maxint-5, -sys.maxint+5 ,-10, -1, 0, 1, 10, sys.maxint-5, sys.maxint+5):
+ # Test repr (ignoring the L in longs)
+ r1 = repr(count(i, j)).replace('L', '')
+ if j == 1:
+ r2 = ('count(%r)' % i).replace('L', '')
+ else:
+ r2 = ('count(%r, %r)' % (i, j)).replace('L', '')
+ self.assertEqual(r1, r2)
def test_cycle(self):
self.assertEqual(take(10, cycle('abc')), list('abcabcabca'))
@@ -918,14 +901,25 @@
self.assertTrue(list(t1) == list(t2) == list(t3) == list('abc'))
# test that tee objects are weak referencable
- a, b = tee(xrange(10))
- p = proxy(a)
- self.assertEqual(getattr(p, '__class__'), type(b))
- del a
+ def delocalize():
+ # local variables in Jython cannot be deleted to see
+ # objects go out of scope immediately. Except for tests
+ # like this however this is not going to be observed!
+ a, b = tee(xrange(10))
+ return dict(a=a, b=b)
- #FIXME: not working in Jython
- if not test_support.is_jython:
- self.assertRaises(ReferenceError, getattr, p, '__class__')
+ d = delocalize()
+ p = proxy(d['a'])
+ self.assertEqual(getattr(p, '__class__'), type(d['b']))
+ del d['a']
+ extra_collect() # necessary for Jython to ensure ref queue is cleared out
+ self.assertRaises(ReferenceError, getattr, p, '__class__')
+
+ # Issue 13454: Crash when deleting backward iterator from tee()
+ def test_tee_del_backward(self):
+ forward, backward = tee(repeat(None, 20000000))
+ any(forward) # exhaust the iterator
+ del backward
def test_StopIteration(self):
self.assertRaises(StopIteration, izip().next)
diff --git a/Lib/test/test_json.py b/Lib/test/test_json.py
deleted file mode 100644
--- a/Lib/test/test_json.py
+++ /dev/null
@@ -1,20 +0,0 @@
-"""Tests for json.
-
-The tests for json are defined in the json.tests package;
-the test_suite() function there returns a test suite that's ready to
-be run.
-"""
-
-import json.tests
-import test.test_support
-
-from json.tests.test_unicode import TestUnicode
-
-def test_main():
- #FIXME: Investigate why test_bad_encoding isn't working in Jython.
- del TestUnicode.test_bad_encoding
- test.test_support.run_unittest(json.tests.test_suite())
-
-
-if __name__ == "__main__":
- test_main()
diff --git a/Lib/test/test_jy_internals.py b/Lib/test/test_jy_internals.py
--- a/Lib/test/test_jy_internals.py
+++ b/Lib/test/test_jy_internals.py
@@ -1,7 +1,6 @@
"""
test some jython internals
"""
-import gc
import unittest
import time
from test import test_support
@@ -18,7 +17,6 @@
class MemoryLeakTests(unittest.TestCase):
- @unittest.skip("FIXME: broken in 2.7.")
def test_class_to_test_weakness(self):
# regrtest for bug 1522, adapted from test code submitted by Matt Brinkley
@@ -28,16 +26,6 @@
# `type`!)
class_to_type_map = getField(type, 'class_to_type').get(None)
- def make_clean():
- # gc a few times just to be really sure, since in this
- # case we don't really care if it takes a few cycles of GC
- # for the garbage to be reached
- gc.collect()
- time.sleep(0.1)
- gc.collect()
- time.sleep(0.5)
- gc.collect()
-
def create_proxies():
pi = PythonInterpreter()
pi.exec("""
@@ -51,16 +39,20 @@
Dog().bark()
""")
- make_clean()
-
# get to steady state first, then verify we don't create new proxies
for i in xrange(2):
create_proxies()
- start_size = class_to_type_map.size()
+ # Ensure the reaper thread can run and clear out weak refs, so
+ # use this supporting function
+ test_support.gc_collect()
+ # Given that taking the len (or size()) of Guava weak maps is
+ # eventually consistent, we should instead take a len of its
+ # keys.
+ start_size = len(list(class_to_type_map))
for i in xrange(5):
create_proxies()
- make_clean()
- self.assertEqual(start_size, class_to_type_map.size())
+ test_support.gc_collect()
+ self.assertEqual(start_size, len(list(class_to_type_map)))
class WeakIdentityMapTests(unittest.TestCase):
diff --git a/Lib/test/test_list.py b/Lib/test/test_list.py
--- a/Lib/test/test_list.py
+++ b/Lib/test/test_list.py
@@ -2,18 +2,19 @@
from test import test_support, list_tests
class ListTest(list_tests.CommonTest):
+
type2test = list
def test_basic(self):
- self.assertEqual(list([]), [])
+ self.assertEqual(self.type2test([]), [])
l0_3 = [0, 1, 2, 3]
- l0_3_bis = list(l0_3)
+ l0_3_bis = self.type2test(l0_3)
self.assertEqual(l0_3, l0_3_bis)
self.assertTrue(l0_3 is not l0_3_bis)
- self.assertEqual(list(()), [])
- self.assertEqual(list((0, 1, 2, 3)), [0, 1, 2, 3])
- self.assertEqual(list(''), [])
- self.assertEqual(list('spam'), ['s', 'p', 'a', 'm'])
+ self.assertEqual(self.type2test(()), [])
+ self.assertEqual(self.type2test((0, 1, 2, 3)), [0, 1, 2, 3])
+ self.assertEqual(self.type2test(''), [])
+ self.assertEqual(self.type2test('spam'), ['s', 'p', 'a', 'm'])
#FIXME: too brutal for us ATM.
if not test_support.is_jython:
@@ -41,20 +42,21 @@
def test_truth(self):
super(ListTest, self).test_truth()
- self.assertTrue(not [])
- self.assertTrue([42])
+ self.assertTrue(not self.type2test([]))
+ self.assertTrue(self.type2test([42]))
def test_identity(self):
self.assertTrue([] is not [])
+ self.assertTrue(self.type2test([]) is not self.type2test([]))
def test_len(self):
super(ListTest, self).test_len()
- self.assertEqual(len([]), 0)
- self.assertEqual(len([0]), 1)
- self.assertEqual(len([0, 1, 2]), 3)
+ self.assertEqual(len(self.type2test([])), 0)
+ self.assertEqual(len(self.type2test([0])), 1)
+ self.assertEqual(len(self.type2test([0, 1, 2])), 3)
def test_overflow(self):
- lst = [4, 5, 6, 7]
+ lst = self.type2test([4, 5, 6, 7])
n = int((sys.maxint*2+2) // len(lst))
def mul(a, b): return a * b
def imul(a, b): a *= b
diff --git a/Lib/test/test_list_jy.py b/Lib/test/test_list_jy.py
--- a/Lib/test/test_list_jy.py
+++ b/Lib/test/test_list_jy.py
@@ -3,6 +3,7 @@
import threading
import time
from test import test_support
+import test_list
if test_support.is_jython:
from java.util import ArrayList
@@ -209,10 +210,37 @@
self.assertEqual(a, expected4)
+class JavaListTestCase(test_list.ListTest):
+
+ type2test = ArrayList
+
+ def test_init(self):
+ # Iterable arg is optional
+ self.assertEqual(self.type2test([]), self.type2test())
+
+ # Unlike with builtin types, we do not guarantee objects can
+ # be overwritten; see corresponding tests
+
+ # Mutables always return a new object
+ a = self.type2test([1, 2, 3])
+ b = self.type2test(a)
+ self.assertNotEqual(id(a), id(b))
+ self.assertEqual(a, b)
+
+
+ def test_extend_java_ArrayList(self):
+ jl = ArrayList([])
+ jl.extend([1,2])
+ self.assertEqual(jl, ArrayList([1,2]))
+ jl.extend(ArrayList([3,4]))
+ self.assertEqual(jl, [1,2,3,4])
+
+
def test_main():
test_support.run_unittest(ListTestCase,
ThreadSafetyTestCase,
- ExtendedSliceTestCase)
+ ExtendedSliceTestCase,
+ JavaListTestCase)
if __name__ == "__main__":
test_main()
diff --git a/Lib/test/test_module.py b/Lib/test/test_module.py
--- a/Lib/test/test_module.py
+++ b/Lib/test/test_module.py
@@ -1,77 +1,87 @@
# Test the module type
+import unittest
+from test.test_support import run_unittest, gc_collect
-from test.test_support import verify, vereq, verbose, TestFailed
-from types import ModuleType as module
+import StringIO # Jython: sub this for sys, given the special status of PySystemState
+ModuleType = type(StringIO)
-# An uninitialized module has no __dict__ or __name__, and __doc__ is None
-foo = module.__new__(module)
-verify(foo.__dict__ is None)
-try:
- s = foo.__name__
-except AttributeError:
- pass
-else:
- raise TestFailed, "__name__ = %s" % repr(s)
-# __doc__ is None by default in CPython but not in Jython.
-# We're not worrying about that now.
-#vereq(foo.__doc__, module.__doc__)
+class ModuleTests(unittest.TestCase):
+ def test_uninitialized(self):
+ # An uninitialized module has no __dict__ or __name__,
+ # and __doc__ is None
+ foo = ModuleType.__new__(ModuleType)
+ self.assertTrue(foo.__dict__ is None)
+ # CPython raises SystemError, but this is more consistent
+ # and doesn't seem worth special casing for dir() here
+ self.assertRaises(TypeError, dir, foo)
+ try:
+ s = foo.__name__
+ self.fail("__name__ = %s" % repr(s))
+ except AttributeError:
+ pass
+ self.assertEqual(foo.__doc__, ModuleType.__doc__)
-try:
- foo_dir = dir(foo)
-except TypeError:
- pass
-else:
- raise TestFailed, "__dict__ = %s" % repr(foo_dir)
+ def test_no_docstring(self):
+ # Regularly initialized module, no docstring
+ foo = ModuleType("foo")
+ self.assertEqual(foo.__name__, "foo")
+ self.assertEqual(foo.__doc__, None)
+ self.assertEqual(foo.__dict__, {"__name__": "foo", "__doc__": None})
-try:
- del foo.somename
-except AttributeError:
- pass
-else:
- raise TestFailed, "del foo.somename"
+ def test_ascii_docstring(self):
+ # ASCII docstring
+ foo = ModuleType("foo", "foodoc")
+ self.assertEqual(foo.__name__, "foo")
+ self.assertEqual(foo.__doc__, "foodoc")
+ self.assertEqual(foo.__dict__,
+ {"__name__": "foo", "__doc__": "foodoc"})
-try:
- del foo.__dict__
-except TypeError:
- pass
-else:
- raise TestFailed, "del foo.__dict__"
+ def test_unicode_docstring(self):
+ # Unicode docstring
+ foo = ModuleType("foo", u"foodoc\u1234")
+ self.assertEqual(foo.__name__, "foo")
+ self.assertEqual(foo.__doc__, u"foodoc\u1234")
+ self.assertEqual(foo.__dict__,
+ {"__name__": "foo", "__doc__": u"foodoc\u1234"})
-try:
- foo.__dict__ = {}
-except TypeError:
- pass
-else:
- raise TestFailed, "foo.__dict__ = {}"
-verify(foo.__dict__ is None)
+ def test_reinit(self):
+ # Reinitialization should not replace the __dict__
+ foo = ModuleType("foo", u"foodoc\u1234")
+ foo.bar = 42
+ d = foo.__dict__
+ foo.__init__("foo", "foodoc")
+ self.assertEqual(foo.__name__, "foo")
+ self.assertEqual(foo.__doc__, "foodoc")
+ self.assertEqual(foo.bar, 42)
+ self.assertEqual(foo.__dict__,
+ {"__name__": "foo", "__doc__": "foodoc", "bar": 42})
+ self.assertTrue(foo.__dict__ is d)
-# Regularly initialized module, no docstring
-foo = module("foo")
-vereq(foo.__name__, "foo")
-vereq(foo.__doc__, None)
-vereq(foo.__dict__, {"__name__": "foo", "__package__": None, "__doc__": None})
+ # @unittest.expectedFailure - works fine on Jython!
+ def test_dont_clear_dict(self):
+ # See issue 7140.
+ def f():
+ foo = ModuleType("foo")
+ foo.bar = 4
+ return foo
+ gc_collect()
+ self.assertEqual(f().__dict__["bar"], 4)
-# ASCII docstring
-foo = module("foo", "foodoc")
-vereq(foo.__name__, "foo")
-vereq(foo.__doc__, "foodoc")
-vereq(foo.__dict__, {"__name__": "foo", "__package__": None, "__doc__": "foodoc"})
+ def test_clear_dict_in_ref_cycle(self):
+ destroyed = []
+ m = ModuleType("foo")
+ m.destroyed = destroyed
+ s = """class A:
+ def __del__(self, destroyed=destroyed):
+ destroyed.append(1)
+a = A()"""
+ exec(s, m.__dict__)
+ del m
+ gc_collect()
+ self.assertEqual(destroyed, [1])
-# Unicode docstring
-foo = module("foo", u"foodoc\u1234")
-vereq(foo.__name__, "foo")
-vereq(foo.__doc__, u"foodoc\u1234")
-vereq(foo.__dict__, {"__name__": "foo", "__package__": None, "__doc__": u"foodoc\u1234"})
+def test_main():
+ run_unittest(ModuleTests)
-# Reinitialization should not replace the __dict__
-foo.bar = 42
-d = foo.__dict__
-foo.__init__("foo", "foodoc")
-vereq(foo.__name__, "foo")
-vereq(foo.__doc__, "foodoc")
-vereq(foo.bar, 42)
-vereq(foo.__dict__, {"__name__": "foo", "__package__": None, "__doc__": "foodoc", "bar": 42})
-verify(foo.__dict__ is d)
-
-if verbose:
- print "All OK"
+if __name__ == '__main__':
+ test_main()
diff --git a/Lib/test/test_pdb.py b/Lib/test/test_pdb.py
deleted file mode 100644
--- a/Lib/test/test_pdb.py
+++ /dev/null
@@ -1,316 +0,0 @@
-# A test suite for pdb; at the moment, this only validates skipping of
-# specified test modules (RFE #5142).
-
-import imp
-import sys
-import os
-import unittest
-import subprocess
-
-from test import test_support
-# This little helper class is essential for testing pdb under doctest.
-from test_doctest import _FakeInput
-
-
-class PdbTestInput(object):
- """Context manager that makes testing Pdb in doctests easier."""
-
- def __init__(self, input):
- self.input = input
-
- def __enter__(self):
- self.real_stdin = sys.stdin
- sys.stdin = _FakeInput(self.input)
-
- def __exit__(self, *exc):
- sys.stdin = self.real_stdin
-
-
-def write(x):
- print x
-
-def test_pdb_displayhook():
- """This tests the custom displayhook for pdb.
-
- >>> def test_function(foo, bar):
- ... import pdb; pdb.Pdb().set_trace()
- ... pass
-
- >>> with PdbTestInput([
- ... 'foo',
- ... 'bar',
- ... 'for i in range(5): write(i)',
- ... 'continue',
- ... ]):
- ... test_function(1, None)
- > <doctest test.test_pdb.test_pdb_displayhook[0]>(3)test_function()
- -> pass
- (Pdb) foo
- 1
- (Pdb) bar
- (Pdb) for i in range(5): write(i)
- 0
- 1
- 2
- 3
- 4
- (Pdb) continue
- """
-
-def test_pdb_breakpoint_commands():
- """Test basic commands related to breakpoints.
-
- >>> def test_function():
- ... import pdb; pdb.Pdb().set_trace()
- ... print(1)
- ... print(2)
- ... print(3)
- ... print(4)
-
- First, need to clear bdb state that might be left over from previous tests.
- Otherwise, the new breakpoints might get assigned different numbers.
-
- >>> from bdb import Breakpoint
- >>> Breakpoint.next = 1
- >>> Breakpoint.bplist = {}
- >>> Breakpoint.bpbynumber = [None]
-
- Now test the breakpoint commands. NORMALIZE_WHITESPACE is needed because
- the breakpoint list outputs a tab for the "stop only" and "ignore next"
- lines, which we don't want to put in here.
-
- >>> with PdbTestInput([ # doctest: +NORMALIZE_WHITESPACE
- ... 'break 3',
- ... 'disable 1',
- ... 'ignore 1 10',
- ... 'condition 1 1 < 2',
- ... 'break 4',
- ... 'break 4',
- ... 'break',
- ... 'clear 3',
- ... 'break',
- ... 'condition 1',
- ... 'enable 1',
- ... 'clear 1',
- ... 'commands 2',
- ... 'print 42',
- ... 'end',
- ... 'continue', # will stop at breakpoint 2 (line 4)
- ... 'clear', # clear all!
- ... 'y',
- ... 'tbreak 5',
- ... 'continue', # will stop at temporary breakpoint
- ... 'break', # make sure breakpoint is gone
- ... 'continue',
- ... ]):
- ... test_function()
- > <doctest test.test_pdb.test_pdb_breakpoint_commands[0]>(3)test_function()
- -> print(1)
- (Pdb) break 3
- Breakpoint 1 at <doctest test.test_pdb.test_pdb_breakpoint_commands[0]>:3
- (Pdb) disable 1
- (Pdb) ignore 1 10
- Will ignore next 10 crossings of breakpoint 1.
- (Pdb) condition 1 1 < 2
- (Pdb) break 4
- Breakpoint 2 at <doctest test.test_pdb.test_pdb_breakpoint_commands[0]>:4
- (Pdb) break 4
- Breakpoint 3 at <doctest test.test_pdb.test_pdb_breakpoint_commands[0]>:4
- (Pdb) break
- Num Type Disp Enb Where
- 1 breakpoint keep no at <doctest test.test_pdb.test_pdb_breakpoint_commands[0]>:3
- stop only if 1 < 2
- ignore next 10 hits
- 2 breakpoint keep yes at <doctest test.test_pdb.test_pdb_breakpoint_commands[0]>:4
- 3 breakpoint keep yes at <doctest test.test_pdb.test_pdb_breakpoint_commands[0]>:4
- (Pdb) clear 3
- Deleted breakpoint 3
- (Pdb) break
- Num Type Disp Enb Where
- 1 breakpoint keep no at <doctest test.test_pdb.test_pdb_breakpoint_commands[0]>:3
- stop only if 1 < 2
- ignore next 10 hits
- 2 breakpoint keep yes at <doctest test.test_pdb.test_pdb_breakpoint_commands[0]>:4
- (Pdb) condition 1
- Breakpoint 1 is now unconditional.
- (Pdb) enable 1
- (Pdb) clear 1
- Deleted breakpoint 1
- (Pdb) commands 2
- (com) print 42
- (com) end
- (Pdb) continue
- 1
- 42
- > <doctest test.test_pdb.test_pdb_breakpoint_commands[0]>(4)test_function()
- -> print(2)
- (Pdb) clear
- Clear all breaks? y
- (Pdb) tbreak 5
- Breakpoint 4 at <doctest test.test_pdb.test_pdb_breakpoint_commands[0]>:5
- (Pdb) continue
- 2
- Deleted breakpoint 4
- > <doctest test.test_pdb.test_pdb_breakpoint_commands[0]>(5)test_function()
- -> print(3)
- (Pdb) break
- (Pdb) continue
- 3
- 4
- """
-
-
-def test_pdb_skip_modules():
- """This illustrates the simple case of module skipping.
-
- >>> def skip_module():
- ... import string
- ... import pdb; pdb.Pdb(skip=['string*']).set_trace()
- ... string.lower('FOO')
-
- >>> with PdbTestInput([
- ... 'step',
- ... 'continue',
- ... ]):
- ... skip_module()
- > <doctest test.test_pdb.test_pdb_skip_modules[0]>(4)skip_module()
- -> string.lower('FOO')
- (Pdb) step
- --Return--
- > <doctest test.test_pdb.test_pdb_skip_modules[0]>(4)skip_module()->None
- -> string.lower('FOO')
- (Pdb) continue
- """
-
-
-# Module for testing skipping of module that makes a callback
-mod = imp.new_module('module_to_skip')
-exec 'def foo_pony(callback): x = 1; callback(); return None' in mod.__dict__
-
-
-def test_pdb_skip_modules_with_callback():
- """This illustrates skipping of modules that call into other code.
-
- >>> def skip_module():
- ... def callback():
- ... return None
- ... import pdb; pdb.Pdb(skip=['module_to_skip*']).set_trace()
- ... mod.foo_pony(callback)
-
- >>> with PdbTestInput([
- ... 'step',
- ... 'step',
- ... 'step',
- ... 'step',
- ... 'step',
- ... 'continue',
- ... ]):
- ... skip_module()
- ... pass # provides something to "step" to
- > <doctest test.test_pdb.test_pdb_skip_modules_with_callback[0]>(5)skip_module()
- -> mod.foo_pony(callback)
- (Pdb) step
- --Call--
- > <doctest test.test_pdb.test_pdb_skip_modules_with_callback[0]>(2)callback()
- -> def callback():
- (Pdb) step
- > <doctest test.test_pdb.test_pdb_skip_modules_with_callback[0]>(3)callback()
- -> return None
- (Pdb) step
- --Return--
- > <doctest test.test_pdb.test_pdb_skip_modules_with_callback[0]>(3)callback()->None
- -> return None
- (Pdb) step
- --Return--
- > <doctest test.test_pdb.test_pdb_skip_modules_with_callback[0]>(5)skip_module()->None
- -> mod.foo_pony(callback)
- (Pdb) step
- > <doctest test.test_pdb.test_pdb_skip_modules_with_callback[1]>(10)<module>()
- -> pass # provides something to "step" to
- (Pdb) continue
- """
-
-
-def test_pdb_continue_in_bottomframe():
- """Test that "continue" and "next" work properly in bottom frame (issue #5294).
-
- >>> def test_function():
- ... import pdb, sys; inst = pdb.Pdb()
- ... inst.set_trace()
- ... inst.botframe = sys._getframe() # hackery to get the right botframe
- ... print(1)
- ... print(2)
- ... print(3)
- ... print(4)
-
- First, need to clear bdb state that might be left over from previous tests.
- Otherwise, the new breakpoints might get assigned different numbers.
-
- >>> from bdb import Breakpoint
- >>> Breakpoint.next = 1
- >>> Breakpoint.bplist = {}
- >>> Breakpoint.bpbynumber = [None]
-
- >>> with PdbTestInput([
- ... 'next',
- ... 'break 7',
- ... 'continue',
- ... 'next',
- ... 'continue',
- ... 'continue',
- ... ]):
- ... test_function()
- > <doctest test.test_pdb.test_pdb_continue_in_bottomframe[0]>(4)test_function()
- -> inst.botframe = sys._getframe() # hackery to get the right botframe
- (Pdb) next
- > <doctest test.test_pdb.test_pdb_continue_in_bottomframe[0]>(5)test_function()
- -> print(1)
- (Pdb) break 7
- Breakpoint 1 at <doctest test.test_pdb.test_pdb_continue_in_bottomframe[0]>:7
- (Pdb) continue
- 1
- 2
- > <doctest test.test_pdb.test_pdb_continue_in_bottomframe[0]>(7)test_function()
- -> print(3)
- (Pdb) next
- 3
- > <doctest test.test_pdb.test_pdb_continue_in_bottomframe[0]>(8)test_function()
- -> print(4)
- (Pdb) continue
- 4
- """
-
-class ModuleInitTester(unittest.TestCase):
-
- @unittest.skipIf(test_support.is_jython, "FIXME: not working in Jython")
- def test_filename_correct(self):
- """
- In issue 7750, it was found that if the filename has a sequence that
- resolves to an escape character in a Python string (such as \t), it
- will be treated as the escaped character.
- """
- # the test_fn must contain something like \t
- # on Windows, this will create 'test_mod.py' in the current directory.
- # on Unix, this will create '.\test_mod.py' in the current directory.
- test_fn = '.\\test_mod.py'
- code = 'print("testing pdb")'
- with open(test_fn, 'w') as f:
- f.write(code)
- self.addCleanup(os.remove, test_fn)
- cmd = [sys.executable, '-m', 'pdb', test_fn,]
- proc = subprocess.Popen(cmd,
- stdout=subprocess.PIPE,
- stdin=subprocess.PIPE,
- stderr=subprocess.STDOUT,
- )
- stdout, stderr = proc.communicate('quit\n')
- self.assertIn(code, stdout, "pdb munged the filename")
-
-
-def test_main():
- from test import test_pdb
- test_support.run_doctest(test_pdb, verbosity=True)
- test_support.run_unittest(ModuleInitTester)
-
-if __name__ == '__main__':
- test_main()
diff --git a/Lib/test/test_pkgutil.py b/Lib/test/test_pkgutil.py
deleted file mode 100644
--- a/Lib/test/test_pkgutil.py
+++ /dev/null
@@ -1,142 +0,0 @@
-from test.test_support import run_unittest, is_jython
-import unittest
-import sys
-import imp
-import pkgutil
-import os
-import os.path
-import tempfile
-import shutil
-import zipfile
-
-
-
-class PkgutilTests(unittest.TestCase):
-
- def setUp(self):
- self.dirname = tempfile.mkdtemp()
- self.addCleanup(shutil.rmtree, self.dirname)
- sys.path.insert(0, self.dirname)
-
- def tearDown(self):
- del sys.path[0]
-
- def test_getdata_filesys(self):
- pkg = 'test_getdata_filesys'
-
- # Include a LF and a CRLF, to test that binary data is read back
- RESOURCE_DATA = 'Hello, world!\nSecond line\r\nThird line'
-
- # Make a package with some resources
- package_dir = os.path.join(self.dirname, pkg)
- os.mkdir(package_dir)
- # Empty init.py
- f = open(os.path.join(package_dir, '__init__.py'), "wb")
- f.close()
- # Resource files, res.txt, sub/res.txt
- f = open(os.path.join(package_dir, 'res.txt'), "wb")
- f.write(RESOURCE_DATA)
- f.close()
- os.mkdir(os.path.join(package_dir, 'sub'))
- f = open(os.path.join(package_dir, 'sub', 'res.txt'), "wb")
- f.write(RESOURCE_DATA)
- f.close()
-
- # Check we can read the resources
- res1 = pkgutil.get_data(pkg, 'res.txt')
- self.assertEqual(res1, RESOURCE_DATA)
- res2 = pkgutil.get_data(pkg, 'sub/res.txt')
- self.assertEqual(res2, RESOURCE_DATA)
-
- del sys.modules[pkg]
-
- def test_getdata_zipfile(self):
- zip = 'test_getdata_zipfile.zip'
- pkg = 'test_getdata_zipfile'
-
- # Include a LF and a CRLF, to test that binary data is read back
- RESOURCE_DATA = 'Hello, world!\nSecond line\r\nThird line'
-
- # Make a package with some resources
- zip_file = os.path.join(self.dirname, zip)
- z = zipfile.ZipFile(zip_file, 'w')
-
- # Empty init.py
- z.writestr(pkg + '/__init__.py', "")
- # Resource files, res.txt, sub/res.txt
- z.writestr(pkg + '/res.txt', RESOURCE_DATA)
- z.writestr(pkg + '/sub/res.txt', RESOURCE_DATA)
- z.close()
-
- # Check we can read the resources
- sys.path.insert(0, zip_file)
- res1 = pkgutil.get_data(pkg, 'res.txt')
- self.assertEqual(res1, RESOURCE_DATA)
- res2 = pkgutil.get_data(pkg, 'sub/res.txt')
- self.assertEqual(res2, RESOURCE_DATA)
- del sys.path[0]
-
- del sys.modules[pkg]
-
- @unittest.skipIf(is_jython, "FIXME: not working on Jython")
- def test_unreadable_dir_on_syspath(self):
- # issue7367 - walk_packages failed if unreadable dir on sys.path
- package_name = "unreadable_package"
- d = os.path.join(self.dirname, package_name)
- # this does not appear to create an unreadable dir on Windows
- # but the test should not fail anyway
- os.mkdir(d, 0)
- self.addCleanup(os.rmdir, d)
- for t in pkgutil.walk_packages(path=[self.dirname]):
- self.fail("unexpected package found")
-
-class PkgutilPEP302Tests(unittest.TestCase):
-
- class MyTestLoader(object):
- def load_module(self, fullname):
- # Create an empty module
- mod = sys.modules.setdefault(fullname, imp.new_module(fullname))
- mod.__file__ = "<%s>" % self.__class__.__name__
- mod.__loader__ = self
- # Make it a package
- mod.__path__ = []
- # Count how many times the module is reloaded
- mod.__dict__['loads'] = mod.__dict__.get('loads',0) + 1
- return mod
-
- def get_data(self, path):
- return "Hello, world!"
-
- class MyTestImporter(object):
- def find_module(self, fullname, path=None):
- return PkgutilPEP302Tests.MyTestLoader()
-
- def setUp(self):
- sys.meta_path.insert(0, self.MyTestImporter())
-
- def tearDown(self):
- del sys.meta_path[0]
-
- def test_getdata_pep302(self):
- # Use a dummy importer/loader
- self.assertEqual(pkgutil.get_data('foo', 'dummy'), "Hello, world!")
- del sys.modules['foo']
-
- def test_alreadyloaded(self):
- # Ensure that get_data works without reloading - the "loads" module
- # variable in the example loader should count how many times a reload
- # occurs.
- import foo
- self.assertEqual(foo.loads, 1)
- self.assertEqual(pkgutil.get_data('foo', 'dummy'), "Hello, world!")
- self.assertEqual(foo.loads, 1)
- del sys.modules['foo']
-
-def test_main():
- run_unittest(PkgutilTests, PkgutilPEP302Tests)
- # this is necessary if test is run repeated (like when finding leaks)
- import zipimport
- zipimport._zip_directory_cache.clear()
-
-if __name__ == '__main__':
- test_main()
diff --git a/Lib/test/test_profilehooks.py b/Lib/test/test_profilehooks.py
--- a/Lib/test/test_profilehooks.py
+++ b/Lib/test/test_profilehooks.py
@@ -17,11 +17,9 @@
def tearDown(self):
sys.setprofile(None)
- @unittest.skip("FIXME: broken")
def test_empty(self):
assert sys.getprofile() == None
- @unittest.skip("FIXME: broken")
def test_setget(self):
def fn(*args):
pass
diff --git a/Lib/test/test_runpy.py b/Lib/test/test_runpy.py
new file mode 100644
--- /dev/null
+++ b/Lib/test/test_runpy.py
@@ -0,0 +1,402 @@
+# Test the runpy module
+import unittest
+import os
+import os.path
+import sys
+import re
+import tempfile
+from test.test_support import verbose, run_unittest, forget
+from test.script_helper import (temp_dir, make_script, compile_script,
+ make_pkg, make_zip_script, make_zip_pkg)
+
+
+from runpy import _run_code, _run_module_code, run_module, run_path
+# Note: This module can't safely test _run_module_as_main as it
+# runs its tests in the current process, which would mess with the
+# real __main__ module (usually test.regrtest)
+# See test_cmd_line_script for a test that executes that code path
+
+# Set up the test code and expected results
+
+class RunModuleCodeTest(unittest.TestCase):
+ """Unit tests for runpy._run_code and runpy._run_module_code"""
+
+ expected_result = ["Top level assignment", "Lower level reference"]
+ test_source = (
+ "# Check basic code execution\n"
+ "result = ['Top level assignment']\n"
+ "def f():\n"
+ " result.append('Lower level reference')\n"
+ "f()\n"
+ "# Check the sys module\n"
+ "import sys\n"
+ "run_argv0 = sys.argv[0]\n"
+ "run_name_in_sys_modules = __name__ in sys.modules\n"
+ "if run_name_in_sys_modules:\n"
+ " module_in_sys_modules = globals() is sys.modules[__name__].__dict__\n"
+ "# Check nested operation\n"
+ "import runpy\n"
+ "nested = runpy._run_module_code('x=1\\n', mod_name='<run>')\n"
+ )
+
+ def test_run_code(self):
+ saved_argv0 = sys.argv[0]
+ d = _run_code(self.test_source, {})
+ self.assertEqual(d["result"], self.expected_result)
+ self.assertIs(d["__name__"], None)
+ self.assertIs(d["__file__"], None)
+ self.assertIs(d["__loader__"], None)
+ self.assertIs(d["__package__"], None)
+ self.assertIs(d["run_argv0"], saved_argv0)
+ self.assertNotIn("run_name", d)
+ self.assertIs(sys.argv[0], saved_argv0)
+
+ def test_run_module_code(self):
+ initial = object()
+ name = "<Nonsense>"
+ file = "Some other nonsense"
+ loader = "Now you're just being silly"
+ package = '' # Treat as a top level module
+ d1 = dict(initial=initial)
+ saved_argv0 = sys.argv[0]
+ d2 = _run_module_code(self.test_source,
+ d1,
+ name,
+ file,
+ loader,
+ package)
+ self.assertNotIn("result", d1)
+ self.assertIs(d2["initial"], initial)
+ self.assertEqual(d2["result"], self.expected_result)
+ self.assertEqual(d2["nested"]["x"], 1)
+ self.assertIs(d2["__name__"], name)
+ self.assertTrue(d2["run_name_in_sys_modules"])
+ self.assertTrue(d2["module_in_sys_modules"])
+ self.assertIs(d2["__file__"], file)
+ self.assertIs(d2["run_argv0"], file)
+ self.assertIs(d2["__loader__"], loader)
+ self.assertIs(d2["__package__"], package)
+ self.assertIs(sys.argv[0], saved_argv0)
+ self.assertNotIn(name, sys.modules)
+
+
+class RunModuleTest(unittest.TestCase):
+ """Unit tests for runpy.run_module"""
+
+ def expect_import_error(self, mod_name):
+ try:
+ run_module(mod_name)
+ except ImportError:
+ pass
+ else:
+ self.fail("Expected import error for " + mod_name)
+
+ def test_invalid_names(self):
+ # Builtin module
+ self.expect_import_error("sys")
+ # Non-existent modules
+ self.expect_import_error("sys.imp.eric")
+ self.expect_import_error("os.path.half")
+ self.expect_import_error("a.bee")
+ self.expect_import_error(".howard")
+ self.expect_import_error("..eaten")
+ # Package without __main__.py
+ self.expect_import_error("multiprocessing")
+
+ def test_library_module(self):
+ run_module("runpy")
+
+ def _add_pkg_dir(self, pkg_dir):
+ os.mkdir(pkg_dir)
+ pkg_fname = os.path.join(pkg_dir, "__init__"+os.extsep+"py")
+ pkg_file = open(pkg_fname, "w")
+ pkg_file.close()
+ return pkg_fname
+
+ def _make_pkg(self, source, depth, mod_base="runpy_test"):
+ pkg_name = "__runpy_pkg__"
+ test_fname = mod_base+os.extsep+"py"
+ pkg_dir = sub_dir = tempfile.mkdtemp()
+ if verbose: print " Package tree in:", sub_dir
+ sys.path.insert(0, pkg_dir)
+ if verbose: print " Updated sys.path:", sys.path[0]
+ for i in range(depth):
+ sub_dir = os.path.join(sub_dir, pkg_name)
+ pkg_fname = self._add_pkg_dir(sub_dir)
+ if verbose: print " Next level in:", sub_dir
+ if verbose: print " Created:", pkg_fname
+ mod_fname = os.path.join(sub_dir, test_fname)
+ mod_file = open(mod_fname, "w")
+ mod_file.write(source)
+ mod_file.close()
+ if verbose: print " Created:", mod_fname
+ mod_name = (pkg_name+".")*depth + mod_base
+ return pkg_dir, mod_fname, mod_name
+
+ def _del_pkg(self, top, depth, mod_name):
+ for entry in list(sys.modules):
+ if entry.startswith("__runpy_pkg__"):
+ del sys.modules[entry]
+ if verbose: print " Removed sys.modules entries"
+ del sys.path[0]
+ if verbose: print " Removed sys.path entry"
+ for root, dirs, files in os.walk(top, topdown=False):
+ for name in files:
+ try:
+ os.remove(os.path.join(root, name))
+ except OSError, ex:
+ if verbose: print ex # Persist with cleaning up
+ for name in dirs:
+ fullname = os.path.join(root, name)
+ try:
+ os.rmdir(fullname)
+ except OSError, ex:
+ if verbose: print ex # Persist with cleaning up
+ try:
+ os.rmdir(top)
+ if verbose: print " Removed package tree"
+ except OSError, ex:
+ if verbose: print ex # Persist with cleaning up
+
+ def _check_module(self, depth):
+ pkg_dir, mod_fname, mod_name = (
+ self._make_pkg("x=1\n", depth))
+ forget(mod_name)
+ try:
+ if verbose: print "Running from source:", mod_name
+ d1 = run_module(mod_name) # Read from source
+ self.assertIn("x", d1)
+ self.assertTrue(d1["x"] == 1)
+ del d1 # Ensure __loader__ entry doesn't keep file open
+ __import__(mod_name)
+ os.remove(mod_fname)
+ if verbose: print "Running from compiled:", mod_name
+ d2 = run_module(mod_name) # Read from bytecode
+ self.assertIn("x", d2)
+ self.assertTrue(d2["x"] == 1)
+ del d2 # Ensure __loader__ entry doesn't keep file open
+ finally:
+ self._del_pkg(pkg_dir, depth, mod_name)
+ if verbose: print "Module executed successfully"
+
+ def _check_package(self, depth):
+ pkg_dir, mod_fname, mod_name = (
+ self._make_pkg("x=1\n", depth, "__main__"))
+ pkg_name, _, _ = mod_name.rpartition(".")
+ forget(mod_name)
+ try:
+ if verbose: print "Running from source:", pkg_name
+ d1 = run_module(pkg_name) # Read from source
+ self.assertIn("x", d1)
+ self.assertTrue(d1["x"] == 1)
+ del d1 # Ensure __loader__ entry doesn't keep file open
+ __import__(mod_name)
+ os.remove(mod_fname)
+ if verbose: print "Running from compiled:", pkg_name
+ d2 = run_module(pkg_name) # Read from bytecode
+ self.assertIn("x", d2)
+ self.assertTrue(d2["x"] == 1)
+ del d2 # Ensure __loader__ entry doesn't keep file open
+ finally:
+ self._del_pkg(pkg_dir, depth, pkg_name)
+ if verbose: print "Package executed successfully"
+
+ def _add_relative_modules(self, base_dir, source, depth):
+ if depth <= 1:
+ raise ValueError("Relative module test needs depth > 1")
+ pkg_name = "__runpy_pkg__"
+ module_dir = base_dir
+ for i in range(depth):
+ parent_dir = module_dir
+ module_dir = os.path.join(module_dir, pkg_name)
+ # Add sibling module
+ sibling_fname = os.path.join(module_dir, "sibling"+os.extsep+"py")
+ sibling_file = open(sibling_fname, "w")
+ sibling_file.close()
+ if verbose: print " Added sibling module:", sibling_fname
+ # Add nephew module
+ uncle_dir = os.path.join(parent_dir, "uncle")
+ self._add_pkg_dir(uncle_dir)
+ if verbose: print " Added uncle package:", uncle_dir
+ cousin_dir = os.path.join(uncle_dir, "cousin")
+ self._add_pkg_dir(cousin_dir)
+ if verbose: print " Added cousin package:", cousin_dir
+ nephew_fname = os.path.join(cousin_dir, "nephew"+os.extsep+"py")
+ nephew_file = open(nephew_fname, "w")
+ nephew_file.close()
+ if verbose: print " Added nephew module:", nephew_fname
+
+ def _check_relative_imports(self, depth, run_name=None):
+ contents = r"""\
+from __future__ import absolute_import
+from . import sibling
+from ..uncle.cousin import nephew
+"""
+ pkg_dir, mod_fname, mod_name = (
+ self._make_pkg(contents, depth))
+ try:
+ self._add_relative_modules(pkg_dir, contents, depth)
+ pkg_name = mod_name.rpartition('.')[0]
+ if verbose: print "Running from source:", mod_name
+ d1 = run_module(mod_name, run_name=run_name) # Read from source
+ self.assertIn("__package__", d1)
+ self.assertTrue(d1["__package__"] == pkg_name)
+ self.assertIn("sibling", d1)
+ self.assertIn("nephew", d1)
+ del d1 # Ensure __loader__ entry doesn't keep file open
+ __import__(mod_name)
+ os.remove(mod_fname)
+ if verbose: print "Running from compiled:", mod_name
+ d2 = run_module(mod_name, run_name=run_name) # Read from bytecode
+ self.assertIn("__package__", d2)
+ self.assertTrue(d2["__package__"] == pkg_name)
+ self.assertIn("sibling", d2)
+ self.assertIn("nephew", d2)
+ del d2 # Ensure __loader__ entry doesn't keep file open
+ finally:
+ self._del_pkg(pkg_dir, depth, mod_name)
+ if verbose: print "Module executed successfully"
+
+ def test_run_module(self):
+ for depth in range(4):
+ if verbose: print "Testing package depth:", depth
+ self._check_module(depth)
+
+ def test_run_package(self):
+ for depth in range(1, 4):
+ if verbose: print "Testing package depth:", depth
+ self._check_package(depth)
+
+ def test_explicit_relative_import(self):
+ for depth in range(2, 5):
+ if verbose: print "Testing relative imports at depth:", depth
+ self._check_relative_imports(depth)
+
+ def test_main_relative_import(self):
+ for depth in range(2, 5):
+ if verbose: print "Testing main relative imports at depth:", depth
+ self._check_relative_imports(depth, "__main__")
+
+
+class RunPathTest(unittest.TestCase):
+ """Unit tests for runpy.run_path"""
+ # Based on corresponding tests in test_cmd_line_script
+
+ test_source = """\
+# Script may be run with optimisation enabled, so don't rely on assert
+# statements being executed
+def assertEqual(lhs, rhs):
+ if lhs != rhs:
+ raise AssertionError('%r != %r' % (lhs, rhs))
+def assertIs(lhs, rhs):
+ if lhs is not rhs:
+ raise AssertionError('%r is not %r' % (lhs, rhs))
+# Check basic code execution
+result = ['Top level assignment']
+def f():
+ result.append('Lower level reference')
+f()
+assertEqual(result, ['Top level assignment', 'Lower level reference'])
+# Check the sys module
+import sys
+assertIs(globals(), sys.modules[__name__].__dict__)
+argv0 = sys.argv[0]
+"""
+
+ def _make_test_script(self, script_dir, script_basename, source=None):
+ if source is None:
+ source = self.test_source
+ return make_script(script_dir, script_basename, source)
+
+ def _check_script(self, script_name, expected_name, expected_file,
+ expected_argv0, expected_package):
+ result = run_path(script_name)
+ self.assertEqual(result["__name__"], expected_name)
+ self.assertEqual(result["__file__"], expected_file)
+ self.assertIn("argv0", result)
+ self.assertEqual(result["argv0"], expected_argv0)
+ self.assertEqual(result["__package__"], expected_package)
+
+ def _check_import_error(self, script_name, msg):
+ msg = re.escape(msg)
+ self.assertRaisesRegexp(ImportError, msg, run_path, script_name)
+
+ def test_basic_script(self):
+ with temp_dir() as script_dir:
+ mod_name = 'script'
+ script_name = self._make_test_script(script_dir, mod_name)
+ self._check_script(script_name, "<run_path>", script_name,
+ script_name, None)
+
+ def test_script_compiled(self):
+ with temp_dir() as script_dir:
+ mod_name = 'script'
+ script_name = self._make_test_script(script_dir, mod_name)
+ compiled_name = compile_script(script_name)
+ os.remove(script_name)
+ self._check_script(compiled_name, "<run_path>", compiled_name,
+ compiled_name, None)
+
+ def test_directory(self):
+ with temp_dir() as script_dir:
+ mod_name = '__main__'
+ script_name = self._make_test_script(script_dir, mod_name)
+ self._check_script(script_dir, "<run_path>", script_name,
+ script_dir, '')
+
+ def test_directory_compiled(self):
+ with temp_dir() as script_dir:
+ mod_name = '__main__'
+ script_name = self._make_test_script(script_dir, mod_name)
+ compiled_name = compile_script(script_name)
+ os.remove(script_name)
+ self._check_script(script_dir, "<run_path>", compiled_name,
+ script_dir, '')
+
+ def test_directory_error(self):
+ with temp_dir() as script_dir:
+ mod_name = 'not_main'
+ script_name = self._make_test_script(script_dir, mod_name)
+ msg = "can't find '__main__' module in %r" % script_dir
+ self._check_import_error(script_dir, msg)
+
+ def test_zipfile(self):
+ with temp_dir() as script_dir:
+ mod_name = '__main__'
+ script_name = self._make_test_script(script_dir, mod_name)
+ zip_name, fname = make_zip_script(script_dir, 'test_zip', script_name)
+ self._check_script(zip_name, "<run_path>", fname, zip_name, '')
+
+ def test_zipfile_compiled(self):
+ with temp_dir() as script_dir:
+ mod_name = '__main__'
+ script_name = self._make_test_script(script_dir, mod_name)
+ compiled_name = compile_script(script_name)
+ zip_name, fname = make_zip_script(script_dir, 'test_zip', compiled_name)
+ self._check_script(zip_name, "<run_path>", fname, zip_name, '')
+
+ def test_zipfile_error(self):
+ with temp_dir() as script_dir:
+ mod_name = 'not_main'
+ script_name = self._make_test_script(script_dir, mod_name)
+ zip_name, fname = make_zip_script(script_dir, 'test_zip', script_name)
+ msg = "can't find '__main__' module in '%s'" % zip_name
+ self._check_import_error(zip_name, msg)
+
+ def test_main_recursion_error(self):
+ with temp_dir() as script_dir, temp_dir() as dummy_dir:
+ mod_name = '__main__'
+ source = ("import runpy\n"
+ "runpy.run_path(%r)\n") % dummy_dir
+ script_name = self._make_test_script(script_dir, mod_name, source)
+ zip_name, fname = make_zip_script(script_dir, 'test_zip', script_name)
+ msg = "recursion depth exceeded"
+ self.assertRaisesRegexp(RuntimeError, msg, run_path, zip_name)
+
+
+
+def test_main():
+ run_unittest(RunModuleCodeTest, RunModuleTest, RunPathTest)
+
+if __name__ == "__main__":
+ test_main()
diff --git a/Lib/test/test_select_new.py b/Lib/test/test_select_new.py
--- a/Lib/test/test_select_new.py
+++ b/Lib/test/test_select_new.py
@@ -16,18 +16,15 @@
DATA_CHUNK = "." * DATA_CHUNK_SIZE
#
-# The timing of these tests depends on the how the unerlying OS socket library
+# The timing of these tests depends on how the underlying OS socket library
# handles buffering. These values may need tweaking for different platforms
#
# The fundamental problem is that there is no reliable way to fill a socket with bytes
-#
+# To address this for running on Netty, we arbitrarily send 10000 bytes
-if test_support.is_jython:
- SELECT_TIMEOUT = 0
-else:
- # zero select timeout fails these tests on cpython (on windows 2003 anyway)
- SELECT_TIMEOUT = 0.001
-
+# zero select timeout fails these tests on cpython (on windows 2003 anyway);
+# on Jython with Netty it will result in flaky test runs
+SELECT_TIMEOUT = 0.001
READ_TIMEOUT = 5
class AsynchronousServer:
@@ -86,6 +83,9 @@
if self.select_writable():
bytes_sent = self.socket.send(DATA_CHUNK)
total_bytes += bytes_sent
+ if test_support.is_jython and total_bytes > 10000:
+ # Netty will buffer indefinitely, so just pick an arbitrary cutoff
+ return total_bytes
else:
return total_bytes
except socket.error, se:
@@ -149,7 +149,7 @@
def start_connect(self):
result = self.socket.connect_ex(SERVER_ADDRESS)
if result == errno.EISCONN:
- self.connected = 1
+ self.connected = True
else:
assert result == errno.EINPROGRESS
diff --git a/Lib/test/test_set_jy.py b/Lib/test/test_set_jy.py
--- a/Lib/test/test_set_jy.py
+++ b/Lib/test/test_set_jy.py
@@ -1,13 +1,14 @@
import unittest
-from test import test_support
+from test import test_support, test_set
+import pickle
import threading
-if test_support.is_jython:
- from java.io import (ByteArrayInputStream, ByteArrayOutputStream,
- ObjectInputStream, ObjectOutputStream)
- from java.util import Random
- from javatests import PySetInJavaTest
+from java.io import (ByteArrayInputStream, ByteArrayOutputStream,
+ ObjectInputStream, ObjectOutputStream)
+from java.util import Random, HashSet, LinkedHashSet
+from javatests import PySetInJavaTest
+
class SetTestCase(unittest.TestCase):
@@ -81,10 +82,41 @@
unserializer = ObjectInputStream(input)
self.assertEqual(s, unserializer.readObject())
+
+class TestJavaSet(test_set.TestSet):
+ thetype = HashSet
+
+ def test_init(self):
+ # Instances of Java types cannot be re-initialized
+ pass
+
+ def test_cyclical_repr(self):
+ pass
+
+ def test_cyclical_print(self):
+ pass
+
+ def test_pickling(self):
+ for i in range(pickle.HIGHEST_PROTOCOL + 1):
+ p = pickle.dumps(self.s, i)
+ dup = pickle.loads(p)
+ self.assertEqual(self.s, dup, "%s != %s" % (self.s, dup))
+
+
+class TestJavaHashSet(TestJavaSet):
+ thetype = HashSet
+
+class TestJavaLinkedHashSet(TestJavaSet):
+ thetype = LinkedHashSet
+
+
def test_main():
- tests = [SetTestCase]
- if test_support.is_jython:
- tests.append(SetInJavaTestCase)
+ tests = [
+ SetTestCase,
+ SetInJavaTestCase,
+ TestJavaHashSet,
+ TestJavaLinkedHashSet,
+ ]
test_support.run_unittest(*tests)
diff --git a/Lib/test/test_slice.py b/Lib/test/test_slice.py
deleted file mode 100644
--- a/Lib/test/test_slice.py
+++ /dev/null
@@ -1,137 +0,0 @@
-# tests for slice objects; in particular the indices method.
-
-import unittest
-from test import test_support
-from cPickle import loads, dumps
-
-import sys
-
-class SliceTest(unittest.TestCase):
-
- def test_constructor(self):
- self.assertRaises(TypeError, slice)
- self.assertRaises(TypeError, slice, 1, 2, 3, 4)
-
- def test_repr(self):
- self.assertEqual(repr(slice(1, 2, 3)), "slice(1, 2, 3)")
-
- def test_hash(self):
- # Verify clearing of SF bug #800796
- self.assertRaises(TypeError, hash, slice(5))
- self.assertRaises(TypeError, slice(5).__hash__)
-
- def test_cmp(self):
- s1 = slice(1, 2, 3)
- s2 = slice(1, 2, 3)
- s3 = slice(1, 2, 4)
- self.assertEqual(s1, s2)
- self.assertNotEqual(s1, s3)
-
- class Exc(Exception):
- pass
-
- class BadCmp(object):
- def __eq__(self, other):
- raise Exc
- __hash__ = None # Silence Py3k warning
-
- s1 = slice(BadCmp())
- s2 = slice(BadCmp())
- self.assertRaises(Exc, cmp, s1, s2)
- self.assertEqual(s1, s1)
-
- s1 = slice(1, BadCmp())
- s2 = slice(1, BadCmp())
- self.assertEqual(s1, s1)
- self.assertRaises(Exc, cmp, s1, s2)
-
- s1 = slice(1, 2, BadCmp())
- s2 = slice(1, 2, BadCmp())
- self.assertEqual(s1, s1)
- self.assertRaises(Exc, cmp, s1, s2)
-
- def test_members(self):
- s = slice(1)
- self.assertEqual(s.start, None)
- self.assertEqual(s.stop, 1)
- self.assertEqual(s.step, None)
-
- s = slice(1, 2)
- self.assertEqual(s.start, 1)
- self.assertEqual(s.stop, 2)
- self.assertEqual(s.step, None)
-
- s = slice(1, 2, 3)
- self.assertEqual(s.start, 1)
- self.assertEqual(s.stop, 2)
- self.assertEqual(s.step, 3)
-
- class AnyClass:
- pass
-
- obj = AnyClass()
- s = slice(obj)
- self.assertTrue(s.stop is obj)
-
- def test_indices(self):
- self.assertEqual(slice(None ).indices(10), (0, 10, 1))
- self.assertEqual(slice(None, None, 2).indices(10), (0, 10, 2))
- self.assertEqual(slice(1, None, 2).indices(10), (1, 10, 2))
- self.assertEqual(slice(None, None, -1).indices(10), (9, -1, -1))
- self.assertEqual(slice(None, None, -2).indices(10), (9, -1, -2))
- self.assertEqual(slice(3, None, -2).indices(10), (3, -1, -2))
- # issue 3004 tests
- self.assertEqual(slice(None, -9).indices(10), (0, 1, 1))
- #FIXME: next two not correct on Jython
- #self.assertEqual(slice(None, -10).indices(10), (0, 0, 1))
- #self.assertEqual(slice(None, -11).indices(10), (0, 0, 1))
- self.assertEqual(slice(None, -10, -1).indices(10), (9, 0, -1))
- self.assertEqual(slice(None, -11, -1).indices(10), (9, -1, -1))
- self.assertEqual(slice(None, -12, -1).indices(10), (9, -1, -1))
- self.assertEqual(slice(None, 9).indices(10), (0, 9, 1))
- self.assertEqual(slice(None, 10).indices(10), (0, 10, 1))
- self.assertEqual(slice(None, 11).indices(10), (0, 10, 1))
- self.assertEqual(slice(None, 8, -1).indices(10), (9, 8, -1))
- self.assertEqual(slice(None, 9, -1).indices(10), (9, 9, -1))
- #FIXME: next not correct on Jython
- #self.assertEqual(slice(None, 10, -1).indices(10), (9, 9, -1))
-
- self.assertEqual(
- slice(-100, 100 ).indices(10),
- slice(None).indices(10)
- )
- self.assertEqual(
- slice(100, -100, -1).indices(10),
- slice(None, None, -1).indices(10)
- )
- self.assertEqual(slice(-100L, 100L, 2L).indices(10), (0, 10, 2))
-
- self.assertEqual(range(10)[::sys.maxint - 1], [0])
-
- self.assertRaises(OverflowError, slice(None).indices, 1L<<100)
-
- def test_setslice_without_getslice(self):
- tmp = []
- class X(object):
- def __setslice__(self, i, j, k):
- tmp.append((i, j, k))
-
- x = X()
- with test_support.check_py3k_warnings():
- x[1:2] = 42
- self.assertEqual(tmp, [(1, 2, 42)])
-
- @unittest.skipIf(test_support.is_jython, "FIXME: not working in Jython")
- def test_pickle(self):
- s = slice(10, 20, 3)
- for protocol in (0,1,2):
- t = loads(dumps(s, protocol))
- self.assertEqual(s, t)
- self.assertEqual(s.indices(15), t.indices(15))
- self.assertNotEqual(id(s), id(t))
-
-def test_main():
- test_support.run_unittest(SliceTest)
-
-if __name__ == "__main__":
- test_main()
diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py
--- a/Lib/test/test_socket.py
+++ b/Lib/test/test_socket.py
@@ -138,9 +138,6 @@
# 1. It takes two collections for finalization to run.
# 2. gc.collect() is only advisory to the JVM, never mandatory. Still
# it usually seems to happen under light load.
- gc.collect()
- time.sleep(0.1)
- gc.collect()
# Wait up to one second for there not to be pending threads
@@ -148,10 +145,10 @@
pending_threads = _check_threadpool_for_pending_threads(group)
if len(pending_threads) == 0:
break
- time.sleep(0.1)
+ test_support.gc_collect()
if pending_threads:
- self.fail("Pending threads in Netty msg={} pool={}".format(msg, pprint.pformat(pending_threads)))
+ print "Pending threads in Netty msg={} pool={}".format(msg, pprint.pformat(pending_threads))
def _tearDown(self):
self.done.wait() # wait for the client to exit
@@ -1128,6 +1125,20 @@
self.serv_conn.send(MSG)
self.serv_conn.send('and ' + MSG)
+ def testSelect(self):
+ # http://bugs.jython.org/issue2242
+ self.assertIs(self.cli_conn.gettimeout(), None, "Server socket is not blocking")
+ start_time = time.time()
+ r, w, x = select.select([self.cli_conn], [], [], 10)
+ if (time.time() - start_time) > 1:
+ self.fail("Child socket was not immediately available for read when set to blocking")
+ self.assertEqual(r[0], self.cli_conn)
+ self.assertEqual(self.cli_conn.recv(1024), MSG)
+
+ def _testSelect(self):
+ self.serv_conn.send(MSG)
+
+
class UDPBindTest(unittest.TestCase):
HOST = HOST
@@ -1399,7 +1410,6 @@
def _testRecvData(self):
self.cli.connect((self.HOST, self.PORT))
self.cli.send(MSG)
- #time.sleep(0.5)
def testRecvNoData(self):
# Testing non-blocking recv
@@ -2071,6 +2081,14 @@
result = socket.getnameinfo(address, flags)
self.failUnlessEqual(result[0], expected)
+
+# TODO: consider re-enabling this set of tests, but for now
+# this set reliably does *not* work on Ubuntu (but does on
+# OSX). However the underlying internal method _get_jsockaddr
+# is exercised by nearly every socket usage, along with the
+# corresponding tests.
+
+@unittest.skipIf(test_support.is_jython, "Skip internal tests for address lookup due to underlying OS issues")
class TestJython_get_jsockaddr(unittest.TestCase):
"These tests are specific to jython: they test a key internal routine"
@@ -2506,7 +2524,6 @@
except socket.error, se:
# FIXME Apparently Netty's doesn't set remoteAddress, even if connected, for datagram channels
# so we may have to shadow
- print "\n\n\ngetpeername()", self.s._sock.channel
self.fail("getpeername() on connected UDP socket should not have raised socket.error")
self.failUnlessEqual(self.s.getpeername(), self._udp_peer.getsockname())
finally:
diff --git a/Lib/test/test_sort.py b/Lib/test/test_sort.py
--- a/Lib/test/test_sort.py
+++ b/Lib/test/test_sort.py
@@ -2,6 +2,10 @@
import random
import sys
import unittest
+try:
+ import java
+except ImportError:
+ pass
verbose = test_support.verbose
nerrors = 0
@@ -39,8 +43,6 @@
return
class TestBase(unittest.TestCase):
- @unittest.skipIf(test_support.is_jython,
- "FIXME: find the part that is too much for Jython.")
def testStressfully(self):
# Try a variety of sizes at and around powers of 2, and at powers of 10.
sizes = [0]
@@ -102,8 +104,15 @@
print " Checking against an insane comparison function."
print " If the implementation isn't careful, this may segfault."
s = x[:]
- s.sort(lambda a, b: int(random.random() * 3) - 1)
- check("an insane function left some permutation", x, s)
+
+ if test_support.is_jython:
+ try:
+ s.sort(lambda a, b: int(random.random() * 3) - 1)
+ except java.lang.IllegalArgumentException:
+ pass
+ else:
+ s.sort(lambda a, b: int(random.random() * 3) - 1)
+ check("an insane function left some permutation", x, s)
x = [Complains(i) for i in x]
s = x[:]
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py
new file mode 100644
--- /dev/null
+++ b/Lib/test/test_ssl.py
@@ -0,0 +1,1395 @@
+# Test the support for SSL and sockets
+
+import sys
+import unittest
+from test import test_support
+import asyncore
+import socket
+import select
+import time
+import gc
+import os
+import errno
+import pprint
+import urllib, urlparse
+import traceback
+import weakref
+import functools
+import platform
+
+from BaseHTTPServer import HTTPServer
+from SimpleHTTPServer import SimpleHTTPRequestHandler
+
+ssl = test_support.import_module("ssl")
+
+HOST = test_support.HOST
+CERTFILE = None
+SVN_PYTHON_ORG_ROOT_CERT = None
+
+def handle_error(prefix):
+ exc_format = ' '.join(traceback.format_exception(*sys.exc_info()))
+ if test_support.verbose:
+ sys.stdout.write(prefix + exc_format)
+
+
+class BasicTests(unittest.TestCase):
+
+ def test_sslwrap_simple(self):
+ # A crude test for the legacy API
+ try:
+ ssl.sslwrap_simple(socket.socket(socket.AF_INET))
+ except IOError, e:
+ if e.errno == 32: # broken pipe when ssl_sock.do_handshake(), this test doesn't care about that
+ pass
+ else:
+ raise
+ try:
+ ssl.sslwrap_simple(socket.socket(socket.AF_INET)._sock)
+ except IOError, e:
+ if e.errno == 32: # broken pipe when ssl_sock.do_handshake(), this test doesn't care about that
+ pass
+ else:
+ raise
+
+# Issue #9415: Ubuntu hijacks their OpenSSL and forcefully disables SSLv2
+def skip_if_broken_ubuntu_ssl(func):
+ if hasattr(ssl, 'PROTOCOL_SSLv2'):
+ # We need to access the lower-level wrapper in order to create an
+ # implicit SSL context without trying to connect or listen.
+ try:
+ import _ssl
+ except ImportError:
+ # The returned function won't get executed, just ignore the error
+ pass
+ @functools.wraps(func)
+ def f(*args, **kwargs):
+ try:
+ s = socket.socket(socket.AF_INET)
+ _ssl.sslwrap(s._sock, 0, None, None,
+ ssl.CERT_NONE, ssl.PROTOCOL_SSLv2, None, None)
+ except ssl.SSLError as e:
+ if (ssl.OPENSSL_VERSION_INFO == (0, 9, 8, 15, 15) and
+ platform.linux_distribution() == ('debian', 'squeeze/sid', '')
+ and 'Invalid SSL protocol variant specified' in str(e)):
+ raise unittest.SkipTest("Patched Ubuntu OpenSSL breaks behaviour")
+ return func(*args, **kwargs)
+ return f
+ else:
+ return func
+
+
+class BasicSocketTests(unittest.TestCase):
+
+ def test_constants(self):
+ #ssl.PROTOCOL_SSLv2
+ ssl.PROTOCOL_SSLv23
+ ssl.PROTOCOL_SSLv3
+ ssl.PROTOCOL_TLSv1
+ ssl.CERT_NONE
+ ssl.CERT_OPTIONAL
+ ssl.CERT_REQUIRED
+
+ def test_random(self):
+ v = ssl.RAND_status()
+ if test_support.verbose:
+ sys.stdout.write("\n RAND_status is %d (%s)\n"
+ % (v, (v and "sufficient randomness") or
+ "insufficient randomness"))
+ self.assertRaises(TypeError, ssl.RAND_egd, 1)
+ self.assertRaises(TypeError, ssl.RAND_egd, 'foo', 1)
+ ssl.RAND_add("this is a random string", 75.0)
+
+ @unittest.skipIf(test_support.is_jython, "Jython uses BouncyCastle")
+ def test_parse_cert(self):
+ # note that this uses an 'unofficial' function in _ssl.c,
+ # provided solely for this test, to exercise the certificate
+ # parsing code
+ p = ssl._ssl._test_decode_cert(CERTFILE, False)
+ if test_support.verbose:
+ sys.stdout.write("\n" + pprint.pformat(p) + "\n")
+ self.assertEqual(p['subject'],
+ ((('countryName', 'XY'),),
+ (('localityName', 'Castle Anthrax'),),
+ (('organizationName', 'Python Software Foundation'),),
+ (('commonName', 'localhost'),))
+ )
+ self.assertEqual(p['subjectAltName'], (('DNS', 'localhost'),))
+ # Issue #13034: the subjectAltName in some certificates
+ # (notably projects.developer.nokia.com:443) wasn't parsed
+ p = ssl._ssl._test_decode_cert(NOKIACERT)
+ if test_support.verbose:
+ sys.stdout.write("\n" + pprint.pformat(p) + "\n")
+ self.assertEqual(p['subjectAltName'],
+ (('DNS', 'projects.developer.nokia.com'),
+ ('DNS', 'projects.forum.nokia.com'))
+ )
+
+ def test_DER_to_PEM(self):
+ with open(SVN_PYTHON_ORG_ROOT_CERT, 'r') as f:
+ pem = f.read()
+ d1 = ssl.PEM_cert_to_DER_cert(pem)
+ p2 = ssl.DER_cert_to_PEM_cert(d1)
+ d2 = ssl.PEM_cert_to_DER_cert(p2)
+ self.assertEqual(d1, d2)
+ if not p2.startswith(ssl.PEM_HEADER + '\n'):
+ self.fail("DER-to-PEM didn't include correct header:\n%r\n" % p2)
+ if not p2.endswith('\n' + ssl.PEM_FOOTER + '\n'):
+ self.fail("DER-to-PEM didn't include correct footer:\n%r\n" % p2)
+
+ def test_openssl_version(self):
+ n = ssl.OPENSSL_VERSION_NUMBER
+ t = ssl.OPENSSL_VERSION_INFO
+ s = ssl.OPENSSL_VERSION
+ self.assertIsInstance(n, (int, long))
+ self.assertIsInstance(t, tuple)
+ self.assertIsInstance(s, str)
+ # Some sanity checks follow
+ # >= 0.9
+ self.assertGreaterEqual(n, 0x900000)
+ # < 2.0
+ self.assertLess(n, 0x20000000)
+ major, minor, fix, patch, status = t
+ self.assertGreaterEqual(major, 0)
+ self.assertLess(major, 2)
+ self.assertGreaterEqual(minor, 0)
+ self.assertLess(minor, 256)
+ self.assertGreaterEqual(fix, 0)
+ self.assertLess(fix, 256)
+ self.assertGreaterEqual(patch, 0)
+ self.assertLessEqual(patch, 26)
+ self.assertGreaterEqual(status, 0)
+ self.assertLessEqual(status, 15)
+ # Version string as returned by OpenSSL, the format might change
+ self.assertTrue(s.startswith("OpenSSL {:d}.{:d}.{:d}".format(major, minor, fix)),
+ (s, t))
+
+ def test_ciphers(self):
+ if not test_support.is_resource_enabled('network'):
+ return
+ remote = ("svn.python.org", 443)
+ with test_support.transient_internet(remote[0]):
+ s = ssl.wrap_socket(socket.socket(socket.AF_INET),
+ cert_reqs=ssl.CERT_NONE, ciphers="ALL")
+ s.connect(remote)
+ s = ssl.wrap_socket(socket.socket(socket.AF_INET),
+ cert_reqs=ssl.CERT_NONE, ciphers="DEFAULT")
+ s.connect(remote)
+ # Error checking occurs when connecting, because the SSL context
+ # isn't created before.
+ s = ssl.wrap_socket(socket.socket(socket.AF_INET),
+ cert_reqs=ssl.CERT_NONE, ciphers="^$:,;?*'dorothyx")
+ with self.assertRaisesRegexp(ssl.SSLError, "No cipher can be selected"):
+ s.connect(remote)
+
+ @test_support.cpython_only
+ def test_refcycle(self):
+ # Issue #7943: an SSL object doesn't create reference cycles with
+ # itself.
+ s = socket.socket(socket.AF_INET)
+ ss = ssl.wrap_socket(s)
+ wr = weakref.ref(ss)
+ del ss
+ self.assertEqual(wr(), None)
+
+ def test_wrapped_unconnected(self):
+ # The _delegate_methods in socket.py are correctly delegated to by an
+ # unconnected SSLSocket, so they will raise a socket.error rather than
+ # something unexpected like TypeError.
+ s = socket.socket(socket.AF_INET)
+ ss = ssl.wrap_socket(s)
+ self.assertRaises(socket.error, ss.recv, 1)
+ self.assertRaises(socket.error, ss.recv_into, bytearray(b'x'))
+ self.assertRaises(socket.error, ss.recvfrom, 1)
+ self.assertRaises(socket.error, ss.recvfrom_into, bytearray(b'x'), 1)
+ self.assertRaises(socket.error, ss.send, b'x')
+ self.assertRaises(socket.error, ss.sendto, b'x', ('0.0.0.0', 0))
+
+
+class NetworkedTests(unittest.TestCase):
+
+ def test_connect(self):
+ with test_support.transient_internet("svn.python.org"):
+ s = ssl.wrap_socket(socket.socket(socket.AF_INET),
+ cert_reqs=ssl.CERT_NONE)
+ s.connect(("svn.python.org", 443))
+ c = s.getpeercert()
+ if c:
+                self.fail("Peer cert %s shouldn't be here!" % c)
+ s.close()
+
+ # this should fail because we have no verification certs
+ s = ssl.wrap_socket(socket.socket(socket.AF_INET),
+ cert_reqs=ssl.CERT_REQUIRED)
+ try:
+ s.connect(("svn.python.org", 443))
+ except ssl.SSLError:
+ pass
+ finally:
+ s.close()
+
+ # this should succeed because we specify the root cert
+ s = ssl.wrap_socket(socket.socket(socket.AF_INET),
+ cert_reqs=ssl.CERT_REQUIRED,
+ ca_certs=SVN_PYTHON_ORG_ROOT_CERT)
+ try:
+ s.connect(("svn.python.org", 443))
+ finally:
+ s.close()
+
+ def test_connect_ex(self):
+ # Issue #11326: check connect_ex() implementation
+ with test_support.transient_internet("svn.python.org"):
+ s = ssl.wrap_socket(socket.socket(socket.AF_INET),
+ cert_reqs=ssl.CERT_REQUIRED,
+ ca_certs=SVN_PYTHON_ORG_ROOT_CERT)
+ try:
+ self.assertEqual(0, s.connect_ex(("svn.python.org", 443)))
+ self.assertTrue(s.getpeercert())
+ finally:
+ s.close()
+
+ def test_non_blocking_connect_ex(self):
+ # Issue #11326: non-blocking connect_ex() should allow handshake
+ # to proceed after the socket gets ready.
+ with test_support.transient_internet("svn.python.org"):
+ s = ssl.wrap_socket(socket.socket(socket.AF_INET),
+ cert_reqs=ssl.CERT_REQUIRED,
+ ca_certs=SVN_PYTHON_ORG_ROOT_CERT,
+ do_handshake_on_connect=False)
+ try:
+ s.setblocking(False)
+ rc = s.connect_ex(('svn.python.org', 443))
+ # EWOULDBLOCK under Windows, EINPROGRESS elsewhere
+ self.assertIn(rc, (0, errno.EINPROGRESS, errno.EWOULDBLOCK))
+ # Wait for connect to finish
+ select.select([], [s], [], 5.0)
+ # Non-blocking handshake
+ while True:
+ try:
+ s.do_handshake()
+ break
+ except ssl.SSLError as err:
+ if err.args[0] == ssl.SSL_ERROR_WANT_READ:
+ select.select([s], [], [], 5.0)
+ elif err.args[0] == ssl.SSL_ERROR_WANT_WRITE:
+ select.select([], [s], [], 5.0)
+ else:
+ raise
+ # SSL established
+ self.assertTrue(s.getpeercert())
+ finally:
+ s.close()
+
+ def test_timeout_connect_ex(self):
+ # Issue #12065: on a timeout, connect_ex() should return the original
+ # errno (mimicking the behaviour of non-SSL sockets).
+ with test_support.transient_internet("svn.python.org"):
+ s = ssl.wrap_socket(socket.socket(socket.AF_INET),
+ cert_reqs=ssl.CERT_REQUIRED,
+ ca_certs=SVN_PYTHON_ORG_ROOT_CERT,
+ do_handshake_on_connect=False)
+ try:
+ s.settimeout(0.0000001)
+ rc = s.connect_ex(('svn.python.org', 443))
+ if rc == 0:
+ self.skipTest("svn.python.org responded too quickly")
+ self.assertIn(rc, (errno.EAGAIN, errno.EWOULDBLOCK))
+ finally:
+ s.close()
+
+ def test_connect_ex_error(self):
+ with test_support.transient_internet("svn.python.org"):
+ s = ssl.wrap_socket(socket.socket(socket.AF_INET),
+ cert_reqs=ssl.CERT_REQUIRED,
+ ca_certs=SVN_PYTHON_ORG_ROOT_CERT)
+ try:
+ self.assertEqual(errno.ECONNREFUSED,
+ s.connect_ex(("svn.python.org", 444)))
+ finally:
+ s.close()
+
+ @unittest.skipIf(os.name == "nt", "Can't use a socket as a file under Windows")
+ def test_makefile_close(self):
+ # Issue #5238: creating a file-like object with makefile() shouldn't
+ # delay closing the underlying "real socket" (here tested with its
+ # file descriptor, hence skipping the test under Windows).
+ with test_support.transient_internet("svn.python.org"):
+ ss = ssl.wrap_socket(socket.socket(socket.AF_INET))
+ ss.connect(("svn.python.org", 443))
+ fd = ss.fileno()
+ f = ss.makefile()
+ f.close()
+ # The fd is still open
+ os.read(fd, 0)
+ # Closing the SSL socket should close the fd too
+ ss.close()
+ gc.collect()
+ with self.assertRaises(OSError) as e:
+ os.read(fd, 0)
+ self.assertEqual(e.exception.errno, errno.EBADF)
+
+ def test_non_blocking_handshake(self):
+ with test_support.transient_internet("svn.python.org"):
+ s = socket.socket(socket.AF_INET)
+ s.connect(("svn.python.org", 443))
+ s.setblocking(False)
+ s = ssl.wrap_socket(s,
+ cert_reqs=ssl.CERT_NONE,
+ do_handshake_on_connect=False)
+ count = 0
+ while True:
+ try:
+ count += 1
+ s.do_handshake()
+ break
+ except ssl.SSLError, err:
+ if err.args[0] == ssl.SSL_ERROR_WANT_READ:
+ select.select([s], [], [])
+ elif err.args[0] == ssl.SSL_ERROR_WANT_WRITE:
+ select.select([], [s], [])
+ else:
+ raise
+ s.close()
+ if test_support.verbose:
+ sys.stdout.write("\nNeeded %d calls to do_handshake() to establish session.\n" % count)
+
+ def test_get_server_certificate(self):
+ with test_support.transient_internet("svn.python.org"):
+ pem = ssl.get_server_certificate(("svn.python.org", 443))
+ if not pem:
+ self.fail("No server certificate on svn.python.org:443!")
+
+ try:
+ pem = ssl.get_server_certificate(("svn.python.org", 443), ca_certs=CERTFILE)
+ except ssl.SSLError:
+ #should fail
+ pass
+ else:
+ self.fail("Got server certificate %s for svn.python.org!" % pem)
+
+ pem = ssl.get_server_certificate(("svn.python.org", 443), ca_certs=SVN_PYTHON_ORG_ROOT_CERT)
+ if not pem:
+ self.fail("No server certificate on svn.python.org:443!")
+ if test_support.verbose:
+ sys.stdout.write("\nVerified certificate for svn.python.org:443 is\n%s\n" % pem)
+
+ def test_algorithms(self):
+ # Issue #8484: all algorithms should be available when verifying a
+ # certificate.
+ # SHA256 was added in OpenSSL 0.9.8
+ if ssl.OPENSSL_VERSION_INFO < (0, 9, 8, 0, 15):
+ self.skipTest("SHA256 not available on %r" % ssl.OPENSSL_VERSION)
+ self.skipTest("remote host needs SNI, only available on Python 3.2+")
+ # NOTE: https://sha2.hboeck.de is another possible test host
+ remote = ("sha256.tbs-internet.com", 443)
+ sha256_cert = os.path.join(os.path.dirname(__file__), "sha256.pem")
+ with test_support.transient_internet("sha256.tbs-internet.com"):
+ s = ssl.wrap_socket(socket.socket(socket.AF_INET),
+ cert_reqs=ssl.CERT_REQUIRED,
+ ca_certs=sha256_cert,)
+ try:
+ s.connect(remote)
+ if test_support.verbose:
+ sys.stdout.write("\nCipher with %r is %r\n" %
+ (remote, s.cipher()))
+ sys.stdout.write("Certificate is:\n%s\n" %
+ pprint.pformat(s.getpeercert()))
+ finally:
+ s.close()
+
+
+try:
+ import threading
+except ImportError:
+ _have_threads = False
+else:
+ _have_threads = True
+
+ class ThreadedEchoServer(threading.Thread):
+
+ class ConnectionHandler(threading.Thread):
+
+ """A mildly complicated class, because we want it to work both
+ with and without the SSL wrapper around the socket connection, so
+ that we can test the STARTTLS functionality."""
+
+ def __init__(self, server, connsock):
+ self.server = server
+ self.running = False
+ self.sock = connsock
+ self.sock.setblocking(1)
+ self.sslconn = None
+ threading.Thread.__init__(self)
+ self.daemon = True
+
+ def show_conn_details(self):
+ if self.server.certreqs == ssl.CERT_REQUIRED:
+ cert = self.sslconn.getpeercert()
+ if test_support.verbose and self.server.chatty:
+ sys.stdout.write(" client cert is " + pprint.pformat(cert) + "\n")
+ cert_binary = self.sslconn.getpeercert(True)
+ if test_support.verbose and self.server.chatty:
+ sys.stdout.write(" cert binary is " + str(len(cert_binary)) + " bytes\n")
+ cipher = self.sslconn.cipher()
+ if test_support.verbose and self.server.chatty:
+ sys.stdout.write(" server: connection cipher is now " + str(cipher) + "\n")
+
+ def wrap_conn(self):
+ try:
+ self.sslconn = ssl.wrap_socket(self.sock, server_side=True,
+ certfile=self.server.certificate,
+ ssl_version=self.server.protocol,
+ ca_certs=self.server.cacerts,
+ cert_reqs=self.server.certreqs,
+ ciphers=self.server.ciphers)
+ except ssl.SSLError as e:
+ # XXX Various errors can have happened here, for example
+ # a mismatching protocol version, an invalid certificate,
+ # or a low-level bug. This should be made more discriminating.
+ self.server.conn_errors.append(e)
+ if self.server.chatty:
+ handle_error("\n server: bad connection attempt from " +
+ str(self.sock.getpeername()) + ":\n")
+ self.close()
+ self.running = False
+ self.server.stop()
+ return False
+ else:
+ return True
+
+ def read(self):
+ if self.sslconn:
+ return self.sslconn.read()
+ else:
+ return self.sock.recv(1024)
+
+ def write(self, bytes):
+ if self.sslconn:
+ return self.sslconn.write(bytes)
+ else:
+ return self.sock.send(bytes)
+
+ def close(self):
+ if self.sslconn:
+ self.sslconn.close()
+ else:
+ self.sock._sock.close()
+
+ def run(self):
+ self.running = True
+ if not self.server.starttls_server:
+ if isinstance(self.sock, ssl.SSLSocket):
+ self.sslconn = self.sock
+ elif not self.wrap_conn():
+ return
+ self.show_conn_details()
+ while self.running:
+ try:
+ msg = self.read()
+ if not msg:
+ # eof, so quit this handler
+ self.running = False
+ self.close()
+ elif msg.strip() == 'over':
+ if test_support.verbose and self.server.connectionchatty:
+ sys.stdout.write(" server: client closed connection\n")
+ self.close()
+ return
+ elif self.server.starttls_server and msg.strip() == 'STARTTLS':
+ if test_support.verbose and self.server.connectionchatty:
+ sys.stdout.write(" server: read STARTTLS from client, sending OK...\n")
+ self.write("OK\n")
+ if not self.wrap_conn():
+ return
+ elif self.server.starttls_server and self.sslconn and msg.strip() == 'ENDTLS':
+ if test_support.verbose and self.server.connectionchatty:
+ sys.stdout.write(" server: read ENDTLS from client, sending OK...\n")
+ self.write("OK\n")
+ self.sslconn.unwrap()
+ self.sslconn = None
+ if test_support.verbose and self.server.connectionchatty:
+ sys.stdout.write(" server: connection is now unencrypted...\n")
+ else:
+ if (test_support.verbose and
+ self.server.connectionchatty):
+ ctype = (self.sslconn and "encrypted") or "unencrypted"
+ sys.stdout.write(" server: read %s (%s), sending back %s (%s)...\n"
+ % (repr(msg), ctype, repr(msg.lower()), ctype))
+ self.write(msg.lower())
+ except ssl.SSLError:
+ if self.server.chatty:
+ handle_error("Test server failure:\n")
+ self.close()
+ self.running = False
+ # normally, we'd just stop here, but for the test
+ # harness, we want to stop the server
+ self.server.stop()
+
+ def __init__(self, certificate, ssl_version=None,
+ certreqs=None, cacerts=None,
+ chatty=True, connectionchatty=False, starttls_server=False,
+ wrap_accepting_socket=False, ciphers=None):
+
+ if ssl_version is None:
+ ssl_version = ssl.PROTOCOL_TLSv1
+ if certreqs is None:
+ certreqs = ssl.CERT_NONE
+ self.certificate = certificate
+ self.protocol = ssl_version
+ self.certreqs = certreqs
+ self.cacerts = cacerts
+ self.ciphers = ciphers
+ self.chatty = chatty
+ self.connectionchatty = connectionchatty
+ self.starttls_server = starttls_server
+ self.sock = socket.socket()
+ self.flag = None
+ if wrap_accepting_socket:
+ self.sock = ssl.wrap_socket(self.sock, server_side=True,
+ certfile=self.certificate,
+ cert_reqs = self.certreqs,
+ ca_certs = self.cacerts,
+ ssl_version = self.protocol,
+ ciphers = self.ciphers)
+ if test_support.verbose and self.chatty:
+ sys.stdout.write(' server: wrapped server socket as %s\n' % str(self.sock))
+ self.port = test_support.bind_port(self.sock)
+ self.active = False
+ self.conn_errors = []
+ threading.Thread.__init__(self)
+ self.daemon = True
+
+ def __enter__(self):
+ self.start(threading.Event())
+ self.flag.wait()
+ return self
+
+ def __exit__(self, *args):
+ self.stop()
+ self.join()
+
+ def start(self, flag=None):
+ self.flag = flag
+ threading.Thread.start(self)
+
+ def run(self):
+ self.sock.settimeout(0.05)
+ self.sock.listen(5)
+ self.active = True
+ if self.flag:
+ # signal an event
+ self.flag.set()
+ while self.active:
+ try:
+ newconn, connaddr = self.sock.accept()
+ if test_support.verbose and self.chatty:
+ sys.stdout.write(' server: new connection from '
+ + str(connaddr) + '\n')
+ handler = self.ConnectionHandler(self, newconn)
+ handler.start()
+ handler.join()
+ except socket.timeout:
+ pass
+ except KeyboardInterrupt:
+ self.stop()
+ self.sock.close()
+
+ def stop(self):
+ self.active = False
+
+ class AsyncoreEchoServer(threading.Thread):
+
+ class EchoServer(asyncore.dispatcher):
+
+ class ConnectionHandler(asyncore.dispatcher_with_send):
+
+ def __init__(self, conn, certfile):
+ asyncore.dispatcher_with_send.__init__(self, conn)
+ self.socket = ssl.wrap_socket(conn, server_side=True,
+ certfile=certfile,
+ do_handshake_on_connect=False)
+ self._ssl_accepting = True
+
+ def readable(self):
+ if isinstance(self.socket, ssl.SSLSocket):
+ while self.socket.pending() > 0:
+ self.handle_read_event()
+ return True
+
+ def _do_ssl_handshake(self):
+ try:
+ self.socket.do_handshake()
+ except ssl.SSLError, err:
+ if err.args[0] in (ssl.SSL_ERROR_WANT_READ,
+ ssl.SSL_ERROR_WANT_WRITE):
+ return
+ elif err.args[0] == ssl.SSL_ERROR_EOF:
+ return self.handle_close()
+ raise
+ except socket.error, err:
+ if err.args[0] == errno.ECONNABORTED:
+ return self.handle_close()
+ else:
+ self._ssl_accepting = False
+
+ def handle_read(self):
+ if self._ssl_accepting:
+ self._do_ssl_handshake()
+ else:
+ data = self.recv(1024)
+ if data and data.strip() != 'over':
+ self.send(data.lower())
+
+ def handle_close(self):
+ self.close()
+ if test_support.verbose:
+ sys.stdout.write(" server: closed connection %s\n" % self.socket)
+
+ def handle_error(self):
+ raise
+
+ def __init__(self, certfile):
+ self.certfile = certfile
+ asyncore.dispatcher.__init__(self)
+ self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
+ self.port = test_support.bind_port(self.socket)
+ self.listen(5)
+
+ def handle_accept(self):
+ sock_obj, addr = self.accept()
+ if test_support.verbose:
+ sys.stdout.write(" server: new connection from %s:%s\n" %addr)
+ self.ConnectionHandler(sock_obj, self.certfile)
+
+ def handle_error(self):
+ raise
+
+ def __init__(self, certfile):
+ self.flag = None
+ self.active = False
+ self.server = self.EchoServer(certfile)
+ self.port = self.server.port
+ threading.Thread.__init__(self)
+ self.daemon = True
+
+ def __str__(self):
+ return "<%s %s>" % (self.__class__.__name__, self.server)
+
+ def __enter__(self):
+ self.start(threading.Event())
+ self.flag.wait()
+ return self
+
+ def __exit__(self, *args):
+ if test_support.verbose:
+ sys.stdout.write(" cleanup: stopping server.\n")
+ self.stop()
+ if test_support.verbose:
+ sys.stdout.write(" cleanup: joining server thread.\n")
+ self.join()
+ if test_support.verbose:
+ sys.stdout.write(" cleanup: successfully joined.\n")
+
+ def start(self, flag=None):
+ self.flag = flag
+ threading.Thread.start(self)
+
+ def run(self):
+ self.active = True
+ if self.flag:
+ self.flag.set()
+ while self.active:
+ asyncore.loop(0.05)
+
+ def stop(self):
+ self.active = False
+ self.server.close()
+
+ class SocketServerHTTPSServer(threading.Thread):
+
+ class HTTPSServer(HTTPServer):
+
+ def __init__(self, server_address, RequestHandlerClass, certfile):
+ HTTPServer.__init__(self, server_address, RequestHandlerClass)
+ # we assume the certfile contains both private key and certificate
+ self.certfile = certfile
+ self.allow_reuse_address = True
+
+ def __str__(self):
+ return ('<%s %s:%s>' %
+ (self.__class__.__name__,
+ self.server_name,
+ self.server_port))
+
+ def get_request(self):
+ # override this to wrap socket with SSL
+ sock, addr = self.socket.accept()
+ sslconn = ssl.wrap_socket(sock, server_side=True,
+ certfile=self.certfile)
+ return sslconn, addr
+
+ class RootedHTTPRequestHandler(SimpleHTTPRequestHandler):
+ # need to override translate_path to get a known root,
+ # instead of using os.curdir, since the test could be
+ # run from anywhere
+
+ server_version = "TestHTTPS/1.0"
+
+ root = None
+
+ def translate_path(self, path):
+ """Translate a /-separated PATH to the local filename syntax.
+
+ Components that mean special things to the local file system
+ (e.g. drive or directory names) are ignored. (XXX They should
+ probably be diagnosed.)
+
+ """
+ # abandon query parameters
+ path = urlparse.urlparse(path)[2]
+ path = os.path.normpath(urllib.unquote(path))
+ words = path.split('/')
+ words = filter(None, words)
+ path = self.root
+ for word in words:
+ drive, word = os.path.splitdrive(word)
+ head, word = os.path.split(word)
+ if word in self.root: continue
+ path = os.path.join(path, word)
+ return path
+
+ def log_message(self, format, *args):
+
+ # we override this to suppress logging unless "verbose"
+
+ if test_support.verbose:
+ sys.stdout.write(" server (%s:%d %s):\n [%s] %s\n" %
+ (self.server.server_address,
+ self.server.server_port,
+ self.request.cipher(),
+ self.log_date_time_string(),
+ format%args))
+
+
+ def __init__(self, certfile):
+ self.flag = None
+ self.RootedHTTPRequestHandler.root = os.path.split(CERTFILE)[0]
+ self.server = self.HTTPSServer(
+ (HOST, 0), self.RootedHTTPRequestHandler, certfile)
+ self.port = self.server.server_port
+ threading.Thread.__init__(self)
+ self.daemon = True
+
+ def __str__(self):
+ return "<%s %s>" % (self.__class__.__name__, self.server)
+
+ def start(self, flag=None):
+ self.flag = flag
+ threading.Thread.start(self)
+
+ def run(self):
+ if self.flag:
+ self.flag.set()
+ self.server.serve_forever(0.05)
+
+ def stop(self):
+ self.server.shutdown()
+
+
+ def bad_cert_test(certfile):
+ """
+ Launch a server with CERT_REQUIRED, and check that trying to
+ connect to it with the given client certificate fails.
+ """
+ server = ThreadedEchoServer(CERTFILE,
+ certreqs=ssl.CERT_REQUIRED,
+ cacerts=CERTFILE, chatty=False)
+ with server:
+ try:
+ s = ssl.wrap_socket(socket.socket(),
+ certfile=certfile,
+ ssl_version=ssl.PROTOCOL_TLSv1)
+ s.connect((HOST, server.port))
+ except ssl.SSLError, x:
+ if test_support.verbose:
+ sys.stdout.write("\nSSLError is %s\n" % x[1])
+ except socket.error, x:
+ if test_support.verbose:
+ sys.stdout.write("\nsocket.error is %s\n" % x[1])
+ else:
+ raise AssertionError("Use of invalid cert should have failed!")
+
+ def server_params_test(certfile, protocol, certreqs, cacertsfile,
+ client_certfile, client_protocol=None, indata="FOO\n",
+ ciphers=None, chatty=True, connectionchatty=False,
+ wrap_accepting_socket=False):
+ """
+ Launch a server, connect a client to it and try various reads
+ and writes.
+ """
+ server = ThreadedEchoServer(certfile,
+ certreqs=certreqs,
+ ssl_version=protocol,
+ cacerts=cacertsfile,
+ ciphers=ciphers,
+ chatty=chatty,
+ connectionchatty=connectionchatty,
+ wrap_accepting_socket=wrap_accepting_socket)
+ with server:
+ # try to connect
+ if client_protocol is None:
+ client_protocol = protocol
+ s = ssl.wrap_socket(socket.socket(),
+ certfile=client_certfile,
+ ca_certs=cacertsfile,
+ ciphers=ciphers,
+ cert_reqs=certreqs,
+ ssl_version=client_protocol)
+ s.connect((HOST, server.port))
+ for arg in [indata, bytearray(indata), memoryview(indata)]:
+ if connectionchatty:
+ if test_support.verbose:
+ sys.stdout.write(
+ " client: sending %s...\n" % (repr(arg)))
+ s.write(arg)
+ outdata = s.read()
+ if connectionchatty:
+ if test_support.verbose:
+ sys.stdout.write(" client: read %s\n" % repr(outdata))
+ if outdata != indata.lower():
+ raise AssertionError(
+ "bad data <<%s>> (%d) received; expected <<%s>> (%d)\n"
+ % (outdata[:min(len(outdata),20)], len(outdata),
+ indata[:min(len(indata),20)].lower(), len(indata)))
+ s.write("over\n")
+ if connectionchatty:
+ if test_support.verbose:
+ sys.stdout.write(" client: closing connection.\n")
+ s.close()
+
+ def try_protocol_combo(server_protocol,
+ client_protocol,
+ expect_success,
+ certsreqs=None):
+ if certsreqs is None:
+ certsreqs = ssl.CERT_NONE
+ certtype = {
+ ssl.CERT_NONE: "CERT_NONE",
+ ssl.CERT_OPTIONAL: "CERT_OPTIONAL",
+ ssl.CERT_REQUIRED: "CERT_REQUIRED",
+ }[certsreqs]
+ if test_support.verbose:
+ formatstr = (expect_success and " %s->%s %s\n") or " {%s->%s} %s\n"
+ sys.stdout.write(formatstr %
+ (ssl.get_protocol_name(client_protocol),
+ ssl.get_protocol_name(server_protocol),
+ certtype))
+ try:
+ # NOTE: we must enable "ALL" ciphers, otherwise an SSLv23 client
+ # will send an SSLv3 hello (rather than SSLv2) starting from
+ # OpenSSL 1.0.0 (see issue #8322).
+ server_params_test(CERTFILE, server_protocol, certsreqs,
+ CERTFILE, CERTFILE, client_protocol,
+ ciphers="ALL", chatty=False)
+ # Protocol mismatch can result in either an SSLError, or a
+ # "Connection reset by peer" error.
+ except ssl.SSLError:
+ if expect_success:
+ raise
+ except socket.error as e:
+ if expect_success or e.errno != errno.ECONNRESET:
+ raise
+ else:
+ if not expect_success:
+ raise AssertionError(
+ "Client protocol %s succeeded with server protocol %s!"
+ % (ssl.get_protocol_name(client_protocol),
+ ssl.get_protocol_name(server_protocol)))
+
+
+ class ThreadedTests(unittest.TestCase):
+
+ def test_rude_shutdown(self):
+ """A brutal shutdown of an SSL server should raise an IOError
+ in the client when attempting handshake.
+ """
+ listener_ready = threading.Event()
+ listener_gone = threading.Event()
+
+ s = socket.socket()
+ port = test_support.bind_port(s, HOST)
+
+ # `listener` runs in a thread. It sits in an accept() until
+ # the main thread connects. Then it rudely closes the socket,
+ # and sets Event `listener_gone` to let the main thread know
+ # the socket is gone.
+ def listener():
+ s.listen(5)
+ listener_ready.set()
+ s.accept()
+ s.close()
+ listener_gone.set()
+
+ def connector():
+ listener_ready.wait()
+ c = socket.socket()
+ c.connect((HOST, port))
+ listener_gone.wait()
+ try:
+ ssl_sock = ssl.wrap_socket(c)
+ except IOError:
+ pass
+ else:
+ self.fail('connecting to closed SSL socket should have failed')
+
+ t = threading.Thread(target=listener)
+ t.start()
+ try:
+ connector()
+ finally:
+ t.join()
+
+ @skip_if_broken_ubuntu_ssl
+ def test_echo(self):
+ """Basic test of an SSL client connecting to a server"""
+ if test_support.verbose:
+ sys.stdout.write("\n")
+ server_params_test(CERTFILE, ssl.PROTOCOL_TLSv1, ssl.CERT_NONE,
+ CERTFILE, CERTFILE, ssl.PROTOCOL_TLSv1,
+ chatty=True, connectionchatty=True)
+
+ def test_getpeercert(self):
+ if test_support.verbose:
+ sys.stdout.write("\n")
+ s2 = socket.socket()
+ server = ThreadedEchoServer(CERTFILE,
+ certreqs=ssl.CERT_NONE,
+ ssl_version=ssl.PROTOCOL_SSLv23,
+ cacerts=CERTFILE,
+ chatty=False)
+ with server:
+ s = ssl.wrap_socket(socket.socket(),
+ certfile=CERTFILE,
+ ca_certs=CERTFILE,
+ cert_reqs=ssl.CERT_REQUIRED,
+ ssl_version=ssl.PROTOCOL_SSLv23)
+ s.connect((HOST, server.port))
+ cert = s.getpeercert()
+ self.assertTrue(cert, "Can't get peer certificate.")
+ cipher = s.cipher()
+ if test_support.verbose:
+ sys.stdout.write(pprint.pformat(cert) + '\n')
+ sys.stdout.write("Connection cipher is " + str(cipher) + '.\n')
+ if 'subject' not in cert:
+ self.fail("No subject field in certificate: %s." %
+ pprint.pformat(cert))
+ if ((('organizationName', 'Python Software Foundation'),)
+ not in cert['subject']):
+ self.fail(
+ "Missing or invalid 'organizationName' field in certificate subject; "
+ "should be 'Python Software Foundation'.")
+ s.close()
+
+ def test_empty_cert(self):
+ """Connecting with an empty cert file"""
+ bad_cert_test(os.path.join(os.path.dirname(__file__) or os.curdir,
+ "nullcert.pem"))
+ def test_malformed_cert(self):
+ """Connecting with a badly formatted certificate (syntax error)"""
+ bad_cert_test(os.path.join(os.path.dirname(__file__) or os.curdir,
+ "badcert.pem"))
+ def test_nonexisting_cert(self):
+ """Connecting with a non-existing cert file"""
+ bad_cert_test(os.path.join(os.path.dirname(__file__) or os.curdir,
+ "wrongcert.pem"))
+ def test_malformed_key(self):
+ """Connecting with a badly formatted key (syntax error)"""
+ bad_cert_test(os.path.join(os.path.dirname(__file__) or os.curdir,
+ "badkey.pem"))
+
+ @skip_if_broken_ubuntu_ssl
+ def test_protocol_sslv2(self):
+ """Connecting to an SSLv2 server with various client options"""
+ if test_support.verbose:
+ sys.stdout.write("\n")
+ if not hasattr(ssl, 'PROTOCOL_SSLv2'):
+ self.skipTest("PROTOCOL_SSLv2 needed")
+ try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True)
+ try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_OPTIONAL)
+ try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_REQUIRED)
+ try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, True)
+ try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv3, False)
+ try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLSv1, False)
+
+ @skip_if_broken_ubuntu_ssl
+ def test_protocol_sslv23(self):
+ """Connecting to an SSLv23 server with various client options"""
+ if test_support.verbose:
+ sys.stdout.write("\n")
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True)
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True)
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True)
+
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True, ssl.CERT_OPTIONAL)
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_OPTIONAL)
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True, ssl.CERT_OPTIONAL)
+
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True, ssl.CERT_REQUIRED)
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_REQUIRED)
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True, ssl.CERT_REQUIRED)
+
+ @skip_if_broken_ubuntu_ssl
+ def test_protocol_sslv3(self):
+ """Connecting to an SSLv3 server with various client options"""
+ if test_support.verbose:
+ sys.stdout.write("\n")
+ try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True)
+ try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True, ssl.CERT_OPTIONAL)
+ try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True, ssl.CERT_REQUIRED)
+ if hasattr(ssl, 'PROTOCOL_SSLv2'):
+ try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv2, False)
+ try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLSv1, False)
+
+ @skip_if_broken_ubuntu_ssl
+ def test_protocol_tlsv1(self):
+ """Connecting to a TLSv1 server with various client options"""
+ if test_support.verbose:
+ sys.stdout.write("\n")
+ try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True)
+ try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True, ssl.CERT_OPTIONAL)
+ try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True, ssl.CERT_REQUIRED)
+ if hasattr(ssl, 'PROTOCOL_SSLv2'):
+ try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv2, False)
+ try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv3, False)
+
+ def test_starttls(self):
+ """Switching from clear text to encrypted and back again."""
+ msgs = ("msg 1", "MSG 2", "STARTTLS", "MSG 3", "msg 4", "ENDTLS", "msg 5", "msg 6")
+
+ server = ThreadedEchoServer(CERTFILE,
+ ssl_version=ssl.PROTOCOL_TLSv1,
+ starttls_server=True,
+ chatty=True,
+ connectionchatty=True)
+ wrapped = False
+ with server:
+ s = socket.socket()
+ s.setblocking(1)
+ s.connect((HOST, server.port))
+ if test_support.verbose:
+ sys.stdout.write("\n")
+ for indata in msgs:
+ if test_support.verbose:
+ sys.stdout.write(
+ " client: sending %s...\n" % repr(indata))
+ if wrapped:
+ conn.write(indata)
+ outdata = conn.read()
+ else:
+ s.send(indata)
+ outdata = s.recv(1024)
+ if (indata == "STARTTLS" and
+ outdata.strip().lower().startswith("ok")):
+ # STARTTLS ok, switch to secure mode
+ if test_support.verbose:
+ sys.stdout.write(
+ " client: read %s from server, starting TLS...\n"
+ % repr(outdata))
+ conn = ssl.wrap_socket(s, ssl_version=ssl.PROTOCOL_TLSv1)
+ wrapped = True
+ elif (indata == "ENDTLS" and
+ outdata.strip().lower().startswith("ok")):
+ # ENDTLS ok, switch back to clear text
+ if test_support.verbose:
+ sys.stdout.write(
+ " client: read %s from server, ending TLS...\n"
+ % repr(outdata))
+ s = conn.unwrap()
+ wrapped = False
+ else:
+ if test_support.verbose:
+ sys.stdout.write(
+ " client: read %s from server\n" % repr(outdata))
+ if test_support.verbose:
+ sys.stdout.write(" client: closing connection.\n")
+ if wrapped:
+ conn.write("over\n")
+ else:
+ s.send("over\n")
+ s.close()
+
+ def test_socketserver(self):
+ """Using a SocketServer to create and manage SSL connections."""
+ server = SocketServerHTTPSServer(CERTFILE)
+ flag = threading.Event()
+ server.start(flag)
+ # wait for it to start
+ flag.wait()
+ # try to connect
+ try:
+ if test_support.verbose:
+ sys.stdout.write('\n')
+ with open(CERTFILE, 'rb') as f:
+ d1 = f.read()
+ d2 = ''
+ # now fetch the same data from the HTTPS server
+ url = 'https://127.0.0.1:%d/%s' % (
+ server.port, os.path.split(CERTFILE)[1])
+ with test_support.check_py3k_warnings():
+ f = urllib.urlopen(url)
+ dlen = f.info().getheader("content-length")
+ if dlen and (int(dlen) > 0):
+ d2 = f.read(int(dlen))
+ if test_support.verbose:
+ sys.stdout.write(
+ " client: read %d bytes from remote server '%s'\n"
+ % (len(d2), server))
+ f.close()
+ self.assertEqual(d1, d2)
+ finally:
+ server.stop()
+ server.join()
+
+ def test_wrapped_accept(self):
+ """Check the accept() method on SSL sockets."""
+ if test_support.verbose:
+ sys.stdout.write("\n")
+ server_params_test(CERTFILE, ssl.PROTOCOL_SSLv23, ssl.CERT_REQUIRED,
+ CERTFILE, CERTFILE, ssl.PROTOCOL_SSLv23,
+ chatty=True, connectionchatty=True,
+ wrap_accepting_socket=True)
+
+ def test_asyncore_server(self):
+ """Check the example asyncore integration."""
+ indata = "TEST MESSAGE of mixed case\n"
+
+ if test_support.verbose:
+ sys.stdout.write("\n")
+ server = AsyncoreEchoServer(CERTFILE)
+ with server:
+ s = ssl.wrap_socket(socket.socket())
+ s.connect(('127.0.0.1', server.port))
+ if test_support.verbose:
+ sys.stdout.write(
+ " client: sending %s...\n" % (repr(indata)))
+ s.write(indata)
+ outdata = s.read()
+ if test_support.verbose:
+ sys.stdout.write(" client: read %s\n" % repr(outdata))
+ if outdata != indata.lower():
+ self.fail(
+ "bad data <<%s>> (%d) received; expected <<%s>> (%d)\n"
+ % (outdata[:min(len(outdata),20)], len(outdata),
+ indata[:min(len(indata),20)].lower(), len(indata)))
+ s.write("over\n")
+ if test_support.verbose:
+ sys.stdout.write(" client: closing connection.\n")
+ s.close()
+
+ def test_recv_send(self):
+ """Test recv(), send() and friends."""
+ if test_support.verbose:
+ sys.stdout.write("\n")
+
+ server = ThreadedEchoServer(CERTFILE,
+ certreqs=ssl.CERT_NONE,
+ ssl_version=ssl.PROTOCOL_TLSv1,
+ cacerts=CERTFILE,
+ chatty=True,
+ connectionchatty=False)
+ with server:
+ s = ssl.wrap_socket(socket.socket(),
+ server_side=False,
+ certfile=CERTFILE,
+ ca_certs=CERTFILE,
+ cert_reqs=ssl.CERT_NONE,
+ ssl_version=ssl.PROTOCOL_TLSv1)
+ s.connect((HOST, server.port))
+ # helper methods for standardising recv* method signatures
+ def _recv_into():
+ b = bytearray("\0"*100)
+ count = s.recv_into(b)
+ return b[:count]
+
+ def _recvfrom_into():
+ b = bytearray("\0"*100)
+ count, addr = s.recvfrom_into(b)
+ return b[:count]
+
+ # (name, method, whether to expect success, *args)
+ send_methods = [
+ ('send', s.send, True, []),
+ ('sendto', s.sendto, False, ["some.address"]),
+ ('sendall', s.sendall, True, []),
+ ]
+ recv_methods = [
+ ('recv', s.recv, True, []),
+ ('recvfrom', s.recvfrom, False, ["some.address"]),
+ ('recv_into', _recv_into, True, []),
+ ('recvfrom_into', _recvfrom_into, False, []),
+ ]
+ data_prefix = u"PREFIX_"
+
+ for meth_name, send_meth, expect_success, args in send_methods:
+ indata = data_prefix + meth_name
+ try:
+ send_meth(indata.encode('ASCII', 'strict'), *args)
+ outdata = s.read()
+ outdata = outdata.decode('ASCII', 'strict')
+ if outdata != indata.lower():
+ self.fail(
+ "While sending with <<%s>> bad data "
+ "<<%r>> (%d) received; "
+ "expected <<%r>> (%d)\n" % (
+ meth_name, outdata[:20], len(outdata),
+ indata[:20], len(indata)
+ )
+ )
+ except ValueError as e:
+ if expect_success:
+ self.fail(
+ "Failed to send with method <<%s>>; "
+ "expected to succeed.\n" % (meth_name,)
+ )
+ if not str(e).startswith(meth_name):
+ self.fail(
+ "Method <<%s>> failed with unexpected "
+ "exception message: %s\n" % (
+ meth_name, e
+ )
+ )
+
+ for meth_name, recv_meth, expect_success, args in recv_methods:
+ indata = data_prefix + meth_name
+ try:
+ s.send(indata.encode('ASCII', 'strict'))
+ outdata = recv_meth(*args)
+ outdata = outdata.decode('ASCII', 'strict')
+ if outdata != indata.lower():
+ self.fail(
+ "While receiving with <<%s>> bad data "
+ "<<%r>> (%d) received; "
+ "expected <<%r>> (%d)\n" % (
+ meth_name, outdata[:20], len(outdata),
+ indata[:20], len(indata)
+ )
+ )
+ except ValueError as e:
+ if expect_success:
+ self.fail(
+ "Failed to receive with method <<%s>>; "
+ "expected to succeed.\n" % (meth_name,)
+ )
+ if not str(e).startswith(meth_name):
+ self.fail(
+ "Method <<%s>> failed with unexpected "
+ "exception message: %s\n" % (
+ meth_name, e
+ )
+ )
+ # consume data
+ s.read()
+
+ s.write("over\n".encode("ASCII", "strict"))
+ s.close()
+
+ def test_handshake_timeout(self):
+ # Issue #5103: SSL handshake must respect the socket timeout
+ server = socket.socket(socket.AF_INET)
+ host = "127.0.0.1"
+ port = test_support.bind_port(server)
+ started = threading.Event()
+ finish = False
+
+ def serve():
+ server.listen(5)
+ started.set()
+ conns = []
+ while not finish:
+ r, w, e = select.select([server], [], [], 0.1)
+ if server in r:
+ # Let the socket hang around rather than having
+ # it closed by garbage collection.
+ conns.append(server.accept()[0])
+
+ t = threading.Thread(target=serve)
+ t.start()
+ started.wait()
+
+ try:
+ try:
+ c = socket.socket(socket.AF_INET)
+ c.settimeout(0.2)
+ c.connect((host, port))
+ # Will attempt handshake and time out
+ self.assertRaisesRegexp(ssl.SSLError, "timed out",
+ ssl.wrap_socket, c)
+ finally:
+ c.close()
+ try:
+ c = socket.socket(socket.AF_INET)
+ c.settimeout(0.2)
+ c = ssl.wrap_socket(c)
+ # Will attempt handshake and time out
+ self.assertRaisesRegexp(ssl.SSLError, "timed out",
+ c.connect, (host, port))
+ finally:
+ c.close()
+ finally:
+ finish = True
+ t.join()
+ server.close()
+
+ def test_default_ciphers(self):
+ with ThreadedEchoServer(CERTFILE,
+ ssl_version=ssl.PROTOCOL_SSLv23,
+ chatty=False) as server:
+ sock = socket.socket()
+ try:
+ # Force a set of weak ciphers on our client socket
+ try:
+ s = ssl.wrap_socket(sock,
+ ssl_version=ssl.PROTOCOL_SSLv23,
+ ciphers="DES")
+ except ssl.SSLError:
+ self.skipTest("no DES cipher available")
+ with self.assertRaises((OSError, ssl.SSLError)):
+ s.connect((HOST, server.port))
+ finally:
+ sock.close()
+ self.assertIn("no shared cipher", str(server.conn_errors[0]))
+
+
+def test_main(verbose=False):
+ global CERTFILE, SVN_PYTHON_ORG_ROOT_CERT, NOKIACERT
+ CERTFILE = os.path.join(os.path.dirname(__file__) or os.curdir,
+ "keycert.pem")
+ SVN_PYTHON_ORG_ROOT_CERT = os.path.join(
+ os.path.dirname(__file__) or os.curdir,
+ "https_svn_python_org_root.pem")
+ NOKIACERT = os.path.join(os.path.dirname(__file__) or os.curdir,
+ "nokia.pem")
+
+ if (not os.path.exists(CERTFILE) or
+ not os.path.exists(SVN_PYTHON_ORG_ROOT_CERT) or
+ not os.path.exists(NOKIACERT)):
+ raise test_support.TestFailed("Can't read certificate files!")
+
+ tests = [BasicTests, BasicSocketTests]
+
+ if test_support.is_resource_enabled('network'):
+ tests.append(NetworkedTests)
+
+ if _have_threads:
+ thread_info = test_support.threading_setup()
+ if thread_info and test_support.is_resource_enabled('network'):
+ tests.append(ThreadedTests)
+
+ try:
+ test_support.run_unittest(*tests)
+ finally:
+ if _have_threads:
+ test_support.threading_cleanup(*thread_info)
+
+if __name__ == "__main__":
+ test_main()
diff --git a/Lib/test/test_support.py b/Lib/test/test_support.py
--- a/Lib/test/test_support.py
+++ b/Lib/test/test_support.py
@@ -969,6 +969,22 @@
def captured_stdin():
return captured_output("stdin")
+def gc_collect():
+ """Force as many objects as possible to be collected.
+
+ In non-CPython implementations of Python, this is needed because timely
+ deallocation is not guaranteed by the garbage collector. (Even in CPython
+ this can be the case in case of reference cycles.) This means that __del__
+ methods may be called later than expected and weakrefs may remain alive for
+ longer than expected. This function tries its best to force all garbage
+ objects to disappear.
+ """
+ gc.collect()
+ if is_jython:
+ time.sleep(0.1)
+ gc.collect()
+ gc.collect()
+
_header = '2P'
if hasattr(sys, "gettotalrefcount"):
diff --git a/Lib/test/test_tarfile.py b/Lib/test/test_tarfile.py
deleted file mode 100644
--- a/Lib/test/test_tarfile.py
+++ /dev/null
@@ -1,1568 +0,0 @@
-# -*- coding: iso-8859-15 -*-
-
-import sys
-import os
-import shutil
-import StringIO
-from hashlib import md5
-import errno
-import gc
-
-import unittest
-import tarfile
-
-from test import test_support
-
-# Check for our compression modules.
-try:
- import gzip
- gzip.GzipFile
-except (ImportError, AttributeError):
- gzip = None
-try:
- import bz2
-except ImportError:
- bz2 = None
-
-def md5sum(data):
- return md5(data).hexdigest()
-
-TEMPDIR = os.path.abspath(test_support.TESTFN)
-tarname = test_support.findfile("testtar.tar")
-gzipname = os.path.join(TEMPDIR, "testtar.tar.gz")
-bz2name = os.path.join(TEMPDIR, "testtar.tar.bz2")
-tmpname = os.path.join(TEMPDIR, "tmp.tar")
-
-md5_regtype = "65f477c818ad9e15f7feab0c6d37742f"
-md5_sparse = "a54fbc4ca4f4399a90e1b27164012fc6"
-
-
-class ReadTest(unittest.TestCase):
-
- tarname = tarname
- mode = "r:"
-
- def setUp(self):
- self.tar = tarfile.open(self.tarname, mode=self.mode, encoding="iso8859-1")
-
- def tearDown(self):
- self.tar.close()
- # Not all files are currently being closed in these tests,
- # so to ensure something similar to CPython's deterministic cleanup,
- # call gc and have finalization happen
- gc.collect()
-
-
-class UstarReadTest(ReadTest):
-
- def test_fileobj_regular_file(self):
- tarinfo = self.tar.getmember("ustar/regtype")
- fobj = self.tar.extractfile(tarinfo)
- data = fobj.read()
- self.assertTrue((len(data), md5sum(data)) == (tarinfo.size, md5_regtype),
- "regular file extraction failed")
-
- def test_fileobj_readlines(self):
- self.tar.extract("ustar/regtype", TEMPDIR)
- tarinfo = self.tar.getmember("ustar/regtype")
- fobj1 = open(os.path.join(TEMPDIR, "ustar/regtype"), "rU")
- fobj2 = self.tar.extractfile(tarinfo)
-
- lines1 = fobj1.readlines()
- lines2 = fobj2.readlines()
- self.assertTrue(lines1 == lines2,
- "fileobj.readlines() failed")
- self.assertTrue(len(lines2) == 114,
- "fileobj.readlines() failed")
- self.assertTrue(lines2[83] ==
- "I will gladly admit that Python is not the fastest running scripting language.\n",
- "fileobj.readlines() failed")
-
- def test_fileobj_iter(self):
- self.tar.extract("ustar/regtype", TEMPDIR)
- tarinfo = self.tar.getmember("ustar/regtype")
- fobj1 = open(os.path.join(TEMPDIR, "ustar/regtype"), "rU")
- fobj2 = self.tar.extractfile(tarinfo)
- lines1 = fobj1.readlines()
- lines2 = [line for line in fobj2]
- self.assertTrue(lines1 == lines2,
- "fileobj.__iter__() failed")
-
- def test_fileobj_seek(self):
- self.tar.extract("ustar/regtype", TEMPDIR)
- fobj = open(os.path.join(TEMPDIR, "ustar/regtype"), "rb")
- data = fobj.read()
- fobj.close()
-
- tarinfo = self.tar.getmember("ustar/regtype")
- fobj = self.tar.extractfile(tarinfo)
-
- text = fobj.read()
- fobj.seek(0)
- self.assertTrue(0 == fobj.tell(),
- "seek() to file's start failed")
- fobj.seek(2048, 0)
- self.assertTrue(2048 == fobj.tell(),
- "seek() to absolute position failed")
- fobj.seek(-1024, 1)
- self.assertTrue(1024 == fobj.tell(),
- "seek() to negative relative position failed")
- fobj.seek(1024, 1)
- self.assertTrue(2048 == fobj.tell(),
- "seek() to positive relative position failed")
- s = fobj.read(10)
- self.assertTrue(s == data[2048:2058],
- "read() after seek failed")
- fobj.seek(0, 2)
- self.assertTrue(tarinfo.size == fobj.tell(),
- "seek() to file's end failed")
- self.assertTrue(fobj.read() == "",
- "read() at file's end did not return empty string")
- fobj.seek(-tarinfo.size, 2)
- self.assertTrue(0 == fobj.tell(),
- "relative seek() to file's start failed")
- fobj.seek(512)
- s1 = fobj.readlines()
- fobj.seek(512)
- s2 = fobj.readlines()
- self.assertTrue(s1 == s2,
- "readlines() after seek failed")
- fobj.seek(0)
- self.assertTrue(len(fobj.readline()) == fobj.tell(),
- "tell() after readline() failed")
- fobj.seek(512)
- self.assertTrue(len(fobj.readline()) + 512 == fobj.tell(),
- "tell() after seek() and readline() failed")
- fobj.seek(0)
- line = fobj.readline()
- self.assertTrue(fobj.read() == data[len(line):],
- "read() after readline() failed")
- fobj.close()
-
- # Test if symbolic and hard links are resolved by extractfile(). The
- # test link members each point to a regular member whose data is
- # supposed to be exported.
- def _test_fileobj_link(self, lnktype, regtype):
- a = self.tar.extractfile(lnktype)
- b = self.tar.extractfile(regtype)
- self.assertEqual(a.name, b.name)
-
- def test_fileobj_link1(self):
- self._test_fileobj_link("ustar/lnktype", "ustar/regtype")
-
- def test_fileobj_link2(self):
- self._test_fileobj_link("./ustar/linktest2/lnktype", "ustar/linktest1/regtype")
-
- def test_fileobj_symlink1(self):
- self._test_fileobj_link("ustar/symtype", "ustar/regtype")
-
- def test_fileobj_symlink2(self):
- self._test_fileobj_link("./ustar/linktest2/symtype", "ustar/linktest1/regtype")
-
-
-class CommonReadTest(ReadTest):
-
- def test_empty_tarfile(self):
- # Test for issue6123: Allow opening empty archives.
- # This test checks if tarfile.open() is able to open an empty tar
- # archive successfully. Note that an empty tar archive is not the
- # same as an empty file!
- tarfile.open(tmpname, self.mode.replace("r", "w")).close()
- try:
- tar = tarfile.open(tmpname, self.mode)
- tar.getnames()
- except tarfile.ReadError:
- self.fail("tarfile.open() failed on empty archive")
- self.assertListEqual(tar.getmembers(), [])
-
- def test_null_tarfile(self):
- # Test for issue6123: Allow opening empty archives.
- # This test guarantees that tarfile.open() does not treat an empty
- # file as an empty tar archive.
- open(tmpname, "wb").close()
- self.assertRaises(tarfile.ReadError, tarfile.open, tmpname, self.mode)
- self.assertRaises(tarfile.ReadError, tarfile.open, tmpname)
-
- def test_ignore_zeros(self):
- # Test TarFile's ignore_zeros option.
- if self.mode.endswith(":gz"):
- _open = gzip.GzipFile
- elif self.mode.endswith(":bz2"):
- _open = bz2.BZ2File
- else:
- _open = open
-
- for char in ('\0', 'a'):
- # Test if EOFHeaderError ('\0') and InvalidHeaderError ('a')
- # are ignored correctly.
- fobj = _open(tmpname, "wb")
- fobj.write(char * 1024)
- fobj.write(tarfile.TarInfo("foo").tobuf())
- fobj.close()
-
- tar = tarfile.open(tmpname, mode="r", ignore_zeros=True)
- self.assertListEqual(tar.getnames(), ["foo"],
- "ignore_zeros=True should have skipped the %r-blocks" % char)
- tar.close()
-
-
-class MiscReadTest(CommonReadTest):
-
- def test_no_name_argument(self):
- fobj = open(self.tarname, "rb")
- self.tar.close()
- self.tar = tarfile.open(fileobj=fobj, mode="r")
- self.assertEqual(self.tar.name, os.path.abspath(fobj.name))
- fobj.close()
-
- def test_no_name_attribute(self):
- fp = open(self.tarname, "rb")
- data = fp.read()
- fp.close()
- fobj = StringIO.StringIO(data)
- self.assertRaises(AttributeError, getattr, fobj, "name")
- self.tar.close()
- self.tar = tarfile.open(fileobj=fobj, mode="r")
- self.assertEqual(self.tar.name, None)
-
- def test_empty_name_attribute(self):
- fp = open(self.tarname, "rb")
- data = fp.read()
- fp.close()
- fobj = StringIO.StringIO(data)
- fobj.name = ""
- self.tar.close()
- self.tar = tarfile.open(fileobj=fobj, mode="r")
- self.assertEqual(self.tar.name, None)
-
- def test_fileobj_with_offset(self):
- # Skip the first member and store values from the second member
- # of the testtar.
- tar = tarfile.open(self.tarname, mode=self.mode)
- tar.next()
- t = tar.next()
- name = t.name
- offset = t.offset
- data = tar.extractfile(t).read()
- tar.close()
-
- # Open the testtar and seek to the offset of the second member.
- if self.mode.endswith(":gz"):
- _open = gzip.GzipFile
- elif self.mode.endswith(":bz2"):
- _open = bz2.BZ2File
- else:
- _open = open
- fobj = _open(self.tarname, "rb")
- fobj.seek(offset)
-
- # Test if the tarfile starts with the second member.
- tar.close()
- tar = tar.open(self.tarname, mode="r:", fileobj=fobj)
- t = tar.next()
- self.assertEqual(t.name, name)
- # Read to the end of fileobj and test if seeking back to the
- # beginning works.
- tar.getmembers()
- self.assertEqual(tar.extractfile(t).read(), data,
- "seek back did not work")
- tar.close()
- fobj.close()
-
- def test_fail_comp(self):
- # For Gzip and Bz2 Tests: fail with a ReadError on an uncompressed file.
- if self.mode == "r:":
- return
- self.assertRaises(tarfile.ReadError, tarfile.open, tarname, self.mode)
- fobj = open(tarname, "rb")
- self.assertRaises(tarfile.ReadError, tarfile.open, fileobj=fobj, mode=self.mode)
-
- def test_v7_dirtype(self):
- # Test old style dirtype member (bug #1336623):
- # Old V7 tars create directory members using an AREGTYPE
- # header with a "/" appended to the filename field.
- tarinfo = self.tar.getmember("misc/dirtype-old-v7")
- self.assertTrue(tarinfo.type == tarfile.DIRTYPE,
- "v7 dirtype failed")
-
- def test_xstar_type(self):
- # The xstar format stores extra atime and ctime fields inside the
- # space reserved for the prefix field. The prefix field must be
- # ignored in this case, otherwise it will mess up the name.
- try:
- self.tar.getmember("misc/regtype-xstar")
- except KeyError:
- self.fail("failed to find misc/regtype-xstar (mangled prefix?)")
-
- def test_check_members(self):
- for tarinfo in self.tar:
- self.assertTrue(int(tarinfo.mtime) == 07606136617,
- "wrong mtime for %s" % tarinfo.name)
- if not tarinfo.name.startswith("ustar/"):
- continue
- self.assertTrue(tarinfo.uname == "tarfile",
- "wrong uname for %s" % tarinfo.name)
-
- def test_find_members(self):
- self.assertTrue(self.tar.getmembers()[-1].name == "misc/eof",
- "could not find all members")
-
- def test_extract_hardlink(self):
- # Test hardlink extraction (e.g. bug #857297).
- tar = tarfile.open(tarname, errorlevel=1, encoding="iso8859-1")
-
- tar.extract("ustar/regtype", TEMPDIR)
- try:
- tar.extract("ustar/lnktype", TEMPDIR)
- except EnvironmentError, e:
- if e.errno == errno.ENOENT:
- self.fail("hardlink not extracted properly")
-
- data = open(os.path.join(TEMPDIR, "ustar/lnktype"), "rb").read()
- self.assertEqual(md5sum(data), md5_regtype)
-
- try:
- tar.extract("ustar/symtype", TEMPDIR)
- except EnvironmentError, e:
- if e.errno == errno.ENOENT:
- self.fail("symlink not extracted properly")
-
- data = open(os.path.join(TEMPDIR, "ustar/symtype"), "rb").read()
- self.assertEqual(md5sum(data), md5_regtype)
-
- def test_extractall(self):
- # Test if extractall() correctly restores directory permissions
- # and times (see issue1735).
- tar = tarfile.open(tarname, encoding="iso8859-1")
- directories = [t for t in tar if t.isdir()]
- tar.extractall(TEMPDIR, directories)
- for tarinfo in directories:
- path = os.path.join(TEMPDIR, tarinfo.name)
- if (sys.platform == "win32" or
- test_support.is_jython and os._name == 'nt'):
- # Win32 has no support for fine grained permissions.
- self.assertEqual(tarinfo.mode & 0777, os.stat(path).st_mode & 0777)
- self.assertEqual(tarinfo.mtime, os.path.getmtime(path))
- tar.close()
-
- def test_init_close_fobj(self):
- # Issue #7341: Close the internal file object in the TarFile
- # constructor in case of an error. For the test we rely on
- # the fact that opening an empty file raises a ReadError.
- empty = os.path.join(TEMPDIR, "empty")
- open(empty, "wb").write("")
-
- try:
- tar = object.__new__(tarfile.TarFile)
- try:
- tar.__init__(empty)
- except tarfile.ReadError:
- self.assertTrue(tar.fileobj.closed)
- else:
- self.fail("ReadError not raised")
- finally:
- os.remove(empty)
-
-
-class StreamReadTest(CommonReadTest):
-
- mode="r|"
-
- def test_fileobj_regular_file(self):
- tarinfo = self.tar.next() # get "regtype" (can't use getmember)
- fobj = self.tar.extractfile(tarinfo)
- data = fobj.read()
- self.assertTrue((len(data), md5sum(data)) == (tarinfo.size, md5_regtype),
- "regular file extraction failed")
-
- def test_provoke_stream_error(self):
- tarinfos = self.tar.getmembers()
- f = self.tar.extractfile(tarinfos[0]) # read the first member
- self.assertRaises(tarfile.StreamError, f.read)
-
- def test_compare_members(self):
- tar1 = tarfile.open(tarname, encoding="iso8859-1")
- tar2 = self.tar
-
- while True:
- t1 = tar1.next()
- t2 = tar2.next()
- if t1 is None:
- break
- self.assertTrue(t2 is not None, "stream.next() failed.")
-
- if t2.islnk() or t2.issym():
- self.assertRaises(tarfile.StreamError, tar2.extractfile, t2)
- continue
-
- v1 = tar1.extractfile(t1)
- v2 = tar2.extractfile(t2)
- if v1 is None:
- continue
- self.assertTrue(v2 is not None, "stream.extractfile() failed")
- self.assertTrue(v1.read() == v2.read(), "stream extraction failed")
-
- tar1.close()
-
-
-class DetectReadTest(unittest.TestCase):
-
- def _testfunc_file(self, name, mode):
- try:
- tarfile.open(name, mode)
- except tarfile.ReadError:
- self.fail()
-
- def _testfunc_fileobj(self, name, mode):
- try:
- tarfile.open(name, mode, fileobj=open(name, "rb"))
- except tarfile.ReadError:
- self.fail()
-
- def _test_modes(self, testfunc):
- testfunc(tarname, "r")
- testfunc(tarname, "r:")
- testfunc(tarname, "r:*")
- testfunc(tarname, "r|")
- testfunc(tarname, "r|*")
-
- if gzip:
- self.assertRaises(tarfile.ReadError, tarfile.open, tarname, mode="r:gz")
- self.assertRaises(tarfile.ReadError, tarfile.open, tarname, mode="r|gz")
- self.assertRaises(tarfile.ReadError, tarfile.open, gzipname, mode="r:")
- self.assertRaises(tarfile.ReadError, tarfile.open, gzipname, mode="r|")
-
- testfunc(gzipname, "r")
- testfunc(gzipname, "r:*")
- testfunc(gzipname, "r:gz")
- testfunc(gzipname, "r|*")
- testfunc(gzipname, "r|gz")
-
- if bz2:
- self.assertRaises(tarfile.ReadError, tarfile.open, tarname, mode="r:bz2")
- self.assertRaises(tarfile.ReadError, tarfile.open, tarname, mode="r|bz2")
- self.assertRaises(tarfile.ReadError, tarfile.open, bz2name, mode="r:")
- self.assertRaises(tarfile.ReadError, tarfile.open, bz2name, mode="r|")
-
- testfunc(bz2name, "r")
- testfunc(bz2name, "r:*")
- testfunc(bz2name, "r:bz2")
- testfunc(bz2name, "r|*")
- testfunc(bz2name, "r|bz2")
-
- def test_detect_file(self):
- self._test_modes(self._testfunc_file)
-
- def test_detect_fileobj(self):
- self._test_modes(self._testfunc_fileobj)
-
-
-class MemberReadTest(ReadTest):
-
- def _test_member(self, tarinfo, chksum=None, **kwargs):
- if chksum is not None:
- self.assertTrue(md5sum(self.tar.extractfile(tarinfo).read()) == chksum,
- "wrong md5sum for %s" % tarinfo.name)
-
- kwargs["mtime"] = 07606136617
- kwargs["uid"] = 1000
- kwargs["gid"] = 100
- if "old-v7" not in tarinfo.name:
- # V7 tar can't handle alphabetic owners.
- kwargs["uname"] = "tarfile"
- kwargs["gname"] = "tarfile"
- for k, v in kwargs.iteritems():
- self.assertTrue(getattr(tarinfo, k) == v,
- "wrong value in %s field of %s" % (k, tarinfo.name))
-
- def test_find_regtype(self):
- tarinfo = self.tar.getmember("ustar/regtype")
- self._test_member(tarinfo, size=7011, chksum=md5_regtype)
-
- def test_find_conttype(self):
- tarinfo = self.tar.getmember("ustar/conttype")
- self._test_member(tarinfo, size=7011, chksum=md5_regtype)
-
- def test_find_dirtype(self):
- tarinfo = self.tar.getmember("ustar/dirtype")
- self._test_member(tarinfo, size=0)
-
- def test_find_dirtype_with_size(self):
- tarinfo = self.tar.getmember("ustar/dirtype-with-size")
- self._test_member(tarinfo, size=255)
-
- def test_find_lnktype(self):
- tarinfo = self.tar.getmember("ustar/lnktype")
- self._test_member(tarinfo, size=0, linkname="ustar/regtype")
-
- def test_find_symtype(self):
- tarinfo = self.tar.getmember("ustar/symtype")
- self._test_member(tarinfo, size=0, linkname="regtype")
-
- def test_find_blktype(self):
- tarinfo = self.tar.getmember("ustar/blktype")
- self._test_member(tarinfo, size=0, devmajor=3, devminor=0)
-
- def test_find_chrtype(self):
- tarinfo = self.tar.getmember("ustar/chrtype")
- self._test_member(tarinfo, size=0, devmajor=1, devminor=3)
-
- def test_find_fifotype(self):
- tarinfo = self.tar.getmember("ustar/fifotype")
- self._test_member(tarinfo, size=0)
-
- def test_find_sparse(self):
- tarinfo = self.tar.getmember("ustar/sparse")
- self._test_member(tarinfo, size=86016, chksum=md5_sparse)
-
- def test_find_umlauts(self):
- tarinfo = self.tar.getmember("ustar/umlauts-ÄÖÜäöüß")
- self._test_member(tarinfo, size=7011, chksum=md5_regtype)
-
- def test_find_ustar_longname(self):
- name = "ustar/" + "12345/" * 39 + "1234567/longname"
- self.assertIn(name, self.tar.getnames())
-
- def test_find_regtype_oldv7(self):
- tarinfo = self.tar.getmember("misc/regtype-old-v7")
- self._test_member(tarinfo, size=7011, chksum=md5_regtype)
-
- def test_find_pax_umlauts(self):
- self.tar = tarfile.open(self.tarname, mode=self.mode, encoding="iso8859-1")
- tarinfo = self.tar.getmember("pax/umlauts-ÄÖÜäöüß")
- self._test_member(tarinfo, size=7011, chksum=md5_regtype)
-
-
-class LongnameTest(ReadTest):
-
- def test_read_longname(self):
- # Test reading of longname (bug #1471427).
- longname = self.subdir + "/" + "123/" * 125 + "longname"
- try:
- tarinfo = self.tar.getmember(longname)
- except KeyError:
- self.fail("longname not found")
- self.assertTrue(tarinfo.type != tarfile.DIRTYPE, "read longname as dirtype")
-
- def test_read_longlink(self):
- longname = self.subdir + "/" + "123/" * 125 + "longname"
- longlink = self.subdir + "/" + "123/" * 125 + "longlink"
- try:
- tarinfo = self.tar.getmember(longlink)
- except KeyError:
- self.fail("longlink not found")
- self.assertTrue(tarinfo.linkname == longname, "linkname wrong")
-
- def test_truncated_longname(self):
- longname = self.subdir + "/" + "123/" * 125 + "longname"
- tarinfo = self.tar.getmember(longname)
- offset = tarinfo.offset
- self.tar.fileobj.seek(offset)
- fobj = StringIO.StringIO(self.tar.fileobj.read(3 * 512))
- self.assertRaises(tarfile.ReadError, tarfile.open, name="foo.tar", fileobj=fobj)
-
- def test_header_offset(self):
- # Test if the start offset of the TarInfo object includes
- # the preceding extended header.
- longname = self.subdir + "/" + "123/" * 125 + "longname"
- offset = self.tar.getmember(longname).offset
- fobj = open(tarname)
- fobj.seek(offset)
- tarinfo = tarfile.TarInfo.frombuf(fobj.read(512))
- self.assertEqual(tarinfo.type, self.longnametype)
-
-
-class GNUReadTest(LongnameTest):
-
- subdir = "gnu"
- longnametype = tarfile.GNUTYPE_LONGNAME
-
- def test_sparse_file(self):
- tarinfo1 = self.tar.getmember("ustar/sparse")
- fobj1 = self.tar.extractfile(tarinfo1)
- tarinfo2 = self.tar.getmember("gnu/sparse")
- fobj2 = self.tar.extractfile(tarinfo2)
- self.assertTrue(fobj1.read() == fobj2.read(),
- "sparse file extraction failed")
-
-
-class PaxReadTest(LongnameTest):
-
- subdir = "pax"
- longnametype = tarfile.XHDTYPE
-
- def test_pax_global_headers(self):
- tar = tarfile.open(tarname, encoding="iso8859-1")
-
- tarinfo = tar.getmember("pax/regtype1")
- self.assertEqual(tarinfo.uname, "foo")
- self.assertEqual(tarinfo.gname, "bar")
- self.assertEqual(tarinfo.pax_headers.get("VENDOR.umlauts"), u"ÄÖÜäöüß")
-
- tarinfo = tar.getmember("pax/regtype2")
- self.assertEqual(tarinfo.uname, "")
- self.assertEqual(tarinfo.gname, "bar")
- self.assertEqual(tarinfo.pax_headers.get("VENDOR.umlauts"), u"ÄÖÜäöüß")
-
- tarinfo = tar.getmember("pax/regtype3")
- self.assertEqual(tarinfo.uname, "tarfile")
- self.assertEqual(tarinfo.gname, "tarfile")
- self.assertEqual(tarinfo.pax_headers.get("VENDOR.umlauts"), u"ÄÖÜäöüß")
-
- def test_pax_number_fields(self):
- # All following number fields are read from the pax header.
- tar = tarfile.open(tarname, encoding="iso8859-1")
- tarinfo = tar.getmember("pax/regtype4")
- self.assertEqual(tarinfo.size, 7011)
- self.assertEqual(tarinfo.uid, 123)
- self.assertEqual(tarinfo.gid, 123)
- self.assertEqual(tarinfo.mtime, 1041808783.0)
- self.assertEqual(type(tarinfo.mtime), float)
- self.assertEqual(float(tarinfo.pax_headers["atime"]), 1041808783.0)
- self.assertEqual(float(tarinfo.pax_headers["ctime"]), 1041808783.0)
-
-
-class WriteTestBase(unittest.TestCase):
- # Put all write tests in here that are supposed to be tested
- # in all possible mode combinations.
-
- def test_fileobj_no_close(self):
- fobj = StringIO.StringIO()
- tar = tarfile.open(fileobj=fobj, mode=self.mode)
- tar.addfile(tarfile.TarInfo("foo"))
- tar.close()
- self.assertTrue(fobj.closed is False, "external fileobjs must never closed")
-
-
-class WriteTest(WriteTestBase):
-
- mode = "w:"
-
- def test_100_char_name(self):
- # The name field in a tar header stores strings of at most 100 chars.
- # If a string is shorter than 100 chars it has to be padded with '\0',
- # which implies that a string of exactly 100 chars is stored without
- # a trailing '\0'.
- name = "0123456789" * 10
- tar = tarfile.open(tmpname, self.mode)
- t = tarfile.TarInfo(name)
- tar.addfile(t)
- tar.close()
-
- tar = tarfile.open(tmpname)
- self.assertTrue(tar.getnames()[0] == name,
- "failed to store 100 char filename")
- tar.close()
-
- def test_tar_size(self):
- # Test for bug #1013882.
- tar = tarfile.open(tmpname, self.mode)
- path = os.path.join(TEMPDIR, "file")
- fobj = open(path, "wb")
- fobj.write("aaa")
- fobj.close()
- tar.add(path)
- tar.close()
- self.assertTrue(os.path.getsize(tmpname) > 0,
- "tarfile is empty")
-
- # The test_*_size tests test for bug #1167128.
- def test_file_size(self):
- tar = tarfile.open(tmpname, self.mode)
-
- path = os.path.join(TEMPDIR, "file")
- fobj = open(path, "wb")
- fobj.close()
- tarinfo = tar.gettarinfo(path)
- self.assertEqual(tarinfo.size, 0)
-
- fobj = open(path, "wb")
- fobj.write("aaa")
- fobj.close()
- tarinfo = tar.gettarinfo(path)
- self.assertEqual(tarinfo.size, 3)
-
- tar.close()
-
- def test_directory_size(self):
- path = os.path.join(TEMPDIR, "directory")
- os.mkdir(path)
- try:
- tar = tarfile.open(tmpname, self.mode)
- tarinfo = tar.gettarinfo(path)
- self.assertEqual(tarinfo.size, 0)
- finally:
- os.rmdir(path)
-
- def test_link_size(self):
- if hasattr(os, "link"):
- link = os.path.join(TEMPDIR, "link")
- target = os.path.join(TEMPDIR, "link_target")
- fobj = open(target, "wb")
- fobj.write("aaa")
- fobj.close()
- os.link(target, link)
- try:
- tar = tarfile.open(tmpname, self.mode)
- # Record the link target in the inodes list.
- tar.gettarinfo(target)
- tarinfo = tar.gettarinfo(link)
- self.assertEqual(tarinfo.size, 0)
- finally:
- os.remove(target)
- os.remove(link)
-
- def test_symlink_size(self):
- if hasattr(os, "symlink"):
- path = os.path.join(TEMPDIR, "symlink")
- os.symlink("link_target", path)
- try:
- tar = tarfile.open(tmpname, self.mode)
- tarinfo = tar.gettarinfo(path)
- self.assertEqual(tarinfo.size, 0)
- finally:
- os.remove(path)
-
- def test_add_self(self):
- # Test for #1257255.
- dstname = os.path.abspath(tmpname)
-
- tar = tarfile.open(tmpname, self.mode)
- self.assertTrue(tar.name == dstname, "archive name must be absolute")
-
- tar.add(dstname)
- self.assertTrue(tar.getnames() == [], "added the archive to itself")
-
- cwd = os.getcwd()
- os.chdir(TEMPDIR)
- tar.add(dstname)
- os.chdir(cwd)
- self.assertTrue(tar.getnames() == [], "added the archive to itself")
-
- def test_exclude(self):
- tempdir = os.path.join(TEMPDIR, "exclude")
- os.mkdir(tempdir)
- try:
- for name in ("foo", "bar", "baz"):
- name = os.path.join(tempdir, name)
- open(name, "wb").close()
-
- exclude = os.path.isfile
-
- tar = tarfile.open(tmpname, self.mode, encoding="iso8859-1")
- with test_support.check_warnings(("use the filter argument",
- DeprecationWarning)):
- tar.add(tempdir, arcname="empty_dir", exclude=exclude)
- tar.close()
-
- tar = tarfile.open(tmpname, "r")
- self.assertEqual(len(tar.getmembers()), 1)
- self.assertEqual(tar.getnames()[0], "empty_dir")
- finally:
- shutil.rmtree(tempdir)
-
- def test_filter(self):
- tempdir = os.path.join(TEMPDIR, "filter")
- os.mkdir(tempdir)
- try:
- for name in ("foo", "bar", "baz"):
- name = os.path.join(tempdir, name)
- open(name, "wb").close()
-
- def filter(tarinfo):
- if os.path.basename(tarinfo.name) == "bar":
- return
- tarinfo.uid = 123
- tarinfo.uname = "foo"
- return tarinfo
-
- tar = tarfile.open(tmpname, self.mode, encoding="iso8859-1")
- tar.add(tempdir, arcname="empty_dir", filter=filter)
- tar.close()
-
- tar = tarfile.open(tmpname, "r")
- for tarinfo in tar:
- self.assertEqual(tarinfo.uid, 123)
- self.assertEqual(tarinfo.uname, "foo")
- self.assertEqual(len(tar.getmembers()), 3)
- tar.close()
- finally:
- shutil.rmtree(tempdir)
-
- # Guarantee that stored pathnames are not modified. Don't
- # remove ./ or ../ or double slashes. Still make absolute
- # pathnames relative.
- # For details see bug #6054.
- def _test_pathname(self, path, cmp_path=None, dir=False):
- # Create a tarfile with an empty member named path
- # and compare the stored name with the original.
- foo = os.path.join(TEMPDIR, "foo")
- if not dir:
- open(foo, "w").close()
- else:
- os.mkdir(foo)
-
- tar = tarfile.open(tmpname, self.mode)
- tar.add(foo, arcname=path)
- tar.close()
-
- tar = tarfile.open(tmpname, "r")
- t = tar.next()
- tar.close()
-
- if not dir:
- os.remove(foo)
- else:
- os.rmdir(foo)
-
- self.assertEqual(t.name, cmp_path or path.replace(os.sep, "/"))
-
- def test_pathnames(self):
- self._test_pathname("foo")
- self._test_pathname(os.path.join("foo", ".", "bar"))
- self._test_pathname(os.path.join("foo", "..", "bar"))
- self._test_pathname(os.path.join(".", "foo"))
- self._test_pathname(os.path.join(".", "foo", "."))
- self._test_pathname(os.path.join(".", "foo", ".", "bar"))
- self._test_pathname(os.path.join(".", "foo", "..", "bar"))
- self._test_pathname(os.path.join(".", "foo", "..", "bar"))
- self._test_pathname(os.path.join("..", "foo"))
- self._test_pathname(os.path.join("..", "foo", ".."))
- self._test_pathname(os.path.join("..", "foo", ".", "bar"))
- self._test_pathname(os.path.join("..", "foo", "..", "bar"))
-
- self._test_pathname("foo" + os.sep + os.sep + "bar")
- self._test_pathname("foo" + os.sep + os.sep, "foo", dir=True)
-
- def test_abs_pathnames(self):
- if sys.platform == "win32":
- self._test_pathname("C:\\foo", "foo")
- else:
- self._test_pathname("/foo", "foo")
- self._test_pathname("///foo", "foo")
-
- def test_cwd(self):
- # Test adding the current working directory.
- cwd = os.getcwd()
- os.chdir(TEMPDIR)
- try:
- open("foo", "w").close()
-
- tar = tarfile.open(tmpname, self.mode)
- tar.add(".")
- tar.close()
-
- tar = tarfile.open(tmpname, "r")
- for t in tar:
- self.assert_(t.name == "." or t.name.startswith("./"))
- tar.close()
- finally:
- os.chdir(cwd)
-
-
-class StreamWriteTest(WriteTestBase):
-
- mode = "w|"
-
- def test_stream_padding(self):
- # Test for bug #1543303.
- tar = tarfile.open(tmpname, self.mode)
- tar.close()
-
- if self.mode.endswith("gz"):
- fobj = gzip.GzipFile(tmpname)
- data = fobj.read()
- fobj.close()
- elif self.mode.endswith("bz2"):
- dec = bz2.BZ2Decompressor()
- data = open(tmpname, "rb").read()
- data = dec.decompress(data)
- self.assertTrue(len(dec.unused_data) == 0,
- "found trailing data")
- else:
- fobj = open(tmpname, "rb")
- data = fobj.read()
- fobj.close()
-
- self.assertTrue(data.count("\0") == tarfile.RECORDSIZE,
- "incorrect zero padding")
-
- def test_file_mode(self):
- # Test for issue #8464: Create files with correct
- # permissions.
- if sys.platform == "win32" or not hasattr(os, "umask"):
- return
-
- if os.path.exists(tmpname):
- os.remove(tmpname)
-
- original_umask = os.umask(0022)
- try:
- tar = tarfile.open(tmpname, self.mode)
- tar.close()
- mode = os.stat(tmpname).st_mode & 0777
- self.assertEqual(mode, 0644, "wrong file permissions")
- finally:
- os.umask(original_umask)
-
-
-class GNUWriteTest(unittest.TestCase):
- # This testcase checks for correct creation of GNU Longname
- # and Longlink extended headers (cp. bug #812325).
-
- def _length(self, s):
- blocks, remainder = divmod(len(s) + 1, 512)
- if remainder:
- blocks += 1
- return blocks * 512
-
- def _calc_size(self, name, link=None):
- # Initial tar header
- count = 512
-
- if len(name) > tarfile.LENGTH_NAME:
- # GNU longname extended header + longname
- count += 512
- count += self._length(name)
- if link is not None and len(link) > tarfile.LENGTH_LINK:
- # GNU longlink extended header + longlink
- count += 512
- count += self._length(link)
- return count
-
- def _test(self, name, link=None):
- tarinfo = tarfile.TarInfo(name)
- if link:
- tarinfo.linkname = link
- tarinfo.type = tarfile.LNKTYPE
-
- tar = tarfile.open(tmpname, "w")
- tar.format = tarfile.GNU_FORMAT
- tar.addfile(tarinfo)
-
- v1 = self._calc_size(name, link)
- v2 = tar.offset
- self.assertTrue(v1 == v2, "GNU longname/longlink creation failed")
-
- tar.close()
-
- tar = tarfile.open(tmpname)
- member = tar.next()
- self.assertIsNotNone(member,
- "unable to read longname member")
- self.assertEqual(tarinfo.name, member.name,
- "unable to read longname member")
- self.assertEqual(tarinfo.linkname, member.linkname,
- "unable to read longname member")
- tar.close()
-
- def test_longname_1023(self):
- self._test(("longnam/" * 127) + "longnam")
-
- def test_longname_1024(self):
- self._test(("longnam/" * 127) + "longname")
-
- def test_longname_1025(self):
- self._test(("longnam/" * 127) + "longname_")
-
- def test_longlink_1023(self):
- self._test("name", ("longlnk/" * 127) + "longlnk")
-
- def test_longlink_1024(self):
- self._test("name", ("longlnk/" * 127) + "longlink")
-
- def test_longlink_1025(self):
- self._test("name", ("longlnk/" * 127) + "longlink_")
-
- def test_longnamelink_1023(self):
- self._test(("longnam/" * 127) + "longnam",
- ("longlnk/" * 127) + "longlnk")
-
- def test_longnamelink_1024(self):
- self._test(("longnam/" * 127) + "longname",
- ("longlnk/" * 127) + "longlink")
-
- def test_longnamelink_1025(self):
- self._test(("longnam/" * 127) + "longname_",
- ("longlnk/" * 127) + "longlink_")
-
-
-class HardlinkTest(unittest.TestCase):
- # Test the creation of LNKTYPE (hardlink) members in an archive.
-
- def setUp(self):
- self.foo = os.path.join(TEMPDIR, "foo")
- self.bar = os.path.join(TEMPDIR, "bar")
-
- fobj = open(self.foo, "wb")
- fobj.write("foo")
- fobj.close()
-
- os.link(self.foo, self.bar)
-
- self.tar = tarfile.open(tmpname, "w")
- self.tar.add(self.foo)
-
- def tearDown(self):
- self.tar.close()
- os.remove(self.foo)
- os.remove(self.bar)
-
- def test_add_twice(self):
- # The same name will be added as a REGTYPE every
- # time regardless of st_nlink.
- tarinfo = self.tar.gettarinfo(self.foo)
- self.assertTrue(tarinfo.type == tarfile.REGTYPE,
- "add file as regular failed")
-
- def test_add_hardlink(self):
- tarinfo = self.tar.gettarinfo(self.bar)
- self.assertTrue(tarinfo.type == tarfile.LNKTYPE,
- "add file as hardlink failed")
-
- def test_dereference_hardlink(self):
- self.tar.dereference = True
- tarinfo = self.tar.gettarinfo(self.bar)
- self.assertTrue(tarinfo.type == tarfile.REGTYPE,
- "dereferencing hardlink failed")
-
-
-class PaxWriteTest(GNUWriteTest):
-
- def _test(self, name, link=None):
- # See GNUWriteTest.
- tarinfo = tarfile.TarInfo(name)
- if link:
- tarinfo.linkname = link
- tarinfo.type = tarfile.LNKTYPE
-
- tar = tarfile.open(tmpname, "w", format=tarfile.PAX_FORMAT)
- tar.addfile(tarinfo)
- tar.close()
-
- tar = tarfile.open(tmpname)
- if link:
- l = tar.getmembers()[0].linkname
- self.assertTrue(link == l, "PAX longlink creation failed")
- else:
- n = tar.getmembers()[0].name
- self.assertTrue(name == n, "PAX longname creation failed")
-
- def test_pax_global_header(self):
- pax_headers = {
- u"foo": u"bar",
- u"uid": u"0",
- u"mtime": u"1.23",
- u"test": u"äöü",
- u"äöü": u"test"}
-
- tar = tarfile.open(tmpname, "w", format=tarfile.PAX_FORMAT,
- pax_headers=pax_headers)
- tar.addfile(tarfile.TarInfo("test"))
- tar.close()
-
- # Test if the global header was written correctly.
- tar = tarfile.open(tmpname, encoding="iso8859-1")
- self.assertEqual(tar.pax_headers, pax_headers)
- self.assertEqual(tar.getmembers()[0].pax_headers, pax_headers)
-
- # Test if all the fields are unicode.
- for key, val in tar.pax_headers.iteritems():
- self.assertTrue(type(key) is unicode)
- self.assertTrue(type(val) is unicode)
- if key in tarfile.PAX_NUMBER_FIELDS:
- try:
- tarfile.PAX_NUMBER_FIELDS[key](val)
- except (TypeError, ValueError):
- self.fail("unable to convert pax header field")
-
- def test_pax_extended_header(self):
- # The fields from the pax header have priority over the
- # TarInfo.
- pax_headers = {u"path": u"foo", u"uid": u"123"}
-
- tar = tarfile.open(tmpname, "w", format=tarfile.PAX_FORMAT, encoding="iso8859-1")
- t = tarfile.TarInfo()
- t.name = u"äöü" # non-ASCII
- t.uid = 8**8 # too large
- t.pax_headers = pax_headers
- tar.addfile(t)
- tar.close()
-
- tar = tarfile.open(tmpname, encoding="iso8859-1")
- t = tar.getmembers()[0]
- self.assertEqual(t.pax_headers, pax_headers)
- self.assertEqual(t.name, "foo")
- self.assertEqual(t.uid, 123)
-
-
-class UstarUnicodeTest(unittest.TestCase):
- # All *UnicodeTests FIXME
-
- format = tarfile.USTAR_FORMAT
-
- def test_iso8859_1_filename(self):
- self._test_unicode_filename("iso8859-1")
-
- @unittest.skipIf(test_support.is_jython, "FIXME: not working in Jython")
- def test_utf7_filename(self):
- self._test_unicode_filename("utf7")
-
- def test_utf8_filename(self):
- self._test_unicode_filename("utf8")
-
- def _test_unicode_filename(self, encoding):
- tar = tarfile.open(tmpname, "w", format=self.format, encoding=encoding, errors="strict")
- name = u"äöü"
- tar.addfile(tarfile.TarInfo(name))
- tar.close()
-
- tar = tarfile.open(tmpname, encoding=encoding)
- self.assertTrue(type(tar.getnames()[0]) is not unicode)
- self.assertEqual(tar.getmembers()[0].name, name.encode(encoding))
- tar.close()
-
- def test_unicode_filename_error(self):
- tar = tarfile.open(tmpname, "w", format=self.format, encoding="ascii", errors="strict")
- tarinfo = tarfile.TarInfo()
-
- tarinfo.name = "äöü"
- if self.format == tarfile.PAX_FORMAT:
- self.assertRaises(UnicodeError, tar.addfile, tarinfo)
- else:
- tar.addfile(tarinfo)
-
- tarinfo.name = u"äöü"
- self.assertRaises(UnicodeError, tar.addfile, tarinfo)
-
- tarinfo.name = "foo"
- tarinfo.uname = u"äöü"
- self.assertRaises(UnicodeError, tar.addfile, tarinfo)
-
- def test_unicode_argument(self):
- tar = tarfile.open(tarname, "r", encoding="iso8859-1", errors="strict")
- for t in tar:
- self.assertTrue(type(t.name) is str)
- self.assertTrue(type(t.linkname) is str)
- self.assertTrue(type(t.uname) is str)
- self.assertTrue(type(t.gname) is str)
- tar.close()
-
- def test_uname_unicode(self):
- for name in (u"äöü", "äöü"):
- t = tarfile.TarInfo("foo")
- t.uname = name
- t.gname = name
-
- fobj = StringIO.StringIO()
- tar = tarfile.open("foo.tar", mode="w", fileobj=fobj, format=self.format, encoding="iso8859-1")
- tar.addfile(t)
- tar.close()
- fobj.seek(0)
-
- tar = tarfile.open("foo.tar", fileobj=fobj, encoding="iso8859-1")
- t = tar.getmember("foo")
- self.assertEqual(t.uname, "äöü")
- self.assertEqual(t.gname, "äöü")
-
-
-class GNUUnicodeTest(UstarUnicodeTest):
-
- format = tarfile.GNU_FORMAT
-
-
-class PaxUnicodeTest(UstarUnicodeTest):
-
- format = tarfile.PAX_FORMAT
-
- def _create_unicode_name(self, name):
- tar = tarfile.open(tmpname, "w", format=self.format)
- t = tarfile.TarInfo()
- t.pax_headers["path"] = name
- tar.addfile(t)
- tar.close()
-
- @unittest.skipIf(test_support.is_jython, "FIXME: not working in Jython")
- def test_error_handlers(self):
- # Test if the unicode error handlers work correctly for characters
- # that cannot be expressed in a given encoding.
- self._create_unicode_name(u"äöü")
-
- for handler, name in (("utf-8", u"äöü".encode("utf8")),
- ("replace", "???"), ("ignore", "")):
- tar = tarfile.open(tmpname, format=self.format, encoding="ascii",
- errors=handler)
- self.assertEqual(tar.getnames()[0], name)
-
- self.assertRaises(UnicodeError, tarfile.open, tmpname,
- encoding="ascii", errors="strict")
-
- def test_error_handler_utf8(self):
- # Create a pathname that has one component representable using
- # iso8859-1 and the other only in iso8859-15.
- self._create_unicode_name(u"äöü/¤")
-
- tar = tarfile.open(tmpname, format=self.format, encoding="iso8859-1",
- errors="utf-8")
- self.assertEqual(tar.getnames()[0], "äöü/" + u"¤".encode("utf8"))
-
-
-class AppendTest(unittest.TestCase):
- # Test append mode (cp. patch #1652681).
-
- def setUp(self):
- self.tarname = tmpname
- if os.path.exists(self.tarname):
- os.remove(self.tarname)
-
- def _add_testfile(self, fileobj=None):
- tar = tarfile.open(self.tarname, "a", fileobj=fileobj)
- tar.addfile(tarfile.TarInfo("bar"))
- tar.close()
-
- def _create_testtar(self, mode="w:"):
- src = tarfile.open(tarname, encoding="iso8859-1")
- t = src.getmember("ustar/regtype")
- t.name = "foo"
- f = src.extractfile(t)
- tar = tarfile.open(self.tarname, mode)
- tar.addfile(t, f)
- tar.close()
-
- def _test(self, names=["bar"], fileobj=None):
- tar = tarfile.open(self.tarname, fileobj=fileobj)
- self.assertEqual(tar.getnames(), names)
-
- def test_non_existing(self):
- self._add_testfile()
- self._test()
-
- def test_empty(self):
- tarfile.open(self.tarname, "w:").close()
- self._add_testfile()
- self._test()
-
- def test_empty_fileobj(self):
- fobj = StringIO.StringIO("\0" * 1024)
- self._add_testfile(fobj)
- fobj.seek(0)
- self._test(fileobj=fobj)
-
- def test_fileobj(self):
- self._create_testtar()
- data = open(self.tarname).read()
- fobj = StringIO.StringIO(data)
- self._add_testfile(fobj)
- fobj.seek(0)
- self._test(names=["foo", "bar"], fileobj=fobj)
-
- def test_existing(self):
- self._create_testtar()
- self._add_testfile()
- self._test(names=["foo", "bar"])
-
- def test_append_gz(self):
- if gzip is None:
- return
- self._create_testtar("w:gz")
- self.assertRaises(tarfile.ReadError, tarfile.open, tmpname, "a")
-
- def test_append_bz2(self):
- if bz2 is None:
- return
- self._create_testtar("w:bz2")
- self.assertRaises(tarfile.ReadError, tarfile.open, tmpname, "a")
-
- # Append mode is supposed to fail if the tarfile to append to
- # does not end with a zero block.
- def _test_error(self, data):
- open(self.tarname, "wb").write(data)
- self.assertRaises(tarfile.ReadError, self._add_testfile)
-
- def test_null(self):
- self._test_error("")
-
- def test_incomplete(self):
- self._test_error("\0" * 13)
-
- def test_premature_eof(self):
- data = tarfile.TarInfo("foo").tobuf()
- self._test_error(data)
-
- def test_trailing_garbage(self):
- data = tarfile.TarInfo("foo").tobuf()
- self._test_error(data + "\0" * 13)
-
- def test_invalid(self):
- self._test_error("a" * 512)
-
-
-class LimitsTest(unittest.TestCase):
-
- def test_ustar_limits(self):
- # 100 char name
- tarinfo = tarfile.TarInfo("0123456789" * 10)
- tarinfo.tobuf(tarfile.USTAR_FORMAT)
-
- # 101 char name that cannot be stored
- tarinfo = tarfile.TarInfo("0123456789" * 10 + "0")
- self.assertRaises(ValueError, tarinfo.tobuf, tarfile.USTAR_FORMAT)
-
- # 256 char name with a slash at pos 156
- tarinfo = tarfile.TarInfo("123/" * 62 + "longname")
- tarinfo.tobuf(tarfile.USTAR_FORMAT)
-
- # 256 char name that cannot be stored
- tarinfo = tarfile.TarInfo("1234567/" * 31 + "longname")
- self.assertRaises(ValueError, tarinfo.tobuf, tarfile.USTAR_FORMAT)
-
- # 512 char name
- tarinfo = tarfile.TarInfo("123/" * 126 + "longname")
- self.assertRaises(ValueError, tarinfo.tobuf, tarfile.USTAR_FORMAT)
-
- # 512 char linkname
- tarinfo = tarfile.TarInfo("longlink")
- tarinfo.linkname = "123/" * 126 + "longname"
- self.assertRaises(ValueError, tarinfo.tobuf, tarfile.USTAR_FORMAT)
-
- # uid > 8 digits
- tarinfo = tarfile.TarInfo("name")
- tarinfo.uid = 010000000
- self.assertRaises(ValueError, tarinfo.tobuf, tarfile.USTAR_FORMAT)
-
- def test_gnu_limits(self):
- tarinfo = tarfile.TarInfo("123/" * 126 + "longname")
- tarinfo.tobuf(tarfile.GNU_FORMAT)
-
- tarinfo = tarfile.TarInfo("longlink")
- tarinfo.linkname = "123/" * 126 + "longname"
- tarinfo.tobuf(tarfile.GNU_FORMAT)
-
- # uid >= 256 ** 7
- tarinfo = tarfile.TarInfo("name")
- tarinfo.uid = 04000000000000000000L
- self.assertRaises(ValueError, tarinfo.tobuf, tarfile.GNU_FORMAT)
-
- def test_pax_limits(self):
- tarinfo = tarfile.TarInfo("123/" * 126 + "longname")
- tarinfo.tobuf(tarfile.PAX_FORMAT)
-
- tarinfo = tarfile.TarInfo("longlink")
- tarinfo.linkname = "123/" * 126 + "longname"
- tarinfo.tobuf(tarfile.PAX_FORMAT)
-
- tarinfo = tarfile.TarInfo("name")
- tarinfo.uid = 04000000000000000000L
- tarinfo.tobuf(tarfile.PAX_FORMAT)
-
-
-class ContextManagerTest(unittest.TestCase):
-
- def test_basic(self):
- with tarfile.open(tarname) as tar:
- self.assertFalse(tar.closed, "closed inside runtime context")
- self.assertTrue(tar.closed, "context manager failed")
-
- def test_closed(self):
- # The __enter__() method is supposed to raise IOError
- # if the TarFile object is already closed.
- tar = tarfile.open(tarname)
- tar.close()
- with self.assertRaises(IOError):
- with tar:
- pass
-
- def test_exception(self):
- # Test if the IOError exception is passed through properly.
- with self.assertRaises(Exception) as exc:
- with tarfile.open(tarname) as tar:
- raise IOError
- self.assertIsInstance(exc.exception, IOError,
- "wrong exception raised in context manager")
- self.assertTrue(tar.closed, "context manager failed")
-
- def test_no_eof(self):
- # __exit__() must not write end-of-archive blocks if an
- # exception was raised.
- try:
- with tarfile.open(tmpname, "w") as tar:
- raise Exception
- except:
- pass
- self.assertEqual(os.path.getsize(tmpname), 0,
- "context manager wrote an end-of-archive block")
- self.assertTrue(tar.closed, "context manager failed")
-
- def test_eof(self):
- # __exit__() must write end-of-archive blocks, i.e. call
- # TarFile.close() if there was no error.
- with tarfile.open(tmpname, "w"):
- pass
- self.assertNotEqual(os.path.getsize(tmpname), 0,
- "context manager wrote no end-of-archive block")
-
- def test_fileobj(self):
- # Test that __exit__() did not close the external file
- # object.
- fobj = open(tmpname, "wb")
- try:
- with tarfile.open(fileobj=fobj, mode="w") as tar:
- raise Exception
- except:
- pass
- self.assertFalse(fobj.closed, "external file object was closed")
- self.assertTrue(tar.closed, "context manager failed")
- fobj.close()
-
-
-class LinkEmulationTest(ReadTest):
-
- # Test for issue #8741 regression. On platforms that do not support
- # symbolic or hard links tarfile tries to extract these types of members as
- # the regular files they point to.
- def _test_link_extraction(self, name):
- self.tar.extract(name, TEMPDIR)
- data = open(os.path.join(TEMPDIR, name), "rb").read()
- self.assertEqual(md5sum(data), md5_regtype)
-
- def test_hardlink_extraction1(self):
- self._test_link_extraction("ustar/lnktype")
-
- def test_hardlink_extraction2(self):
- self._test_link_extraction("./ustar/linktest2/lnktype")
-
- def test_symlink_extraction1(self):
- self._test_link_extraction("ustar/symtype")
-
- def test_symlink_extraction2(self):
- self._test_link_extraction("./ustar/linktest2/symtype")
-
-
-class GzipMiscReadTest(MiscReadTest):
- tarname = gzipname
- mode = "r:gz"
-class GzipUstarReadTest(UstarReadTest):
- tarname = gzipname
- mode = "r:gz"
-class GzipStreamReadTest(StreamReadTest):
- tarname = gzipname
- mode = "r|gz"
-class GzipWriteTest(WriteTest):
- mode = "w:gz"
-class GzipStreamWriteTest(StreamWriteTest):
- mode = "w|gz"
-
-
-class Bz2MiscReadTest(MiscReadTest):
- tarname = bz2name
- mode = "r:bz2"
-class Bz2UstarReadTest(UstarReadTest):
- tarname = bz2name
- mode = "r:bz2"
-class Bz2StreamReadTest(StreamReadTest):
- tarname = bz2name
- mode = "r|bz2"
-class Bz2WriteTest(WriteTest):
- mode = "w:bz2"
-class Bz2StreamWriteTest(StreamWriteTest):
- mode = "w|bz2"
-
-class Bz2PartialReadTest(unittest.TestCase):
- # Issue5068: The _BZ2Proxy.read() method loops forever
- # on an empty or partial bzipped file.
-
- def _test_partial_input(self, mode):
- class MyStringIO(StringIO.StringIO):
- hit_eof = False
- def read(self, n):
- if self.hit_eof:
- raise AssertionError("infinite loop detected in tarfile.open()")
- self.hit_eof = self.pos == self.len
- return StringIO.StringIO.read(self, n)
- def seek(self, *args):
- self.hit_eof = False
- return StringIO.StringIO.seek(self, *args)
-
- data = bz2.compress(tarfile.TarInfo("foo").tobuf())
- for x in range(len(data) + 1):
- try:
- tarfile.open(fileobj=MyStringIO(data[:x]), mode=mode)
- except tarfile.ReadError:
- pass # we have no interest in ReadErrors
-
- def test_partial_input(self):
- self._test_partial_input("r")
-
- def test_partial_input_bz2(self):
- self._test_partial_input("r:bz2")
-
-
-def test_main():
- os.makedirs(TEMPDIR)
-
- tests = [
- UstarReadTest,
- MiscReadTest,
- StreamReadTest,
- DetectReadTest,
- MemberReadTest,
- GNUReadTest,
- PaxReadTest,
- WriteTest,
- StreamWriteTest,
- GNUWriteTest,
- PaxWriteTest,
- UstarUnicodeTest,
- GNUUnicodeTest,
- PaxUnicodeTest,
- AppendTest,
- LimitsTest,
- ContextManagerTest,
- ]
-
- if hasattr(os, "link"):
- tests.append(HardlinkTest)
- else:
- tests.append(LinkEmulationTest)
-
- fobj = open(tarname, "rb")
- data = fobj.read()
- fobj.close()
-
- if gzip:
- # Create testtar.tar.gz and add gzip-specific tests.
- tar = gzip.open(gzipname, "wb")
- tar.write(data)
- tar.close()
-
- tests += [
- GzipMiscReadTest,
- GzipUstarReadTest,
- GzipStreamReadTest,
- GzipWriteTest,
- GzipStreamWriteTest,
- ]
-
- if bz2:
- # Create testtar.tar.bz2 and add bz2-specific tests.
- tar = bz2.BZ2File(bz2name, "wb")
- tar.write(data)
- tar.close()
-
- tests += [
- Bz2MiscReadTest,
- Bz2UstarReadTest,
- Bz2StreamReadTest,
- Bz2WriteTest,
- Bz2StreamWriteTest,
- Bz2PartialReadTest,
- ]
-
- try:
- test_support.run_unittest(*tests)
- finally:
- if os.path.exists(TEMPDIR):
- shutil.rmtree(TEMPDIR)
-
-if __name__ == "__main__":
- test_main()
diff --git a/Lib/test/test_time.py b/Lib/test/test_time.py
--- a/Lib/test/test_time.py
+++ b/Lib/test/test_time.py
@@ -125,7 +125,6 @@
except ValueError:
self.fail('strptime failed on empty args.')
- @unittest.skip("FIXME: broken")
def test_asctime(self):
time.asctime(time.gmtime(self.t))
self.assertRaises(TypeError, time.asctime, 0)
diff --git a/Lib/test/test_types.py b/Lib/test/test_types.py
--- a/Lib/test/test_types.py
+++ b/Lib/test/test_types.py
@@ -741,7 +741,7 @@
for code in 'xXobns':
self.assertRaises(ValueError, format, 0, ',' + code)
- @unittest.skipIf(is_jython, "FIXME: not working")
+ @unittest.skipIf(is_jython, "Java does not allow access to object sizes")
def test_internal_sizes(self):
self.assertGreater(object.__basicsize__, 0)
self.assertGreater(tuple.__itemsize__, 0)
diff --git a/Lib/test/test_weakref.py b/Lib/test/test_weakref.py
--- a/Lib/test/test_weakref.py
+++ b/Lib/test/test_weakref.py
@@ -153,7 +153,7 @@
o2 = C()
ref3 = weakref.proxy(o2)
del o2
- gc.collect()
+ extra_collect()
self.assertRaises(weakref.ReferenceError, bool, ref3)
self.assertTrue(self.cbcalled == 2)
@@ -638,7 +638,7 @@
del c1, c2, C # make them all trash
self.assertEqual(alist, []) # del isn't enough to reclaim anything
- gc.collect()
+ extra_collect()
# c1.wr and c2.wr were part of the cyclic trash, so should have
# been cleared without their callbacks executing. OTOH, the weakref
# to C is bound to a function local (wr), and wasn't trash, so that
@@ -682,7 +682,7 @@
del callback, c, d, C
self.assertEqual(alist, []) # del isn't enough to clean up cycles
- gc.collect()
+ extra_collect()
self.assertEqual(alist, ["safe_callback called"])
self.assertEqual(external_wr(), None)
@@ -755,12 +755,12 @@
weakref.ref(int)
a = weakref.ref(A, l.append)
A = None
- gc.collect()
+ extra_collect()
self.assertEqual(a(), None)
self.assertEqual(l, [a])
b = weakref.ref(B, l.append)
B = None
- gc.collect()
+ extra_collect()
self.assertEqual(b(), None)
self.assertEqual(l, [a, b])
@@ -850,7 +850,7 @@
self.assertTrue(mr.called)
self.assertEqual(mr.value, 24)
del o
- gc.collect()
+ extra_collect()
self.assertTrue(mr() is None)
self.assertTrue(mr.called)
diff --git a/Lib/test/test_weakset.py b/Lib/test/test_weakset.py
--- a/Lib/test/test_weakset.py
+++ b/Lib/test/test_weakset.py
@@ -84,7 +84,11 @@
self.assertEqual(len(self.fs), 1)
del self.obj
gc.collect()
- self.assertEqual(len(self.fs), 0)
+ # len of weak collections is eventually consistent on
+ # Jython. In practice this does not matter because of the
+ # nature of weaksets - we cannot rely on what happens in the
+ # reaper thread and how it interacts with gc
+ self.assertIn(len(self.fs), (0, 1))
def test_contains(self):
for c in self.letters:
@@ -391,7 +395,7 @@
# We have removed either the first consumed items, or another one
self.assertIn(len(list(it)), [len(items), len(items) - 1])
del it
- gc.collect()
+ extra_collect()
# The removal has been committed
self.assertEqual(len(s), len(items))
diff --git a/Lib/test/test_xmlrpc.py b/Lib/test/test_xmlrpc.py
--- a/Lib/test/test_xmlrpc.py
+++ b/Lib/test/test_xmlrpc.py
@@ -304,6 +304,15 @@
s.setblocking(True)
return s, port
+ def handle_error(self, request, client_address):
+ # test_partial_post causes a close error (as might be
+ # expected), apparently because the timing is different
+ # between CPython and Jython. So ignore it, so that the
+ # default SocketServer.handle_error logging does not emit
+ # unexpected text output in the overall
+ # regrtest.
+ pass
+
if not requestHandler:
requestHandler = SimpleXMLRPCServer.SimpleXMLRPCRequestHandler
serv = MyXMLRPCServer(("localhost", 0), requestHandler,
@@ -605,7 +614,10 @@
# Check that a partial POST doesn't make the server loop: issue #14001.
conn = httplib.HTTPConnection(ADDR, PORT)
conn.request('POST', '/RPC2 HTTP/1.0\r\nContent-Length: 100\r\n\r\nbye')
- conn.close()
+ try:
+ conn.close()
+ except Exception, e:
+ print "Got this exception", type(e), e
class MultiPathServerTestCase(BaseServerTestCase):
threadFunc = staticmethod(http_multi_server)
diff --git a/src/org/python/compiler/ClassFile.java b/src/org/python/compiler/ClassFile.java
--- a/src/org/python/compiler/ClassFile.java
+++ b/src/org/python/compiler/ClassFile.java
@@ -189,7 +189,7 @@
av.visitEnd();
}
}
-
+
public void write(OutputStream stream) throws IOException {
cw.visit(Opcodes.V1_5, Opcodes.ACC_PUBLIC + Opcodes.ACC_SUPER, this.name, null, this.superclass, interfaces);
AnnotationVisitor av = cw.visitAnnotation("Lorg/python/compiler/APIVersion;", true);
@@ -203,6 +203,9 @@
av.visitEnd();
if (sfilename != null) {
+ av = cw.visitAnnotation("Lorg/python/compiler/Filename;", true);
+ av.visit("value", sfilename);
+ av.visitEnd();
cw.visitSource(sfilename, null);
}
endClassAnnotations();
diff --git a/src/org/python/compiler/Filename.java b/src/org/python/compiler/Filename.java
new file mode 100644
--- /dev/null
+++ b/src/org/python/compiler/Filename.java
@@ -0,0 +1,9 @@
+package org.python.compiler;
+
+import java.lang.annotation.Retention;
+import java.lang.annotation.RetentionPolicy;
+
+ at Retention(RetentionPolicy.RUNTIME)
+public @interface Filename {
+ String value();
+}
diff --git a/src/org/python/core/AnnotationReader.java b/src/org/python/core/AnnotationReader.java
--- a/src/org/python/core/AnnotationReader.java
+++ b/src/org/python/core/AnnotationReader.java
@@ -24,9 +24,11 @@
private boolean nextVisitIsVersion = false;
private boolean nextVisitIsMTime = false;
+ private boolean nextVisitIsFilename = false;
private int version = -1;
private long mtime = -1;
+ private String filename = null;
/**
* Reads the classfile bytecode in data and to extract the version.
@@ -50,6 +52,7 @@
public AnnotationVisitor visitAnnotation(String desc, boolean visible) {
nextVisitIsVersion = desc.equals("Lorg/python/compiler/APIVersion;");
nextVisitIsMTime = desc.equals("Lorg/python/compiler/MTime;");
+ nextVisitIsFilename = desc.equals("Lorg/python/compiler/Filename;");
return new AnnotationVisitor(Opcodes.ASM4) {
public void visit(String name, Object value) {
@@ -58,8 +61,11 @@
nextVisitIsVersion = false;
} else if (nextVisitIsMTime) {
mtime = (Long)value;
- nextVisitIsVersion = false;
- }
+ nextVisitIsMTime = false;
+ } else if (nextVisitIsFilename) {
+ filename = (String)value;
+ nextVisitIsFilename = false;
+ }
}
};
}
@@ -71,4 +77,8 @@
public long getMTime() {
return mtime;
}
+
+ public String getFilename() {
+ return filename;
+ }
}
diff --git a/src/org/python/core/BaseSet.java b/src/org/python/core/BaseSet.java
--- a/src/org/python/core/BaseSet.java
+++ b/src/org/python/core/BaseSet.java
@@ -20,6 +20,10 @@
_set = set;
}
+ public Set<PyObject> getSet() {
+ return _set;
+ }
+
protected void _update(PyObject data) {
_update(_set, data);
}
diff --git a/src/org/python/core/JavaIterator.java b/src/org/python/core/JavaIterator.java
new file mode 100644
--- /dev/null
+++ b/src/org/python/core/JavaIterator.java
@@ -0,0 +1,20 @@
+package org.python.core;
+
+import java.util.Iterator;
+
+public class JavaIterator extends PyIterator {
+
+ final private Iterator<Object> proxy;
+
+ public JavaIterator(Iterable<Object> proxy) {
+ this(proxy.iterator());
+ }
+
+ public JavaIterator(Iterator<Object> proxy) {
+ this.proxy = proxy;
+ }
+
+ public PyObject __iternext__() {
+ return proxy.hasNext() ? Py.java2py(proxy.next()) : null;
+ }
+}
diff --git a/src/org/python/core/JavaProxyList.java b/src/org/python/core/JavaProxyList.java
new file mode 100644
--- /dev/null
+++ b/src/org/python/core/JavaProxyList.java
@@ -0,0 +1,637 @@
+package org.python.core;
+
+/**
+ * Proxy Java objects implementing java.util.List with Python methods
+ * corresponding to the standard list type
+ */
+
+import org.python.util.Generic;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.ConcurrentModificationException;
+import java.util.Iterator;
+import java.util.List;
+import java.util.ListIterator;
+
+
+class JavaProxyList {
+
+ private static class ListMethod extends PyBuiltinMethodNarrow {
+ protected ListMethod(String name, int numArgs) {
+ super(name, numArgs);
+ }
+
+ protected ListMethod(String name, int minArgs, int maxArgs) {
+ super(name, minArgs, maxArgs);
+ }
+
+ protected List asList() {
+ return (List) self.getJavaProxy();
+ }
+
+ protected List newList() {
+ try {
+ return (List) asList().getClass().newInstance();
+ } catch (IllegalAccessException e) {
+ throw Py.JavaError(e);
+ } catch (InstantiationException e) {
+ throw Py.JavaError(e);
+ }
+ }
+ }
+
+ protected static class ListIndexDelegate extends SequenceIndexDelegate {
+
+ private final List list;
+
+ public ListIndexDelegate(List list) {
+ this.list = list;
+ }
+
+ @Override
+ public void delItem(int idx) {
+ list.remove(idx);
+ }
+
+ @Override
+ public PyObject getItem(int idx) {
+ return Py.java2py(list.get(idx));
+ }
+
+ @Override
+ public PyObject getSlice(int start, int stop, int step) {
+ if (step > 0 && stop < start) {
+ stop = start;
+ }
+ int n = PySequence.sliceLength(start, stop, step);
+ List newList;
+ try {
+ newList = list.getClass().newInstance();
+ } catch (Exception e) {
+ throw Py.JavaError(e);
+ }
+ int j = 0;
+ for (int i = start; j < n; i += step) {
+ newList.add(list.get(i));
+ j++;
+ }
+ return Py.java2py(newList);
+ }
+
+ @Override
+ public String getTypeName() {
+ return list.getClass().getName();
+ }
+
+ @Override
+ public int len() {
+ return list.size();
+ }
+
+ protected int fixBoundIndex(PyObject index) {
+ PyInteger length = Py.newInteger(len());
+ if (index._lt(Py.Zero).__nonzero__()) {
+ index = index._add(length);
+ if (index._lt(Py.Zero).__nonzero__()) {
+ index = Py.Zero;
+ }
+ } else if (index._gt(length).__nonzero__()) {
+ index = length;
+ }
+ int i = index.asIndex();
+ assert i >= 0;
+ return i;
+ }
+
+ @Override
+ public void setItem(int idx, PyObject value) {
+ list.set(idx, value.__tojava__(Object.class));
+ }
+
+ @Override
+ public void setSlice(int start, int stop, int step, PyObject value) {
+ if (stop < start) {
+ stop = start;
+ }
+ if (value.javaProxy == this.list) {
+ List xs = Generic.list();
+ xs.addAll(this.list);
+ setsliceList(start, stop, step, xs);
+ } else if (value instanceof PyList) {
+ setslicePyList(start, stop, step, (PyList) value);
+ } else {
+ Object valueList = value.__tojava__(List.class);
+ if (valueList != null && valueList != Py.NoConversion) {
+ setsliceList(start, stop, step, (List) valueList);
+ } else {
+ setsliceIterator(start, stop, step, value.asIterable().iterator());
+ }
+ }
+ }
+
+ final private void setsliceList(int start, int stop, int step, List value) {
+ if (step == 1) {
+ list.subList(start, stop).clear();
+ list.addAll(start, value);
+ } else {
+ int size = list.size();
+ Iterator<Object> iter = value.listIterator();
+ for (int j = start; iter.hasNext(); j += step) {
+ Object item = iter.next();
+ if (j >= size) {
+ list.add(item);
+ } else {
+ list.set(j, item);
+ }
+ }
+ }
+ }
+
+ final private void setsliceIterator(int start, int stop, int step, Iterator<PyObject> iter) {
+ if (step == 1) {
+ List insertion = new ArrayList();
+ if (iter != null) {
+ while (iter.hasNext()) {
+ insertion.add(iter.next().__tojava__(Object.class));
+ }
+ }
+ list.subList(start, stop).clear();
+ list.addAll(start, insertion);
+ } else {
+ int size = list.size();
+ for (int j = start; iter.hasNext(); j += step) {
+ Object item = iter.next().__tojava__(Object.class);
+ if (j >= size) {
+ list.add(item);
+ } else {
+ list.set(j, item);
+ }
+ }
+ }
+ }
+
+ final private void setslicePyList(int start, int stop, int step, PyList value) {
+ if (step == 1) {
+ list.subList(start, stop).clear();
+ int n = value.getList().size();
+ for (int i = 0, j = start; i < n; i++, j++) {
+ Object item = value.getList().get(i).__tojava__(Object.class);
+ list.add(j, item);
+ }
+ } else {
+ int size = list.size();
+ Iterator<PyObject> iter = value.getList().listIterator();
+ for (int j = start; iter.hasNext(); j += step) {
+ Object item = iter.next().__tojava__(Object.class);
+ if (j >= size) {
+ list.add(item);
+ } else {
+ list.set(j, item);
+ }
+ }
+ }
+ }
+
+
+ @Override
+ public void delItems(int start, int stop) {
+ int n = stop - start;
+ while (n-- > 0) {
+ delItem(start);
+ }
+ }
+ }
+
+
+ private static class ListMulProxyClass extends ListMethod {
+ protected ListMulProxyClass(String name, int numArgs) {
+ super(name, numArgs);
+ }
+
+ @Override
+ public PyObject __call__(PyObject obj) {
+ List jList = asList();
+ int mult = obj.asInt();
+ List newList = null;
+ // for a multiplier of zero or below, we return an empty list
+ if (mult > 0) {
+ try {
+ newList = new ArrayList(jList.size() * mult);
+ // otherwise, extend it x times, where x is int-cast from obj
+ for (; mult > 0; mult--) {
+ for (Object entry : jList) {
+ newList.add(entry);
+ }
+ }
+ } catch (OutOfMemoryError t) {
+ throw Py.MemoryError("");
+ }
+ } else {
+ newList = Collections.EMPTY_LIST;
+ }
+ return Py.java2py(newList);
+ }
+ }
+
+
+ private static class KV {
+
+ private final PyObject key;
+ private final Object value;
+
+ KV(PyObject key, Object value) {
+ this.key = key;
+ this.value = value;
+ }
+ }
+
+ private static class KVComparator implements Comparator<KV> {
+
+ private final PyObject cmp;
+
+ KVComparator(PyObject cmp) {
+ this.cmp = cmp;
+ }
+
+ public int compare(KV o1, KV o2) {
+ int result;
+ if (cmp != null && cmp != Py.None) {
+ PyObject pyresult = cmp.__call__(o1.key, o2.key);
+ if (pyresult instanceof PyInteger || pyresult instanceof PyLong) {
+ return pyresult.asInt();
+ } else {
+ throw Py.TypeError(
+ String.format("comparison function must return int, not %.200s",
+ pyresult.getType().fastGetName()));
+ }
+ } else {
+ result = o1.key._cmp(o2.key);
+ }
+ return result;
+ }
+
+ public boolean equals(Object o) {
+ if (o == this) {
+ return true;
+ }
+
+ if (o instanceof KVComparator) {
+ return cmp.equals(((KVComparator) o).cmp);
+ }
+ return false;
+ }
+ }
+
+ private synchronized static void list_sort(List list, PyObject cmp, PyObject key, boolean reverse) {
+ int size = list.size();
+ final ArrayList<KV> decorated = new ArrayList(size);
+ for (Object value : list) {
+ PyObject pyvalue = Py.java2py(value);
+ if (key == null || key == Py.None) {
+ decorated.add(new KV(pyvalue, value));
+ } else {
+ decorated.add(new KV(key.__call__(pyvalue), value));
+ }
+ }
+ // we will rebuild the list from the decorated version
+ list.clear();
+ KVComparator c = new KVComparator(cmp);
+ if (reverse) {
+ Collections.reverse(decorated); // maintain stability of sort by reversing first
+ }
+ Collections.sort(decorated, c);
+ if (reverse) {
+ Collections.reverse(decorated);
+ }
+ boolean modified = list.size() > 0;
+ for (KV kv : decorated) {
+ list.add(kv.value);
+ }
+ if (modified) {
+ throw Py.ValueError("list modified during sort");
+ }
+ }
+
+ private static final PyBuiltinMethodNarrow listGetProxy = new ListMethod("__getitem__", 1) {
+ @Override
+ public PyObject __call__(PyObject key) {
+ return new ListIndexDelegate(asList()).checkIdxAndGetItem(key);
+ }
+ };
+ private static final PyBuiltinMethodNarrow listSetProxy = new ListMethod("__setitem__", 2) {
+ @Override
+ public PyObject __call__(PyObject key, PyObject value) {
+ new ListIndexDelegate(asList()).checkIdxAndSetItem(key, value);
+ return Py.None;
+ }
+ };
+ private static final PyBuiltinMethodNarrow listRemoveProxy = new ListMethod("__delitem__", 1) {
+ @Override
+ public PyObject __call__(PyObject key) {
+ new ListIndexDelegate(asList()).checkIdxAndDelItem(key);
+ return Py.None;
+ }
+ };
+ private static final PyBuiltinMethodNarrow listEqProxy = new ListMethod("__eq__", 1) {
+ @Override
+ public PyObject __call__(PyObject other) {
+ List jList = asList();
+ if (other.getType().isSubType(PyList.TYPE)) {
+ PyList oList = (PyList) other;
+ if (jList.size() != oList.size()) {
+ return Py.False;
+ }
+ for (int i = 0; i < jList.size(); i++) {
+ if (!Py.java2py(jList.get(i))._eq(oList.pyget(i)).__nonzero__()) {
+ return Py.False;
+ }
+ }
+ return Py.True;
+ } else {
+ Object oj = other.getJavaProxy();
+ if (oj instanceof List) {
+ List oList = (List) oj;
+ if (jList.size() != oList.size()) {
+ return Py.False;
+ }
+ for (int i = 0; i < jList.size(); i++) {
+ if (!Py.java2py(jList.get(i))._eq(
+ Py.java2py(oList.get(i))).__nonzero__()) {
+ return Py.False;
+ }
+ }
+ return Py.True;
+ } else {
+ return null;
+ }
+ }
+ }
+ };
+ private static final PyBuiltinMethodNarrow listAppendProxy = new ListMethod("append", 1) {
+ @Override
+ public PyObject __call__(PyObject value) {
+ asList().add(value);
+ return Py.None;
+ }
+ };
+ private static final PyBuiltinMethodNarrow listExtendProxy = new ListMethod("extend", 1) {
+ @Override
+ public PyObject __call__(PyObject obj) {
+ List jList = asList();
+ List extension = new ArrayList();
+
+ // Building a separate extension list first is necessary
+ // in case the list is being extended with itself
+ for (PyObject item : obj.asIterable()) {
+ extension.add(item);
+ }
+ jList.addAll(extension);
+ return Py.None;
+ }
+ };
+ private static final PyBuiltinMethodNarrow listInsertProxy = new ListMethod("insert", 2) {
+ @Override
+ public PyObject __call__(PyObject index, PyObject object) {
+ List jlist = asList();
+ ListIndexDelegate lid = new ListIndexDelegate(jlist);
+ int idx = lid.fixBoundIndex(index);
+ jlist.add(idx, object);
+ return Py.None;
+ }
+ };
+ private static final PyBuiltinMethodNarrow listPopProxy = new ListMethod("pop", 0, 1) {
+ @Override
+ public PyObject __call__() {
+ return __call__(Py.newInteger(-1));
+ }
+
+ @Override
+ public PyObject __call__(PyObject index) {
+ List jlist = asList();
+ if (jlist.isEmpty()) {
+ throw Py.IndexError("pop from empty list");
+ }
+ ListIndexDelegate ldel = new ListIndexDelegate(jlist);
+ PyObject item = ldel.checkIdxAndFindItem(index.asInt());
+ if (item == null) {
+ throw Py.IndexError("pop index out of range");
+ } else {
+ ldel.checkIdxAndDelItem(index);
+ return item;
+ }
+ }
+ };
+ private static final PyBuiltinMethodNarrow listIndexProxy = new ListMethod("index", 1, 3) {
+ @Override
+ public PyObject __call__(PyObject object) {
+ return __call__(object, Py.newInteger(0), Py.newInteger(asList().size()));
+ }
+
+ @Override
+ public PyObject __call__(PyObject object, PyObject start) {
+ return __call__(object, start, Py.newInteger(asList().size()));
+ }
+
+ @Override
+ public PyObject __call__(PyObject object, PyObject start, PyObject end) {
+ List jlist = asList();
+ ListIndexDelegate lid = new ListIndexDelegate(jlist);
+ int start_index = lid.fixBoundIndex(start);
+ int end_index = lid.fixBoundIndex(end);
+ int i = start_index;
+ try {
+ for (ListIterator it = jlist.listIterator(start_index); it.hasNext(); i++) {
+ if (i == end_index) {
+ break;
+ }
+ Object jobj = it.next();
+ if (Py.java2py(jobj)._eq(object).__nonzero__()) {
+ return Py.newInteger(i);
+ }
+ }
+ } catch (ConcurrentModificationException e) {
+ throw Py.ValueError(object.toString() + " is not in list");
+ }
+ throw Py.ValueError(object.toString() + " is not in list");
+ }
+ };
+ private static final PyBuiltinMethodNarrow listCountProxy = new ListMethod("count", 1) {
+ @Override
+ public PyObject __call__(PyObject object) {
+ int count = 0;
+ List jlist = asList();
+ for (int i = 0; i < jlist.size(); i++) {
+ Object jobj = jlist.get(i);
+ if (Py.java2py(jobj)._eq(object).__nonzero__()) {
+ ++count;
+ }
+ }
+ return Py.newInteger(count);
+ }
+ };
+ private static final PyBuiltinMethodNarrow listReverseProxy = new ListMethod("reverse", 0) {
+ @Override
+ public PyObject __call__() {
+ List jlist = asList();
+ Collections.reverse(jlist);
+ return Py.None;
+ }
+ };
+ private static final PyBuiltinMethodNarrow listRemoveOverrideProxy = new ListMethod("remove", 1) {
+ @Override
+ public PyObject __call__(PyObject object) {
+ List jlist = asList();
+ for (int i = 0; i < jlist.size(); i++) {
+ Object jobj = jlist.get(i);
+ if (Py.java2py(jobj)._eq(object).__nonzero__()) {
+ jlist.remove(i);
+ return Py.None;
+ }
+ }
+ throw Py.ValueError(object.toString() + " is not in list");
+ }
+ };
+ private static final PyBuiltinMethodNarrow listRAddProxy = new ListMethod("__radd__", 1) {
+ @Override
+ public PyObject __call__(PyObject obj) {
+ // first, clone the self list
+ List jList = asList();
+ List jClone;
+ try {
+ jClone = (List) jList.getClass().newInstance();
+ } catch (IllegalAccessException e) {
+ throw Py.JavaError(e);
+ } catch (InstantiationException e) {
+ throw Py.JavaError(e);
+ }
+ for (Object entry : jList) {
+ jClone.add(entry);
+ }
+
+ // then, extend it with elements from the other list
+ // (but, since this is reverse add, we are technically
+ // pre-pending the clone with elements from the other list)
+ if (obj instanceof Collection) {
+ jClone.addAll(0, (Collection) obj);
+ } else {
+ int i = 0;
+ for (PyObject item : obj.asIterable()) {
+ jClone.add(i, item);
+ i++;
+ }
+ }
+
+ return Py.java2py(jClone);
+ }
+ };
+ private static final PyBuiltinMethodNarrow listIAddProxy = new ListMethod("__iadd__", 1) {
+ @Override
+ public PyObject __call__(PyObject obj) {
+ List jList = asList();
+ if (obj instanceof Collection) {
+ jList.addAll((Collection) obj);
+ } else {
+ for (PyObject item : obj.asIterable()) {
+ jList.add(item);
+ }
+ }
+ return self;
+ }
+ };
+ private static final PyBuiltinMethodNarrow listIMulProxy = new ListMethod("__imul__", 1) {
+ @Override
+ public PyObject __call__(PyObject obj) {
+ List jList = asList();
+ int mult = obj.asInt();
+
+ // for a multiplier of zero or below, we clear the list
+ if (mult <= 0) {
+ jList.clear();
+ } else {
+ try {
+ if (jList instanceof ArrayList) {
+ ((ArrayList) jList).ensureCapacity(jList.size() * (mult - 1));
+ }
+ // otherwise, extend it (in-place) x times, where x is int-cast from obj
+ int originalSize = jList.size();
+ for (mult = mult - 1; mult > 0; mult--) {
+ for (int i = 0; i < originalSize; i++) {
+ jList.add(jList.get(i));
+ }
+ }
+ } catch (OutOfMemoryError t) {
+ throw Py.MemoryError("");
+ }
+ }
+ return self;
+ }
+ };
+ private static final PyBuiltinMethodNarrow listSortProxy = new ListMethod("sort", 0, 3) {
+ @Override
+ public PyObject __call__() {
+ list_sort(asList(), Py.None, Py.None, false);
+ return Py.None;
+ }
+
+ @Override
+ public PyObject __call__(PyObject cmp) {
+ list_sort(asList(), cmp, Py.None, false);
+ return Py.None;
+ }
+
+ @Override
+ public PyObject __call__(PyObject cmp, PyObject key) {
+ list_sort(asList(), cmp, key, false);
+ return Py.None;
+ }
+
+ @Override
+ public PyObject __call__(PyObject cmp, PyObject key, PyObject reverse) {
+ list_sort(asList(), cmp, key, reverse.__nonzero__());
+ return Py.None;
+ }
+
+ @Override
+ public PyObject __call__(PyObject[] args, String[] kwds) {
+ ArgParser ap = new ArgParser("list", args, kwds, new String[]{
+ "cmp", "key", "reverse"}, 0);
+ PyObject cmp = ap.getPyObject(0, Py.None);
+ PyObject key = ap.getPyObject(1, Py.None);
+ PyObject reverse = ap.getPyObject(2, Py.False);
+ list_sort(asList(), cmp, key, reverse.__nonzero__());
+ return Py.None;
+ }
+ };
+
+ static PyBuiltinMethod[] getProxyMethods() {
+ return new PyBuiltinMethod[]{
+ listGetProxy,
+ listSetProxy,
+ listEqProxy,
+ listRemoveProxy,
+ listAppendProxy,
+ listExtendProxy,
+ listInsertProxy,
+ listPopProxy,
+ listIndexProxy,
+ listCountProxy,
+ listReverseProxy,
+ listRAddProxy,
+ listIAddProxy,
+ new ListMulProxyClass("__mul__", 1),
+ new ListMulProxyClass("__rmul__", 1),
+ listIMulProxy,
+ listSortProxy,
+ };
+ }
+
+ static PyBuiltinMethod[] getPostProxyMethods() {
+ return new PyBuiltinMethod[]{
+ listRemoveOverrideProxy
+ };
+ }
+
+}
diff --git a/src/org/python/core/JavaProxyMap.java b/src/org/python/core/JavaProxyMap.java
new file mode 100644
--- /dev/null
+++ b/src/org/python/core/JavaProxyMap.java
@@ -0,0 +1,478 @@
+package org.python.core;
+
+import java.util.Iterator;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Proxy Java objects implementing java.util.Map with Python methods
+ * corresponding to the standard dict type
+ */
+
+
+class JavaProxyMap {
+
+ private static class MapMethod extends PyBuiltinMethodNarrow {
+ protected MapMethod(String name, int numArgs) {
+ super(name, numArgs);
+ }
+
+ protected MapMethod(String name, int minArgs, int maxArgs) {
+ super(name, minArgs, maxArgs);
+ }
+
+ protected Map<Object, Object> asMap() {
+ return (Map<Object, Object>) self.getJavaProxy();
+ }
+ }
+
+ private static class MapClassMethod extends PyBuiltinClassMethodNarrow {
+ protected MapClassMethod(String name, int minArgs, int maxArgs) {
+ super(name, minArgs, maxArgs);
+ }
+
+ protected Class<?> asClass() {
+ return (Class<?>) self.getJavaProxy();
+ }
+ }
+
+ private static PyObject mapEq(PyObject self, PyObject other) {
+ Map<Object, Object> selfMap = ((Map<Object, Object>) self.getJavaProxy());
+ if (other.getType().isSubType(PyDictionary.TYPE)) {
+ PyDictionary oDict = (PyDictionary) other;
+ if (selfMap.size() != oDict.size()) {
+ return Py.False;
+ }
+ for (Object jkey : selfMap.keySet()) {
+ Object jval = selfMap.get(jkey);
+ PyObject oVal = oDict.__finditem__(Py.java2py(jkey));
+ if (oVal == null) {
+ return Py.False;
+ }
+ if (!Py.java2py(jval)._eq(oVal).__nonzero__()) {
+ return Py.False;
+ }
+ }
+ return Py.True;
+ } else {
+ Object oj = other.getJavaProxy();
+ if (oj instanceof Map) {
+ Map<Object, Object> oMap = (Map<Object, Object>) oj;
+ return Py.newBoolean(selfMap.equals(oMap));
+ } else {
+ return null;
+ }
+ }
+ }
+
+ // Map ordering comparisons (lt, le, gt, ge) are based on the key sets;
+ // we just define mapLe + mapEq for total ordering of such key sets
+ private static PyObject mapLe(PyObject self, PyObject other) {
+ Set<Object> selfKeys = ((Map<Object, Object>) self.getJavaProxy()).keySet();
+ if (other.getType().isSubType(PyDictionary.TYPE)) {
+ PyDictionary oDict = (PyDictionary) other;
+ for (Object jkey : selfKeys) {
+ if (!oDict.__contains__(Py.java2py(jkey))) {
+ return Py.False;
+ }
+ }
+ return Py.True;
+ } else {
+ Object oj = other.getJavaProxy();
+ if (oj instanceof Map) {
+ Map<Object, Object> oMap = (Map<Object, Object>) oj;
+ return Py.newBoolean(oMap.keySet().containsAll(selfKeys));
+ } else {
+ return null;
+ }
+ }
+ }
+
+ // Map doesn't extend Collection, so it needs its own version of len, iter and contains
+ private static final PyBuiltinMethodNarrow mapLenProxy = new MapMethod("__len__", 0) {
+ @Override
+ public PyObject __call__() {
+ return Py.java2py(asMap().size());
+ }
+ };
+ private static final PyBuiltinMethodNarrow mapReprProxy = new MapMethod("__repr__", 0) {
+ @Override
+ public PyObject __call__() {
+ StringBuilder repr = new StringBuilder("{");
+ for (Map.Entry<Object, Object> entry : asMap().entrySet()) {
+ Object jkey = entry.getKey();
+ Object jval = entry.getValue();
+ repr.append(jkey.toString());
+ repr.append(": ");
+ repr.append(jval == asMap() ? "{...}" : (jval == null ? "None" : jval.toString()));
+ repr.append(", ");
+ }
+ int lastindex = repr.lastIndexOf(", ");
+ if (lastindex > -1) {
+ repr.delete(lastindex, lastindex + 2);
+ }
+ repr.append("}");
+ return new PyString(repr.toString());
+ }
+ };
+ private static final PyBuiltinMethodNarrow mapEqProxy = new MapMethod("__eq__", 1) {
+ @Override
+ public PyObject __call__(PyObject other) {
+ return mapEq(self, other);
+ }
+ };
+ private static final PyBuiltinMethodNarrow mapLeProxy = new MapMethod("__le__", 1) {
+ @Override
+ public PyObject __call__(PyObject other) {
+ return mapLe(self, other);
+ }
+ };
+ private static final PyBuiltinMethodNarrow mapGeProxy = new MapMethod("__ge__", 1) {
+ @Override
+ public PyObject __call__(PyObject other) {
+ return (mapLe(self, other).__not__()).__or__(mapEq(self, other));
+ }
+ };
+ private static final PyBuiltinMethodNarrow mapLtProxy = new MapMethod("__lt__", 1) {
+ @Override
+ public PyObject __call__(PyObject other) {
+ return mapLe(self, other).__and__(mapEq(self, other).__not__());
+ }
+ };
+ private static final PyBuiltinMethodNarrow mapGtProxy = new MapMethod("__gt__", 1) {
+ @Override
+ public PyObject __call__(PyObject other) {
+ return mapLe(self, other).__not__();
+ }
+ };
+ private static final PyBuiltinMethodNarrow mapIterProxy = new MapMethod("__iter__", 0) {
+ @Override
+ public PyObject __call__() {
+ return new JavaIterator(asMap().keySet());
+ }
+ };
+ private static final PyBuiltinMethodNarrow mapContainsProxy = new MapMethod("__contains__", 1) {
+ @Override
+ public PyObject __call__(PyObject obj) {
+ Object other = obj.__tojava__(Object.class);
+ return asMap().containsKey(other) ? Py.True : Py.False;
+ }
+ };
+ // "get" needs to override java.util.Map#get() in its subclasses, too, so this needs to be injected last
+ // (i.e. when HashMap is loaded not when it is recursively loading its super-type Map)
+ private static final PyBuiltinMethodNarrow mapGetProxy = new MapMethod("get", 1, 2) {
+ @Override
+ public PyObject __call__(PyObject key) {
+ return __call__(key, Py.None);
+ }
+
+ @Override
+ public PyObject __call__(PyObject key, PyObject _default) {
+ Object jkey = Py.tojava(key, Object.class);
+ if (asMap().containsKey(jkey)) {
+ return Py.java2py(asMap().get(jkey));
+ } else {
+ return _default;
+ }
+ }
+ };
+ private static final PyBuiltinMethodNarrow mapGetItemProxy = new MapMethod("__getitem__", 1) {
+ @Override
+ public PyObject __call__(PyObject key) {
+ Object jkey = Py.tojava(key, Object.class);
+ if (asMap().containsKey(jkey)) {
+ return Py.java2py(asMap().get(jkey));
+ } else {
+ throw Py.KeyError(key);
+ }
+ }
+ };
+ private static final PyBuiltinMethodNarrow mapPutProxy = new MapMethod("__setitem__", 2) {
+ @Override
+ public PyObject __call__(PyObject key, PyObject value) {
+ asMap().put(Py.tojava(key, Object.class),
+ value == Py.None ? Py.None : Py.tojava(value, Object.class));
+ return Py.None;
+ }
+ };
+ private static final PyBuiltinMethodNarrow mapRemoveProxy = new MapMethod("__delitem__", 1) {
+ @Override
+ public PyObject __call__(PyObject key) {
+ Object jkey = Py.tojava(key, Object.class);
+ if (asMap().remove(jkey) == null) {
+ throw Py.KeyError(key);
+ }
+ return Py.None;
+ }
+ };
+ private static final PyBuiltinMethodNarrow mapIterItemsProxy = new MapMethod("iteritems", 0) {
+ @Override
+ public PyObject __call__() {
+ final Iterator<Map.Entry<Object, Object>> entrySetIterator = asMap().entrySet().iterator();
+ return new PyIterator() {
+ @Override
+ public PyObject __iternext__() {
+ if (entrySetIterator.hasNext()) {
+ Map.Entry<Object, Object> nextEntry = entrySetIterator.next();
+ // yield a Python tuple object (key, value)
+ return new PyTuple(Py.java2py(nextEntry.getKey()),
+ Py.java2py(nextEntry.getValue()));
+ }
+ return null;
+ }
+ };
+ }
+ };
+ private static final PyBuiltinMethodNarrow mapHasKeyProxy = new MapMethod("has_key", 1) {
+ @Override
+ public PyObject __call__(PyObject key) {
+ return asMap().containsKey(Py.tojava(key, Object.class)) ? Py.True : Py.False;
+ }
+ };
+ private static final PyBuiltinMethodNarrow mapKeysProxy = new MapMethod("keys", 0) {
+ @Override
+ public PyObject __call__() {
+ PyList keys = new PyList();
+ for (Object key : asMap().keySet()) {
+ keys.add(Py.java2py(key));
+ }
+ return keys;
+ }
+ };
+ private static final PyBuiltinMethod mapValuesProxy = new MapMethod("values", 0) {
+ @Override
+ public PyObject __call__() {
+ PyList values = new PyList();
+ for (Object value : asMap().values()) {
+ values.add(Py.java2py(value));
+ }
+ return values;
+ }
+ };
+ private static final PyBuiltinMethodNarrow mapSetDefaultProxy = new MapMethod("setdefault", 1, 2) {
+ @Override
+ public PyObject __call__(PyObject key) {
+ return __call__(key, Py.None);
+ }
+
+ @Override
+ public PyObject __call__(PyObject key, PyObject _default) {
+ Object jkey = Py.tojava(key, Object.class);
+ Object jval = asMap().get(jkey);
+ if (jval == null) {
+ asMap().put(jkey, _default == Py.None ? Py.None : Py.tojava(_default, Object.class));
+ return _default;
+ }
+ return Py.java2py(jval);
+ }
+ };
+ private static final PyBuiltinMethodNarrow mapPopProxy = new MapMethod("pop", 1, 2) {
+ @Override
+ public PyObject __call__(PyObject key) {
+ return __call__(key, null);
+ }
+
+ @Override
+ public PyObject __call__(PyObject key, PyObject _default) {
+ Object jkey = Py.tojava(key, Object.class);
+ if (asMap().containsKey(jkey)) {
+ PyObject value = Py.java2py(asMap().remove(jkey));
+ assert (value != null);
+ return Py.java2py(value);
+ } else {
+ if (_default == null) {
+ throw Py.KeyError(key);
+ }
+ return _default;
+ }
+ }
+ };
+ private static final PyBuiltinMethodNarrow mapPopItemProxy = new MapMethod("popitem", 0) {
+ @Override
+ public PyObject __call__() {
+ if (asMap().size() == 0) {
+ throw Py.KeyError("popitem(): map is empty");
+ }
+ Object key = asMap().keySet().toArray()[0];
+ Object val = asMap().remove(key);
+ return Py.java2py(val);
+ }
+ };
+ private static final PyBuiltinMethodNarrow mapItemsProxy = new MapMethod("items", 0) {
+ @Override
+ public PyObject __call__() {
+ PyList items = new PyList();
+ for (Map.Entry<Object, Object> entry : asMap().entrySet()) {
+ items.add(new PyTuple(Py.java2py(entry.getKey()),
+ Py.java2py(entry.getValue())));
+ }
+ return items;
+ }
+ };
+ private static final PyBuiltinMethodNarrow mapCopyProxy = new MapMethod("copy", 0) {
+ @Override
+ public PyObject __call__() {
+ Map<Object, Object> jmap = asMap();
+ Map<Object, Object> jclone;
+ try {
+ jclone = (Map<Object, Object>) jmap.getClass().newInstance();
+ } catch (IllegalAccessException e) {
+ throw Py.JavaError(e);
+ } catch (InstantiationException e) {
+ throw Py.JavaError(e);
+ }
+ for (Map.Entry<Object, Object> entry : jmap.entrySet()) {
+ jclone.put(entry.getKey(), entry.getValue());
+ }
+ return Py.java2py(jclone);
+ }
+ };
+ private static final PyBuiltinMethodNarrow mapUpdateProxy = new MapMethod("update", 0, 1) {
+ private Map<Object, Object> jmap;
+
+ @Override
+ public PyObject __call__() {
+ return Py.None;
+ }
+
+ @Override
+ public PyObject __call__(PyObject other) {
+ // `other` is either another dict-like object, or an iterable of key/value pairs (as tuples
+ // or other iterables of length two)
+ return __call__(new PyObject[]{other}, new String[]{});
+ }
+
+ @Override
+ public PyObject __call__(PyObject[] args, String[] keywords) {
+ if ((args.length - keywords.length) != 1) {
+ throw info.unexpectedCall(args.length, false);
+ }
+ jmap = asMap();
+ PyObject other = args[0];
+ // update with entries from `other` (adapted from their equivalent in PyDictionary#update)
+ Object proxy = other.getJavaProxy();
+ if (proxy instanceof Map) {
+ merge((Map<Object, Object>) proxy);
+ } else if (other.__findattr__("keys") != null) {
+ merge(other);
+ } else {
+ mergeFromSeq(other);
+ }
+ // update with entries from keyword arguments
+ for (int i = 0; i < keywords.length; i++) {
+ String jkey = keywords[i];
+ PyObject value = args[1 + i];
+ jmap.put(jkey, Py.tojava(value, Object.class));
+ }
+ return Py.None;
+ }
+
+ private void merge(Map<Object, Object> other) {
+ for (Map.Entry<Object, Object> entry : other.entrySet()) {
+ jmap.put(entry.getKey(), entry.getValue());
+ }
+ }
+
+ private void merge(PyObject other) {
+ if (other instanceof PyDictionary) {
+ jmap.putAll(((PyDictionary) other).getMap());
+ } else if (other instanceof PyStringMap) {
+ mergeFromKeys(other, ((PyStringMap) other).keys());
+ } else {
+ mergeFromKeys(other, other.invoke("keys"));
+ }
+ }
+
+ private void mergeFromKeys(PyObject other, PyObject keys) {
+ for (PyObject key : keys.asIterable()) {
+ jmap.put(Py.tojava(key, Object.class),
+ Py.tojava(other.__getitem__(key), Object.class));
+ }
+ }
+
+ private void mergeFromSeq(PyObject other) {
+ PyObject pairs = other.__iter__();
+ PyObject pair;
+
+ for (int i = 0; (pair = pairs.__iternext__()) != null; i++) {
+ try {
+ pair = PySequence.fastSequence(pair, "");
+ } catch (PyException pye) {
+ if (pye.match(Py.TypeError)) {
+ throw Py.TypeError(String.format("cannot convert dictionary update sequence "
+ + "element #%d to a sequence", i));
+ }
+ throw pye;
+ }
+ int n;
+ if ((n = pair.__len__()) != 2) {
+ throw Py.ValueError(String.format("dictionary update sequence element #%d "
+ + "has length %d; 2 is required", i, n));
+ }
+ jmap.put(Py.tojava(pair.__getitem__(0), Object.class),
+ Py.tojava(pair.__getitem__(1), Object.class));
+ }
+ }
+ };
+ private static final PyBuiltinClassMethodNarrow mapFromKeysProxy = new MapClassMethod("fromkeys", 1, 2) {
+ @Override
+ public PyObject __call__(PyObject keys) {
+ return __call__(keys, null);
+ }
+
+ @Override
+ public PyObject __call__(PyObject keys, PyObject _default) {
+ Object defobj = _default == null ? Py.None : Py.tojava(_default, Object.class);
+ Class<?> theClass = asClass();
+ try {
+ // always injected to java.util.Map, so we know the class object we get from asClass is subtype of java.util.Map
+ Map<Object, Object> theMap = (Map<Object, Object>) theClass.newInstance();
+ for (PyObject key : keys.asIterable()) {
+ theMap.put(Py.tojava(key, Object.class), defobj);
+ }
+ return Py.java2py(theMap);
+ } catch (InstantiationException e) {
+ throw Py.JavaError(e);
+ } catch (IllegalAccessException e) {
+ throw Py.JavaError(e);
+ }
+ }
+ };
+
+ static PyBuiltinMethod[] getProxyMethods() {
+ return new PyBuiltinMethod[]{
+ mapLenProxy,
+ // map IterProxy can conflict with Iterable.class;
+ // fix after the fact in handleMroError
+ mapIterProxy,
+ mapReprProxy,
+ mapEqProxy,
+ mapLeProxy,
+ mapLtProxy,
+ mapGeProxy,
+ mapGtProxy,
+ mapContainsProxy,
+ mapGetItemProxy,
+ mapPutProxy,
+ mapRemoveProxy,
+ mapIterItemsProxy,
+ mapHasKeyProxy,
+ mapKeysProxy,
+ mapSetDefaultProxy,
+ mapPopProxy,
+ mapPopItemProxy,
+ mapItemsProxy,
+ mapCopyProxy,
+ mapUpdateProxy,
+ mapFromKeysProxy // class method
+
+ };
+ }
+
+ static PyBuiltinMethod[] getPostProxyMethods() {
+ return new PyBuiltinMethod[]{
+ mapGetProxy,
+ mapValuesProxy
+ };
+ }
+}
diff --git a/src/org/python/core/JavaProxySet.java b/src/org/python/core/JavaProxySet.java
new file mode 100644
--- /dev/null
+++ b/src/org/python/core/JavaProxySet.java
@@ -0,0 +1,574 @@
+package org.python.core;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.NavigableSet;
+import java.util.NoSuchElementException;
+import java.util.Set;
+
+/** Proxy objects implementing java.util.Set */
+
+class JavaProxySet {
+
+ private static class SetMethod extends PyBuiltinMethodNarrow {
+
+ protected SetMethod(String name, int numArgs) {
+ super(name, numArgs);
+ }
+
+ protected SetMethod(String name, int minArgs, int maxArgs) {
+ super(name, minArgs, maxArgs);
+ }
+
+ @SuppressWarnings("unchecked")
+ protected Set<Object> asSet() {
+ return (Set<Object>) self.getJavaProxy();
+ }
+
+ // Unlike list and dict, set maintains the derived type for the set
+ // so we replicate this behavior
+ protected PyObject makePySet(Set newSet) {
+ PyObject newPySet = self.getType().__call__();
+ @SuppressWarnings("unchecked")
+ Set<Object> jSet = ((Set<Object>) newPySet.getJavaProxy());
+ jSet.addAll(newSet);
+ return newPySet;
+ }
+
+ public boolean isEqual(PyObject other) {
+ Set<Object> selfSet = asSet();
+ Object oj = other.getJavaProxy();
+ if (oj != null && oj instanceof Set) {
+ @SuppressWarnings("unchecked")
+ Set<Object> otherSet = (Set<Object>) oj;
+ if (selfSet.size() != otherSet.size()) {
+ return false;
+ }
+ return selfSet.containsAll(otherSet);
+ } else if (isPySet(other)) {
+ Set<PyObject> otherPySet = ((BaseSet) other).getSet();
+ if (selfSet.size() != otherPySet.size()) {
+ return false;
+ }
+ for (PyObject pyobj : otherPySet) {
+ if (!selfSet.contains(pyobj.__tojava__(Object.class))) {
+ return false;
+ }
+ }
+ return true;
+ }
+ return false;
+ }
+
+ public boolean isSuperset(PyObject other) {
+ Set<Object> selfSet = asSet();
+ Object oj = other.getJavaProxy();
+ if (oj != null && oj instanceof Set) {
+ Set otherSet = (Set) oj;
+ return selfSet.containsAll(otherSet);
+ } else if (isPySet(other)) {
+ Set<PyObject> otherPySet = ((BaseSet) other).getSet();
+ for (PyObject pyobj : otherPySet) {
+ if (!selfSet.contains(pyobj.__tojava__(Object.class))) {
+ return false;
+ }
+ }
+ return true;
+ }
+ return false;
+ }
+
+ public boolean isSubset(PyObject other) {
+ Set<Object> selfSet = asSet();
+ Object oj = other.getJavaProxy();
+ if (oj != null && oj instanceof Set) {
+ @SuppressWarnings("unchecked")
+ Set<Object> otherSet = (Set<Object>) oj;
+ return otherSet.containsAll(selfSet);
+ } else if (isPySet(other)) {
+ Set<PyObject> otherPySet = ((BaseSet) other).getSet();
+ for (Object obj : selfSet) {
+ if (!otherPySet.contains(Py.java2py(obj))) {
+ return false;
+ }
+ }
+ return true;
+ }
+ return false;
+ }
+
+ protected Set difference(Collection<Object> other) {
+ Set<Object> selfSet = asSet();
+ Set<Object> diff = new HashSet<>(selfSet);
+ diff.removeAll(other);
+ return diff;
+ }
+ protected void differenceUpdate(Collection other) {
+ asSet().removeAll(other);
+ }
+
+ protected Set intersect(Collection[] others) {
+ Set<Object> selfSet = asSet();
+ Set<Object> intersection = new HashSet<>(selfSet);
+ for (Collection other : others) {
+ intersection.retainAll(other);
+ }
+ return intersection;
+ }
+ protected void intersectUpdate(Collection[] others) {
+ Set<Object> selfSet = asSet();
+ for (Collection other : others) {
+ selfSet.retainAll(other);
+ }
+ }
+
+ protected Set union(Collection<Object> other) {
+ Set<Object> selfSet = asSet();
+ Set<Object> u = new HashSet<>(selfSet);
+ u.addAll(other);
+ return u;
+ }
+ protected void update(Collection<Object> other) {
+ asSet().addAll(other);
+ }
+
+ protected Set symDiff(Collection<Object> other) {
+ Set<Object> selfSet = asSet();
+ Set<Object> symDiff = new HashSet<>(selfSet);
+ symDiff.addAll(other);
+ Set<Object> intersection = new HashSet<>(selfSet);
+ intersection.retainAll(other);
+ symDiff.removeAll(intersection);
+ return symDiff;
+ }
+ protected void symDiffUpdate(Collection<Object> other) {
+ Set<Object> selfSet = asSet();
+ Set<Object> intersection = new HashSet<>(selfSet);
+ intersection.retainAll(other);
+ selfSet.addAll(other);
+ selfSet.removeAll(intersection);
+ }
+ }
+
+ private static class SetMethodVarargs extends SetMethod {
+ protected SetMethodVarargs(String name) {
+ super(name, 0, -1);
+ }
+
+ public PyObject __call__() {
+ return __call__(Py.EmptyObjects);
+ }
+
+ public PyObject __call__(PyObject obj) {
+ return __call__(new PyObject[]{obj});
+ }
+
+ public PyObject __call__(PyObject obj1, PyObject obj2) {
+ return __call__(new PyObject[]{obj1, obj2});
+ }
+
+ public PyObject __call__(PyObject obj1, PyObject obj2, PyObject obj3) {
+ return __call__(new PyObject[]{obj1, obj2, obj3});
+ }
+
+ public PyObject __call__(PyObject obj1, PyObject obj2, PyObject obj3, PyObject obj4) {
+ return __call__(new PyObject[]{obj1, obj2, obj3, obj4});
+ }
+ }
+
+ private static boolean isPySet(PyObject obj) {
+ PyType type = obj.getType();
+ return type.isSubType(PySet.TYPE) || type.isSubType(PyFrozenSet.TYPE);
+ }
+
+ private static Collection<Object> getJavaSet(PyObject self, String op, PyObject obj) {
+ Collection<Object> items;
+ if (isPySet(obj)) {
+ Set<PyObject> otherPySet = ((BaseSet)obj).getSet();
+ items = new ArrayList<>(otherPySet.size());
+ for (PyObject pyobj : otherPySet) {
+ items.add(pyobj.__tojava__(Object.class));
+ }
+ } else {
+ Object oj = obj.getJavaProxy();
+ if (oj instanceof Set) {
+ @SuppressWarnings("unchecked")
+ Set<Object> jSet = (Set<Object>) oj;
+ items = jSet;
+ } else {
+ throw Py.TypeError(String.format(
+ "unsupported operand type(s) for %s: '%.200s' and '%.200s'",
+ op, self.getType().fastGetName(), obj.getType().fastGetName()));
+ }
+ }
+ return items;
+ }
+
+ private static Collection<Object> getJavaCollection(PyObject obj) {
+ Collection<Object> items;
+ Object oj = obj.getJavaProxy();
+ if (oj != null) {
+ if (oj instanceof Collection) {
+ @SuppressWarnings("unchecked")
+ Collection<Object> jCollection = (Collection<Object>) oj;
+ items = jCollection;
+ } else if (oj instanceof Iterable) {
+ items = new HashSet<>();
+ for (Object item: (Iterable) oj) {
+ items.add(item);
+ }
+ } else {
+ throw Py.TypeError(String.format("unsupported operand type(s): '%.200s'",
+ obj.getType().fastGetName()));
+ }
+ } else {
+ // This step verifies objects are hashable
+ items = new HashSet<>();
+ for (PyObject pyobj : obj.asIterable()) {
+ items.add(pyobj.__tojava__(Object.class));
+ }
+ }
+ return items;
+ }
+
+ private static Collection<Object>[] getJavaCollections(PyObject[] objs) {
+ Collection[] collections = new Collection[objs.length];
+ int i = 0;
+ for (PyObject obj : objs) {
+ collections[i++] = getJavaCollection(obj);
+ }
+ return collections;
+ }
+
+ private static Collection<Object> getCombinedJavaCollections(PyObject[] objs) {
+ if (objs.length == 0) {
+ return Collections.emptyList();
+ }
+ if (objs.length == 1) {
+ return getJavaCollection(objs[0]);
+ }
+ Set<Object> items = new HashSet<>();
+ for (PyObject obj : objs) {
+ Object oj = obj.getJavaProxy();
+ if (oj != null) {
+ if (oj instanceof Iterable) {
+ for (Object item : (Iterable) oj) {
+ items.add(item);
+ }
+ } else {
+ throw Py.TypeError(String.format("unsupported operand type(s): '%.200s'",
+ obj.getType().fastGetName()));
+ }
+ } else {
+ for (PyObject pyobj : obj.asIterable()) {
+ items.add(pyobj.__tojava__(Object.class));
+ }
+ }
+ }
+ return items;
+ }
+
+ private static final SetMethod cmpProxy = new SetMethod("__cmp__", 1) {
+ @Override
+ public PyObject __call__(PyObject value) {
+ throw Py.TypeError("cannot compare sets using cmp()");
+ }
+ };
+ private static final SetMethod eqProxy = new SetMethod("__eq__", 1) {
+ @Override
+ public PyObject __call__(PyObject other) {
+ return Py.newBoolean(isEqual(other));
+ }
+ };
+ private static final SetMethod ltProxy = new SetMethod("__lt__", 1) {
+ @Override
+ public PyObject __call__(PyObject other) {
+ return Py.newBoolean(!isEqual(other) && isSubset(other));
+ }
+ };
+ private static class IsSubsetMethod extends SetMethod {
+ // __le__, issubset
+
+ protected IsSubsetMethod(String name) {
+ super(name, 1);
+ }
+
+ @Override
+ public PyObject __call__(PyObject other) {
+ return Py.newBoolean(isSubset(other));
+ }
+ }
+ private static class IsSupersetMethod extends SetMethod {
+ // __ge__, issuperset
+
+ protected IsSupersetMethod(String name) {
+ super(name, 1);
+ }
+
+ @Override
+ public PyObject __call__(PyObject other) {
+ return Py.newBoolean(isSuperset(other));
+ }
+ }
+ private static final SetMethod gtProxy = new SetMethod("__gt__", 1) {
+ @Override
+ public PyObject __call__(PyObject other) {
+ return Py.newBoolean(!isEqual(other) && isSuperset(other));
+ }
+ };
+ private static final SetMethod isDisjointProxy = new SetMethod("isdisjoint", 1) {
+ @Override
+ public PyObject __call__(PyObject other) {
+ return Py.newBoolean(intersect(new Collection[]{getJavaCollection(other)}).size() == 0);
+ }
+ };
+
+ private static final SetMethod differenceProxy = new SetMethodVarargs("difference") {
+ @Override
+ public PyObject __call__(PyObject[] others) {
+ return makePySet(difference(getCombinedJavaCollections(others)));
+ }
+ };
+ private static final SetMethod differenceUpdateProxy = new SetMethodVarargs("difference_update") {
+ @Override
+ public PyObject __call__(PyObject[] others) {
+ differenceUpdate(getCombinedJavaCollections(others));
+ return Py.None;
+ }
+ };
+ private static final SetMethod subProxy = new SetMethod("__sub__", 1) {
+ @Override
+ public PyObject __call__(PyObject other) {
+ return makePySet(difference(getJavaSet(self, "-", other)));
+ }
+ };
+ private static final SetMethod isubProxy = new SetMethod("__isub__", 1) {
+ @Override
+ public PyObject __call__(PyObject other) {
+ differenceUpdate(getJavaSet(self, "-=", other));
+ return self;
+ }
+ };
+
+ private static final SetMethod intersectionProxy = new SetMethodVarargs("intersection") {
+ @Override
+ public PyObject __call__(PyObject[] others) {
+ return makePySet(intersect(getJavaCollections(others)));
+ }
+ };
+ private static final SetMethod intersectionUpdateProxy = new SetMethodVarargs("intersection_update") {
+ @Override
+ public PyObject __call__(PyObject[] others) {
+ intersectUpdate(getJavaCollections(others));
+ return Py.None;
+ }
+ };
+ private static final SetMethod andProxy = new SetMethod("__and__", 1) {
+ @Override
+ public PyObject __call__(PyObject other) {
+ return makePySet(intersect(new Collection[]{getJavaSet(self, "&", other)}));
+ }
+ };
+ private static final SetMethod iandProxy = new SetMethod("__iand__", 1) {
+ @Override
+ public PyObject __call__(PyObject other) {
+ intersectUpdate(new Collection[]{getJavaSet(self, "&=", other)});
+ return self;
+ }
+ };
+
+ private static final SetMethod symDiffProxy = new SetMethod("symmetric_difference", 1) {
+ @Override
+ public PyObject __call__(PyObject other) {
+ return makePySet(symDiff(getJavaCollection(other)));
+ }
+ };
+ private static final SetMethod symDiffUpdateProxy = new SetMethod("symmetric_difference_update", 1) {
+ @Override
+ public PyObject __call__(PyObject other) {
+ symDiffUpdate(getJavaCollection(other));
+ return Py.None;
+ }
+ };
+ private static final SetMethod xorProxy = new SetMethod("__xor__", 1) {
+ @Override
+ public PyObject __call__(PyObject other) {
+ return makePySet(symDiff(getJavaSet(self, "^", other)));
+ }
+ };
+ private static final SetMethod ixorProxy = new SetMethod("__ixor__", 1) {
+ @Override
+ public PyObject __call__(PyObject other) {
+ symDiffUpdate(getJavaSet(self, "^=", other));
+ return self;
+ }
+ };
+
+ private static final SetMethod unionProxy = new SetMethodVarargs("union") {
+ @Override
+ public PyObject __call__(PyObject[] others) {
+ return makePySet(union(getCombinedJavaCollections(others)));
+ }
+ };
+ private static final SetMethod updateProxy = new SetMethodVarargs("update") {
+ @Override
+ public PyObject __call__(PyObject[] others) {
+ update(getCombinedJavaCollections(others));
+ return Py.None;
+ }
+ };
+ private static final SetMethod orProxy = new SetMethod("__or__", 1) {
+ @Override
+ public PyObject __call__(PyObject other) {
+ return makePySet(union(getJavaSet(self, "|", other)));
+ }
+ };
+ private static final SetMethod iorProxy = new SetMethod("__ior__", 1) {
+ @Override
+ public PyObject __call__(PyObject other) {
+ update(getJavaSet(self, "|=", other));
+ return self;
+ }
+ };
+
+ private static class CopyMethod extends SetMethod {
+ protected CopyMethod(String name) {
+ super(name, 0);
+ }
+ @Override
+ public PyObject __call__() {
+ return makePySet(asSet());
+ }
+ }
+
+ private static final SetMethod deepcopyOverrideProxy = new SetMethod("__deepcopy__", 1) {
+ @Override
+ public PyObject __call__(PyObject memo) {
+ Set<Object> newSet = new HashSet<>();
+ for (Object obj : asSet()) {
+ PyObject pyobj = Py.java2py(obj);
+ PyObject newobj = pyobj.invoke("__deepcopy__", memo);
+ newSet.add(newobj.__tojava__(Object.class));
+ }
+ return makePySet(newSet);
+ }
+ };
+
+ private static final SetMethod reduceProxy = new SetMethod("__reduce__", 0) {
+ @Override
+ public PyObject __call__() {
+ PyObject args = new PyTuple(new PyList(new JavaIterator(asSet())));
+ PyObject dict = __findattr__("__dict__");
+ if (dict == null) {
+ dict = Py.None;
+ }
+ return new PyTuple(self.getType(), args, dict);
+ }
+ };
+
+ private static final SetMethod containsProxy = new SetMethod("__contains__", 1) {
+ @Override
+ public PyObject __call__(PyObject value) {
+ return Py.newBoolean(asSet().contains(value.__tojava__(Object.class)));
+ }
+ };
+ private static final SetMethod hashProxy = new SetMethod("__hash__", 0) {
+ // in general, we don't know if this is really true or not
+ @Override
+ public PyObject __call__(PyObject value) {
+ throw Py.TypeError(String.format("unhashable type: '%.200s'", self.getType().fastGetName()));
+ }
+ };
+
+ private static final SetMethod discardProxy = new SetMethod("discard", 1) {
+ @Override
+ public PyObject __call__(PyObject value) {
+ asSet().remove(value.__tojava__(Object.class));
+ return Py.None;
+ }
+ };
+ private static final SetMethod popProxy = new SetMethod("pop", 0) {
+ @Override
+ public PyObject __call__() {
+ Set selfSet = asSet();
+ Iterator it;
+ if (selfSet instanceof NavigableSet) {
+ it = ((NavigableSet) selfSet).descendingIterator();
+ } else {
+ it = selfSet.iterator();
+ }
+ try {
+ PyObject value = Py.java2py(it.next());
+ it.remove();
+ return value;
+ } catch (NoSuchElementException ex) {
+ throw Py.KeyError("pop from an empty set");
+ }
+ }
+ };
+ private static final SetMethod removeOverrideProxy = new SetMethod("remove", 1) {
+ @Override
+ public PyObject __call__(PyObject value) {
+ boolean removed = asSet().remove(value.__tojava__(Object.class));
+ if (!removed) {
+ throw Py.KeyError(value);
+ }
+ return Py.None;
+ }
+ };
+
+ static PyBuiltinMethod[] getProxyMethods() {
+ return new PyBuiltinMethod[]{
+ cmpProxy,
+ eqProxy,
+ ltProxy,
+ new IsSubsetMethod("__le__"),
+ new IsSubsetMethod("issubset"),
+ new IsSupersetMethod("__ge__"),
+ new IsSupersetMethod("issuperset"),
+ gtProxy,
+ isDisjointProxy,
+
+ differenceProxy,
+ differenceUpdateProxy,
+ subProxy,
+ isubProxy,
+
+ intersectionProxy,
+ intersectionUpdateProxy,
+ andProxy,
+ iandProxy,
+
+ symDiffProxy,
+ symDiffUpdateProxy,
+ xorProxy,
+ ixorProxy,
+
+ unionProxy,
+ updateProxy,
+ orProxy,
+ iorProxy,
+
+ new CopyMethod("copy"),
+ new CopyMethod("__copy__"),
+ reduceProxy,
+
+ containsProxy,
+ hashProxy,
+
+ discardProxy,
+ popProxy
+ };
+ }
+
+ static PyBuiltinMethod[] getPostProxyMethods() {
+ return new PyBuiltinMethod[]{
+ deepcopyOverrideProxy,
+ removeOverrideProxy
+ };
+ }
+
+}
diff --git a/src/org/python/core/Py.java b/src/org/python/core/Py.java
--- a/src/org/python/core/Py.java
+++ b/src/org/python/core/Py.java
@@ -1316,6 +1316,19 @@
if (globals == null || globals == Py.None) {
globals = ts.frame.f_globals;
+ } else if (globals.__finditem__("__builtins__") == null) {
+ // Apply side effect of copying into globals,
+ // per documentation of eval and observed behavior of exec
+ try {
+ globals.__setitem__("__builtins__", Py.getSystemState().modules.__finditem__("__builtin__").__getattr__("__dict__"));
+ } catch (PyException e) {
+ // Quietly ignore if cannot set __builtins__ - Jython previously allowed a much wider range of
+ // mappable objects for the globals mapping than CPython, do not want to break existing code
+ // as we try to get better CPython compliance
+ if (!e.match(AttributeError)) {
+ throw e;
+ }
+ }
}
PyBaseCode baseCode = null;
diff --git a/src/org/python/core/PyBaseCode.java b/src/org/python/core/PyBaseCode.java
--- a/src/org/python/core/PyBaseCode.java
+++ b/src/org/python/core/PyBaseCode.java
@@ -5,6 +5,7 @@
package org.python.core;
import org.python.modules._systemrestart;
+import com.google.common.base.CharMatcher;
public abstract class PyBaseCode extends PyCode {
@@ -199,14 +200,14 @@
public PyObject call(ThreadState state, PyObject args[], String kws[], PyObject globals,
PyObject[] defs, PyObject closure) {
- PyFrame frame = new PyFrame(this, globals);
- int argcount = args.length - kws.length;
+ final PyFrame frame = new PyFrame(this, globals);
+ final int argcount = args.length - kws.length;
- if (co_argcount > 0 || (varargs || varkwargs)) {
+ if ((co_argcount > 0) || varargs || varkwargs) {
int i;
int n = argcount;
PyObject kwdict = null;
- PyObject[] fastlocals = frame.f_fastlocals;
+ final PyObject[] fastlocals = frame.f_fastlocals;
if (varkwargs) {
kwdict = new PyDictionary();
i = co_argcount;
@@ -222,9 +223,9 @@
co_name,
defcount > 0 ? "at most" : "exactly",
co_argcount,
- kws.length > 0 ? "non-keyword " : "",
+ kws.length > 0 ? "" : "",
co_argcount == 1 ? "" : "s",
- argcount);
+ args.length);
throw Py.TypeError(msg);
}
n = co_argcount;
@@ -242,11 +243,6 @@
String keyword = kws[i];
PyObject value = args[i + argcount];
int j;
- // XXX: keywords aren't PyObjects, can't ensure strings
- //if (keyword == null || keyword.getClass() != PyString.class) {
- // throw Py.TypeError(String.format("%.200s() keywords must be strings",
- // co_name));
- //}
for (j = 0; j < co_argcount; j++) {
if (co_varnames[j].equals(keyword)) {
break;
@@ -254,11 +250,16 @@
}
if (j >= co_argcount) {
if (kwdict == null) {
- throw Py.TypeError(String.format("%.200s() got an unexpected keyword "
- + "argument '%.400s'",
- co_name, keyword));
+ throw Py.TypeError(String.format(
+ "%.200s() got an unexpected keyword argument '%.400s'",
+ co_name,
+ Py.newUnicode(keyword).encode("ascii", "replace")));
}
- kwdict.__setitem__(keyword, value);
+ if (CharMatcher.ASCII.matchesAllOf(keyword)) {
+ kwdict.__setitem__(keyword, value);
+ } else {
+ kwdict.__setitem__(Py.newUnicode(keyword), value);
+ }
} else {
if (fastlocals[j] != null) {
throw Py.TypeError(String.format("%.200s() got multiple values for "
@@ -269,16 +270,18 @@
}
}
if (argcount < co_argcount) {
- int defcount = defs != null ? defs.length : 0;
- int m = co_argcount - defcount;
+ final int defcount = defs != null ? defs.length : 0;
+ final int m = co_argcount - defcount;
for (i = argcount; i < m; i++) {
if (fastlocals[i] == null) {
String msg =
String.format("%.200s() takes %s %d %sargument%s (%d given)",
- co_name, (varargs || defcount > 0) ?
- "at least" : "exactly",
- m, kws.length > 0 ? "non-keyword " : "",
- m == 1 ? "" : "s", i);
+ co_name,
+ (varargs || defcount > 0) ? "at least" : "exactly",
+ m,
+ kws.length > 0 ? "" : "",
+ m == 1 ? "" : "s",
+ args.length);
throw Py.TypeError(msg);
}
}
@@ -293,9 +296,9 @@
}
}
}
- } else if (argcount > 0) {
+ } else if ((argcount > 0) || (args.length > 0 && (co_argcount == 0 && !varargs && !varkwargs))) {
throw Py.TypeError(String.format("%.200s() takes no arguments (%d given)",
- co_name, argcount));
+ co_name, args.length));
}
if (co_flags.isFlagSet(CodeFlag.CO_GENERATOR)) {
diff --git a/src/org/python/core/PyByteArray.java b/src/org/python/core/PyByteArray.java
--- a/src/org/python/core/PyByteArray.java
+++ b/src/org/python/core/PyByteArray.java
@@ -151,7 +151,7 @@
*
* @param storage pre-initialised with desired value: the caller should not keep a reference
*/
- PyByteArray(byte[] storage) {
+ public PyByteArray(byte[] storage) {
super(TYPE);
setStorage(storage);
}
@@ -165,7 +165,7 @@
* @throws IllegalArgumentException if the range [0:size] is not within the array bounds of the
* storage.
*/
- PyByteArray(byte[] storage, int size) {
+ public PyByteArray(byte[] storage, int size) {
super(TYPE);
setStorage(storage, size);
}
diff --git a/src/org/python/core/PyDictionary.java b/src/org/python/core/PyDictionary.java
--- a/src/org/python/core/PyDictionary.java
+++ b/src/org/python/core/PyDictionary.java
@@ -259,7 +259,7 @@
PyType thisType = getType();
PyType otherType = otherObj.getType();
if (otherType != thisType && !thisType.isSubType(otherType)
- && !otherType.isSubType(thisType)) {
+ && !otherType.isSubType(thisType) || otherType == PyObject.TYPE) {
return null;
}
PyDictionary other = (PyDictionary)otherObj;
@@ -344,7 +344,7 @@
PyType thisType = getType();
PyType otherType = otherObj.getType();
if (otherType != thisType && !thisType.isSubType(otherType)
- && !otherType.isSubType(thisType)) {
+ && !otherType.isSubType(thisType) || otherType == PyObject.TYPE) {
return -2;
}
PyDictionary other = (PyDictionary)otherObj;
@@ -618,7 +618,7 @@
final PyObject dict_pop(PyObject key, PyObject defaultValue) {
if (!getMap().containsKey(key)) {
if (defaultValue == null) {
- throw Py.KeyError("popitem(): dictionary is empty");
+ throw Py.KeyError(key.toString());
}
return defaultValue;
}
diff --git a/src/org/python/core/PyIterator.java b/src/org/python/core/PyIterator.java
--- a/src/org/python/core/PyIterator.java
+++ b/src/org/python/core/PyIterator.java
@@ -1,7 +1,10 @@
// Copyright 2000 Finn Bock
package org.python.core;
+import java.util.ArrayList;
+import java.util.Collection;
import java.util.Iterator;
+import java.util.List;
/**
* An abstract helper class useful when implementing an iterator object. This implementation supply
@@ -62,4 +65,22 @@
}
};
}
+
+ @Override
+ public Object __tojava__(Class<?> c) {
+ if (c.isAssignableFrom(Iterable.class)) {
+ return this;
+ }
+ if (c.isAssignableFrom(Iterator.class)) {
+ return iterator();
+ }
+ if (c.isAssignableFrom(Collection.class)) {
+ List<Object> list = new ArrayList();
+ for (Object obj : this) {
+ list.add(obj);
+ }
+ return list;
+ }
+ return super.__tojava__(c);
+ }
}
diff --git a/src/org/python/core/PyJavaType.java b/src/org/python/core/PyJavaType.java
--- a/src/org/python/core/PyJavaType.java
+++ b/src/org/python/core/PyJavaType.java
@@ -15,20 +15,22 @@
import java.lang.reflect.Member;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
-import java.util.ArrayList;
import java.util.Collection;
+import java.util.Collections;
import java.util.Enumeration;
import java.util.EventListener;
+import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
+import java.util.Queue;
import java.util.Set;
-import java.util.Queue;
import org.python.core.util.StringUtil;
import org.python.util.Generic;
+
public class PyJavaType extends PyType {
private final static Class<?>[] OO = {PyObject.class, PyObject.class};
@@ -66,8 +68,7 @@
java.net.URI.class,
java.util.concurrent.TimeUnit.class);
- private static Map<Class<?>, PyBuiltinMethod[]> collectionProxies;
- private static Map<Class<?>, PyBuiltinMethod[]> postCollectionProxies;
+
/**
* Other Java classes this type has MRO conflicts with. This doesn't matter for Java method
@@ -538,7 +539,7 @@
addMethod(meth);
}
}
- // allow for some methods to override the Java type's as a late injection
+ // allow for some methods to override the Java type's methods as a late injection
for (Class<?> type : getPostCollectionProxies().keySet()) {
if (type.isAssignableFrom(forClass)) {
for (PyBuiltinMethod meth : getPostCollectionProxies().get(type)) {
@@ -876,57 +877,6 @@
}
}
- private static class IteratorIter extends PyIterator {
-
- private Iterator<Object> proxy;
-
- public IteratorIter(Iterable<Object> proxy) {
- this(proxy.iterator());
- }
-
- public IteratorIter(Iterator<Object> proxy) {
- this.proxy = proxy;
- }
-
- public PyObject __iternext__() {
- return proxy.hasNext() ? Py.java2py(proxy.next()) : null;
- }
- }
-
- private static class ListMethod extends PyBuiltinMethodNarrow {
- protected ListMethod(String name, int numArgs) {
- super(name, numArgs);
- }
-
- protected List<Object> asList(){
- return (List<Object>)self.getJavaProxy();
- }
- }
-
- private static class MapMethod extends PyBuiltinMethodNarrow {
- protected MapMethod(String name, int numArgs) {
- super(name, numArgs);
- }
-
- protected MapMethod(String name, int minArgs, int maxArgs) {
- super(name, minArgs, maxArgs);
- }
-
- protected Map<Object, Object> asMap(){
- return (Map<Object, Object>)self.getJavaProxy();
- }
- }
-
- private static class MapClassMethod extends PyBuiltinClassMethodNarrow {
- protected MapClassMethod(String name, int minArgs, int maxArgs) {
- super(name, minArgs, maxArgs);
- }
-
- protected Class<?> asClass() {
- return (Class<?>) self.getJavaProxy();
- }
- }
-
private static abstract class ComparableMethod extends PyBuiltinMethodNarrow {
protected ComparableMethod(String name, int numArgs) {
super(name, numArgs);
@@ -946,6 +896,28 @@
protected abstract boolean getResult(int comparison);
}
+ private static class CollectionProxies {
+ final Map<Class<?>, PyBuiltinMethod[]> proxies;
+ final Map<Class<?>, PyBuiltinMethod[]> postProxies;
+
+ CollectionProxies() {
+ proxies = buildCollectionProxies();
+ postProxies = buildPostCollectionProxies();
+ }
+ }
+
+ private static class CollectionsProxiesHolder {
+ static final CollectionProxies proxies = new CollectionProxies();
+ }
+
+ private static Map<Class<?>, PyBuiltinMethod[]> getCollectionProxies() {
+ return CollectionsProxiesHolder.proxies.proxies;
+ }
+
+ private static Map<Class<?>, PyBuiltinMethod[]> getPostCollectionProxies() {
+ return CollectionsProxiesHolder.proxies.postProxies;
+ }
+
/**
* Build a map of common Java collection base types (Map, Iterable, etc) that need to be
* injected with Python's equivalent types' builtin methods (__len__, __iter__, iteritems, etc).
@@ -953,596 +925,73 @@
* @return A map whose key is the base Java collection types and whose entry is a list of
* injected methods.
*/
- private static Map<Class<?>, PyBuiltinMethod[]> getCollectionProxies() {
- if (collectionProxies == null) {
- collectionProxies = Generic.map();
- postCollectionProxies = Generic.map();
+ private static Map<Class<?>, PyBuiltinMethod[]> buildCollectionProxies() {
+ final Map<Class<?>, PyBuiltinMethod[]> proxies = new HashMap();
- PyBuiltinMethodNarrow iterableProxy = new PyBuiltinMethodNarrow("__iter__") {
- @Override
- public PyObject __call__() {
- return new IteratorIter(((Iterable)self.getJavaProxy()));
- }
- };
- collectionProxies.put(Iterable.class, new PyBuiltinMethod[] {iterableProxy});
+ PyBuiltinMethodNarrow iterableProxy = new PyBuiltinMethodNarrow("__iter__") {
+ @Override
+ public PyObject __call__() {
+ return new JavaIterator(((Iterable) self.getJavaProxy()));
+ }
+ };
+ proxies.put(Iterable.class, new PyBuiltinMethod[]{iterableProxy});
- PyBuiltinMethodNarrow lenProxy = new PyBuiltinMethodNarrow("__len__") {
- @Override
- public PyObject __call__() {
- return Py.newInteger(((Collection<?>)self.getJavaProxy()).size());
- }
- };
+ PyBuiltinMethodNarrow lenProxy = new PyBuiltinMethodNarrow("__len__") {
+ @Override
+ public PyObject __call__() {
+ return Py.newInteger(((Collection<?>) self.getJavaProxy()).size());
+ }
+ };
- PyBuiltinMethodNarrow containsProxy = new PyBuiltinMethodNarrow("__contains__", 1) {
- @Override
- public PyObject __call__(PyObject obj) {
- Object other = obj.__tojava__(Object.class);
- boolean contained = ((Collection<?>)self.getJavaProxy()).contains(other);
- return contained ? Py.True : Py.False;
- }
- };
- collectionProxies.put(Collection.class, new PyBuiltinMethod[] {lenProxy,
- containsProxy});
-
- PyBuiltinMethodNarrow iteratorProxy = new PyBuiltinMethodNarrow("__iter__") {
- @Override
- public PyObject __call__() {
- return new IteratorIter(((Iterator)self.getJavaProxy()));
- }
- };
- collectionProxies.put(Iterator.class, new PyBuiltinMethod[] {iteratorProxy});
-
- PyBuiltinMethodNarrow enumerationProxy = new PyBuiltinMethodNarrow("__iter__") {
- @Override
- public PyObject __call__() {
- return new EnumerationIter(((Enumeration)self.getJavaProxy()));
- }
- };
- collectionProxies.put(Enumeration.class, new PyBuiltinMethod[] {enumerationProxy});
-
- // Map doesn't extend Collection, so it needs its own version of len, iter and contains
- PyBuiltinMethodNarrow mapLenProxy = new MapMethod("__len__", 0) {
- @Override
- public PyObject __call__() {
- return Py.java2py(asMap().size());
- }
- };
- PyBuiltinMethodNarrow mapReprProxy = new MapMethod("__repr__", 0) {
- @Override
- public PyObject __call__() {
- StringBuilder repr = new StringBuilder("{");
- for (Map.Entry<Object, Object> entry : asMap().entrySet()) {
- Object jkey = entry.getKey();
- Object jval = entry.getValue();
- repr.append(jkey.toString());
- repr.append(": ");
- repr.append(jval == asMap() ? "{...}" : (jval == null ? "None" : jval.toString()));
- repr.append(", ");
- }
- int lastindex = repr.lastIndexOf(", ");
- if (lastindex > -1) {
- repr.delete(lastindex, lastindex + 2);
- }
- repr.append("}");
- return new PyString(repr.toString());
- }
- };
- PyBuiltinMethodNarrow mapEqProxy = new MapMethod("__eq__", 1) {
- @Override
- public PyObject __call__(PyObject other) {
- if (other.getType().isSubType(PyDictionary.TYPE)) {
- PyDictionary oDict = (PyDictionary) other;
- if (asMap().size() != oDict.size()) {
- return Py.False;
- }
- for (Object jkey : asMap().keySet()) {
- Object jval = asMap().get(jkey);
- PyObject oVal = oDict.__finditem__(Py.java2py(jkey));
- if (oVal == null) {
- return Py.False;
- }
- if (!Py.java2py(jval)._eq(oVal).__nonzero__()) {
- return Py.False;
- }
- }
- return Py.True;
- } else {
- Object oj = other.getJavaProxy();
- if (oj instanceof Map) {
- Map<Object, Object> oMap = (Map<Object, Object>) oj;
- return asMap().equals(oMap) ? Py.True : Py.False;
- } else {
- return null;
+ PyBuiltinMethodNarrow containsProxy = new PyBuiltinMethodNarrow("__contains__", 1) {
+ @Override
+ public PyObject __call__(PyObject obj) {
+ boolean contained = false;
+ Object proxy = obj.getJavaProxy();
+ if (proxy == null) {
+ for (Object item : (Collection<?>) self.getJavaProxy()) {
+ if (Py.java2py(item)._eq(obj).__nonzero__()) {
+ contained = true;
+ break;
}
}
+ } else {
+ Object other = obj.__tojava__(Object.class);
+ contained = ((Collection<?>) self.getJavaProxy()).contains(other);
+
}
- };
- PyBuiltinMethodNarrow mapIterProxy = new MapMethod("__iter__", 0) {
- @Override
- public PyObject __call__() {
- return new IteratorIter(asMap().keySet());
- }
- };
- PyBuiltinMethodNarrow mapContainsProxy = new MapMethod("__contains__", 1) {
- @Override
- public PyObject __call__(PyObject obj) {
- Object other = obj.__tojava__(Object.class);
- return asMap().containsKey(other) ? Py.True : Py.False;
- }
- };
- // "get" needs to override java.util.Map#get() in its subclasses, too, so this needs to be injected last
- // (i.e. when HashMap is loaded not when it is recursively loading its super-type Map)
- PyBuiltinMethodNarrow mapGetProxy = new MapMethod("get", 1, 2) {
- @Override
- public PyObject __call__(PyObject key) {
- return __call__(key, Py.None);
- }
- @Override
- public PyObject __call__(PyObject key, PyObject _default) {
- Object jkey = Py.tojava(key, Object.class);
- if (asMap().containsKey(jkey)) {
- return Py.java2py(asMap().get(jkey));
- } else {
- return _default;
- }
- }
- };
- PyBuiltinMethodNarrow mapGetItemProxy = new MapMethod("__getitem__", 1) {
- @Override
- public PyObject __call__(PyObject key) {
- Object jkey = Py.tojava(key, Object.class);
- if (asMap().containsKey(jkey)) {
- return Py.java2py(asMap().get(jkey));
- } else {
- throw Py.KeyError(key);
- }
- }
- };
- PyBuiltinMethodNarrow mapPutProxy = new MapMethod("__setitem__", 2) {
- @Override
- public PyObject __call__(PyObject key, PyObject value) {
- asMap().put(Py.tojava(key, Object.class),
- value == Py.None ? Py.None : Py.tojava(value, Object.class));
- return Py.None;
- }
- };
- PyBuiltinMethodNarrow mapRemoveProxy = new MapMethod("__delitem__", 1) {
- @Override
- public PyObject __call__(PyObject key) {
- Object jkey = Py.tojava(key, Object.class);
- if (asMap().remove(jkey) == null) {
- throw Py.KeyError(key);
- }
- return Py.None;
- }
- };
- PyBuiltinMethodNarrow mapIterItemsProxy = new MapMethod("iteritems", 0) {
- @Override
- public PyObject __call__() {
- final Iterator<Map.Entry<Object, Object>> entrySetIterator = asMap().entrySet().iterator();
- return new PyIterator() {
- @Override
- public PyObject __iternext__() {
- if (entrySetIterator.hasNext()) {
- Map.Entry<Object, Object> nextEntry = entrySetIterator.next();
- // yield a Python tuple object (key, value)
- return new PyTuple(Py.java2py(nextEntry.getKey()),
- Py.java2py(nextEntry.getValue()));
- }
- return null;
- }
- };
- }
- };
- PyBuiltinMethodNarrow mapHasKeyProxy = new MapMethod("has_key", 1) {
- @Override
- public PyObject __call__(PyObject key) {
- return asMap().containsKey(Py.tojava(key, Object.class)) ? Py.True : Py.False;
- }
- };
- PyBuiltinMethodNarrow mapKeysProxy = new MapMethod("keys", 0) {
- @Override
- public PyObject __call__() {
- PyList keys = new PyList();
- for (Object key : asMap().keySet()) {
- keys.add(Py.java2py(key));
- }
- return keys;
- }
- };
- PyBuiltinMethod mapValuesProxy = new MapMethod("values", 0) {
- @Override
- public PyObject __call__() {
- PyList values = new PyList();
- for (Object value : asMap().values()) {
- values.add(Py.java2py(value));
- }
- return values;
- }
- };
- PyBuiltinMethodNarrow mapSetDefaultProxy = new MapMethod("setdefault", 1, 2) {
- @Override
- public PyObject __call__(PyObject key) {
- return __call__(key, Py.None);
- }
- @Override
- public PyObject __call__(PyObject key, PyObject _default) {
- Object jkey = Py.tojava(key, Object.class);
- Object jval = asMap().get(jkey);
- if (jval == null) {
- asMap().put(jkey, _default == Py.None? Py.None : Py.tojava(_default, Object.class));
- return _default;
- }
- return Py.java2py(jval);
- }
- };
- PyBuiltinMethodNarrow mapPopProxy = new MapMethod("pop", 1, 2) {
- @Override
- public PyObject __call__(PyObject key) {
- return __call__(key, null);
- }
- @Override
- public PyObject __call__(PyObject key, PyObject _default) {
- Object jkey = Py.tojava(key, Object.class);
- if (asMap().containsKey(jkey)) {
- PyObject value = Py.java2py(asMap().remove(jkey));
- assert (value != null);
- return Py.java2py(value);
- } else {
- if (_default == null) {
- throw Py.KeyError(key);
- }
- return _default;
- }
- }
- };
- PyBuiltinMethodNarrow mapPopItemProxy = new MapMethod("popitem", 0) {
- @Override
- public PyObject __call__() {
- if (asMap().size() == 0) {
- throw Py.KeyError("popitem(): map is empty");
- }
- Object key = asMap().keySet().toArray()[0];
- Object val = asMap().remove(key);
- return Py.java2py(val);
- }
- };
- PyBuiltinMethodNarrow mapItemsProxy = new MapMethod("items", 0) {
- @Override
- public PyObject __call__() {
- PyList items = new PyList();
- for (Map.Entry<Object, Object> entry : asMap().entrySet()) {
- items.add(new PyTuple(Py.java2py(entry.getKey()),
- Py.java2py(entry.getValue())));
- }
- return items;
- }
- };
- PyBuiltinMethodNarrow mapCopyProxy = new MapMethod("copy", 0) {
- @Override
- public PyObject __call__() {
- Map<Object, Object> jmap = asMap();
- Map<Object, Object> jclone;
- try {
- jclone = (Map<Object, Object>) jmap.getClass().newInstance();
- } catch (IllegalAccessException e) {
- throw Py.JavaError(e);
- } catch (InstantiationException e) {
- throw Py.JavaError(e);
- }
- for (Map.Entry<Object, Object> entry : jmap.entrySet()) {
- jclone.put(entry.getKey(), entry.getValue());
- }
- return Py.java2py(jclone);
- }
- };
- PyBuiltinMethodNarrow mapUpdateProxy = new MapMethod("update", 0, 1) {
- private Map<Object, Object> jmap;
- @Override
- public PyObject __call__() {
- return Py.None;
- }
- @Override
- public PyObject __call__(PyObject other) {
- // `other` is either another dict-like object, or an iterable of key/value pairs (as tuples
- // or other iterables of length two)
- return __call__(new PyObject[]{other}, new String[]{});
- }
- @Override
- public PyObject __call__(PyObject[] args, String[] keywords) {
- if ((args.length - keywords.length) != 1) {
- throw info.unexpectedCall(args.length, false);
- }
- jmap = asMap();
- PyObject other = args[0];
- // update with entries from `other` (adapted from their equivalent in PyDictionary#update)
- Object proxy = other.getJavaProxy();
- if (proxy instanceof Map) {
- merge((Map<Object, Object>)proxy);
- }
- else if (other.__findattr__("keys") != null) {
- merge(other);
- } else {
- mergeFromSeq(other);
- }
- // update with entries from keyword arguments
- for (int i = 0; i < keywords.length; i++) {
- String jkey = keywords[i];
- PyObject value = args[1+i];
- jmap.put(jkey, Py.tojava(value, Object.class));
- }
- return Py.None;
- }
- private void merge(Map<Object, Object> other) {
- for (Map.Entry<Object, Object> entry : other.entrySet()) {
- jmap.put(entry.getKey(), entry.getValue());
- }
- }
- private void merge(PyObject other) {
- if (other instanceof PyDictionary) {
- jmap.putAll(((PyDictionary) other).getMap());
- } else if (other instanceof PyStringMap) {
- mergeFromKeys(other, ((PyStringMap)other).keys());
- } else {
- mergeFromKeys(other, other.invoke("keys"));
- }
- }
- private void mergeFromKeys(PyObject other, PyObject keys) {
- for (PyObject key : keys.asIterable()) {
- jmap.put(Py.tojava(key, Object.class),
- Py.tojava(other.__getitem__(key), Object.class));
- }
- }
- private void mergeFromSeq(PyObject other) {
- PyObject pairs = other.__iter__();
- PyObject pair;
+ return contained ? Py.True : Py.False;
+ }
+ };
+ proxies.put(Collection.class, new PyBuiltinMethod[]{lenProxy,
+ containsProxy});
- for (int i = 0; (pair = pairs.__iternext__()) != null; i++) {
- try {
- pair = PySequence.fastSequence(pair, "");
- } catch(PyException pye) {
- if (pye.match(Py.TypeError)) {
- throw Py.TypeError(String.format("cannot convert dictionary update sequence "
- + "element #%d to a sequence", i));
- }
- throw pye;
- }
- int n;
- if ((n = pair.__len__()) != 2) {
- throw Py.ValueError(String.format("dictionary update sequence element #%d "
- + "has length %d; 2 is required", i, n));
- }
- jmap.put(Py.tojava(pair.__getitem__(0), Object.class),
- Py.tojava(pair.__getitem__(1), Object.class));
- }
- }
- };
- PyBuiltinClassMethodNarrow mapFromKeysProxy = new MapClassMethod("fromkeys", 1, 2) {
- @Override
- public PyObject __call__(PyObject keys) {
- return __call__(keys, null);
- }
- @Override
- public PyObject __call__(PyObject keys, PyObject _default) {
- Object defobj = _default == null ? Py.None : Py.tojava(_default, Object.class);
- Class<?> theClass = asClass();
- try {
- // always injected to java.util.Map, so we know the class object we get from asClass is subtype of java.util.Map
- Map<Object, Object> theMap = (Map<Object, Object>) theClass.newInstance();
- for (PyObject key : keys.asIterable()) {
- theMap.put(Py.tojava(key, Object.class), defobj);
- }
- return Py.java2py(theMap);
- } catch (InstantiationException e) {
- throw Py.JavaError(e);
- } catch (IllegalAccessException e) {
- throw Py.JavaError(e);
- }
- }
- };
- collectionProxies.put(Map.class, new PyBuiltinMethod[] {mapLenProxy,
- // map IterProxy can conflict with Iterable.class; fix after the fact in handleMroError
- mapIterProxy,
- mapReprProxy,
- mapEqProxy,
- mapContainsProxy,
- mapGetItemProxy,
- //mapGetProxy,
- mapPutProxy,
- mapRemoveProxy,
- mapIterItemsProxy,
- mapHasKeyProxy,
- mapKeysProxy,
- //mapValuesProxy,
- mapSetDefaultProxy,
- mapPopProxy,
- mapPopItemProxy,
- mapItemsProxy,
- mapCopyProxy,
- mapUpdateProxy,
- mapFromKeysProxy}); // class method
- postCollectionProxies.put(Map.class, new PyBuiltinMethod[] {mapGetProxy,
- mapValuesProxy});
+ PyBuiltinMethodNarrow iteratorProxy = new PyBuiltinMethodNarrow("__iter__") {
+ @Override
+ public PyObject __call__() {
+ return new JavaIterator(((Iterator) self.getJavaProxy()));
+ }
+ };
+ proxies.put(Iterator.class, new PyBuiltinMethod[]{iteratorProxy});
- PyBuiltinMethodNarrow listGetProxy = new ListMethod("__getitem__", 1) {
- @Override
- public PyObject __call__(PyObject key) {
- return new ListIndexDelegate(asList()).checkIdxAndGetItem(key);
- }
- };
- PyBuiltinMethodNarrow listSetProxy = new ListMethod("__setitem__", 2) {
- @Override
- public PyObject __call__(PyObject key, PyObject value) {
- new ListIndexDelegate(asList()).checkIdxAndSetItem(key, value);
- return Py.None;
- }
- };
- PyBuiltinMethodNarrow listRemoveProxy = new ListMethod("__delitem__", 1) {
- @Override
- public PyObject __call__(PyObject key) {
- new ListIndexDelegate(asList()).checkIdxAndDelItem(key);
- return Py.None;
- }
- };
- collectionProxies.put(List.class, new PyBuiltinMethod[] {listGetProxy,
- listSetProxy,
- listRemoveProxy});
- }
- return collectionProxies;
+ PyBuiltinMethodNarrow enumerationProxy = new PyBuiltinMethodNarrow("__iter__") {
+ @Override
+ public PyObject __call__() {
+ return new EnumerationIter(((Enumeration) self.getJavaProxy()));
+ }
+ };
+ proxies.put(Enumeration.class, new PyBuiltinMethod[]{enumerationProxy});
+ proxies.put(List.class, JavaProxyList.getProxyMethods());
+ proxies.put(Map.class, JavaProxyMap.getProxyMethods());
+ proxies.put(Set.class, JavaProxySet.getProxyMethods());
+ return Collections.unmodifiableMap(proxies);
}
- private static Map<Class<?>, PyBuiltinMethod[]> getPostCollectionProxies() {
- getCollectionProxies();
- assert (postCollectionProxies != null);
- return postCollectionProxies;
- }
-
-
- protected static class ListIndexDelegate extends SequenceIndexDelegate {
-
- private final List list;
-
- public ListIndexDelegate(List list) {
- this.list = list;
- }
- @Override
- public void delItem(int idx) {
- list.remove(idx);
- }
-
- @Override
- public PyObject getItem(int idx) {
- return Py.java2py(list.get(idx));
- }
-
- @Override
- public PyObject getSlice(int start, int stop, int step) {
- if (step > 0 && stop < start) {
- stop = start;
- }
- int n = PySequence.sliceLength(start, stop, step);
- List<Object> newList;
- try {
- newList = list.getClass().newInstance();
- } catch (Exception e) {
- throw Py.JavaError(e);
- }
- int j = 0;
- for (int i = start; j < n; i += step) {
- newList.add(list.get(i));
- j++;
- }
- return Py.java2py(newList);
- }
-
- @Override
- public String getTypeName() {
- return list.getClass().getName();
- }
-
- @Override
- public int len() {
- return list.size();
- }
-
- @Override
- public void setItem(int idx, PyObject value) {
- list.set(idx, value.__tojava__(Object.class));
- }
-
- @Override
- public void setSlice(int start, int stop, int step, PyObject value) {
- if (stop < start) {
- stop = start;
- }
- if (value.javaProxy == this.list) {
- List<Object> xs = Generic.list();
- xs.addAll(this.list);
- setsliceList(start, stop, step, xs);
- } else if (value instanceof PyList) {
- setslicePyList(start, stop, step, (PyList)value);
- } else {
- Object valueList = value.__tojava__(List.class);
- if (valueList != null && valueList != Py.NoConversion) {
- setsliceList(start, stop, step, (List)valueList);
- } else {
- setsliceIterator(start, stop, step, value.asIterable().iterator());
- }
- }
- }
-
-
-
- final private void setsliceList(int start, int stop, int step, List<Object> value) {
- if (step == 1) {
- list.subList(start, stop).clear();
- list.addAll(start, value);
- } else {
- int size = list.size();
- Iterator<Object> iter = value.listIterator();
- for (int j = start; iter.hasNext(); j += step) {
- Object item =iter.next();
- if (j >= size) {
- list.add(item);
- } else {
- list.set(j, item);
- }
- }
- }
- }
-
- final private void setsliceIterator(int start, int stop, int step, Iterator<PyObject> iter) {
- if (step == 1) {
- List<Object> insertion = new ArrayList<Object>();
- if (iter != null) {
- while (iter.hasNext()) {
- insertion.add(iter.next().__tojava__(Object.class));
- }
- }
- list.subList(start, stop).clear();
- list.addAll(start, insertion);
- } else {
- int size = list.size();
- for (int j = start; iter.hasNext(); j += step) {
- Object item = iter.next().__tojava__(Object.class);
- if (j >= size) {
- list.add(item);
- } else {
- list.set(j, item);
- }
- }
- }
- }
-
- final private void setslicePyList(int start, int stop, int step, PyList value) {
- if (step == 1) {
- list.subList(start, stop).clear();
- int n = value.getList().size();
- for (int i=0, j=start; i<n; i++, j++) {
- Object item = value.getList().get(i).__tojava__(Object.class);
- list.add(j, item);
- }
- } else {
- int size = list.size();
- Iterator<PyObject> iter = value.getList().listIterator();
- for (int j = start; iter.hasNext(); j += step) {
- Object item = iter.next().__tojava__(Object.class);
- if (j >= size) {
- list.add(item);
- } else {
- list.set(j, item);
- }
- }
- }
- }
-
-
- @Override
- public void delItems(int start, int stop) {
- int n = stop - start;
- while (n-- > 0) {
- delItem(start);
- }
- }
+ private static Map<Class<?>, PyBuiltinMethod[]> buildPostCollectionProxies() {
+ final Map<Class<?>, PyBuiltinMethod[]> postProxies = new HashMap();
+ postProxies.put(List.class, JavaProxyList.getPostProxyMethods());
+ postProxies.put(Map.class, JavaProxyMap.getPostProxyMethods());
+ postProxies.put(Set.class, JavaProxySet.getPostProxyMethods());
+ return Collections.unmodifiableMap(postProxies);
}
}
diff --git a/src/org/python/core/PyList.java b/src/org/python/core/PyList.java
--- a/src/org/python/core/PyList.java
+++ b/src/org/python/core/PyList.java
@@ -14,11 +14,13 @@
import java.util.Collections;
import java.util.Comparator;
import java.util.ConcurrentModificationException;
+import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.ListIterator;
import java.lang.reflect.Array;
+import java.util.Map;
@ExposedType(name = "list", base = PyObject.class, doc = BuiltinDocs.list_doc)
public class PyList extends PySequenceList implements List {
@@ -101,6 +103,21 @@
this(TYPE, listify(iter));
}
+ // refactor and put in Py presumably;
+ // presumably we can consume an arbitrary iterable too!
+ private static void addCollection(List<PyObject> list, Collection<Object> seq) {
+ Map<Long, PyObject> seen = new HashMap();
+ for (Object item : seq) {
+ long id = Py.java_obj_id(item);
+ PyObject seen_obj = seen.get(id);
+ if (seen_obj != null) {
+ seen_obj = Py.java2py(item);
+ seen.put(id, seen_obj);
+ }
+ list.add(seen_obj);
+ }
+ }
+
@ExposedNew
@ExposedMethod(doc = BuiltinDocs.list___init___doc)
final void list___init__(PyObject[] args, String[] kwds) {
@@ -114,6 +131,9 @@
list.addAll(((PyList) seq).list); // don't convert
} else if (seq instanceof PyTuple) {
list.addAll(((PyTuple) seq).getList());
+ } else if (seq.getClass().isAssignableFrom(Collection.class)) {
+ System.err.println("Adding from collection");
+ addCollection(list, (Collection)seq);
} else {
for (PyObject item : seq.asIterable()) {
append(item);
diff --git a/src/org/python/core/PyModule.java b/src/org/python/core/PyModule.java
--- a/src/org/python/core/PyModule.java
+++ b/src/org/python/core/PyModule.java
@@ -60,7 +60,9 @@
ensureDict();
__dict__.__setitem__("__name__", name);
__dict__.__setitem__("__doc__", doc);
- __dict__.__setitem__("__package__", Py.None);
+ if (name.equals(new PyString("__main__"))) {
+ __dict__.__setitem__("__builtins__", Py.getSystemState().modules.__finditem__("__builtin__"));
+ }
}
public PyObject fastGetDict() {
@@ -166,10 +168,24 @@
}
public PyObject __dir__() {
- if (__dict__ == null) {
- throw Py.TypeError("module.__dict__ is not a dictionary");
+ // Some special casing to ensure that classes deriving from PyModule
+ // can use their own __dict__. Although it would be nice to do this in
+ // PyModuleDerived, current templating in gderived.py does not support
+ // including from object, then overriding a specific method.
+ PyObject d;
+ if (this instanceof PyModuleDerived) {
+ d = __findattr_ex__("__dict__");
+ } else {
+ d = __dict__;
}
- return __dict__.invoke("keys");
+ if (d == null ||
+ !(d instanceof PyDictionary ||
+ d instanceof PyStringMap ||
+ d instanceof PyDictProxy)) {
+ throw Py.TypeError(String.format("%.200s.__dict__ is not a dictionary",
+ getType().fastGetName().toLowerCase()));
+ }
+ return d.invoke("keys");
}
private void ensureDict() {
diff --git a/src/org/python/core/PySequence.java b/src/org/python/core/PySequence.java
--- a/src/org/python/core/PySequence.java
+++ b/src/org/python/core/PySequence.java
@@ -156,7 +156,7 @@
}
final PyObject seq___eq__(PyObject o) {
- if (!isSubType(o)) {
+ if (!isSubType(o) || o.getType() == PyObject.TYPE) {
return null;
}
int tl = __len__();
@@ -174,7 +174,7 @@
}
final PyObject seq___ne__(PyObject o) {
- if (!isSubType(o)) {
+ if (!isSubType(o) || o.getType() == PyObject.TYPE) {
return null;
}
int tl = __len__();
@@ -192,7 +192,7 @@
}
final PyObject seq___lt__(PyObject o) {
- if (!isSubType(o)) {
+ if (!isSubType(o) || o.getType() == PyObject.TYPE) {
return null;
}
int i = cmp(this, -1, o, -1);
@@ -208,7 +208,7 @@
}
final PyObject seq___le__(PyObject o) {
- if (!isSubType(o)) {
+ if (!isSubType(o) || o.getType() == PyObject.TYPE) {
return null;
}
int i = cmp(this, -1, o, -1);
@@ -224,7 +224,7 @@
}
final PyObject seq___gt__(PyObject o) {
- if (!isSubType(o)) {
+ if (!isSubType(o) || o.getType() == PyObject.TYPE) {
return null;
}
int i = cmp(this, -1, o, -1);
@@ -240,7 +240,7 @@
}
final PyObject seq___ge__(PyObject o) {
- if (!isSubType(o)) {
+ if (!isSubType(o) || o.getType() == PyObject.TYPE) {
return null;
}
int i = cmp(this, -1, o, -1);
diff --git a/src/org/python/core/PySlice.java b/src/org/python/core/PySlice.java
--- a/src/org/python/core/PySlice.java
+++ b/src/org/python/core/PySlice.java
@@ -113,62 +113,64 @@
* @return an array with the start at index 0, stop at index 1, step at index 2 and
* slicelength at index 3
*/
- public int[] indicesEx(int len) {
- int start;
- int stop;
- int step;
- int slicelength;
+ public int[] indicesEx(int length) {
+ /* The corresponding C code (PySlice_GetIndicesEx) states:
+ * "this is harder to get right than you might think"
+ * As a consequence, I have chosen to copy the code and translate to Java.
+ * Note *rstart, etc., become result_start - the usual changes we need
+ * when going from pointers to corresponding Java.
+ */
- if (getStep() == Py.None) {
- step = 1;
+ int defstart, defstop;
+ int result_start, result_stop, result_step, result_slicelength;
+
+ if (step == Py.None) {
+ result_step = 1;
} else {
- step = calculateSliceIndex(getStep());
- if (step == 0) {
+ result_step = calculateSliceIndex(step);
+ if (result_step == 0) {
throw Py.ValueError("slice step cannot be zero");
}
}
- if (getStart() == Py.None) {
- start = step < 0 ? len - 1 : 0;
+ defstart = result_step < 0 ? length - 1 : 0;
+ defstop = result_step < 0 ? -1 : length;
+
+ if (start == Py.None) {
+ result_start = defstart;
} else {
- start = calculateSliceIndex(getStart());
- if (start < 0) {
- start += len;
- }
- if (start < 0) {
- start = step < 0 ? -1 : 0;
- }
- if (start >= len) {
- start = step < 0 ? len - 1 : len;
+ result_start = calculateSliceIndex(start);
+ if (result_start < 0) result_start += length;
+ if (result_start < 0) result_start = (result_step < 0) ? -1 : 0;
+ if (result_start >= length) {
+ result_start = (result_step < 0) ? length - 1 : length;
}
}
- if (getStop() == Py.None) {
- stop = step < 0 ? -1 : len;
+ if (stop == Py.None) {
+ result_stop = defstop;
} else {
- stop = calculateSliceIndex(getStop());
- if (stop < 0) {
- stop += len;
- }
- if (stop < 0) {
- stop = -1;
- }
- if (stop > len) {
- stop = len;
+ result_stop = calculateSliceIndex(stop);
+ if (result_stop < 0) result_stop += length;
+ if (result_stop < 0) result_stop = (result_step < 0) ? -1 : 0;
+ if (result_stop >= length) {
+ result_stop = (result_step < 0) ? length - 1 : length;
}
}
- if ((step < 0 && stop >= start) || (step > 0 && start >= stop)) {
- slicelength = 0;
- } else if (step < 0) {
- slicelength = (stop - start + 1) / (step) + 1;
+ if ((result_step < 0 && result_stop >= result_start)
+ || (result_step > 0 && result_start >= result_stop)) {
+ result_slicelength = 0;
+ } else if (result_step < 0) {
+ result_slicelength = (result_stop - result_start + 1) / (result_step) + 1;
} else {
- slicelength = (stop - start - 1) / (step) + 1;
+ result_slicelength = (result_stop - result_start - 1) / (result_step) + 1;
}
- return new int[] {start, stop, step, slicelength};
+ return new int[]{result_start, result_stop, result_step, result_slicelength};
}
+
/**
* Calculate indices for the deprecated __get/set/delslice__ methods.
*
@@ -230,4 +232,14 @@
public final PyObject getStep() {
return step;
}
+
+ @ExposedMethod
+ final PyObject slice___reduce__() {
+ return new PyTuple(getType(), new PyTuple(start, stop, step));
+ }
+
+ @ExposedMethod(defaults = "Py.None")
+ final PyObject slice___reduce_ex__(PyObject protocol) {
+ return new PyTuple(getType(), new PyTuple(start, stop, step));
+ }
}
diff --git a/src/org/python/core/PyString.java b/src/org/python/core/PyString.java
--- a/src/org/python/core/PyString.java
+++ b/src/org/python/core/PyString.java
@@ -4,8 +4,11 @@
import java.lang.ref.Reference;
import java.lang.ref.SoftReference;
import java.math.BigInteger;
+import java.util.ArrayList;
+import java.util.Collection;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
+import java.util.List;
import org.python.core.buffer.BaseBuffer;
import org.python.core.buffer.SimpleStringBuffer;
@@ -720,6 +723,14 @@
}
}
+ if (c.isAssignableFrom(Collection.class)) {
+ List<Object> list = new ArrayList();
+ for (int i = 0; i < __len__(); i++) {
+ list.add(pyget(i).__tojava__(String.class));
+ }
+ return list;
+ }
+
if (c.isInstance(this)) {
return this;
}
@@ -733,6 +744,10 @@
return Py.makeCharacter(string.charAt(i));
}
+ public int getInt(int i) {
+ return string.charAt(i);
+ }
+
@Override
protected PyObject getslice(int start, int stop, int step) {
if (step > 0 && stop < start) {
diff --git a/src/org/python/core/PySystemState.java b/src/org/python/core/PySystemState.java
--- a/src/org/python/core/PySystemState.java
+++ b/src/org/python/core/PySystemState.java
@@ -19,6 +19,7 @@
import java.nio.charset.Charset;
import java.nio.charset.UnsupportedCharsetException;
import java.security.AccessControlException;
+import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.LinkedList;
@@ -29,6 +30,7 @@
import java.util.StringTokenizer;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.locks.ReentrantLock;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;
@@ -130,6 +132,9 @@
public PyList argv = new PyList();
public PyObject modules;
+ public Map<String, PyModule> modules_reloading;
+ private ReentrantLock importLock;
+ private ClassLoader syspathJavaLoader;
public PyList path;
public PyList warnoptions = new PyList();
@@ -190,6 +195,9 @@
initialize();
closer = new PySystemStateCloser(this);
modules = new PyStringMap();
+ modules_reloading = new HashMap<String, PyModule>();
+ importLock = new ReentrantLock();
+ syspathJavaLoader = new SyspathJavaLoader(imp.getParentClassLoader());
argv = (PyList)defaultArgv.repeat(1);
path = (PyList)defaultPath.repeat(1);
@@ -337,6 +345,14 @@
return codecState;
}
+ public ReentrantLock getImportLock() {
+ return importLock;
+ }
+
+ public ClassLoader getSyspathJavaLoader() {
+ return syspathJavaLoader;
+ }
+
// xxx fix this accessors
@Override
public PyObject __findattr_ex__(String name) {
@@ -428,6 +444,15 @@
this.recursionlimit = recursionlimit;
}
+ public PyObject gettrace() {
+ ThreadState ts = Py.getThreadState();
+ if (ts.tracefunc == null) {
+ return Py.None;
+ } else {
+ return ((PythonTraceFunction)ts.tracefunc).tracefunc;
+ }
+ }
+
public void settrace(PyObject tracefunc) {
ThreadState ts = Py.getThreadState();
if (tracefunc == Py.None) {
@@ -437,6 +462,15 @@
}
}
+ public PyObject getprofile() {
+ ThreadState ts = Py.getThreadState();
+ if (ts.profilefunc == null) {
+ return Py.None;
+ } else {
+ return ((PythonTraceFunction)ts.profilefunc).tracefunc;
+ }
+ }
+
public void setprofile(PyObject profilefunc) {
ThreadState ts = Py.getThreadState();
if (profilefunc == Py.None) {
diff --git a/src/org/python/core/PyType.java b/src/org/python/core/PyType.java
--- a/src/org/python/core/PyType.java
+++ b/src/org/python/core/PyType.java
@@ -203,6 +203,8 @@
type.bases = tmpBases.length == 0 ? new PyObject[] {PyObject.TYPE} : tmpBases;
type.dict = dict;
type.tp_flags = Py.TPFLAGS_HEAPTYPE | Py.TPFLAGS_BASETYPE;
+ // Enable defining a custom __dict__ via a property, method, or other descriptor
+ boolean defines_dict = dict.__finditem__("__dict__") != null;
// immediately setup the javaProxy if applicable. may modify bases
List<Class<?>> interfaces = Generic.list();
@@ -215,7 +217,7 @@
base.name));
}
- type.createAllSlots(!base.needs_userdict, !base.needs_weakref);
+ type.createAllSlots(!(base.needs_userdict || defines_dict), !base.needs_weakref);
type.ensureAttributes();
type.invalidateMethodCache();
diff --git a/src/org/python/core/PyUnicode.java b/src/org/python/core/PyUnicode.java
--- a/src/org/python/core/PyUnicode.java
+++ b/src/org/python/core/PyUnicode.java
@@ -7,6 +7,7 @@
import java.util.List;
import java.util.Set;
+import com.google.common.base.CharMatcher;
import org.python.core.stringlib.FieldNameIterator;
import org.python.core.stringlib.MarkupIterator;
import org.python.expose.ExposedMethod;
@@ -578,6 +579,11 @@
return string.length() - translator.suppCount();
}
+ private static String checkEncoding(String s) {
+ if (s == null || CharMatcher.ASCII.matchesAllOf(s)) { return s; }
+ return codecs.PyUnicode_EncodeASCII(s, s.length(), null);
+ }
+
@ExposedNew
final static PyObject unicode_new(PyNewWrapper new_, boolean init, PyType subtype,
PyObject[] args, String[] keywords) {
@@ -585,8 +591,8 @@
new ArgParser("unicode", args, keywords, new String[] {"string", "encoding",
"errors"}, 0);
PyObject S = ap.getPyObject(0, null);
- String encoding = ap.getString(1, null);
- String errors = ap.getString(2, null);
+ String encoding = checkEncoding(ap.getString(1, null));
+ String errors = checkEncoding(ap.getString(2, null));
if (new_.for_type == subtype) {
if (S == null) {
return new PyUnicode("");
@@ -731,6 +737,10 @@
return Py.makeCharacter(codepoint, true);
}
+ public int getInt(int i) {
+ return getString().codePointAt(translator.utf16Index(i));
+ }
+
private class SubsequenceIteratorImpl implements Iterator {
private int current, k, stop, step;
diff --git a/src/org/python/core/PyXRange.java b/src/org/python/core/PyXRange.java
--- a/src/org/python/core/PyXRange.java
+++ b/src/org/python/core/PyXRange.java
@@ -5,6 +5,11 @@
import org.python.expose.ExposedNew;
import org.python.expose.ExposedType;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.List;
+
/**
* The builtin xrange type.
*/
@@ -188,4 +193,22 @@
return String.format("xrange(%d, %d, %d)", start, stop, step);
}
}
+
+ @Override
+ public Object __tojava__(Class<?> c) {
+ if (c.isAssignableFrom(Iterable.class)) {
+ return new JavaIterator(range_iter());
+ }
+ if (c.isAssignableFrom(Iterator.class)) {
+ return (new JavaIterator(range_iter())).iterator();
+ }
+ if (c.isAssignableFrom(Collection.class)) {
+ List<Object> list = new ArrayList();
+ for (Object obj : new JavaIterator(range_iter())) {
+ list.add(obj);
+ }
+ return list;
+ }
+ return super.__tojava__(c);
+ }
}
diff --git a/src/org/python/core/__builtin__.java b/src/org/python/core/__builtin__.java
--- a/src/org/python/core/__builtin__.java
+++ b/src/org/python/core/__builtin__.java
@@ -14,8 +14,6 @@
import org.python.antlr.base.mod;
import org.python.core.stringlib.IntegerFormatter;
-import org.python.core.stringlib.InternalFormat;
-import org.python.core.stringlib.InternalFormat.Spec;
import org.python.core.util.ExtraMath;
import org.python.core.util.RelativeFile;
import org.python.modules._functools._functools;
@@ -70,8 +68,7 @@
case 5:
return __builtin__.hash(arg1);
case 6:
- return Py.newUnicode(__builtin__.unichr(Py.py2int(arg1, "unichr(): 1st arg can't "
- + "be coerced to int")));
+ return Py.newUnicode(__builtin__.unichr(arg1));
case 7:
return __builtin__.abs(arg1);
case 9:
@@ -401,6 +398,16 @@
return obj.isCallable();
}
+ public static int unichr(PyObject obj) {
+ long l = obj.asLong();
+ if (l < PySystemState.minint) {
+ throw Py.OverflowError("signed integer is less than minimum");
+ } else if (l > PySystemState.maxint) {
+ throw Py.OverflowError("signed integer is greater than maximum");
+ }
+ return unichr((int)l);
+ }
+
public static int unichr(int i) {
if (i < 0 || i > PySystemState.maxunicode) {
throw Py.ValueError("unichr() arg not in range(0x110000)");
@@ -435,8 +442,11 @@
}
public static PyObject dir(PyObject o) {
- PyList ret = (PyList) o.__dir__();
- ret.sort();
+ PyObject ret = o.__dir__();
+ if (!Py.isInstance(ret, PyList.TYPE)) {
+ throw Py.TypeError("__dir__() must return a list, not " + ret.getType().fastGetName());
+ }
+ ((PyList)ret).sort();
return ret;
}
@@ -884,39 +894,6 @@
y.getType().fastGetName(), z.getType().fastGetName()));
}
- public static PyObject range(PyObject start, PyObject stop, PyObject step) {
- int ilow = 0;
- int ihigh = 0;
- int istep = 1;
- int n;
-
- try {
- ilow = start.asInt();
- ihigh = stop.asInt();
- istep = step.asInt();
- } catch (PyException pye) {
- return handleRangeLongs(start, stop, step);
- }
-
- if (istep == 0) {
- throw Py.ValueError("range() step argument must not be zero");
- }
- if (istep > 0) {
- n = PyXRange.getLenOfRange(ilow, ihigh, istep);
- } else {
- n = PyXRange.getLenOfRange(ihigh, ilow, -istep);
- }
- if (n < 0) {
- throw Py.OverflowError("range() result has too many items");
- }
-
- PyObject[] range = new PyObject[n];
- for (int i = 0; i < n; i++, ilow += istep) {
- range[i] = Py.newInteger(ilow);
- }
- return new PyList(range);
- }
-
public static PyObject range(PyObject n) {
return range(Py.Zero, n, Py.One);
}
@@ -925,10 +902,7 @@
return range(start, stop, Py.One);
}
- /**
- * Handle range() when PyLong arguments (that OverFlow ints) are given.
- */
- private static PyObject handleRangeLongs(PyObject ilow, PyObject ihigh, PyObject istep) {
+ public static PyObject range(PyObject ilow, PyObject ihigh, PyObject istep) {
ilow = getRangeLongArgument(ilow, "start");
ihigh = getRangeLongArgument(ihigh, "end");
istep = getRangeLongArgument(istep, "step");
@@ -949,8 +923,8 @@
PyObject[] range = new PyObject[n];
for (int i = 0; i < n; i++) {
- range[i] = ilow.__long__();
- ilow = ilow.__add__(istep);
+ range[i] = ilow;
+ ilow = ilow._add(istep);
}
return new PyList(range);
}
@@ -972,7 +946,7 @@
}
try {
// See PyXRange.getLenOfRange for the primitive version
- PyObject diff = hi.__sub__(lo).__sub__(Py.One);
+ PyObject diff = hi._sub(lo)._sub(Py.One);
PyObject n = diff.__floordiv__(step).__add__(Py.One);
return n.asInt();
} catch (PyException pye) {
@@ -1400,7 +1374,11 @@
@Override
public PyObject __call__(PyObject arg1, PyObject arg2) {
- return arg1.__format__(arg2);
+ PyObject formatted = arg1.__format__(arg2);
+ if (!Py.isInstance(formatted, PyString.TYPE) && !Py.isInstance(formatted, PyUnicode.TYPE) ) {
+ throw Py.TypeError("instance.__format__ must return string or unicode, not " + formatted.getType().fastGetName());
+ }
+ return formatted;
}
}
diff --git a/src/org/python/core/imp.java b/src/org/python/core/imp.java
--- a/src/org/python/core/imp.java
+++ b/src/org/python/core/imp.java
@@ -7,6 +7,7 @@
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
+import java.util.Map;
import java.util.concurrent.locks.ReentrantLock;
import org.python.compiler.Module;
@@ -26,30 +27,47 @@
private static final String UNKNOWN_SOURCEFILE = "<unknown>";
- private static final int APIVersion = 34;
+ private static final int APIVersion = 35;
public static final int NO_MTIME = -1;
- // This should change to 0 for Python 2.7 and 3.0 see PEP 328
+ // This should change to Python 3.x; note that 2.7 allows relative
+ // imports unless `from __future__ import absolute_import`
public static final int DEFAULT_LEVEL = -1;
+ public static class CodeData {
+ private final byte[] bytes;
+ private final long mtime;
+ private final String filename;
+
+ public CodeData(byte[] bytes, long mtime, String filename) {
+ this.bytes = bytes;
+ this.mtime = mtime;
+ this.filename = filename;
+ }
+
+ public byte[] getBytes() {
+ return bytes;
+ }
+
+ public long getMTime() {
+ return mtime;
+ }
+
+ public String getFilename() {
+ return filename;
+ }
+ }
+
+ public static enum CodeImport {
+ source, compiled_only;
+ }
+
/** A non-empty fromlist for __import__'ing sub-modules. */
private static final PyObject nonEmptyFromlist = new PyTuple(Py.newString("__doc__"));
- /** Synchronizes import operations */
- public static final ReentrantLock importLock = new ReentrantLock();
-
- private static Object syspathJavaLoaderLock = new Object();
-
- private static ClassLoader syspathJavaLoader = null;
-
public static ClassLoader getSyspathJavaLoader() {
- synchronized (syspathJavaLoaderLock) {
- if (syspathJavaLoader == null) {
- syspathJavaLoader = new SyspathJavaLoader(getParentClassLoader());
- }
- }
- return syspathJavaLoader;
+ return Py.getSystemState().getSyspathJavaLoader();
}
/**
@@ -127,6 +145,10 @@
return module;
}
module = new PyModule(name, null);
+ PyModule __builtin__ = (PyModule)modules.__finditem__("__builtin__");
+ PyObject __dict__ = module.__getattr__("__dict__");
+ __dict__.__setitem__("__builtins__", __builtin__.__getattr__("__dict__"));
+ __dict__.__setitem__("__package__", Py.None);
modules.__setitem__(name, module);
return module;
}
@@ -180,9 +202,14 @@
static PyObject createFromPyClass(String name, InputStream fp, boolean testing,
String sourceName, String compiledName, long mtime) {
- byte[] data = null;
+ return createFromPyClass(name, fp, testing, sourceName, compiledName, mtime, CodeImport.source);
+ }
+
+ static PyObject createFromPyClass(String name, InputStream fp, boolean testing,
+ String sourceName, String compiledName, long mtime, CodeImport source) {
+ CodeData data = null;
try {
- data = readCode(name, fp, testing, mtime);
+ data = readCodeData(name, fp, testing, mtime);
} catch (IOException ioe) {
if (!testing) {
throw Py.ImportError(ioe.getMessage() + "[name=" + name + ", source=" + sourceName
@@ -194,7 +221,8 @@
}
PyCode code;
try {
- code = BytecodeLoader.makeCode(name + "$py", data, sourceName);
+ code = BytecodeLoader.makeCode(name + "$py", data.getBytes(),
+ source == CodeImport.compiled_only ? data.getFilename() : sourceName);
} catch (Throwable t) {
if (testing) {
return null;
@@ -205,7 +233,6 @@
Py.writeComment(IMPORT_LOG, String.format("import %s # precompiled from %s", name,
compiledName));
-
return createFromCode(name, code, compiledName);
}
@@ -214,6 +241,14 @@
}
public static byte[] readCode(String name, InputStream fp, boolean testing, long mtime) throws IOException {
+ return readCodeData(name, fp, testing, mtime).getBytes();
+ }
+
+ public static CodeData readCodeData(String name, InputStream fp, boolean testing) throws IOException {
+ return readCodeData(name, fp, testing, NO_MTIME);
+ }
+
+ public static CodeData readCodeData(String name, InputStream fp, boolean testing, long mtime) throws IOException {
byte[] data = readBytes(fp);
int api;
AnnotationReader ar = new AnnotationReader(data);
@@ -232,7 +267,7 @@
return null;
}
}
- return data;
+ return new CodeData(data, mtime, ar.getFilename());
}
public static byte[] compileSource(String name, File file) {
@@ -369,7 +404,7 @@
* running c. Sets __file__ on the module to be moduleLocation unless
* moduleLocation is null. If c comes from a local .py file or compiled
* $py.class class moduleLocation should be the result of running new
- * File(moduleLocation).getAbsoultePath(). If c comes from a remote file or
+ * File(moduleLocation).getAbsolutePath(). If c comes from a remote file or
* is a jar moduleLocation should be the full uri for c.
*/
public static PyObject createFromCode(String name, PyCode c, String moduleLocation) {
@@ -388,14 +423,18 @@
Py.writeDebug(IMPORT_LOG, String.format("Warning: %s __file__ is unknown", name));
}
+ ReentrantLock importLock = Py.getSystemState().getImportLock();
+ importLock.lock();
try {
PyFrame f = new PyFrame(code, module.__dict__, module.__dict__, null);
code.call(Py.getThreadState(), f);
+ return module;
} catch (RuntimeException t) {
removeModule(name);
throw t;
+ } finally {
+ importLock.unlock();
}
- return module;
}
static PyObject createFromClass(String name, Class<?> c) {
@@ -517,7 +556,13 @@
static PyObject loadFromLoader(PyObject importer, String name) {
PyObject load_module = importer.__getattr__("load_module");
- return load_module.__call__(new PyObject[] { new PyString(name) });
+ ReentrantLock importLock = Py.getSystemState().getImportLock();
+ importLock.lock();
+ try {
+ return load_module.__call__(new PyObject[]{new PyString(name)});
+ } finally {
+ importLock.unlock();
+ }
}
public static PyObject loadFromCompiled(String name, InputStream stream, String sourceName,
@@ -544,8 +589,15 @@
boolean pkg = false;
try {
- pkg = dir.isDirectory() && caseok(dir, name)
- && (sourceFile.isFile() || compiledFile.isFile());
+ if (dir.isDirectory()) {
+ if (caseok(dir, name) && (sourceFile.isFile() || compiledFile.isFile())) {
+ pkg = true;
+ } else {
+ Py.warning(Py.ImportWarning, String.format(
+ "Not importing directory '%s': missing __init__.py",
+ dirName));
+ }
+ }
} catch (SecurityException e) {
// ok
}
@@ -571,8 +623,9 @@
Py.writeDebug(IMPORT_LOG, "trying precompiled " + compiledFile.getPath());
long classTime = compiledFile.lastModified();
if (classTime >= pyTime) {
- PyObject ret = createFromPyClass(modName, makeStream(compiledFile), true,
- displaySourceName, displayCompiledName, pyTime);
+ PyObject ret = createFromPyClass(
+ modName, makeStream(compiledFile), true,
+ displaySourceName, displayCompiledName, pyTime);
if (ret != null) {
return ret;
}
@@ -587,8 +640,9 @@
// If no source, try loading precompiled
Py.writeDebug(IMPORT_LOG, "trying precompiled with no source " + compiledFile.getPath());
if (compiledFile.isFile() && caseok(compiledFile, compiledName)) {
- return createFromPyClass(modName, makeStream(compiledFile), true, displaySourceName,
- displayCompiledName);
+ return createFromPyClass(
+ modName, makeStream(compiledFile), true, displaySourceName,
+ displayCompiledName, NO_MTIME, CodeImport.compiled_only);
}
} catch (SecurityException e) {
// ok
@@ -630,7 +684,13 @@
* @return the loaded module
*/
public static PyObject load(String name) {
- return import_first(name, new StringBuilder());
+ ReentrantLock importLock = Py.getSystemState().getImportLock();
+ importLock.lock();
+ try {
+ return import_first(name, new StringBuilder());
+ } finally {
+ importLock.unlock();
+ }
}
/**
@@ -659,7 +719,9 @@
*/
private static String get_parent(PyObject dict, int level) {
String modname;
- if (dict == null || level == 0) {
+ int orig_level = level;
+
+ if ((dict == null && level == -1) || level == 0) {
// try an absolute import
return null;
}
@@ -709,6 +771,21 @@
}
modname = modname.substring(0, dot);
}
+
+ if (Py.getSystemState().modules.__finditem__(modname) == null) {
+ if (orig_level < 1) {
+ if (modname.length() > 0) {
+ Py.warning(Py.RuntimeWarning,
+ String.format(
+ "Parent module '%.200s' not found " +
+ "while handling absolute import", modname));
+ }
+ } else {
+ throw Py.SystemError(String.format(
+ "Parent module '%.200s' not loaded, " +
+ "cannot perform relative import", modname));
+ }
+ }
return modname.intern();
}
@@ -734,7 +811,7 @@
return ret;
}
if (mod == null) {
- ret = find_module(fullName.intern(), name, null);
+ ret = find_module(fullName, name, null);
} else {
ret = mod.impAttr(name.intern());
}
@@ -863,7 +940,11 @@
}
parentNameBuffer = new StringBuilder("");
// could throw ImportError
- topMod = import_first(firstName, parentNameBuffer, name, fromlist);
+ if (level > 0) {
+ topMod = import_first(pkgName + "." + firstName, parentNameBuffer, name, fromlist);
+ } else {
+ topMod = import_first(firstName, parentNameBuffer, name, fromlist);
+ }
}
PyObject mod = topMod;
if (dot != -1) {
@@ -927,7 +1008,13 @@
* @return an imported module (Java or Python)
*/
public static PyObject importName(String name, boolean top) {
- return import_module_level(name, top, null, null, DEFAULT_LEVEL);
+ ReentrantLock importLock = Py.getSystemState().getImportLock();
+ importLock.lock();
+ try {
+ return import_module_level(name, top, null, null, DEFAULT_LEVEL);
+ } finally {
+ importLock.unlock();
+ }
}
/**
@@ -941,6 +1028,7 @@
*/
public static PyObject importName(String name, boolean top,
PyObject modDict, PyObject fromlist, int level) {
+ ReentrantLock importLock = Py.getSystemState().getImportLock();
importLock.lock();
try {
return import_module_level(name, top, modDict, fromlist, level);
@@ -955,7 +1043,7 @@
*/
@Deprecated
public static PyObject importOne(String mod, PyFrame frame) {
- return importOne(mod, frame, imp.DEFAULT_LEVEL);
+ return importOne(mod, frame, imp.DEFAULT_LEVEL);
}
/**
* Called from jython generated code when a statement like "import spam" is
@@ -1130,15 +1218,34 @@
}
static PyObject reload(PyModule m) {
+ PySystemState sys = Py.getSystemState();
+ PyObject modules = sys.modules;
+ Map<String, PyModule> modules_reloading = sys.modules_reloading;
+ ReentrantLock importLock = Py.getSystemState().getImportLock();
+ importLock.lock();
+ try {
+ return _reload(m, modules, modules_reloading);
+ } finally {
+ modules_reloading.clear();
+ importLock.unlock();
+ }
+ }
+
+ private static PyObject _reload(PyModule m, PyObject modules, Map<String, PyModule> modules_reloading) {
String name = m.__getattr__("__name__").toString().intern();
-
- PyObject modules = Py.getSystemState().modules;
PyModule nm = (PyModule) modules.__finditem__(name);
-
if (nm == null || !nm.__getattr__("__name__").toString().equals(name)) {
throw Py.ImportError("reload(): module " + name
+ " not in sys.modules");
}
+ PyModule existing_module = modules_reloading.get(name);
+ if (existing_module != null) {
+ // Due to a recursive reload, this module is already being reloaded.
+ return existing_module;
+ }
+ // Since we are already in a re-entrant lock,
+ // this test & set is guaranteed to be atomic
+ modules_reloading.put(name, nm);
PyList path = Py.getSystemState().path;
String modName = name;
@@ -1153,10 +1260,17 @@
name = name.substring(dot + 1, name.length()).intern();
}
- nm.__setattr__("__name__", new PyString(modName));
- PyObject ret = find_module(name, modName, path);
- modules.__setitem__(modName, ret);
- return ret;
+ nm.__setattr__("__name__", new PyString(modName)); // FIXME necessary?!
+ try {
+ PyObject ret = find_module(name, modName, path);
+ modules.__setitem__(modName, ret);
+ return ret;
+ } catch (RuntimeException t) {
+ // Need to restore module, due to the semantics of addModule, which removed it
+ // Fortunately we are in a module import lock
+ modules.__setitem__(modName, nm);
+ throw t;
+ }
}
public static int getAPIVersion() {
diff --git a/src/org/python/modules/Setup.java b/src/org/python/modules/Setup.java
--- a/src/org/python/modules/Setup.java
+++ b/src/org/python/modules/Setup.java
@@ -34,6 +34,7 @@
"_functools:org.python.modules._functools._functools",
"_hashlib",
"_io:org.python.modules._io._io",
+ "_json:org.python.modules._json._json",
"_jythonlib:org.python.modules._jythonlib._jythonlib",
"_marshal",
"_py_compile",
diff --git a/src/org/python/modules/_codecs.java b/src/org/python/modules/_codecs.java
--- a/src/org/python/modules/_codecs.java
+++ b/src/org/python/modules/_codecs.java
@@ -36,12 +36,28 @@
codecs.register(search_function);
}
- public static PyTuple lookup(String encoding) {
- return codecs.lookup(encoding);
+ private static String _castString(PyString pystr) {
+ // Jython used to treat String as equivalent to PyString, or maybe PyUnicode, as
+ // it made sense. We need to be more careful now! Insert this cast check as necessary
+ // to ensure the appropriate compliance.
+ if (pystr == null) {
+ return null;
+ }
+ String s = pystr.toString();
+ if (pystr instanceof PyUnicode) {
+ return s;
+ } else {
+ // May throw UnicodeEncodeError, per CPython behavior
+ return codecs.PyUnicode_EncodeASCII(s, s.length(), null);
+ }
}
- public static PyObject lookup_error(String handlerName) {
- return codecs.lookup_error(handlerName);
+ public static PyTuple lookup(PyString encoding) {
+ return codecs.lookup(_castString(encoding));
+ }
+
+ public static PyObject lookup_error(PyString handlerName) {
+ return codecs.lookup_error(_castString(handlerName));
}
public static void register_error(String name, PyObject errorHandler) {
@@ -68,7 +84,7 @@
* @param encoding name of encoding (to look up in codec registry)
* @return Unicode string decoded from <code>bytes</code>
*/
- public static PyObject decode(PyString bytes, String encoding) {
+ public static PyObject decode(PyString bytes, PyString encoding) {
return decode(bytes, encoding, null);
}
@@ -85,8 +101,8 @@
* @param errors error policy name (e.g. "ignore")
* @return Unicode string decoded from <code>bytes</code>
*/
- public static PyObject decode(PyString bytes, String encoding, String errors) {
- return codecs.decode(bytes, encoding, errors);
+ public static PyObject decode(PyString bytes, PyString encoding, PyString errors) {
+ return codecs.decode(bytes, _castString(encoding), _castString(errors));
}
/**
@@ -109,7 +125,7 @@
* @param encoding name of encoding (to look up in codec registry)
* @return bytes object encoding <code>unicode</code>
*/
- public static PyString encode(PyUnicode unicode, String encoding) {
+ public static PyString encode(PyUnicode unicode, PyString encoding) {
return encode(unicode, encoding, null);
}
@@ -126,8 +142,8 @@
* @param errors error policy name (e.g. "ignore")
* @return bytes object encoding <code>unicode</code>
*/
- public static PyString encode(PyUnicode unicode, String encoding, String errors) {
- return Py.newString(codecs.encode(unicode, encoding, errors));
+ public static PyString encode(PyUnicode unicode, PyString encoding, PyString errors) {
+ return Py.newString(codecs.encode(unicode, _castString(encoding), _castString(errors)));
}
/* --- Some codec support methods -------------------------------------------- */
@@ -222,6 +238,10 @@
return utf_8_decode(str, errors, false);
}
+ public static PyTuple utf_8_decode(String str, String errors, PyObject final_) {
+ return utf_8_decode(str, errors, final_.__nonzero__());
+ }
+
public static PyTuple utf_8_decode(String str, String errors, boolean final_) {
int[] consumed = final_ ? null : new int[1];
return decode_tuple(codecs.PyUnicode_DecodeUTF8Stateful(str, errors, consumed), final_
diff --git a/src/org/python/modules/_imp.java b/src/org/python/modules/_imp.java
--- a/src/org/python/modules/_imp.java
+++ b/src/org/python/modules/_imp.java
@@ -299,7 +299,7 @@
*
*/
public static void acquire_lock() {
- org.python.core.imp.importLock.lock();
+ Py.getSystemState().getImportLock().lock();
}
/**
@@ -308,7 +308,7 @@
*/
public static void release_lock() {
try{
- org.python.core.imp.importLock.unlock();
+ Py.getSystemState().getImportLock().unlock();
}catch(IllegalMonitorStateException e){
throw Py.RuntimeError("not holding the import lock");
}
@@ -320,6 +320,6 @@
* @return true if the import lock is currently held, else false.
*/
public static boolean lock_held() {
- return org.python.core.imp.importLock.isHeldByCurrentThread();
+ return Py.getSystemState().getImportLock().isHeldByCurrentThread();
}
}
diff --git a/src/org/python/modules/_json/Encoder.java b/src/org/python/modules/_json/Encoder.java
new file mode 100644
--- /dev/null
+++ b/src/org/python/modules/_json/Encoder.java
@@ -0,0 +1,204 @@
+package org.python.modules._json;
+
+import org.python.core.ArgParser;
+import org.python.core.Py;
+import org.python.core.PyDictionary;
+import org.python.core.PyException;
+import org.python.core.PyFloat;
+import org.python.core.PyInteger;
+import org.python.core.PyLong;
+import org.python.core.PyList;
+import org.python.core.PyObject;
+import org.python.core.PyString;
+import org.python.core.PyTuple;
+import org.python.core.PyType;
+import org.python.core.PyUnicode;
+import org.python.expose.ExposedGet;
+import org.python.expose.ExposedType;
+
+ at ExposedType(name = "_json.encoder", base = PyObject.class)
+public class Encoder extends PyObject {
+
+ public static final PyType TYPE = PyType.fromClass(Encoder.class);
+
+ @ExposedGet
+ public final String __module__ = "_json";
+
+ final PyDictionary markers;
+ final PyObject defaultfn;
+ final PyObject encoder;
+ final PyObject indent;
+ final PyObject key_separator;
+ final PyObject item_separator;
+ final PyObject sort_keys;
+ final boolean skipkeys;
+ final boolean allow_nan;
+
+ public Encoder(PyObject[] args, String[] kwds) {
+ super();
+ ArgParser ap = new ArgParser("encoder", args, kwds,
+ new String[]{"markers", "default", "encoder", "indent",
+ "key_separator", "item_separator", "sort_keys", "skipkeys", "allow_nan"});
+ ap.noKeywords();
+ PyObject m = ap.getPyObject(0);
+ markers = m == Py.None ? null : (PyDictionary) m;
+ defaultfn = ap.getPyObject(1);
+ encoder = ap.getPyObject(2);
+ indent = ap.getPyObject(3);
+ key_separator = ap.getPyObject(4);
+ item_separator = ap.getPyObject(5);
+ sort_keys = ap.getPyObject(6);
+ skipkeys = ap.getPyObject(7).__nonzero__();
+ allow_nan = ap.getPyObject(8).__nonzero__();
+ }
+
+ public PyObject __call__(PyObject obj) {
+ return __call__(obj, Py.Zero);
+ }
+
+ public PyObject __call__(PyObject obj, PyObject indent_level) {
+ PyList rval = new PyList();
+ encode_obj(rval, obj, 0);
+ return rval;
+ }
+
+ private PyString encode_float(PyObject obj) {
+ /* Return the JSON representation of a PyFloat */
+ double i = obj.asDouble();
+ if (Double.isInfinite(i) || Double.isNaN(i)) {
+ if (!allow_nan) {
+ throw Py.ValueError("Out of range float values are not JSON compliant");
+ }
+ if (i == Double.POSITIVE_INFINITY) {
+ return new PyString("Infinity");
+ } else if (i == Double.NEGATIVE_INFINITY) {
+ return new PyString("-Infinity");
+ } else {
+ return new PyString("NaN");
+ }
+ }
+ /* Use a better float format here? */
+ return obj.__repr__();
+ }
+
+ private PyString encode_string(PyObject obj) {
+ /* Return the JSON representation of a string */
+ return (PyString) encoder.__call__(obj);
+ }
+
+ private PyObject checkCircularReference(PyObject obj) {
+ PyObject ident = null;
+ if (markers != null) {
+ ident = Py.newInteger(Py.id(obj));
+ if (markers.__contains__(ident)) {
+ throw Py.ValueError("Circular reference detected");
+ }
+ markers.__setitem__(ident, obj);
+ }
+ return ident;
+ }
+
+ private void encode_obj(PyList rval, PyObject obj, int indent_level) {
+ /* Encode Python object obj to a JSON term, rval is a PyList */
+ if (obj == Py.None) {
+ rval.append(new PyString("null"));
+ } else if (obj == Py.True) {
+ rval.append(new PyString("true"));
+ } else if (obj == Py.False) {
+ rval.append(new PyString("false"));
+ } else if (obj instanceof PyString) {
+ rval.append(encode_string(obj));
+ } else if (obj instanceof PyInteger || obj instanceof PyLong) {
+ rval.append(obj.__str__());
+ } else if (obj instanceof PyFloat) {
+ rval.append(encode_float(obj));
+ } else if (obj instanceof PyList || obj instanceof PyTuple) {
+ encode_list(rval, obj, indent_level);
+ } else if (obj instanceof PyDictionary) {
+ encode_dict(rval, (PyDictionary) obj, indent_level);
+ } else {
+ PyObject ident = checkCircularReference(obj);
+ if (defaultfn == Py.None) {
+ throw Py.TypeError(String.format(".80s is not JSON serializable", obj.__repr__()));
+ }
+
+ PyObject newobj = defaultfn.__call__(obj);
+ encode_obj(rval, newobj, indent_level);
+ if (ident != null) {
+ markers.__delitem__(ident);
+ }
+ }
+ }
+
+ private void encode_dict(PyList rval, PyDictionary dct, int indent_level) {
+ /* Encode Python dict dct a JSON term */
+ if (dct.__len__() == 0) {
+ rval.append(new PyString("{}"));
+ return;
+ }
+
+ PyObject ident = checkCircularReference(dct);
+ rval.append(new PyString("{"));
+
+ /* TODO: C speedup not implemented for sort_keys */
+
+ int idx = 0;
+ for (PyObject key : dct.asIterable()) {
+ PyString kstr;
+
+ if (key instanceof PyString || key instanceof PyUnicode) {
+ kstr = (PyString) key;
+ } else if (key instanceof PyFloat) {
+ kstr = encode_float(key);
+ } else if (key instanceof PyInteger || key instanceof PyLong) {
+ kstr = key.__str__();
+ } else if (key == Py.True) {
+ kstr = new PyString("true");
+ } else if (key == Py.False) {
+ kstr = new PyString("false");
+ } else if (key == Py.None) {
+ kstr = new PyString("null");
+ } else if (skipkeys) {
+ continue;
+ } else {
+ throw Py.TypeError(String.format("keys must be a string: %.80s", key.__repr__()));
+ }
+
+ if (idx > 0) {
+ rval.append(item_separator);
+ }
+
+ PyObject value = dct.__getitem__(key);
+ PyString encoded = encode_string(kstr);
+ rval.append(encoded);
+ rval.append(key_separator);
+ encode_obj(rval, value, indent_level);
+ idx += 1;
+ }
+
+ if (ident != null) {
+ markers.__delitem__(ident);
+ }
+ rval.append(new PyString("}"));
+ }
+
+
+ private void encode_list(PyList rval, PyObject seq, int indent_level) {
+ PyObject ident = checkCircularReference(seq);
+ rval.append(new PyString("["));
+
+ int i = 0;
+ for (PyObject obj : seq.asIterable()) {
+ if (i > 0) {
+ rval.append(item_separator);
+ }
+ encode_obj(rval, obj, indent_level);
+ i++;
+ }
+
+ if (ident != null) {
+ markers.__delitem__(ident);
+ }
+ rval.append(new PyString("]"));
+ }
+}
diff --git a/src/org/python/modules/_json/Scanner.java b/src/org/python/modules/_json/Scanner.java
new file mode 100644
--- /dev/null
+++ b/src/org/python/modules/_json/Scanner.java
@@ -0,0 +1,343 @@
+package org.python.modules._json;
+
+import org.python.core.Py;
+import org.python.core.PyDictionary;
+import org.python.core.PyList;
+import org.python.core.PyObject;
+import org.python.core.PyString;
+import org.python.core.PyTuple;
+import org.python.core.PyType;
+import org.python.core.PyUnicode;
+import org.python.core.codecs;
+import org.python.expose.ExposedGet;
+import org.python.expose.ExposedType;
+
+
/**
 * JSON scanner: a callable that reads one JSON term from a str/unicode
 * starting at a given index and returns a (value, next_index) tuple.
 *
 * Port of the scanner from CPython's _json.c; configuration (encoding,
 * strictness, object/pairs hooks and the parse_* callables) is captured once
 * from the decoder context object passed to the constructor.
 */
@ExposedType(name = "_json.Scanner", base = PyObject.class)
public class Scanner extends PyObject {

    public static final PyType TYPE = PyType.fromClass(Scanner.class);

    @ExposedGet
    public final String __module__ = "_json";

    // Decoding configuration, copied from the JSONDecoder context object.
    final String encoding;          // source encoding for non-unicode input
    final boolean strict;           // reject raw control chars inside strings
    final PyObject object_hook;     // optional callable applied to each decoded dict
    final PyObject pairs_hook;      // optional callable applied to the raw (key, value) pair list
    final PyObject parse_float;     // callable turning a number literal into a float
    final PyObject parse_int;       // callable turning a number literal into an int
    final PyObject parse_constant;  // callable for NaN / Infinity / -Infinity

    /**
     * Capture scanning configuration from the decoder {@code context}
     * (normally a json.decoder.JSONDecoder instance).
     */
    public Scanner(PyObject context) {
        super();
        encoding = _castString(context.__getattr__("encoding"), "utf-8");
        strict = context.__getattr__("strict").__nonzero__();
        object_hook = context.__getattr__("object_hook");
        pairs_hook = context.__getattr__("object_pairs_hook");
        parse_float = context.__getattr__("parse_float");
        parse_int = context.__getattr__("parse_int");
        parse_constant = context.__getattr__("parse_constant");
    }

    /**
     * scanner(string, idx) -> (value, next_idx).
     *
     * NOTE(review): the unchecked cast raises ClassCastException for
     * non-string input where CPython raises TypeError — confirm callers
     * always pass str/unicode.
     */
    public PyObject __call__(PyObject string, PyObject idx) {
        return _scan_once((PyString)string, idx.asInt());
    }

    // JSON whitespace: space, tab, newline, carriage return (RFC 4627).
    private static boolean IS_WHITESPACE(int c) {
        return (c == ' ') || (c == '\t') || (c == '\n') || (c == '\r');
    }

    /**
     * Coerce a Python string attribute to an ASCII Java String, returning
     * {@code defaultValue} when the attribute is None.
     */
    private static String _castString(PyObject pystr, String defaultValue) {
        // Jython used to treat String as equivalent to PyString, or maybe PyUnicode, as
        // it made sense. We need to be more careful now! Insert this cast check as necessary
        // to ensure the appropriate compliance.
        if (pystr == Py.None) {
            return defaultValue;
        }
        if (!(pystr instanceof PyString)) {
            throw Py.TypeError("encoding is not a string");
        }
        String s = pystr.toString();
        return codecs.PyUnicode_EncodeASCII(s, s.length(), null);
    }

    // Build the (value, next_index) result tuple used throughout the scanner.
    static PyTuple valIndex(PyObject obj, int i) {
        return new PyTuple(obj, Py.newInteger(i));
    }

    public PyTuple _parse_object(PyString pystr, int idx) { // }, Py_ssize_t *next_idx_ptr) {
        /* Read a JSON object from PyString pystr.
        idx is the index of the first character after the opening curly brace.

        Returns a new PyTuple of a PyObject (usually a dict, but object_hook can change that)
        and the next_idx to the first character after
        the closing curly brace.
        */
        PyString str = pystr;
        int end_idx = pystr.__len__() - 1;
        PyList pairs = new PyList();
        PyObject item;
        PyObject key;
        PyObject val;

        /* skip whitespace after { */
        while (idx <= end_idx && IS_WHITESPACE(str.getInt(idx))) idx++;

        /* only loop if the object is non-empty */
        if (idx <= end_idx && str.getInt(idx) != '}') {
            while (idx <= end_idx) {
                /* read key */
                if (str.getInt(idx) != '"') {
                    _json.raise_errmsg("Expecting property name", pystr, idx);
                }
                PyTuple key_idx = _json.scanstring(pystr, idx + 1, encoding, strict);
                key = key_idx.pyget(0);
                idx = key_idx.pyget(1).asInt();

                /* skip whitespace between key and : delimiter, read :, skip whitespace */
                while (idx <= end_idx && IS_WHITESPACE(str.getInt(idx))) idx++;
                if (idx > end_idx || str.getInt(idx) != ':') {
                    _json.raise_errmsg("Expecting : delimiter", pystr, idx);
                }
                idx++;
                while (idx <= end_idx && IS_WHITESPACE(str.getInt(idx))) idx++;

                /* read any JSON data type */
                PyTuple val_idx = _scan_once(pystr, idx);
                val = val_idx.pyget(0);
                idx = val_idx.pyget(1).asInt();
                pairs.append(new PyTuple(key, val));

                /* skip whitespace before } or , */
                while (idx <= end_idx && IS_WHITESPACE(str.getInt(idx))) idx++;

                /* bail if the object is closed or we didn't get the , delimiter */
                if (idx > end_idx) break;
                if (str.getInt(idx) == '}') {
                    break;
                } else if (str.getInt(idx) != ',') {
                    _json.raise_errmsg("Expecting , delimiter", pystr, idx);
                }
                idx++;

                /* skip whitespace after , delimiter */
                while (idx <= end_idx && IS_WHITESPACE(str.getInt(idx))) idx++;
            }
        }
        /* verify that idx < end_idx, str[idx] should be '}' */
        if (idx > end_idx || str.getInt(idx) != '}') {
            _json.raise_errmsg("Expecting object", pystr, end_idx);
        }

        /* if pairs_hook is not None: rval = object_pairs_hook(pairs) */
        if (pairs_hook != Py.None) {
            return valIndex(pairs_hook.__call__(pairs), idx + 1);
        }

        PyObject rval = new PyDictionary();
        ((PyDictionary)rval).update(pairs);

        /* if object_hook is not None: rval = object_hook(rval) */
        if (object_hook != Py.None) {
            rval = object_hook.__call__(rval);
        }

        return valIndex(rval, idx + 1);
    }

    public PyTuple _parse_array(PyString pystr, int idx) {
        /* Read a JSON array from PyString pystr.
        idx is the index of the first character after the opening bracket.

        Returns a new PyTuple of a PyList and next_idx (first character after
        the closing brace.)
        */
        PyString str = pystr;
        int end_idx = pystr.__len__() - 1;
        PyList rval = new PyList();
        int next_idx;

        /* skip whitespace after [ */
        while (idx <= end_idx && IS_WHITESPACE(str.getInt(idx))) idx++;

        /* only loop if the array is non-empty */
        if (idx <= end_idx && str.getInt(idx) != ']') {
            while (idx <= end_idx) {

                /* read any JSON term and de-tuplefy the (rval, idx) */
                PyTuple val_idx = _scan_once(pystr, idx);
                PyObject val = val_idx.pyget(0);
                idx = val_idx.pyget(1).asInt();
                rval.append(val);

                /* skip whitespace between term and , */
                while (idx <= end_idx && IS_WHITESPACE(str.getInt(idx))) idx++;

                /* bail if the array is closed or we didn't get the , delimiter */
                if (idx > end_idx) break;
                if (str.getInt(idx) == ']') {
                    break;
                } else if (str.getInt(idx) != ',') {
                    _json.raise_errmsg("Expecting , delimiter", pystr, idx);
                }
                idx++;

                /* skip whitespace after , */
                while (idx <= end_idx && IS_WHITESPACE(str.getInt(idx))) idx++;
            }
        }

        /* verify that idx < end_idx, str[idx] should be ']' */
        if (idx > end_idx || str.getInt(idx) != ']') {
            _json.raise_errmsg("Expecting object", pystr, end_idx);
        }
        return valIndex(rval, idx + 1);
    }


    public PyTuple _scan_once(PyString pystr, int idx) {
        /* Read one JSON term (of any kind) from PyString pystr.
        idx is the index of the first character of the term

        Returns a new PyTuple of a PyObject representation of the term along
        with the next_idx
        */
        PyString str = pystr;
        int length = pystr.__len__();
        if (idx >= length) {
            // StopIteration signals "no more data" to the Python-level decoder.
            throw Py.StopIteration("");
        }
        // Dispatch on the first character of the term.
        switch (str.getInt(idx)) {
            case '"':
                /* string */
                return _json.scanstring(pystr, idx + 1, encoding, strict);
            case '{':
                /* object */
                return _parse_object(pystr, idx + 1);
            case '[':
                /* array */
                return _parse_array(pystr, idx + 1);
            case 'n':
                /* null */
                if ((idx + 3 < length) && str.getInt(idx + 1) == 'u' && str.getInt(idx + 2) == 'l' && str.getInt(idx + 3) == 'l') {
                    return valIndex(Py.None, idx + 4);
                }
                break;
            case 't':
                /* true */
                if ((idx + 3 < length) && str.getInt(idx + 1) == 'r' && str.getInt(idx + 2) == 'u' && str.getInt(idx + 3) == 'e') {
                    return valIndex(Py.True, idx + 4);
                }
                break;
            case 'f':
                /* false */
                if ((idx + 4 < length) && str.getInt(idx + 1) == 'a' && str.getInt(idx + 2) == 'l' && str.getInt(idx + 3) == 's' && str.getInt(idx + 4) == 'e') {
                    return valIndex(Py.False, idx + 5);
                }
                break;
            case 'N':
                /* NaN */
                if ((idx + 2 < length) && str.getInt(idx + 1) == 'a' && str.getInt(idx + 2) == 'N') {
                    return _parse_constant("NaN", idx + 3);
                }
                break;
            case 'I':
                /* Infinity */
                if ((idx + 7 < length) && str.getInt(idx + 1) == 'n' && str.getInt(idx + 2) == 'f' && str.getInt(idx + 3) == 'i' && str.getInt(idx + 4) == 'n' && str.getInt(idx + 5) == 'i' && str.getInt(idx + 6) == 't' && str.getInt(idx + 7) == 'y') {
                    return _parse_constant("Infinity", idx + 8);
                }
                break;
            case '-':
                /* -Infinity; a plain negative number falls through to _match_number */
                if ((idx + 8 < length) && str.getInt(idx + 1) == 'I' && str.getInt(idx + 2) == 'n' && str.getInt(idx + 3) == 'f' && str.getInt(idx + 4) == 'i' && str.getInt(idx + 5) == 'n' && str.getInt(idx + 6) == 'i' && str.getInt(idx + 7) == 't' && str.getInt(idx + 8) == 'y') {
                    return _parse_constant("-Infinity", idx + 9);
                }
                break;
        }
        /* Didn't find a string, object, array, or named constant. Look for a number. */
        return _match_number(pystr, idx);
    }

    /** Feed a named constant (NaN/Infinity/-Infinity) through parse_constant. */
    public PyTuple _parse_constant(String constant, int idx) {
        return valIndex(parse_constant.__call__(Py.newString(constant)), idx);
    }

    public PyTuple _match_number(PyString pystr, int start) {
        /* Read a JSON number from PyString pystr.
        idx is the index of the first character of the number

        Returns a new PyObject representation of that number:
        PyInt, PyLong, or PyFloat.
        May return other types if parse_int or parse_float are set
        along with index to the first character after
        the number.
        */
        PyString str = pystr;
        int end_idx = pystr.__len__() - 1;
        int idx = start;
        boolean is_float = false;

        /* read a sign if it's there, make sure it's not the end of the string */
        if (str.getInt(idx) == '-') {
            idx++;
            if (idx > end_idx) {
                throw Py.StopIteration("");
            }
        }

        /* read as many integer digits as we find as long as it doesn't start with 0 */
        if (str.getInt(idx) >= '1' && str.getInt(idx) <= '9') {
            idx++;
            while (idx <= end_idx && str.getInt(idx) >= '0' && str.getInt(idx) <= '9') idx++;
        }
        /* if it starts with 0 we only expect one integer digit */
        else if (str.getInt(idx) == '0') {
            idx++;
        }
        /* no integer digits, error */
        else {
            throw Py.StopIteration("");
        }

        /* if the next char is '.' followed by a digit then read all float digits */
        if (idx < end_idx && str.getInt(idx) == '.' && str.getInt(idx + 1) >= '0' && str.getInt(idx + 1) <= '9') {
            is_float = true;
            idx += 2;
            while (idx <= end_idx && str.getInt(idx) >= '0' && str.getInt(idx) <= '9') idx++;
        }

        /* if the next char is 'e' or 'E' then maybe read the exponent (or backtrack) */
        if (idx < end_idx && (str.getInt(idx) == 'e' || str.getInt(idx) == 'E')) {

            /* save the index of the 'e' or 'E' just in case we need to backtrack */
            int e_start = idx;
            idx++;

            /* read an exponent sign if present */
            if (idx < end_idx && (str.getInt(idx) == '-' || str.getInt(idx) == '+')) idx++;

            /* read all digits */
            while (idx <= end_idx && str.getInt(idx) >= '0' && str.getInt(idx) <= '9') idx++;

            /* if we got a digit, then parse as float. if not, backtrack */
            if (str.getInt(idx - 1) >= '0' && str.getInt(idx - 1) <= '9') {
                is_float = true;
            } else {
                idx = e_start;
            }
        }

        /* copy the section we determined to be a number */
        PyString numstr = (PyString) str.__getslice__(Py.newInteger(start), Py.newInteger(idx));
        if (is_float) {
            /* parse as a float using a fast path if available, otherwise call user defined method */
            return valIndex(parse_float.__call__(numstr), idx);
        } else {
            /* parse as an int using a fast path if available, otherwise call user defined method */
            return valIndex(parse_int.__call__(numstr), idx);
        }
    }


}
diff --git a/src/org/python/modules/_json/_json.java b/src/org/python/modules/_json/_json.java
new file mode 100644
--- /dev/null
+++ b/src/org/python/modules/_json/_json.java
@@ -0,0 +1,422 @@
+/* Copyright (c) Jython Developers */
+package org.python.modules._json;
+
+import org.python.core.ArgParser;
+import org.python.core.ClassDictInit;
+import org.python.core.Py;
+import org.python.core.PyBuiltinFunctionNarrow;
+import org.python.core.PyList;
+import org.python.core.PyObject;
+import org.python.core.PyString;
+import org.python.core.PyTuple;
+import org.python.core.PyUnicode;
+import org.python.core.codecs;
+import org.python.expose.ExposedGet;
+
+import java.util.Iterator;
+
+/**
+ * This module is a nearly exact line by line port of _json.c to Java. Names and comments are retained
+ * to make it easy to follow, but classes and methods are modified to following Java calling conventions.
+ *
+ * (Retained comments use the standard commenting convention for C.)
+ */
+public class _json implements ClassDictInit {
+
+ public static final PyString __doc__ = new PyString("Port of _json C module.");
+
+ public static void classDictInit(PyObject dict) {
+ dict.__setitem__("__name__", new PyString("_json"));
+ dict.__setitem__("__doc__", __doc__);
+ dict.__setitem__("encode_basestring_ascii", new EncodeBasestringAsciiFunction());
+ dict.__setitem__("make_encoder", Encoder.TYPE);
+ dict.__setitem__("make_scanner", Scanner.TYPE);
+ dict.__setitem__("scanstring", new ScanstringFunction());
+ dict.__setitem__("__module__", new PyString("_json"));
+
+ // ensure __module__ is set properly in these modules,
+ // based on how the module name lookups are chained
+ Encoder.TYPE.setName("_json.Encoder");
+ Scanner.TYPE.setName("_json.Scanner");
+
+ // Hide from Python
+ dict.__setitem__("classDictInit", null);
+ }
+
+ private static PyObject errmsg_fn;
+
+ private static synchronized PyObject get_errmsg_fn() {
+ if (errmsg_fn == null) {
+ PyObject json = org.python.core.__builtin__.__import__("json");
+ if (json != null) {
+ PyObject decoder = json.__findattr__("decoder");
+ if (decoder != null) {
+ errmsg_fn = decoder.__findattr__("errmsg");
+ }
+ }
+ }
+ return errmsg_fn;
+ }
+
+ static void raise_errmsg(String msg, PyObject s) {
+ raise_errmsg(msg, s, Py.None, Py.None);
+ }
+
+ static void raise_errmsg(String msg, PyObject s, int pos) {
+ raise_errmsg(msg, s, Py.newInteger(pos), Py.None);
+ }
+
+ static void raise_errmsg(String msg, PyObject s, PyObject pos, PyObject end) {
+ /* Use the Python function json.decoder.errmsg to raise a nice
+ looking ValueError exception */
+ final PyObject errmsg_fn = get_errmsg_fn();
+ if (errmsg_fn != null) {
+ throw Py.ValueError(errmsg_fn.__call__(Py.newString(msg), s, pos, end).asString());
+ } else {
+ throw Py.ValueError(msg);
+ }
+ }
+
+ static class ScanstringFunction extends PyBuiltinFunctionNarrow {
+ ScanstringFunction() {
+ super("scanstring", 2, 4, "scanstring");
+ }
+
+ @Override
+ @ExposedGet(name = "__module__")
+ public PyObject getModule() {
+ return new PyString("_json");
+ }
+
+
+ @Override
+ public PyObject __call__(PyObject s, PyObject end) {
+ return __call__(s, end, new PyString("utf-8"), Py.True);
+ }
+
+ @Override
+ public PyObject __call__(PyObject s, PyObject end, PyObject encoding) {
+ return __call__(s, end, encoding, Py.True);
+ }
+
+ @Override
+ public PyObject __call__(PyObject[] args, String[] kwds) {
+ ArgParser ap = new ArgParser("scanstring", args, kwds, new String[]{
+ "s", "end", "encoding", "strict"}, 2);
+ return __call__(
+ ap.getPyObject(0),
+ ap.getPyObject(1),
+ ap.getPyObject(2, new PyString("utf-8")),
+ ap.getPyObject(3, Py.True));
+ }
+
+ @Override
+ public PyObject __call__(PyObject s, PyObject end, PyObject encoding, PyObject strict) {
+ // but rethrow in case it does work - see the test case for issue 362
+ int end_idx = end.asIndex(Py.OverflowError);
+ boolean is_strict = strict.__nonzero__();
+ if (s instanceof PyString) {
+ return scanstring((PyString) s, end_idx,
+ encoding == Py.None ? null : encoding.toString(), is_strict);
+ } else {
+ throw Py.TypeError(String.format(
+ "first argument must be a string, not %.80s",
+ s.getType().fastGetName()));
+ }
+ }
+
+ }
+
+ static PyTuple scanstring(PyString pystr, int end, String encoding, boolean strict) {
+ int len = pystr.__len__();
+ int begin = end - 1;
+ if (end < 0 || len <= end) {
+ throw Py.ValueError("end is out of bounds");
+ }
+ int next;
+ final PyList chunks = new PyList();
+ while (true) {
+ /* Find the end of the string or the next escape */
+ int c = 0;
+
+ for (next = end; next < len; next++) {
+ c = pystr.getInt(next);
+ if (c == '"' || c == '\\') {
+ break;
+ } else if (strict && c <= 0x1f) {
+ raise_errmsg("Invalid control character at", pystr, next);
+ }
+ }
+ if (!(c == '"' || c == '\\')) {
+ raise_errmsg("Unterminated string starting at", pystr, begin);
+ }
+
+ /* Pick up this chunk if it's not zero length */
+ if (next != end) {
+ PyString strchunk = (PyString) pystr.__getslice__(Py.newInteger(end), Py.newInteger(next));
+ if (strchunk instanceof PyUnicode) {
+ chunks.append(strchunk);
+ } else {
+ chunks.append(codecs.decode(strchunk, encoding, null));
+ }
+ }
+ next++;
+ if (c == '"') {
+ end = next;
+ break;
+ }
+ if (next == len) {
+ raise_errmsg("Unterminated string starting at", pystr, begin);
+ }
+ c = pystr.getInt(next);
+ if (c != 'u') {
+ /* Non-unicode backslash escapes */
+ end = next + 1;
+ switch (c) {
+ case '"':
+ break;
+ case '\\':
+ break;
+ case '/':
+ break;
+ case 'b':
+ c = '\b';
+ break;
+ case 'f':
+ c = '\f';
+ break;
+ case 'n':
+ c = '\n';
+ break;
+ case 'r':
+ c = '\r';
+ break;
+ case 't':
+ c = '\t';
+ break;
+ default:
+ c = 0;
+ }
+ if (c == 0) {
+ raise_errmsg("Invalid \\escape", pystr, end - 2);
+ }
+ } else {
+ c = 0;
+ next++;
+ end = next + 4;
+ if (end >= len) {
+ raise_errmsg("Invalid \\uXXXX escape", pystr, next - 1);
+ }
+ /* Decode 4 hex digits */
+ for (; next < end; next++) {
+ int digit = pystr.getInt(next);
+ c <<= 4;
+ switch (digit) {
+ case '0':
+ case '1':
+ case '2':
+ case '3':
+ case '4':
+ case '5':
+ case '6':
+ case '7':
+ case '8':
+ case '9':
+ c |= (digit - '0');
+ break;
+ case 'a':
+ case 'b':
+ case 'c':
+ case 'd':
+ case 'e':
+ case 'f':
+ c |= (digit - 'a' + 10);
+ break;
+ case 'A':
+ case 'B':
+ case 'C':
+ case 'D':
+ case 'E':
+ case 'F':
+ c |= (digit - 'A' + 10);
+ break;
+ default:
+ raise_errmsg("Invalid \\uXXXX escape", pystr, end - 5);
+ }
+ }
+ /* Surrogate pair */
+ if ((c & 0xfc00) == 0xd800) {
+ int c2 = 0;
+ if (end + 6 >= len) {
+ raise_errmsg("Unpaired high surrogate", pystr, end - 5);
+ }
+ if (pystr.getInt(next++) != '\\' || pystr.getInt(next++) != 'u') {
+ raise_errmsg("Unpaired high surrogate", pystr, end - 5);
+ }
+ end += 6;
+ /* Decode 4 hex digits */
+ for (; next < end; next++) {
+ int digit = pystr.getInt(next);
+ c2 <<= 4;
+ switch (digit) {
+ case '0':
+ case '1':
+ case '2':
+ case '3':
+ case '4':
+ case '5':
+ case '6':
+ case '7':
+ case '8':
+ case '9':
+ c2 |= (digit - '0');
+ break;
+ case 'a':
+ case 'b':
+ case 'c':
+ case 'd':
+ case 'e':
+ case 'f':
+ c2 |= (digit - 'a' + 10);
+ break;
+ case 'A':
+ case 'B':
+ case 'C':
+ case 'D':
+ case 'E':
+ case 'F':
+ c2 |= (digit - 'A' + 10);
+ break;
+ default:
+ raise_errmsg("Invalid \\uXXXX escape", pystr, end - 5);
+ }
+ }
+ if ((c2 & 0xfc00) != 0xdc00) {
+ raise_errmsg("Unpaired high surrogate", pystr, end - 5);
+ }
+ c = 0x10000 + (((c - 0xd800) << 10) | (c2 - 0xdc00));
+ } else if ((c & 0xfc00) == 0xdc00) {
+ raise_errmsg("Unpaired low surrogate", pystr, end - 5);
+ }
+ }
+ chunks.append(new PyUnicode(c));
+ }
+
+ return new PyTuple(Py.EmptyUnicode.join(chunks), Py.newInteger(end));
+ }
+
+ static class EncodeBasestringAsciiFunction extends PyBuiltinFunctionNarrow {
+ EncodeBasestringAsciiFunction() {
+ super("encode_basestring_ascii", 1, 1, "encode_basestring_ascii");
+ }
+
+ @Override
+ @ExposedGet(name = "__module__")
+ public PyObject getModule() {
+ return new PyString("_json");
+ }
+
+ @Override
+ public PyObject __call__(PyObject pystr) {
+ return encode_basestring_ascii(pystr);
+ }
+ }
+
+ static PyString encode_basestring_ascii(PyObject pystr) {
+ if (pystr instanceof PyUnicode) {
+ return ascii_escape((PyUnicode) pystr);
+ } else if (pystr instanceof PyString) {
+ return ascii_escape((PyString) pystr);
+ } else {
+ throw Py.TypeError(String.format(
+ "first argument must be a string, not %.80s",
+ pystr.getType().fastGetName()));
+ }
+ }
+
+ private static PyString ascii_escape(PyUnicode pystr) {
+ StringBuilder rval = new StringBuilder(pystr.__len__());
+ rval.append("\"");
+ for (Iterator<Integer> iter = pystr.newSubsequenceIterator(); iter.hasNext(); ) {
+ _write_char(rval, iter.next());
+ }
+ rval.append("\"");
+ return new PyString(rval.toString());
+ }
+
+ private static PyString ascii_escape(PyString pystr) {
+ int len = pystr.__len__();
+ String s = pystr.getString();
+ StringBuilder rval = new StringBuilder(len);
+ rval.append("\"");
+ for (int i = 0; i < len; i++) {
+ int c = s.charAt(i);
+ if (c > 127) {
+ return ascii_escape(new PyUnicode(codecs.PyUnicode_DecodeUTF8(s, null)));
+ }
+ _write_char(rval, c);
+ }
+ rval.append("\"");
+ return new PyString(rval.toString());
+ }
+
+ private static void _write_char(StringBuilder builder, int c) {
+ /* Escape unicode code point c to ASCII escape sequences
+ in char *output. output must have at least 12 bytes unused to
+ accommodate an escaped surrogate pair "\ u XXXX \ u XXXX" */
+ if (c >= ' ' && c <= '~' && c != '\\' & c != '"') {
+ builder.append((char) c);
+ } else {
+ _ascii_escape_char(builder, c);
+ }
+ }
+
+ private static void _write_hexchar(StringBuilder builder, int c) {
+ builder.append("0123456789abcdef".charAt(c & 0xf));
+ }
+
+ private static void _ascii_escape_char(StringBuilder builder, int c) {
+ builder.append('\\');
+ switch (c) {
+ case '\\':
+ builder.append((char) c);
+ break;
+ case '"':
+ builder.append((char) c);
+ break;
+ case '\b':
+ builder.append('b');
+ break;
+ case '\f':
+ builder.append('f');
+ break;
+ case '\n':
+ builder.append('n');
+ break;
+ case '\r':
+ builder.append('r');
+ break;
+ case '\t':
+ builder.append('t');
+ break;
+ default:
+ if (c >= 0x10000) {
+ /* UTF-16 surrogate pair */
+ int v = c - 0x10000;
+ c = 0xd800 | ((v >> 10) & 0x3ff);
+ builder.append('u');
+ _write_hexchar(builder, c >> 12);
+ _write_hexchar(builder, c >> 8);
+ _write_hexchar(builder, c >> 4);
+ _write_hexchar(builder, c);
+ c = 0xdc00 | (v & 0x3ff);
+ builder.append('\\');
+ }
+ builder.append('u');
+ _write_hexchar(builder, c >> 12);
+ _write_hexchar(builder, c >> 8);
+ _write_hexchar(builder, c >> 4);
+ _write_hexchar(builder, c);
+ }
+ }
+}
diff --git a/src/org/python/modules/bz2/PyBZ2Decompressor.java b/src/org/python/modules/bz2/PyBZ2Decompressor.java
--- a/src/org/python/modules/bz2/PyBZ2Decompressor.java
+++ b/src/org/python/modules/bz2/PyBZ2Decompressor.java
@@ -8,6 +8,7 @@
import org.apache.commons.compress.compressors.bzip2.BZip2CompressorInputStream;
import org.python.core.ArgParser;
import org.python.core.Py;
+import org.python.core.PyByteArray;
import org.python.core.PyObject;
import org.python.core.PyString;
import org.python.core.PyType;
@@ -89,19 +90,17 @@
return Py.EmptyString;
}
- ByteArrayOutputStream databuf = new ByteArrayOutputStream();
+ PyByteArray databuf = new PyByteArray();
int currentByte = -1;
try {
while ((currentByte = decompressStream.read()) != -1) {
- databuf.write(currentByte);
+ databuf.append((byte)currentByte);
}
- returnData = new PyString(new String(databuf.toByteArray()));
+ returnData = databuf.__str__();
if (compressedData.available() > 0) {
byte[] unusedbuf = new byte[compressedData.available()];
compressedData.read(unusedbuf);
-
- unused_data = (PyString) unused_data.__add__(new PyString(
- new String(unusedbuf)));
+ unused_data = (PyString)unused_data.__add__((new PyByteArray(unusedbuf)).__str__());
}
eofReached = true;
} catch (IOException e) {
diff --git a/src/org/python/modules/itertools/PyTeeIterator.java b/src/org/python/modules/itertools/PyTeeIterator.java
--- a/src/org/python/modules/itertools/PyTeeIterator.java
+++ b/src/org/python/modules/itertools/PyTeeIterator.java
@@ -94,7 +94,7 @@
throw Py.ValueError("n must be >= 0");
}
- PyObject[] tees = new PyTeeIterator[n];
+ PyObject[] tees = new PyObject[n];
if (n == 0) {
return tees;
diff --git a/src/org/python/modules/itertools/count.java b/src/org/python/modules/itertools/count.java
--- a/src/org/python/modules/itertools/count.java
+++ b/src/org/python/modules/itertools/count.java
@@ -3,12 +3,14 @@
import org.python.core.ArgParser;
import org.python.core.Py;
+import org.python.core.PyException;
import org.python.core.PyInteger;
import org.python.core.PyIterator;
import org.python.core.PyObject;
import org.python.core.PyString;
import org.python.core.PyTuple;
import org.python.core.PyType;
+import org.python.core.__builtin__;
import org.python.expose.ExposedNew;
import org.python.expose.ExposedMethod;
import org.python.expose.ExposedType;
@@ -18,8 +20,16 @@
public static final PyType TYPE = PyType.fromClass(count.class);
private PyIterator iter;
- private int counter;
- private int stepper;
+ private PyObject counter;
+ private PyObject stepper;
+
+ private static PyObject NumberClass;
+ private static synchronized PyObject getNumberClass() {
+ if (NumberClass == null) {
+ NumberClass = __builtin__.__import__("numbers").__getattr__("Number");
+ }
+ return NumberClass;
+ }
public static final String count_doc =
"count(start=0, step=1) --> count object\n\n" +
@@ -37,62 +47,104 @@
}
/**
- * Creates an iterator that returns consecutive integers starting at 0.
+ * Creates an iterator that returns consecutive numbers starting at 0.
*/
public count() {
super();
- count___init__(0, 1);
+ count___init__(Py.Zero, Py.One);
}
/**
- * Creates an iterator that returns consecutive integers starting at <code>start</code>.
+ * Creates an iterator that returns consecutive numbers starting at <code>start</code>.
*/
- public count(final int start) {
+ public count(final PyObject start) {
super();
- count___init__(start, 1);
+ count___init__(start, Py.One);
}
/**
- * Creates an iterator that returns consecutive integers starting at <code>start</code> with <code>step</code> step.
+ * Creates an iterator that returns consecutive numbers starting at <code>start</code> with <code>step</code> step.
*/
- public count(final int start, final int step) {
+ public count(final PyObject start, final PyObject step) {
super();
count___init__(start, step);
}
+ // TODO: move into Py, although NumberClass import time resolution becomes
+ // TODO: a bit trickier
+ private static PyObject getNumber(PyObject obj) {
+ if (Py.isInstance(obj, getNumberClass())) {
+ return obj;
+ }
+ try {
+ PyObject intObj = obj.__int__();
+ if (Py.isInstance(obj, getNumberClass())) {
+ return intObj;
+ }
+ throw Py.TypeError("a number is required");
+ } catch (PyException exc) {
+ if (exc.match(Py.ValueError)) {
+ throw Py.TypeError("a number is required");
+ }
+ throw exc;
+ }
+ }
+
@ExposedNew
@ExposedMethod
final void count___init__(final PyObject[] args, String[] kwds) {
ArgParser ap = new ArgParser("count", args, kwds, new String[] {"start", "step"}, 0);
-
- int start = ap.getInt(0, 0);
- int step = ap.getInt(1, 1);
+ PyObject start = getNumber(ap.getPyObject(0, Py.Zero));
+ PyObject step = getNumber(ap.getPyObject(1, Py.One));
count___init__(start, step);
}
- private void count___init__(final int start, final int step) {
+ private void count___init__(final PyObject start, final PyObject step) {
counter = start;
stepper = step;
iter = new PyIterator() {
public PyObject __iternext__() {
- int result = counter;
- counter += stepper;
- return new PyInteger(result);
+ PyObject result = counter;
+ counter = counter._add(stepper);
+ return result;
}
};
}
    @ExposedMethod
    // copy.copy support: a count is fully described by (counter, stepper).
    public PyObject count___copy__() {
        return new count(counter, stepper);
    }

    @ExposedMethod
    // Pickle support: delegate to __reduce_ex__ regardless of protocol.
    final PyObject count___reduce_ex__(PyObject protocol) {
        return __reduce_ex__(protocol);
    }

    @ExposedMethod
    final PyObject count___reduce__() {
        return __reduce_ex__(Py.Zero);
    }


    // Reconstruct as count(counter) when the step is the default, otherwise
    // count(counter, stepper).
    // NOTE(review): the identity test only matches the Py.One singleton; a
    // count built with an explicit step of 1 takes the two-argument form.
    // Harmless (unpickles identically) but worth confirming it is intended.
    public PyObject __reduce_ex__(PyObject protocol) {
        if (stepper == Py.One) {
            return new PyTuple(getType(), new PyTuple(counter));
        } else {
            return new PyTuple(getType(), new PyTuple(counter, stepper));
        }
    }

    @ExposedMethod
    // repr mirrors CPython: omit the step when it equals integer 1.
    public PyString __repr__() {
        if (stepper instanceof PyInteger && stepper._cmp(Py.One) == 0) {
            return Py.newString(String.format("count(%s)", counter));
        }
        else {
            return Py.newString(String.format("count(%s, %s)", counter, stepper));
        }
    }
diff --git a/src/org/python/modules/itertools/repeat.java b/src/org/python/modules/itertools/repeat.java
--- a/src/org/python/modules/itertools/repeat.java
+++ b/src/org/python/modules/itertools/repeat.java
@@ -97,6 +97,11 @@
}
    @ExposedMethod
    // copy.copy support: a fresh repeat with the same object and remaining
    // count behaves identically to this one.
    final PyObject __copy__() {
        return new repeat(object, counter);
    }
+
+ @ExposedMethod
public int __len__() {
if (counter < 0) {
throw Py.TypeError("object of type 'itertools.repeat' has no len()");
diff --git a/src/org/python/modules/time/Time.java b/src/org/python/modules/time/Time.java
--- a/src/org/python/modules/time/Time.java
+++ b/src/org/python/modules/time/Time.java
@@ -31,6 +31,7 @@
import org.python.core.PyException;
import org.python.core.PyInteger;
import org.python.core.PyObject;
+import org.python.core.PySequence;
import org.python.core.PyString;
import org.python.core.PyTuple;
import org.python.core.__builtin__;
@@ -412,7 +413,19 @@
return asctime(localtime());
}
- public static PyString asctime(PyTuple tup) {
+ public static PyString asctime(PyObject obj) {
+ PyTuple tup;
+ if (obj instanceof PyTuple) {
+ tup = (PyTuple)obj;
+ } else {
+ tup = PyTuple.fromIterable(obj);
+ }
+ int len = tup.__len__();
+ if (len != 9) {
+ throw Py.TypeError(
+ String.format("argument must be sequence of length 9, not %d", len));
+ }
+
StringBuilder buf = new StringBuilder(25);
buf.append(enshortdays[item(tup, 6)]).append(' ');
buf.append(enshortmonths[item(tup, 1)]).append(' ');
--
Repository URL: https://hg.python.org/jython
More information about the Jython-checkins
mailing list