[Jython-checkins] jython (merge default -> default): Merged trunk
jim.baker
jython-checkins at python.org
Wed Mar 14 22:04:40 CET 2012
http://hg.python.org/jython/rev/2538e8ea568f
changeset: 6374:2538e8ea568f
parent: 6373:f777b02ce694
parent: 6372:940c47216bdb
user: Jim Baker <jbaker at zyasoft.com>
date: Wed Mar 14 14:04:05 2012 -0700
summary:
Merged trunk
files:
.hgignore | 3 +
.idea/misc.xml | 150 -
.idea/vcs.xml | 2 +-
.project | 16 +
Lib/decimal.py | 2032 ++++++--
Lib/distutils/ccompiler.py | 473 +-
Lib/distutils/command/bdist.py | 10 +-
Lib/distutils/command/bdist_dumb.py | 43 +-
Lib/distutils/command/install_scripts.py | 4 +-
Lib/distutils/file_util.py | 158 +-
Lib/distutils/spawn.py | 92 +-
Lib/distutils/sysconfig.py | 157 +-
Lib/distutils/tests/test_build_py.py | 19 +-
Lib/distutils/util.py | 89 +-
Lib/filecmp.py | 7 +-
Lib/fileinput.py | 2 +-
Lib/gettext.py | 2 +-
Lib/mailbox.py | 156 +-
Lib/netrc.py | 4 +
Lib/new.py | 4 +
Lib/py_compile.py | 8 +-
Lib/robotparser.py | 7 +-
Lib/test/list_tests.py | 36 +-
Lib/test/pickletester.py | 361 +-
Lib/test/test_array.py | 183 +-
Lib/test/test_code.py | 47 +-
Lib/test/test_codeccallbacks.py | 94 +-
Lib/test/test_compile.py | 70 +-
Lib/test/test_copy.py | 5 +-
Lib/test/test_descrtut.py | 7 +-
Lib/test/test_dumbdbm.py | 24 +
Lib/test/test_genexps.py | 2 +-
Lib/test/test_hashlib.py | 180 +-
Lib/test/test_hmac.py | 51 +-
Lib/test/test_iter.py | 24 +-
Lib/test/test_logging.py | 2150 +++++++--
Lib/test/test_new.py | 2 +-
Lib/test/test_operator.py | 63 +-
Lib/test/test_os.py | 554 ++-
Lib/test/test_pkgimport.py | 8 +-
Lib/test/test_pprint.py | 285 +-
Lib/test/test_profilehooks.py | 43 +-
Lib/test/test_random.py | 96 +-
Lib/test/test_repr.py | 67 +-
Lib/test/test_robotparser.py | 51 +-
Lib/test/test_shutil.py | 702 +++-
Lib/test/test_str.py | 382 -
Lib/test/test_tempfile.py | 190 +-
Lib/test/test_time.py | 90 +-
Lib/test/test_trace.py | 47 +-
Lib/test/test_univnewlines.py | 50 +-
Lib/test/test_urllib2.py | 396 +-
Lib/test/test_weakref.py | 145 +-
Lib/test/test_xml_etree.py | 5 +-
Lib/test/test_xml_etree_c.py | 5 +-
Lib/test/test_zlib.py | 161 +-
Lib/timeit.py | 67 +-
Lib/types.py | 30 +-
Lib/warnings.py | 40 +-
Lib/weakref.py | 30 +-
Lib/zipfile.py | 907 +++-
src/org/python/core/ArgParser.java | 36 +-
src/org/python/core/PyString.java | 36 +-
src/org/python/modules/struct.java | 20 +
64 files changed, 8034 insertions(+), 3146 deletions(-)
diff --git a/.hgignore b/.hgignore
--- a/.hgignore
+++ b/.hgignore
@@ -6,11 +6,14 @@
*.orig
*.rej
*.swp
+\#*
*~
# IntelliJ files
*.ipr
*.iml
*.iws
+.idea/misc.xml
+.idea/workspace.xml
.AppleDouble
.DS_Store
diff --git a/.idea/misc.xml b/.idea/misc.xml
deleted file mode 100644
--- a/.idea/misc.xml
+++ /dev/null
@@ -1,150 +0,0 @@
-<?xml version="1.0" encoding="UTF-8"?>
-<project version="4">
- <component name="DaemonCodeAnalyzer">
- <disable_hints />
- </component>
- <component name="EntryPointsManager">
- <entry_points version="2.0" />
- </component>
- <component name="ProjectLevelVcsManager" settingsEditedManually="false">
- <OptionsSetting value="true" id="Add" />
- <OptionsSetting value="true" id="Remove" />
- <OptionsSetting value="true" id="Checkout" />
- <OptionsSetting value="true" id="Update" />
- <OptionsSetting value="true" id="Status" />
- <OptionsSetting value="true" id="Edit" />
- <ConfirmationsSetting value="0" id="Add" />
- <ConfirmationsSetting value="0" id="Remove" />
- </component>
- <component name="ProjectResources">
- <default-html-doctype>http://www.w3.org/1999/xhtml</default-html-doctype>
- </component>
- <component name="ProjectRootManager" version="2" languageLevel="JDK_1_6" assert-keyword="true" jdk-15="true" project-jdk-name="1.7" project-jdk-type="JavaSDK">
- <output url="file://$PROJECT_DIR$/out" />
- </component>
- <component name="RunManager">
- <configuration default="true" type="#org.jetbrains.idea.devkit.run.PluginConfigurationType" factoryName="Plugin">
- <module name="" />
- <option name="VM_PARAMETERS" value="-Xmx512m -Xms256m -XX:MaxPermSize=250m" />
- <option name="PROGRAM_PARAMETERS" />
- <method>
- <option name="AntTarget" enabled="false" />
- <option name="BuildArtifacts" enabled="false" />
- <option name="Make" enabled="true" />
- <option name="Maven.BeforeRunTask" enabled="false" />
- </method>
- </configuration>
- <configuration default="true" type="Remote" factoryName="Remote">
- <option name="USE_SOCKET_TRANSPORT" value="true" />
- <option name="SERVER_MODE" value="false" />
- <option name="SHMEM_ADDRESS" value="javadebug" />
- <option name="HOST" value="localhost" />
- <option name="PORT" value="5005" />
- <method>
- <option name="AntTarget" enabled="false" />
- <option name="BuildArtifacts" enabled="false" />
- <option name="Maven.BeforeRunTask" enabled="false" />
- </method>
- </configuration>
- <configuration default="true" type="Applet" factoryName="Applet">
- <module name="" />
- <option name="MAIN_CLASS_NAME" />
- <option name="HTML_FILE_NAME" />
- <option name="HTML_USED" value="false" />
- <option name="WIDTH" value="400" />
- <option name="HEIGHT" value="300" />
- <option name="POLICY_FILE" value="$APPLICATION_HOME_DIR$/bin/appletviewer.policy" />
- <option name="VM_PARAMETERS" />
- <option name="ALTERNATIVE_JRE_PATH_ENABLED" value="false" />
- <option name="ALTERNATIVE_JRE_PATH" />
- <method>
- <option name="AntTarget" enabled="false" />
- <option name="BuildArtifacts" enabled="false" />
- <option name="Make" enabled="true" />
- <option name="Maven.BeforeRunTask" enabled="false" />
- </method>
- </configuration>
- <configuration default="true" type="TestNG" factoryName="TestNG">
- <module name="" />
- <option name="ALTERNATIVE_JRE_PATH_ENABLED" value="false" />
- <option name="ALTERNATIVE_JRE_PATH" />
- <option name="SUITE_NAME" />
- <option name="PACKAGE_NAME" />
- <option name="MAIN_CLASS_NAME" />
- <option name="METHOD_NAME" />
- <option name="GROUP_NAME" />
- <option name="TEST_OBJECT" value="CLASS" />
- <option name="VM_PARAMETERS" value="-ea" />
- <option name="PARAMETERS" />
- <option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
- <option name="OUTPUT_DIRECTORY" />
- <option name="ANNOTATION_TYPE" />
- <option name="ENV_VARIABLES" />
- <option name="PASS_PARENT_ENVS" value="true" />
- <option name="TEST_SEARCH_SCOPE">
- <value defaultName="moduleWithDependencies" />
- </option>
- <option name="USE_DEFAULT_REPORTERS" value="false" />
- <option name="PROPERTIES_FILE" />
- <envs />
- <properties />
- <listeners />
- <method>
- <option name="AntTarget" enabled="false" />
- <option name="BuildArtifacts" enabled="false" />
- <option name="Make" enabled="true" />
- <option name="Maven.BeforeRunTask" enabled="false" />
- </method>
- </configuration>
- <configuration default="true" type="Application" factoryName="Application">
- <option name="MAIN_CLASS_NAME" />
- <option name="VM_PARAMETERS" />
- <option name="PROGRAM_PARAMETERS" />
- <option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
- <option name="ALTERNATIVE_JRE_PATH_ENABLED" value="false" />
- <option name="ALTERNATIVE_JRE_PATH" />
- <option name="ENABLE_SWING_INSPECTOR" value="false" />
- <option name="ENV_VARIABLES" />
- <option name="PASS_PARENT_ENVS" value="true" />
- <module name="" />
- <envs />
- <method>
- <option name="AntTarget" enabled="false" />
- <option name="BuildArtifacts" enabled="false" />
- <option name="Make" enabled="true" />
- <option name="Maven.BeforeRunTask" enabled="false" />
- </method>
- </configuration>
- <configuration default="true" type="JUnit" factoryName="JUnit">
- <module name="" />
- <option name="ALTERNATIVE_JRE_PATH_ENABLED" value="false" />
- <option name="ALTERNATIVE_JRE_PATH" />
- <option name="PACKAGE_NAME" />
- <option name="MAIN_CLASS_NAME" />
- <option name="METHOD_NAME" />
- <option name="TEST_OBJECT" value="class" />
- <option name="VM_PARAMETERS" value="-ea" />
- <option name="PARAMETERS" />
- <option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
- <option name="ENV_VARIABLES" />
- <option name="PASS_PARENT_ENVS" value="true" />
- <option name="TEST_SEARCH_SCOPE">
- <value defaultName="moduleWithDependencies" />
- </option>
- <envs />
- <patterns />
- <method>
- <option name="AntTarget" enabled="false" />
- <option name="BuildArtifacts" enabled="false" />
- <option name="Make" enabled="true" />
- <option name="Maven.BeforeRunTask" enabled="false" />
- </method>
- </configuration>
- <list size="0" />
- <configuration name="<template>" type="WebApp" default="true" selected="false">
- <Host>localhost</Host>
- <Port>5050</Port>
- </configuration>
- </component>
-</project>
-
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
--- a/.idea/vcs.xml
+++ b/.idea/vcs.xml
@@ -1,7 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
- <mapping directory="" vcs="" />
+ <mapping directory="" vcs="hg4idea" />
</component>
</project>
diff --git a/.project b/.project
--- a/.project
+++ b/.project
@@ -6,12 +6,28 @@
</projects>
<buildSpec>
<buildCommand>
+ <name>org.python.pydev.PyDevBuilder</name>
+ <arguments>
+ </arguments>
+ </buildCommand>
+ <buildCommand>
<name>org.eclipse.jdt.core.javabuilder</name>
<arguments>
</arguments>
</buildCommand>
+ <buildCommand>
+ <name>org.eclipse.ui.externaltools.ExternalToolBuilder</name>
+ <triggers>full,incremental,</triggers>
+ <arguments>
+ <dictionary>
+ <key>LaunchConfigHandle</key>
+ <value><project>/.externalToolBuilders/New_Builder.launch</value>
+ </dictionary>
+ </arguments>
+ </buildCommand>
</buildSpec>
<natures>
<nature>org.eclipse.jdt.core.javanature</nature>
+ <nature>org.python.pydev.pythonNature</nature>
</natures>
</projectDescription>
diff --git a/Lib/decimal.py b/Lib/decimal.py
--- a/Lib/decimal.py
+++ b/Lib/decimal.py
@@ -35,26 +35,26 @@
useful for financial applications or for contexts where users have
expectations that are at odds with binary floating point (for instance,
in binary floating point, 1.00 % 0.1 gives 0.09999999999999995 instead
-of the expected Decimal("0.00") returned by decimal floating point).
+of the expected Decimal('0.00') returned by decimal floating point).
Here are some examples of using the decimal module:
>>> from decimal import *
>>> setcontext(ExtendedContext)
>>> Decimal(0)
-Decimal("0")
->>> Decimal("1")
-Decimal("1")
->>> Decimal("-.0123")
-Decimal("-0.0123")
+Decimal('0')
+>>> Decimal('1')
+Decimal('1')
+>>> Decimal('-.0123')
+Decimal('-0.0123')
>>> Decimal(123456)
-Decimal("123456")
->>> Decimal("123.45e12345678901234567890")
-Decimal("1.2345E+12345678901234567892")
->>> Decimal("1.33") + Decimal("1.27")
-Decimal("2.60")
->>> Decimal("12.34") + Decimal("3.87") - Decimal("18.41")
-Decimal("-2.20")
+Decimal('123456')
+>>> Decimal('123.45e12345678901234567890')
+Decimal('1.2345E+12345678901234567892')
+>>> Decimal('1.33') + Decimal('1.27')
+Decimal('2.60')
+>>> Decimal('12.34') + Decimal('3.87') - Decimal('18.41')
+Decimal('-2.20')
>>> dig = Decimal(1)
>>> print dig / Decimal(3)
0.333333333
@@ -91,7 +91,7 @@
>>> print c.flags[InvalidOperation]
0
>>> c.divide(Decimal(0), Decimal(0))
-Decimal("NaN")
+Decimal('NaN')
>>> c.traps[InvalidOperation] = 1
>>> print c.flags[InvalidOperation]
1
@@ -134,7 +134,17 @@
'setcontext', 'getcontext', 'localcontext'
]
+__version__ = '1.70' # Highest version of the spec this complies with
+
import copy as _copy
+import math as _math
+import numbers as _numbers
+
+try:
+ from collections import namedtuple as _namedtuple
+ DecimalTuple = _namedtuple('DecimalTuple', 'sign digits exponent')
+except ImportError:
+ DecimalTuple = lambda *args: args
# Rounding
ROUND_DOWN = 'ROUND_DOWN'
@@ -158,7 +168,7 @@
anything, though.
handle -- Called when context._raise_error is called and the
- trap_enabler is set. First argument is self, second is the
+ trap_enabler is not set. First argument is self, second is the
context. More arguments can be given, those being after
the explanation in _raise_error (For example,
context._raise_error(NewError, '(-x)!', self._sign) would
@@ -210,7 +220,7 @@
if args:
ans = _dec_from_triple(args[0]._sign, args[0]._int, 'n', True)
return ans._fix_nan(context)
- return NaN
+ return _NaN
class ConversionSyntax(InvalidOperation):
"""Trying to convert badly formed string.
@@ -220,7 +230,7 @@
syntax. The result is [0,qNaN].
"""
def handle(self, context, *args):
- return NaN
+ return _NaN
class DivisionByZero(DecimalException, ZeroDivisionError):
"""Division by 0.
@@ -236,7 +246,7 @@
"""
def handle(self, context, sign, *args):
- return Infsign[sign]
+ return _SignedInfinity[sign]
class DivisionImpossible(InvalidOperation):
"""Cannot perform the division adequately.
@@ -247,7 +257,7 @@
"""
def handle(self, context, *args):
- return NaN
+ return _NaN
class DivisionUndefined(InvalidOperation, ZeroDivisionError):
"""Undefined result of division.
@@ -258,7 +268,7 @@
"""
def handle(self, context, *args):
- return NaN
+ return _NaN
class Inexact(DecimalException):
"""Had to round, losing information.
@@ -284,7 +294,7 @@
"""
def handle(self, context, *args):
- return NaN
+ return _NaN
class Rounded(DecimalException):
"""Number got rounded (not necessarily changed during rounding).
@@ -334,15 +344,15 @@
def handle(self, context, sign, *args):
if context.rounding in (ROUND_HALF_UP, ROUND_HALF_EVEN,
ROUND_HALF_DOWN, ROUND_UP):
- return Infsign[sign]
+ return _SignedInfinity[sign]
if sign == 0:
if context.rounding == ROUND_CEILING:
- return Infsign[sign]
+ return _SignedInfinity[sign]
return _dec_from_triple(sign, '9'*context.prec,
context.Emax-context.prec+1)
if sign == 1:
if context.rounding == ROUND_FLOOR:
- return Infsign[sign]
+ return _SignedInfinity[sign]
return _dec_from_triple(sign, '9'*context.prec,
context.Emax-context.prec+1)
@@ -471,11 +481,7 @@
# General Decimal Arithmetic Specification
return +s # Convert result to normal context
- """
- # The string below can't be included in the docstring until Python 2.6
- # as the doctest module doesn't understand __future__ statements
- """
- >>> from __future__ import with_statement
+ >>> setcontext(DefaultContext)
>>> print getcontext().prec
28
>>> with localcontext():
@@ -510,13 +516,15 @@
"""Create a decimal point instance.
>>> Decimal('3.14') # string input
- Decimal("3.14")
+ Decimal('3.14')
>>> Decimal((0, (3, 1, 4), -2)) # tuple (sign, digit_tuple, exponent)
- Decimal("3.14")
+ Decimal('3.14')
>>> Decimal(314) # int or long
- Decimal("314")
+ Decimal('314')
>>> Decimal(Decimal(314)) # another decimal instance
- Decimal("314")
+ Decimal('314')
+ >>> Decimal(' 3.14 \\n') # leading and trailing whitespace okay
+ Decimal('3.14')
"""
# Note that the coefficient, self._int, is actually stored as
@@ -532,7 +540,7 @@
# From a string
# REs insist on real strings, so we can too.
if isinstance(value, basestring):
- m = _parser(value)
+ m = _parser(value.strip())
if m is None:
if context is None:
context = getcontext()
@@ -546,20 +554,16 @@
intpart = m.group('int')
if intpart is not None:
# finite number
- fracpart = m.group('frac')
+ fracpart = m.group('frac') or ''
exp = int(m.group('exp') or '0')
- if fracpart is not None:
- self._int = str((intpart+fracpart).lstrip('0') or '0')
- self._exp = exp - len(fracpart)
- else:
- self._int = str(intpart.lstrip('0') or '0')
- self._exp = exp
+ self._int = str(int(intpart+fracpart))
+ self._exp = exp - len(fracpart)
self._is_special = False
else:
diag = m.group('diag')
if diag is not None:
# NaN
- self._int = str(diag.lstrip('0'))
+ self._int = str(int(diag or '0')).lstrip('0')
if m.group('signal'):
self._exp = 'N'
else:
@@ -644,11 +648,55 @@
return self
if isinstance(value, float):
- raise TypeError("Cannot convert float to Decimal. " +
- "First convert the float to a string")
+ value = Decimal.from_float(value)
+ self._exp = value._exp
+ self._sign = value._sign
+ self._int = value._int
+ self._is_special = value._is_special
+ return self
raise TypeError("Cannot convert %r to Decimal" % value)
+ # @classmethod, but @decorator is not valid Python 2.3 syntax, so
+ # don't use it (see notes on Py2.3 compatibility at top of file)
+ def from_float(cls, f):
+ """Converts a float to a decimal number, exactly.
+
+ Note that Decimal.from_float(0.1) is not the same as Decimal('0.1').
+ Since 0.1 is not exactly representable in binary floating point, the
+ value is stored as the nearest representable value which is
+ 0x1.999999999999ap-4. The exact equivalent of the value in decimal
+ is 0.1000000000000000055511151231257827021181583404541015625.
+
+ >>> Decimal.from_float(0.1)
+ Decimal('0.1000000000000000055511151231257827021181583404541015625')
+ >>> Decimal.from_float(float('nan'))
+ Decimal('NaN')
+ >>> Decimal.from_float(float('inf'))
+ Decimal('Infinity')
+ >>> Decimal.from_float(-float('inf'))
+ Decimal('-Infinity')
+ >>> Decimal.from_float(-0.0)
+ Decimal('-0')
+
+ """
+ if isinstance(f, (int, long)): # handle integer inputs
+ return cls(f)
+ if _math.isinf(f) or _math.isnan(f): # raises TypeError if not a float
+ return cls(repr(f))
+ if _math.copysign(1.0, f) == 1.0:
+ sign = 0
+ else:
+ sign = 1
+ n, d = abs(f).as_integer_ratio()
+ k = d.bit_length() - 1
+ result = _dec_from_triple(sign, str(n*5**k), -k)
+ if cls is Decimal:
+ return result
+ else:
+ return cls(result)
+ from_float = classmethod(from_float)
+
def _isnan(self):
"""Returns whether the number is not actually one.
@@ -709,6 +757,39 @@
return other._fix_nan(context)
return 0
+ def _compare_check_nans(self, other, context):
+ """Version of _check_nans used for the signaling comparisons
+ compare_signal, __le__, __lt__, __ge__, __gt__.
+
+ Signal InvalidOperation if either self or other is a (quiet
+ or signaling) NaN. Signaling NaNs take precedence over quiet
+ NaNs.
+
+ Return 0 if neither operand is a NaN.
+
+ """
+ if context is None:
+ context = getcontext()
+
+ if self._is_special or other._is_special:
+ if self.is_snan():
+ return context._raise_error(InvalidOperation,
+ 'comparison involving sNaN',
+ self)
+ elif other.is_snan():
+ return context._raise_error(InvalidOperation,
+ 'comparison involving sNaN',
+ other)
+ elif self.is_qnan():
+ return context._raise_error(InvalidOperation,
+ 'comparison involving NaN',
+ self)
+ elif other.is_qnan():
+ return context._raise_error(InvalidOperation,
+ 'comparison involving NaN',
+ other)
+ return 0
+
def __nonzero__(self):
"""Return True if self is nonzero; otherwise return False.
@@ -716,21 +797,23 @@
"""
return self._is_special or self._int != '0'
- def __cmp__(self, other):
- other = _convert_other(other)
- if other is NotImplemented:
- # Never return NotImplemented
- return 1
+ def _cmp(self, other):
+ """Compare the two non-NaN decimal instances self and other.
+
+ Returns -1 if self < other, 0 if self == other and 1
+ if self > other. This routine is for internal use only."""
if self._is_special or other._is_special:
- # check for nans, without raising on a signaling nan
- if self._isnan() or other._isnan():
- return 1 # Comparison involving NaN's always reports self > other
-
- # INF = INF
- return cmp(self._isinfinity(), other._isinfinity())
-
- # check for zeros; note that cmp(0, -0) should return 0
+ self_inf = self._isinfinity()
+ other_inf = other._isinfinity()
+ if self_inf == other_inf:
+ return 0
+ elif self_inf < other_inf:
+ return -1
+ else:
+ return 1
+
+ # check for zeros; Decimal('0') == Decimal('-0')
if not self:
if not other:
return 0
@@ -750,21 +833,85 @@
if self_adjusted == other_adjusted:
self_padded = self._int + '0'*(self._exp - other._exp)
other_padded = other._int + '0'*(other._exp - self._exp)
- return cmp(self_padded, other_padded) * (-1)**self._sign
+ if self_padded == other_padded:
+ return 0
+ elif self_padded < other_padded:
+ return -(-1)**self._sign
+ else:
+ return (-1)**self._sign
elif self_adjusted > other_adjusted:
return (-1)**self._sign
else: # self_adjusted < other_adjusted
return -((-1)**self._sign)
- def __eq__(self, other):
- if not isinstance(other, (Decimal, int, long)):
- return NotImplemented
- return self.__cmp__(other) == 0
-
- def __ne__(self, other):
- if not isinstance(other, (Decimal, int, long)):
- return NotImplemented
- return self.__cmp__(other) != 0
+ # Note: The Decimal standard doesn't cover rich comparisons for
+ # Decimals. In particular, the specification is silent on the
+ # subject of what should happen for a comparison involving a NaN.
+ # We take the following approach:
+ #
+ # == comparisons involving a quiet NaN always return False
+ # != comparisons involving a quiet NaN always return True
+ # == or != comparisons involving a signaling NaN signal
+ # InvalidOperation, and return False or True as above if the
+ # InvalidOperation is not trapped.
+ # <, >, <= and >= comparisons involving a (quiet or signaling)
+ # NaN signal InvalidOperation, and return False if the
+ # InvalidOperation is not trapped.
+ #
+ # This behavior is designed to conform as closely as possible to
+ # that specified by IEEE 754.
+
+ def __eq__(self, other, context=None):
+ other = _convert_other(other, allow_float=True)
+ if other is NotImplemented:
+ return other
+ if self._check_nans(other, context):
+ return False
+ return self._cmp(other) == 0
+
+ def __ne__(self, other, context=None):
+ other = _convert_other(other, allow_float=True)
+ if other is NotImplemented:
+ return other
+ if self._check_nans(other, context):
+ return True
+ return self._cmp(other) != 0
+
+ def __lt__(self, other, context=None):
+ other = _convert_other(other, allow_float=True)
+ if other is NotImplemented:
+ return other
+ ans = self._compare_check_nans(other, context)
+ if ans:
+ return False
+ return self._cmp(other) < 0
+
+ def __le__(self, other, context=None):
+ other = _convert_other(other, allow_float=True)
+ if other is NotImplemented:
+ return other
+ ans = self._compare_check_nans(other, context)
+ if ans:
+ return False
+ return self._cmp(other) <= 0
+
+ def __gt__(self, other, context=None):
+ other = _convert_other(other, allow_float=True)
+ if other is NotImplemented:
+ return other
+ ans = self._compare_check_nans(other, context)
+ if ans:
+ return False
+ return self._cmp(other) > 0
+
+ def __ge__(self, other, context=None):
+ other = _convert_other(other, allow_float=True)
+ if other is NotImplemented:
+ return other
+ ans = self._compare_check_nans(other, context)
+ if ans:
+ return False
+ return self._cmp(other) >= 0
def compare(self, other, context=None):
"""Compares one to another.
@@ -783,7 +930,7 @@
if ans:
return ans
- return Decimal(self.__cmp__(other))
+ return Decimal(self._cmp(other))
def __hash__(self):
"""x.__hash__() <==> hash(x)"""
@@ -791,16 +938,44 @@
#
# The hash of a nonspecial noninteger Decimal must depend only
# on the value of that Decimal, and not on its representation.
- # For example: hash(Decimal("100E-1")) == hash(Decimal("10")).
+ # For example: hash(Decimal('100E-1')) == hash(Decimal('10')).
+
+ # Equality comparisons involving signaling nans can raise an
+ # exception; since equality checks are implicitly and
+ # unpredictably used when checking set and dict membership, we
+ # prevent signaling nans from being used as set elements or
+ # dict keys by making __hash__ raise an exception.
if self._is_special:
- if self._isnan():
- raise TypeError('Cannot hash a NaN value.')
- return hash(str(self))
- if not self:
- return 0
+ if self.is_snan():
+ raise TypeError('Cannot hash a signaling NaN value.')
+ elif self.is_nan():
+ # 0 to match hash(float('nan'))
+ return 0
+ else:
+ # values chosen to match hash(float('inf')) and
+ # hash(float('-inf')).
+ if self._sign:
+ return -271828
+ else:
+ return 314159
+
+ # In Python 2.7, we're allowing comparisons (but not
+ # arithmetic operations) between floats and Decimals; so if
+ # a Decimal instance is exactly representable as a float then
+ # its hash should match that of the float.
+ self_as_float = float(self)
+ if Decimal.from_float(self_as_float) == self:
+ return hash(self_as_float)
+
if self._isinteger():
op = _WorkRep(self.to_integral_value())
- return hash((-1)**op.sign*op.int*10**op.exp)
+ # to make computation feasible for Decimals with large
+ # exponent, we use the fact that hash(n) == hash(m) for
+ # any two nonzero integers n and m such that (i) n and m
+ # have the same sign, and (ii) n is congruent to m modulo
+ # 2**64-1. So we can replace hash((-1)**s*c*10**e) with
+ # hash((-1)**s*c*pow(10, e, 2**64-1).
+ return hash((-1)**op.sign*op.int*pow(10, op.exp, 2**64-1))
# The value of a nonzero nonspecial Decimal instance is
# faithfully represented by the triple consisting of its sign,
# its adjusted exponent, and its coefficient with trailing
@@ -814,12 +989,12 @@
To show the internals exactly as they are.
"""
- return (self._sign, tuple(map(int, self._int)), self._exp)
+ return DecimalTuple(self._sign, tuple(map(int, self._int)), self._exp)
def __repr__(self):
"""Represents the number as an instance of Decimal."""
# Invariant: eval(repr(d)) == d
- return 'Decimal("%s")' % str(self)
+ return "Decimal('%s')" % str(self)
def __str__(self, eng=False, context=None):
"""Return string representation of the number in scientific notation.
@@ -1077,12 +1252,12 @@
if self._isinfinity():
if not other:
return context._raise_error(InvalidOperation, '(+-)INF * 0')
- return Infsign[resultsign]
+ return _SignedInfinity[resultsign]
if other._isinfinity():
if not self:
return context._raise_error(InvalidOperation, '0 * (+-)INF')
- return Infsign[resultsign]
+ return _SignedInfinity[resultsign]
resultexp = self._exp + other._exp
@@ -1112,7 +1287,7 @@
return ans
__rmul__ = __mul__
- def __div__(self, other, context=None):
+ def __truediv__(self, other, context=None):
"""Return self / other."""
other = _convert_other(other)
if other is NotImplemented:
@@ -1132,7 +1307,7 @@
return context._raise_error(InvalidOperation, '(+-)INF/(+-)INF')
if self._isinfinity():
- return Infsign[sign]
+ return _SignedInfinity[sign]
if other._isinfinity():
context._raise_error(Clamped, 'Division by infinity')
@@ -1171,8 +1346,6 @@
ans = _dec_from_triple(sign, str(coeff), exp)
return ans._fix(context)
- __truediv__ = __div__
-
def _divide(self, other, context):
"""Return (self // other, self % other), to context.prec precision.
@@ -1206,13 +1379,15 @@
'quotient too large in //, % or divmod')
return ans, ans
- def __rdiv__(self, other, context=None):
- """Swaps self/other and returns __div__."""
+ def __rtruediv__(self, other, context=None):
+ """Swaps self/other and returns __truediv__."""
other = _convert_other(other)
if other is NotImplemented:
return other
- return other.__div__(self, context=context)
- __rtruediv__ = __rdiv__
+ return other.__truediv__(self, context=context)
+
+ __div__ = __truediv__
+ __rdiv__ = __rtruediv__
def __divmod__(self, other, context=None):
"""
@@ -1235,7 +1410,7 @@
ans = context._raise_error(InvalidOperation, 'divmod(INF, INF)')
return ans, ans
else:
- return (Infsign[sign],
+ return (_SignedInfinity[sign],
context._raise_error(InvalidOperation, 'INF % x'))
if not other:
@@ -1383,7 +1558,7 @@
if other._isinfinity():
return context._raise_error(InvalidOperation, 'INF // INF')
else:
- return Infsign[self._sign ^ other._sign]
+ return _SignedInfinity[self._sign ^ other._sign]
if not other:
if self:
@@ -1409,16 +1584,31 @@
"""Converts self to an int, truncating if necessary."""
if self._is_special:
if self._isnan():
- context = getcontext()
- return context._raise_error(InvalidContext)
+ raise ValueError("Cannot convert NaN to integer")
elif self._isinfinity():
- raise OverflowError("Cannot convert infinity to long")
+ raise OverflowError("Cannot convert infinity to integer")
s = (-1)**self._sign
if self._exp >= 0:
return s*int(self._int)*10**self._exp
else:
return s*int(self._int[:self._exp] or '0')
+ __trunc__ = __int__
+
+ def real(self):
+ return self
+ real = property(real)
+
+ def imag(self):
+ return Decimal(0)
+ imag = property(imag)
+
+ def conjugate(self):
+ return self
+
+ def __complex__(self):
+ return complex(float(self))
+
def __long__(self):
"""Converts to a long.
@@ -1474,47 +1664,53 @@
exp_min = len(self._int) + self._exp - context.prec
if exp_min > Etop:
# overflow: exp_min > Etop iff self.adjusted() > Emax
+ ans = context._raise_error(Overflow, 'above Emax', self._sign)
context._raise_error(Inexact)
context._raise_error(Rounded)
- return context._raise_error(Overflow, 'above Emax', self._sign)
+ return ans
+
self_is_subnormal = exp_min < Etiny
if self_is_subnormal:
- context._raise_error(Subnormal)
exp_min = Etiny
# round if self has too many digits
if self._exp < exp_min:
- context._raise_error(Rounded)
digits = len(self._int) + self._exp - exp_min
if digits < 0:
self = _dec_from_triple(self._sign, '1', exp_min-1)
digits = 0
- this_function = getattr(self, self._pick_rounding_function[context.rounding])
- changed = this_function(digits)
+ rounding_method = self._pick_rounding_function[context.rounding]
+ changed = getattr(self, rounding_method)(digits)
coeff = self._int[:digits] or '0'
- if changed == 1:
+ if changed > 0:
coeff = str(int(coeff)+1)
- ans = _dec_from_triple(self._sign, coeff, exp_min)
-
+ if len(coeff) > context.prec:
+ coeff = coeff[:-1]
+ exp_min += 1
+
+ # check whether the rounding pushed the exponent out of range
+ if exp_min > Etop:
+ ans = context._raise_error(Overflow, 'above Emax', self._sign)
+ else:
+ ans = _dec_from_triple(self._sign, coeff, exp_min)
+
+ # raise the appropriate signals, taking care to respect
+ # the precedence described in the specification
+ if changed and self_is_subnormal:
+ context._raise_error(Underflow)
+ if self_is_subnormal:
+ context._raise_error(Subnormal)
if changed:
context._raise_error(Inexact)
- if self_is_subnormal:
- context._raise_error(Underflow)
- if not ans:
- # raise Clamped on underflow to 0
- context._raise_error(Clamped)
- elif len(ans._int) == context.prec+1:
- # we get here only if rescaling rounds the
- # cofficient up to exactly 10**context.prec
- if ans._exp < Etop:
- ans = _dec_from_triple(ans._sign,
- ans._int[:-1], ans._exp+1)
- else:
- # Inexact and Rounded have already been raised
- ans = context._raise_error(Overflow, 'above Emax',
- self._sign)
+ context._raise_error(Rounded)
+ if not ans:
+ # raise Clamped on underflow to 0
+ context._raise_error(Clamped)
return ans
+ if self_is_subnormal:
+ context._raise_error(Subnormal)
+
# fold down if _clamp == 1 and self has too few digits
if context._clamp == 1 and self._exp > Etop:
context._raise_error(Clamped)
@@ -1622,12 +1818,12 @@
if not other:
return context._raise_error(InvalidOperation,
'INF * 0 in fma')
- product = Infsign[self._sign ^ other._sign]
+ product = _SignedInfinity[self._sign ^ other._sign]
elif other._exp == 'F':
if not self:
return context._raise_error(InvalidOperation,
'0 * INF in fma')
- product = Infsign[self._sign ^ other._sign]
+ product = _SignedInfinity[self._sign ^ other._sign]
else:
product = _dec_from_triple(self._sign ^ other._sign,
str(int(self._int) * int(other._int)),
@@ -1794,12 +1990,14 @@
# case where xc == 1: result is 10**(xe*y), with xe*y
# required to be an integer
if xc == 1:
- if ye >= 0:
- exponent = xe*yc*10**ye
- else:
- exponent, remainder = divmod(xe*yc, 10**-ye)
- if remainder:
- return None
+ xe *= yc
+ # result is now 10**(xe * 10**ye); xe * 10**ye must be integral
+ while xe % 10 == 0:
+ xe //= 10
+ ye += 1
+ if ye < 0:
+ return None
+ exponent = xe * 10**ye
if y.sign == 1:
exponent = -exponent
# if other is a nonnegative integer, use ideal exponent
@@ -1977,7 +2175,7 @@
if not self:
return context._raise_error(InvalidOperation, '0 ** 0')
else:
- return Dec_p1
+ return _One
# result has sign 1 iff self._sign is 1 and other is an odd integer
result_sign = 0
@@ -1999,19 +2197,19 @@
if other._sign == 0:
return _dec_from_triple(result_sign, '0', 0)
else:
- return Infsign[result_sign]
+ return _SignedInfinity[result_sign]
# Inf**(+ve or Inf) = Inf; Inf**(-ve or -Inf) = 0
if self._isinfinity():
if other._sign == 0:
- return Infsign[result_sign]
+ return _SignedInfinity[result_sign]
else:
return _dec_from_triple(result_sign, '0', 0)
# 1**other = 1, but the choice of exponent and the flags
# depend on the exponent of self, and on whether other is a
# positive integer, a negative integer, or neither
- if self == Dec_p1:
+ if self == _One:
if other._isinteger():
# exp = max(self._exp*max(int(other), 0),
# 1-context.prec) but evaluating int(other) directly
@@ -2044,11 +2242,12 @@
if (other._sign == 0) == (self_adj < 0):
return _dec_from_triple(result_sign, '0', 0)
else:
- return Infsign[result_sign]
+ return _SignedInfinity[result_sign]
# from here on, the result always goes through the call
# to _fix at the end of this function.
ans = None
+ exact = False
# crude test to catch cases of extreme overflow/underflow. If
# log10(self)*other >= 10**bound and bound >= len(str(Emax))
@@ -2071,8 +2270,10 @@
# try for an exact result with precision +1
if ans is None:
ans = self._power_exact(other, context.prec + 1)
- if ans is not None and result_sign == 1:
- ans = _dec_from_triple(1, ans._int, ans._exp)
+ if ans is not None:
+ if result_sign == 1:
+ ans = _dec_from_triple(1, ans._int, ans._exp)
+ exact = True
# usual case: inexact result, x**y computed directly as exp(y*log(x))
if ans is None:
@@ -2095,24 +2296,55 @@
ans = _dec_from_triple(result_sign, str(coeff), exp)
- # the specification says that for non-integer other we need to
- # raise Inexact, even when the result is actually exact. In
- # the same way, we need to raise Underflow here if the result
- # is subnormal. (The call to _fix will take care of raising
- # Rounded and Subnormal, as usual.)
- if not other._isinteger():
- context._raise_error(Inexact)
- # pad with zeros up to length context.prec+1 if necessary
+ # unlike exp, ln and log10, the power function respects the
+ # rounding mode; no need to switch to ROUND_HALF_EVEN here
+
+ # There's a difficulty here when 'other' is not an integer and
+ # the result is exact. In this case, the specification
+ # requires that the Inexact flag be raised (in spite of
+ # exactness), but since the result is exact _fix won't do this
+ # for us. (Correspondingly, the Underflow signal should also
+ # be raised for subnormal results.) We can't directly raise
+ # these signals either before or after calling _fix, since
+ # that would violate the precedence for signals. So we wrap
+ # the ._fix call in a temporary context, and reraise
+ # afterwards.
+ if exact and not other._isinteger():
+ # pad with zeros up to length context.prec+1 if necessary; this
+ # ensures that the Rounded signal will be raised.
if len(ans._int) <= context.prec:
- expdiff = context.prec+1 - len(ans._int)
+ expdiff = context.prec + 1 - len(ans._int)
ans = _dec_from_triple(ans._sign, ans._int+'0'*expdiff,
ans._exp-expdiff)
- if ans.adjusted() < context.Emin:
- context._raise_error(Underflow)
-
- # unlike exp, ln and log10, the power function respects the
- # rounding mode; no need to use ROUND_HALF_EVEN here
- ans = ans._fix(context)
+
+ # create a copy of the current context, with cleared flags/traps
+ newcontext = context.copy()
+ newcontext.clear_flags()
+ for exception in _signals:
+ newcontext.traps[exception] = 0
+
+ # round in the new context
+ ans = ans._fix(newcontext)
+
+ # raise Inexact, and if necessary, Underflow
+ newcontext._raise_error(Inexact)
+ if newcontext.flags[Subnormal]:
+ newcontext._raise_error(Underflow)
+
+ # propagate signals to the original context; _fix could
+ # have raised any of Overflow, Underflow, Subnormal,
+ # Inexact, Rounded, Clamped. Overflow needs the correct
+ # arguments. Note that the order of the exceptions is
+ # important here.
+ if newcontext.flags[Overflow]:
+ context._raise_error(Overflow, 'above Emax', ans._sign)
+ for exception in Underflow, Subnormal, Inexact, Rounded, Clamped:
+ if newcontext.flags[exception]:
+ context._raise_error(exception)
+
+ else:
+ ans = ans._fix(context)
+
return ans
def __rpow__(self, other, context=None):
@@ -2206,14 +2438,15 @@
'quantize result has too many digits for current context')
# raise appropriate flags
+ if ans and ans.adjusted() < context.Emin:
+ context._raise_error(Subnormal)
if ans._exp > self._exp:
- context._raise_error(Rounded)
if ans != self:
context._raise_error(Inexact)
- if ans and ans.adjusted() < context.Emin:
- context._raise_error(Subnormal)
-
- # call to fix takes care of any necessary folddown
+ context._raise_error(Rounded)
+
+ # call to fix takes care of any necessary folddown, and
+ # signals Clamped if necessary
ans = ans._fix(context)
return ans
@@ -2266,6 +2499,29 @@
coeff = str(int(coeff)+1)
return _dec_from_triple(self._sign, coeff, exp)
+ def _round(self, places, rounding):
+ """Round a nonzero, nonspecial Decimal to a fixed number of
+ significant figures, using the given rounding mode.
+
+ Infinities, NaNs and zeros are returned unaltered.
+
+ This operation is quiet: it raises no flags, and uses no
+ information from the context.
+
+ """
+ if places <= 0:
+ raise ValueError("argument should be at least 1 in _round")
+ if self._is_special or not self:
+ return Decimal(self)
+ ans = self._rescale(self.adjusted()+1-places, rounding)
+ # it can happen that the rescale alters the adjusted exponent;
+ # for example when rounding 99.97 to 3 significant figures.
+ # When this happens we end up with an extra 0 at the end of
+ # the number; a second rescale fixes this.
+ if ans.adjusted() != self.adjusted():
+ ans = ans._rescale(ans.adjusted()+1-places, rounding)
+ return ans
+
def to_integral_exact(self, rounding=None, context=None):
"""Rounds to a nearby integer.
@@ -2289,10 +2545,10 @@
context = getcontext()
if rounding is None:
rounding = context.rounding
- context._raise_error(Rounded)
ans = self._rescale(0, rounding)
if ans != self:
context._raise_error(Inexact)
+ context._raise_error(Rounded)
return ans
def to_integral_value(self, rounding=None, context=None):
@@ -2436,7 +2692,7 @@
return other._fix(context)
return self._check_nans(other, context)
- c = self.__cmp__(other)
+ c = self._cmp(other)
if c == 0:
# If both operands are finite and equal in numerical value
# then an ordering is applied:
@@ -2478,7 +2734,7 @@
return other._fix(context)
return self._check_nans(other, context)
- c = self.__cmp__(other)
+ c = self._cmp(other)
if c == 0:
c = self.compare_total(other)
@@ -2526,23 +2782,10 @@
It's pretty much like compare(), but all NaNs signal, with signaling
NaNs taking precedence over quiet NaNs.
"""
- if context is None:
- context = getcontext()
-
- self_is_nan = self._isnan()
- other_is_nan = other._isnan()
- if self_is_nan == 2:
- return context._raise_error(InvalidOperation, 'sNaN',
- self)
- if other_is_nan == 2:
- return context._raise_error(InvalidOperation, 'sNaN',
- other)
- if self_is_nan:
- return context._raise_error(InvalidOperation, 'NaN in compare_signal',
- self)
- if other_is_nan:
- return context._raise_error(InvalidOperation, 'NaN in compare_signal',
- other)
+ other = _convert_other(other, raiseit = True)
+ ans = self._compare_check_nans(other, context)
+ if ans:
+ return ans
return self.compare(other, context=context)
def compare_total(self, other):
@@ -2552,11 +2795,13 @@
value. Note that a total ordering is defined for all possible abstract
representations.
"""
+ other = _convert_other(other, raiseit=True)
+
# if one is negative and the other is positive, it's easy
if self._sign and not other._sign:
- return Dec_n1
+ return _NegativeOne
if not self._sign and other._sign:
- return Dec_p1
+ return _One
sign = self._sign
# let's handle both NaN types
@@ -2564,53 +2809,56 @@
other_nan = other._isnan()
if self_nan or other_nan:
if self_nan == other_nan:
- if self._int < other._int:
+ # compare payloads as though they're integers
+ self_key = len(self._int), self._int
+ other_key = len(other._int), other._int
+ if self_key < other_key:
if sign:
- return Dec_p1
+ return _One
else:
- return Dec_n1
- if self._int > other._int:
+ return _NegativeOne
+ if self_key > other_key:
if sign:
- return Dec_n1
+ return _NegativeOne
else:
- return Dec_p1
- return Dec_0
+ return _One
+ return _Zero
if sign:
if self_nan == 1:
- return Dec_n1
+ return _NegativeOne
if other_nan == 1:
- return Dec_p1
+ return _One
if self_nan == 2:
- return Dec_n1
+ return _NegativeOne
if other_nan == 2:
- return Dec_p1
+ return _One
else:
if self_nan == 1:
- return Dec_p1
+ return _One
if other_nan == 1:
- return Dec_n1
+ return _NegativeOne
if self_nan == 2:
- return Dec_p1
+ return _One
if other_nan == 2:
- return Dec_n1
+ return _NegativeOne
if self < other:
- return Dec_n1
+ return _NegativeOne
if self > other:
- return Dec_p1
+ return _One
if self._exp < other._exp:
if sign:
- return Dec_p1
+ return _One
else:
- return Dec_n1
+ return _NegativeOne
if self._exp > other._exp:
if sign:
- return Dec_n1
+ return _NegativeOne
else:
- return Dec_p1
- return Dec_0
+ return _One
+ return _Zero
def compare_total_mag(self, other):
@@ -2618,6 +2866,8 @@
Like compare_total, but with operand's sign ignored and assumed to be 0.
"""
+ other = _convert_other(other, raiseit=True)
+
s = self.copy_abs()
o = other.copy_abs()
return s.compare_total(o)
@@ -2635,6 +2885,7 @@
def copy_sign(self, other):
"""Returns self with the sign of other."""
+ other = _convert_other(other, raiseit=True)
return _dec_from_triple(other._sign, self._int,
self._exp, self._is_special)
@@ -2651,11 +2902,11 @@
# exp(-Infinity) = 0
if self._isinfinity() == -1:
- return Dec_0
+ return _Zero
# exp(0) = 1
if not self:
- return Dec_p1
+ return _One
# exp(Infinity) = Infinity
if self._isinfinity() == 1:
@@ -2743,7 +2994,7 @@
return False
if context is None:
context = getcontext()
- return context.Emin <= self.adjusted() <= context.Emax
+ return context.Emin <= self.adjusted()
def is_qnan(self):
"""Return True if self is a quiet NaN; otherwise return False."""
@@ -2807,15 +3058,15 @@
# ln(0.0) == -Infinity
if not self:
- return negInf
+ return _NegativeInfinity
# ln(Infinity) = Infinity
if self._isinfinity() == 1:
- return Inf
+ return _Infinity
# ln(1.0) == 0.0
- if self == Dec_p1:
- return Dec_0
+ if self == _One:
+ return _Zero
# ln(negative) raises InvalidOperation
if self._sign == 1:
@@ -2887,11 +3138,11 @@
# log10(0.0) == -Infinity
if not self:
- return negInf
+ return _NegativeInfinity
# log10(Infinity) = Infinity
if self._isinfinity() == 1:
- return Inf
+ return _Infinity
# log10(negative or -Infinity) raises InvalidOperation
if self._sign == 1:
@@ -2943,7 +3194,7 @@
# logb(+/-Inf) = +Inf
if self._isinfinity():
- return Inf
+ return _Infinity
# logb(0) = -Inf, DivisionByZero
if not self:
@@ -2952,12 +3203,13 @@
# otherwise, simply return the adjusted exponent of self, as a
# Decimal. Note that no attempt is made to fit the result
# into the current context.
- return Decimal(self.adjusted())
+ ans = Decimal(self.adjusted())
+ return ans._fix(context)
def _islogical(self):
"""Return True if self is a logical operand.
- For being logical, it must be a finite numbers with a sign of 0,
+ For being logical, it must be a finite number with a sign of 0,
an exponent of 0, and a coefficient whose digits must all be
either 0 or 1.
"""
@@ -2985,6 +3237,9 @@
"""Applies an 'and' operation between self and other's digits."""
if context is None:
context = getcontext()
+
+ other = _convert_other(other, raiseit=True)
+
if not self._islogical() or not other._islogical():
return context._raise_error(InvalidOperation)
@@ -3006,6 +3261,9 @@
"""Applies an 'or' operation between self and other's digits."""
if context is None:
context = getcontext()
+
+ other = _convert_other(other, raiseit=True)
+
if not self._islogical() or not other._islogical():
return context._raise_error(InvalidOperation)
@@ -3013,13 +3271,16 @@
(opa, opb) = self._fill_logical(context, self._int, other._int)
# make the operation, and clean starting zeroes
- result = "".join(str(int(a)|int(b)) for a,b in zip(opa,opb))
+ result = "".join([str(int(a)|int(b)) for a,b in zip(opa,opb)])
return _dec_from_triple(0, result.lstrip('0') or '0', 0)
def logical_xor(self, other, context=None):
"""Applies an 'xor' operation between self and other's digits."""
if context is None:
context = getcontext()
+
+ other = _convert_other(other, raiseit=True)
+
if not self._islogical() or not other._islogical():
return context._raise_error(InvalidOperation)
@@ -3027,7 +3288,7 @@
(opa, opb) = self._fill_logical(context, self._int, other._int)
# make the operation, and clean starting zeroes
- result = "".join(str(int(a)^int(b)) for a,b in zip(opa,opb))
+ result = "".join([str(int(a)^int(b)) for a,b in zip(opa,opb)])
return _dec_from_triple(0, result.lstrip('0') or '0', 0)
def max_mag(self, other, context=None):
@@ -3049,7 +3310,7 @@
return other._fix(context)
return self._check_nans(other, context)
- c = self.copy_abs().__cmp__(other.copy_abs())
+ c = self.copy_abs()._cmp(other.copy_abs())
if c == 0:
c = self.compare_total(other)
@@ -3079,7 +3340,7 @@
return other._fix(context)
return self._check_nans(other, context)
- c = self.copy_abs().__cmp__(other.copy_abs())
+ c = self.copy_abs()._cmp(other.copy_abs())
if c == 0:
c = self.compare_total(other)
@@ -3100,7 +3361,7 @@
return ans
if self._isinfinity() == -1:
- return negInf
+ return _NegativeInfinity
if self._isinfinity() == 1:
return _dec_from_triple(0, '9'*context.prec, context.Etop())
@@ -3123,7 +3384,7 @@
return ans
if self._isinfinity() == 1:
- return Inf
+ return _Infinity
if self._isinfinity() == -1:
return _dec_from_triple(1, '9'*context.prec, context.Etop())
@@ -3154,7 +3415,7 @@
if ans:
return ans
- comparison = self.__cmp__(other)
+ comparison = self._cmp(other)
if comparison == 0:
return self.copy_sign(other)
@@ -3168,13 +3429,13 @@
context._raise_error(Overflow,
'Infinite result from next_toward',
ans._sign)
+ context._raise_error(Inexact)
context._raise_error(Rounded)
- context._raise_error(Inexact)
elif ans.adjusted() < context.Emin:
context._raise_error(Underflow)
context._raise_error(Subnormal)
+ context._raise_error(Inexact)
context._raise_error(Rounded)
- context._raise_error(Inexact)
# if precision == 1 then we don't raise Clamped for a
# result 0E-Etiny.
if not ans:
@@ -3233,6 +3494,8 @@
if context is None:
context = getcontext()
+ other = _convert_other(other, raiseit=True)
+
ans = self._check_nans(other, context)
if ans:
return ans
@@ -3249,19 +3512,23 @@
torot = int(other)
rotdig = self._int
topad = context.prec - len(rotdig)
- if topad:
+ if topad > 0:
rotdig = '0'*topad + rotdig
+ elif topad < 0:
+ rotdig = rotdig[-topad:]
# let's rotate!
rotated = rotdig[torot:] + rotdig[:torot]
return _dec_from_triple(self._sign,
rotated.lstrip('0') or '0', self._exp)
- def scaleb (self, other, context=None):
+ def scaleb(self, other, context=None):
"""Returns self operand after adding the second value to its exp."""
if context is None:
context = getcontext()
+ other = _convert_other(other, raiseit=True)
+
ans = self._check_nans(other, context)
if ans:
return ans
@@ -3285,6 +3552,8 @@
if context is None:
context = getcontext()
+ other = _convert_other(other, raiseit=True)
+
ans = self._check_nans(other, context)
if ans:
return ans
@@ -3299,45 +3568,118 @@
# get values, pad if necessary
torot = int(other)
- if not torot:
- return Decimal(self)
rotdig = self._int
topad = context.prec - len(rotdig)
- if topad:
+ if topad > 0:
rotdig = '0'*topad + rotdig
+ elif topad < 0:
+ rotdig = rotdig[-topad:]
# let's shift!
if torot < 0:
- rotated = rotdig[:torot]
+ shifted = rotdig[:torot]
else:
- rotated = rotdig + '0'*torot
- rotated = rotated[-context.prec:]
+ shifted = rotdig + '0'*torot
+ shifted = shifted[-context.prec:]
return _dec_from_triple(self._sign,
- rotated.lstrip('0') or '0', self._exp)
+ shifted.lstrip('0') or '0', self._exp)
# Support for pickling, copy, and deepcopy
def __reduce__(self):
return (self.__class__, (str(self),))
def __copy__(self):
- if type(self) == Decimal:
+ if type(self) is Decimal:
return self # I'm immutable; therefore I am my own clone
return self.__class__(str(self))
def __deepcopy__(self, memo):
- if type(self) == Decimal:
+ if type(self) is Decimal:
return self # My components are also immutable
return self.__class__(str(self))
- # support for Jython __tojava__:
- def __tojava__(self, java_class):
- from java.lang import Object
- from java.math import BigDecimal
- from org.python.core import Py
- if java_class not in (BigDecimal, Object):
- return Py.NoConversion
- return BigDecimal(str(self))
+ # PEP 3101 support. the _localeconv keyword argument should be
+ # considered private: it's provided for ease of testing only.
+ def __format__(self, specifier, context=None, _localeconv=None):
+ """Format a Decimal instance according to the given specifier.
+
+ The specifier should be a standard format specifier, with the
+ form described in PEP 3101. Formatting types 'e', 'E', 'f',
+ 'F', 'g', 'G', 'n' and '%' are supported. If the formatting
+ type is omitted it defaults to 'g' or 'G', depending on the
+ value of context.capitals.
+ """
+
+ # Note: PEP 3101 says that if the type is not present then
+ # there should be at least one digit after the decimal point.
+ # We take the liberty of ignoring this requirement for
+ # Decimal---it's presumably there to make sure that
+ # format(float, '') behaves similarly to str(float).
+ if context is None:
+ context = getcontext()
+
+ spec = _parse_format_specifier(specifier, _localeconv=_localeconv)
+
+ # special values don't care about the type or precision
+ if self._is_special:
+ sign = _format_sign(self._sign, spec)
+ body = str(self.copy_abs())
+ return _format_align(sign, body, spec)
+
+ # a type of None defaults to 'g' or 'G', depending on context
+ if spec['type'] is None:
+ spec['type'] = ['g', 'G'][context.capitals]
+
+ # if type is '%', adjust exponent of self accordingly
+ if spec['type'] == '%':
+ self = _dec_from_triple(self._sign, self._int, self._exp+2)
+
+ # round if necessary, taking rounding mode from the context
+ rounding = context.rounding
+ precision = spec['precision']
+ if precision is not None:
+ if spec['type'] in 'eE':
+ self = self._round(precision+1, rounding)
+ elif spec['type'] in 'fF%':
+ self = self._rescale(-precision, rounding)
+ elif spec['type'] in 'gG' and len(self._int) > precision:
+ self = self._round(precision, rounding)
+ # special case: zeros with a positive exponent can't be
+ # represented in fixed point; rescale them to 0e0.
+ if not self and self._exp > 0 and spec['type'] in 'fF%':
+ self = self._rescale(0, rounding)
+
+ # figure out placement of the decimal point
+ leftdigits = self._exp + len(self._int)
+ if spec['type'] in 'eE':
+ if not self and precision is not None:
+ dotplace = 1 - precision
+ else:
+ dotplace = 1
+ elif spec['type'] in 'fF%':
+ dotplace = leftdigits
+ elif spec['type'] in 'gG':
+ if self._exp <= 0 and leftdigits > -6:
+ dotplace = leftdigits
+ else:
+ dotplace = 1
+
+ # find digits before and after decimal point, and get exponent
+ if dotplace < 0:
+ intpart = '0'
+ fracpart = '0'*(-dotplace) + self._int
+ elif dotplace > len(self._int):
+ intpart = self._int + '0'*(dotplace-len(self._int))
+ fracpart = ''
+ else:
+ intpart = self._int[:dotplace] or '0'
+ fracpart = self._int[dotplace:]
+ exp = leftdigits-dotplace
+
+ # done with the decimal-specific stuff; hand over the rest
+ # of the formatting to the _format_number function
+ return _format_number(self._sign, intpart, fracpart, exp, spec)
def _dec_from_triple(sign, coefficient, exponent, special=False):
"""Create a decimal instance directly, without any validation,
@@ -3355,6 +3697,12 @@
return self
+# Register Decimal as a kind of Number (an abstract base class).
+# However, do not register it as Real (because Decimals are not
+# interoperable with floats).
+_numbers.Number.register(Decimal)
+
+
##### Context class #######################################################
@@ -3393,7 +3741,7 @@
traps - If traps[exception] = 1, then the exception is
raised when it is caused. Otherwise, a value is
substituted in.
- flags - When an exception is caused, flags[exception] is incremented.
+ flags - When an exception is caused, flags[exception] is set.
(Whether or not the trap_enabler is set)
Should be reset by user of Decimal instance.
Emin - Minimum exponent
@@ -3408,22 +3756,38 @@
Emin=None, Emax=None,
capitals=None, _clamp=0,
_ignored_flags=None):
+ # Set defaults; for everything except flags and _ignored_flags,
+ # inherit from DefaultContext.
+ try:
+ dc = DefaultContext
+ except NameError:
+ pass
+
+ self.prec = prec if prec is not None else dc.prec
+ self.rounding = rounding if rounding is not None else dc.rounding
+ self.Emin = Emin if Emin is not None else dc.Emin
+ self.Emax = Emax if Emax is not None else dc.Emax
+ self.capitals = capitals if capitals is not None else dc.capitals
+ self._clamp = _clamp if _clamp is not None else dc._clamp
+
+ if _ignored_flags is None:
+ self._ignored_flags = []
+ else:
+ self._ignored_flags = _ignored_flags
+
+ if traps is None:
+ self.traps = dc.traps.copy()
+ elif not isinstance(traps, dict):
+ self.traps = dict((s, int(s in traps)) for s in _signals)
+ else:
+ self.traps = traps
+
if flags is None:
- flags = []
- if _ignored_flags is None:
- _ignored_flags = []
- if not isinstance(flags, dict):
- flags = dict([(s,s in flags) for s in _signals])
- del s
- if traps is not None and not isinstance(traps, dict):
- traps = dict([(s,s in traps) for s in _signals])
- del s
- for name, val in locals().items():
- if val is None:
- setattr(self, name, _copy.copy(getattr(DefaultContext, name)))
- else:
- setattr(self, name, val)
- del self.self
+ self.flags = dict.fromkeys(_signals, 0)
+ elif not isinstance(flags, dict):
+ self.flags = dict((s, int(s in flags)) for s in _signals)
+ else:
+ self.flags = flags
def __repr__(self):
"""Show the current context."""
@@ -3461,23 +3825,23 @@
"""Handles an error
If the flag is in _ignored_flags, returns the default response.
- Otherwise, it increments the flag, then, if the corresponding
- trap_enabler is set, it reaises the exception. Otherwise, it returns
- the default value after incrementing the flag.
+ Otherwise, it sets the flag, then, if the corresponding
+ trap_enabler is set, it reraises the exception. Otherwise, it returns
+ the default value after setting the flag.
"""
error = _condition_map.get(condition, condition)
if error in self._ignored_flags:
# Don't touch the flag
return error().handle(self, *args)
- self.flags[error] += 1
+ self.flags[error] = 1
if not self.traps[error]:
# The errors define how to handle themselves.
return condition().handle(self, *args)
# Errors should only be risked on copies of the context
# self._ignored_flags = []
- raise error, explanation
+ raise error(explanation)
def _ignore_all_flags(self):
"""Ignore all flags, if they are raised"""
@@ -3497,10 +3861,8 @@
for flag in flags:
self._ignored_flags.remove(flag)
- def __hash__(self):
- """A Context cannot be hashed."""
- # We inherit object.__hash__, so we must deny this explicitly
- raise TypeError("Cannot hash a Context.")
+ # We inherit object.__hash__, so we must deny this explicitly
+ __hash__ = None
def Etiny(self):
"""Returns Etiny (= Emin - prec + 1)"""
@@ -3530,13 +3892,39 @@
return rounding
def create_decimal(self, num='0'):
- """Creates a new Decimal instance but using self as context."""
+ """Creates a new Decimal instance but using self as context.
+
+ This method implements the to-number operation of the
+ IBM Decimal specification."""
+
+ if isinstance(num, basestring) and num != num.strip():
+ return self._raise_error(ConversionSyntax,
+ "no trailing or leading whitespace is "
+ "permitted.")
+
d = Decimal(num, context=self)
if d._isnan() and len(d._int) > self.prec - self._clamp:
return self._raise_error(ConversionSyntax,
"diagnostic info too long in NaN")
return d._fix(self)
+ def create_decimal_from_float(self, f):
+ """Creates a new Decimal instance from a float but rounding using self
+ as the context.
+
+ >>> context = Context(prec=5, rounding=ROUND_DOWN)
+ >>> context.create_decimal_from_float(3.1415926535897932)
+ Decimal('3.1415')
+ >>> context = Context(prec=5, traps=[Inexact])
+ >>> context.create_decimal_from_float(3.1415926535897932)
+ Traceback (most recent call last):
+ ...
+ Inexact: None
+
+ """
+ d = Decimal.from_float(f) # An exact conversion
+ return d._fix(self) # Apply the context rounding
+
# Methods
def abs(self, a):
"""Returns the absolute value of the operand.
@@ -3546,25 +3934,39 @@
the plus operation on the operand.
>>> ExtendedContext.abs(Decimal('2.1'))
- Decimal("2.1")
+ Decimal('2.1')
>>> ExtendedContext.abs(Decimal('-100'))
- Decimal("100")
+ Decimal('100')
>>> ExtendedContext.abs(Decimal('101.5'))
- Decimal("101.5")
+ Decimal('101.5')
>>> ExtendedContext.abs(Decimal('-101.5'))
- Decimal("101.5")
+ Decimal('101.5')
+ >>> ExtendedContext.abs(-1)
+ Decimal('1')
"""
+ a = _convert_other(a, raiseit=True)
return a.__abs__(context=self)
def add(self, a, b):
"""Return the sum of the two operands.
>>> ExtendedContext.add(Decimal('12'), Decimal('7.00'))
- Decimal("19.00")
+ Decimal('19.00')
>>> ExtendedContext.add(Decimal('1E+2'), Decimal('1.01E+4'))
- Decimal("1.02E+4")
+ Decimal('1.02E+4')
+ >>> ExtendedContext.add(1, Decimal(2))
+ Decimal('3')
+ >>> ExtendedContext.add(Decimal(8), 5)
+ Decimal('13')
+ >>> ExtendedContext.add(5, 5)
+ Decimal('10')
"""
- return a.__add__(b, context=self)
+ a = _convert_other(a, raiseit=True)
+ r = a.__add__(b, context=self)
+ if r is NotImplemented:
+ raise TypeError("Unable to convert %s to Decimal" % b)
+ else:
+ return r
def _apply(self, a):
return str(a._fix(self))
@@ -3576,7 +3978,7 @@
received object already is in its canonical form.
>>> ExtendedContext.canonical(Decimal('2.50'))
- Decimal("2.50")
+ Decimal('2.50')
"""
return a.canonical(context=self)
@@ -3595,18 +3997,25 @@
zero or negative zero, or '1' if the result is greater than zero.
>>> ExtendedContext.compare(Decimal('2.1'), Decimal('3'))
- Decimal("-1")
+ Decimal('-1')
>>> ExtendedContext.compare(Decimal('2.1'), Decimal('2.1'))
- Decimal("0")
+ Decimal('0')
>>> ExtendedContext.compare(Decimal('2.1'), Decimal('2.10'))
- Decimal("0")
+ Decimal('0')
>>> ExtendedContext.compare(Decimal('3'), Decimal('2.1'))
- Decimal("1")
+ Decimal('1')
>>> ExtendedContext.compare(Decimal('2.1'), Decimal('-3'))
- Decimal("1")
+ Decimal('1')
>>> ExtendedContext.compare(Decimal('-3'), Decimal('2.1'))
- Decimal("-1")
+ Decimal('-1')
+ >>> ExtendedContext.compare(1, 2)
+ Decimal('-1')
+ >>> ExtendedContext.compare(Decimal(1), 2)
+ Decimal('-1')
+ >>> ExtendedContext.compare(1, Decimal(2))
+ Decimal('-1')
"""
+ a = _convert_other(a, raiseit=True)
return a.compare(b, context=self)
def compare_signal(self, a, b):
@@ -3617,24 +4026,31 @@
>>> c = ExtendedContext
>>> c.compare_signal(Decimal('2.1'), Decimal('3'))
- Decimal("-1")
+ Decimal('-1')
>>> c.compare_signal(Decimal('2.1'), Decimal('2.1'))
- Decimal("0")
+ Decimal('0')
>>> c.flags[InvalidOperation] = 0
>>> print c.flags[InvalidOperation]
0
>>> c.compare_signal(Decimal('NaN'), Decimal('2.1'))
- Decimal("NaN")
+ Decimal('NaN')
>>> print c.flags[InvalidOperation]
1
>>> c.flags[InvalidOperation] = 0
>>> print c.flags[InvalidOperation]
0
>>> c.compare_signal(Decimal('sNaN'), Decimal('2.1'))
- Decimal("NaN")
+ Decimal('NaN')
>>> print c.flags[InvalidOperation]
1
+ >>> c.compare_signal(-1, 2)
+ Decimal('-1')
+ >>> c.compare_signal(Decimal(-1), 2)
+ Decimal('-1')
+ >>> c.compare_signal(-1, Decimal(2))
+ Decimal('-1')
"""
+ a = _convert_other(a, raiseit=True)
return a.compare_signal(b, context=self)
def compare_total(self, a, b):
@@ -3645,18 +4061,25 @@
representations.
>>> ExtendedContext.compare_total(Decimal('12.73'), Decimal('127.9'))
- Decimal("-1")
+ Decimal('-1')
>>> ExtendedContext.compare_total(Decimal('-127'), Decimal('12'))
- Decimal("-1")
+ Decimal('-1')
>>> ExtendedContext.compare_total(Decimal('12.30'), Decimal('12.3'))
- Decimal("-1")
+ Decimal('-1')
>>> ExtendedContext.compare_total(Decimal('12.30'), Decimal('12.30'))
- Decimal("0")
+ Decimal('0')
>>> ExtendedContext.compare_total(Decimal('12.3'), Decimal('12.300'))
- Decimal("1")
+ Decimal('1')
>>> ExtendedContext.compare_total(Decimal('12.3'), Decimal('NaN'))
- Decimal("-1")
+ Decimal('-1')
+ >>> ExtendedContext.compare_total(1, 2)
+ Decimal('-1')
+ >>> ExtendedContext.compare_total(Decimal(1), 2)
+ Decimal('-1')
+ >>> ExtendedContext.compare_total(1, Decimal(2))
+ Decimal('-1')
"""
+ a = _convert_other(a, raiseit=True)
return a.compare_total(b)
def compare_total_mag(self, a, b):
@@ -3664,36 +4087,46 @@
Like compare_total, but with operand's sign ignored and assumed to be 0.
"""
+ a = _convert_other(a, raiseit=True)
return a.compare_total_mag(b)
def copy_abs(self, a):
"""Returns a copy of the operand with the sign set to 0.
>>> ExtendedContext.copy_abs(Decimal('2.1'))
- Decimal("2.1")
+ Decimal('2.1')
>>> ExtendedContext.copy_abs(Decimal('-100'))
- Decimal("100")
+ Decimal('100')
+ >>> ExtendedContext.copy_abs(-1)
+ Decimal('1')
"""
+ a = _convert_other(a, raiseit=True)
return a.copy_abs()
def copy_decimal(self, a):
- """Returns a copy of the decimal objet.
+ """Returns a copy of the decimal object.
>>> ExtendedContext.copy_decimal(Decimal('2.1'))
- Decimal("2.1")
+ Decimal('2.1')
>>> ExtendedContext.copy_decimal(Decimal('-1.00'))
- Decimal("-1.00")
+ Decimal('-1.00')
+ >>> ExtendedContext.copy_decimal(1)
+ Decimal('1')
"""
+ a = _convert_other(a, raiseit=True)
return Decimal(a)
def copy_negate(self, a):
"""Returns a copy of the operand with the sign inverted.
>>> ExtendedContext.copy_negate(Decimal('101.5'))
- Decimal("-101.5")
+ Decimal('-101.5')
>>> ExtendedContext.copy_negate(Decimal('-101.5'))
- Decimal("101.5")
+ Decimal('101.5')
+ >>> ExtendedContext.copy_negate(1)
+ Decimal('-1')
"""
+ a = _convert_other(a, raiseit=True)
return a.copy_negate()
def copy_sign(self, a, b):
@@ -3703,56 +4136,103 @@
equal to the sign of the second operand.
>>> ExtendedContext.copy_sign(Decimal( '1.50'), Decimal('7.33'))
- Decimal("1.50")
+ Decimal('1.50')
>>> ExtendedContext.copy_sign(Decimal('-1.50'), Decimal('7.33'))
- Decimal("1.50")
+ Decimal('1.50')
>>> ExtendedContext.copy_sign(Decimal( '1.50'), Decimal('-7.33'))
- Decimal("-1.50")
+ Decimal('-1.50')
>>> ExtendedContext.copy_sign(Decimal('-1.50'), Decimal('-7.33'))
- Decimal("-1.50")
+ Decimal('-1.50')
+ >>> ExtendedContext.copy_sign(1, -2)
+ Decimal('-1')
+ >>> ExtendedContext.copy_sign(Decimal(1), -2)
+ Decimal('-1')
+ >>> ExtendedContext.copy_sign(1, Decimal(-2))
+ Decimal('-1')
"""
+ a = _convert_other(a, raiseit=True)
return a.copy_sign(b)
def divide(self, a, b):
"""Decimal division in a specified context.
>>> ExtendedContext.divide(Decimal('1'), Decimal('3'))
- Decimal("0.333333333")
+ Decimal('0.333333333')
>>> ExtendedContext.divide(Decimal('2'), Decimal('3'))
- Decimal("0.666666667")
+ Decimal('0.666666667')
>>> ExtendedContext.divide(Decimal('5'), Decimal('2'))
- Decimal("2.5")
+ Decimal('2.5')
>>> ExtendedContext.divide(Decimal('1'), Decimal('10'))
- Decimal("0.1")
+ Decimal('0.1')
>>> ExtendedContext.divide(Decimal('12'), Decimal('12'))
- Decimal("1")
+ Decimal('1')
>>> ExtendedContext.divide(Decimal('8.00'), Decimal('2'))
- Decimal("4.00")
+ Decimal('4.00')
>>> ExtendedContext.divide(Decimal('2.400'), Decimal('2.0'))
- Decimal("1.20")
+ Decimal('1.20')
>>> ExtendedContext.divide(Decimal('1000'), Decimal('100'))
- Decimal("10")
+ Decimal('10')
>>> ExtendedContext.divide(Decimal('1000'), Decimal('1'))
- Decimal("1000")
+ Decimal('1000')
>>> ExtendedContext.divide(Decimal('2.40E+6'), Decimal('2'))
- Decimal("1.20E+6")
+ Decimal('1.20E+6')
+ >>> ExtendedContext.divide(5, 5)
+ Decimal('1')
+ >>> ExtendedContext.divide(Decimal(5), 5)
+ Decimal('1')
+ >>> ExtendedContext.divide(5, Decimal(5))
+ Decimal('1')
"""
- return a.__div__(b, context=self)
+ a = _convert_other(a, raiseit=True)
+ r = a.__div__(b, context=self)
+ if r is NotImplemented:
+ raise TypeError("Unable to convert %s to Decimal" % b)
+ else:
+ return r
def divide_int(self, a, b):
"""Divides two numbers and returns the integer part of the result.
>>> ExtendedContext.divide_int(Decimal('2'), Decimal('3'))
- Decimal("0")
+ Decimal('0')
>>> ExtendedContext.divide_int(Decimal('10'), Decimal('3'))
- Decimal("3")
+ Decimal('3')
>>> ExtendedContext.divide_int(Decimal('1'), Decimal('0.3'))
- Decimal("3")
+ Decimal('3')
+ >>> ExtendedContext.divide_int(10, 3)
+ Decimal('3')
+ >>> ExtendedContext.divide_int(Decimal(10), 3)
+ Decimal('3')
+ >>> ExtendedContext.divide_int(10, Decimal(3))
+ Decimal('3')
"""
- return a.__floordiv__(b, context=self)
+ a = _convert_other(a, raiseit=True)
+ r = a.__floordiv__(b, context=self)
+ if r is NotImplemented:
+ raise TypeError("Unable to convert %s to Decimal" % b)
+ else:
+ return r
def divmod(self, a, b):
- return a.__divmod__(b, context=self)
+ """Return (a // b, a % b).
+
+ >>> ExtendedContext.divmod(Decimal(8), Decimal(3))
+ (Decimal('2'), Decimal('2'))
+ >>> ExtendedContext.divmod(Decimal(8), Decimal(4))
+ (Decimal('2'), Decimal('0'))
+ >>> ExtendedContext.divmod(8, 4)
+ (Decimal('2'), Decimal('0'))
+ >>> ExtendedContext.divmod(Decimal(8), 4)
+ (Decimal('2'), Decimal('0'))
+ >>> ExtendedContext.divmod(8, Decimal(4))
+ (Decimal('2'), Decimal('0'))
+ """
+ a = _convert_other(a, raiseit=True)
+ r = a.__divmod__(b, context=self)
+ if r is NotImplemented:
+ raise TypeError("Unable to convert %s to Decimal" % b)
+ else:
+ return r
def exp(self, a):
"""Returns e ** a.
@@ -3761,18 +4241,21 @@
>>> c.Emin = -999
>>> c.Emax = 999
>>> c.exp(Decimal('-Infinity'))
- Decimal("0")
+ Decimal('0')
>>> c.exp(Decimal('-1'))
- Decimal("0.367879441")
+ Decimal('0.367879441')
>>> c.exp(Decimal('0'))
- Decimal("1")
+ Decimal('1')
>>> c.exp(Decimal('1'))
- Decimal("2.71828183")
+ Decimal('2.71828183')
>>> c.exp(Decimal('0.693147181'))
- Decimal("2.00000000")
+ Decimal('2.00000000')
>>> c.exp(Decimal('+Infinity'))
- Decimal("Infinity")
+ Decimal('Infinity')
+ >>> c.exp(10)
+ Decimal('22026.4658')
"""
+ a =_convert_other(a, raiseit=True)
return a.exp(context=self)
def fma(self, a, b, c):
@@ -3783,12 +4266,19 @@
multiplication, using add, all with only one final rounding.
>>> ExtendedContext.fma(Decimal('3'), Decimal('5'), Decimal('7'))
- Decimal("22")
+ Decimal('22')
>>> ExtendedContext.fma(Decimal('3'), Decimal('-5'), Decimal('7'))
- Decimal("-8")
+ Decimal('-8')
>>> ExtendedContext.fma(Decimal('888565290'), Decimal('1557.96930'), Decimal('-86087.7578'))
- Decimal("1.38435736E+12")
+ Decimal('1.38435736E+12')
+ >>> ExtendedContext.fma(1, 3, 4)
+ Decimal('7')
+ >>> ExtendedContext.fma(1, Decimal(3), 4)
+ Decimal('7')
+ >>> ExtendedContext.fma(1, 3, Decimal(4))
+ Decimal('7')
"""
+ a = _convert_other(a, raiseit=True)
return a.fma(b, c, context=self)
def is_canonical(self, a):
@@ -3818,7 +4308,10 @@
False
>>> ExtendedContext.is_finite(Decimal('NaN'))
False
+ >>> ExtendedContext.is_finite(1)
+ True
"""
+ a = _convert_other(a, raiseit=True)
return a.is_finite()
def is_infinite(self, a):
@@ -3830,7 +4323,10 @@
True
>>> ExtendedContext.is_infinite(Decimal('NaN'))
False
+ >>> ExtendedContext.is_infinite(1)
+ False
"""
+ a = _convert_other(a, raiseit=True)
return a.is_infinite()
def is_nan(self, a):
@@ -3843,7 +4339,10 @@
True
>>> ExtendedContext.is_nan(Decimal('-sNaN'))
True
+ >>> ExtendedContext.is_nan(1)
+ False
"""
+ a = _convert_other(a, raiseit=True)
return a.is_nan()
def is_normal(self, a):
@@ -3863,7 +4362,10 @@
False
>>> c.is_normal(Decimal('NaN'))
False
+ >>> c.is_normal(1)
+ True
"""
+ a = _convert_other(a, raiseit=True)
return a.is_normal(context=self)
def is_qnan(self, a):
@@ -3875,7 +4377,10 @@
True
>>> ExtendedContext.is_qnan(Decimal('sNaN'))
False
+ >>> ExtendedContext.is_qnan(1)
+ False
"""
+ a = _convert_other(a, raiseit=True)
return a.is_qnan()
def is_signed(self, a):
@@ -3887,7 +4392,12 @@
True
>>> ExtendedContext.is_signed(Decimal('-0'))
True
+ >>> ExtendedContext.is_signed(8)
+ False
+ >>> ExtendedContext.is_signed(-8)
+ True
"""
+ a = _convert_other(a, raiseit=True)
return a.is_signed()
def is_snan(self, a):
@@ -3900,7 +4410,10 @@
False
>>> ExtendedContext.is_snan(Decimal('sNaN'))
True
+ >>> ExtendedContext.is_snan(1)
+ False
"""
+ a = _convert_other(a, raiseit=True)
return a.is_snan()
def is_subnormal(self, a):
@@ -3919,7 +4432,10 @@
False
>>> c.is_subnormal(Decimal('NaN'))
False
+ >>> c.is_subnormal(1)
+ False
"""
+ a = _convert_other(a, raiseit=True)
return a.is_subnormal(context=self)
def is_zero(self, a):
@@ -3931,7 +4447,12 @@
False
>>> ExtendedContext.is_zero(Decimal('-0E+2'))
True
+ >>> ExtendedContext.is_zero(1)
+ False
+ >>> ExtendedContext.is_zero(0)
+ True
"""
+ a = _convert_other(a, raiseit=True)
return a.is_zero()
def ln(self, a):
@@ -3941,16 +4462,19 @@
>>> c.Emin = -999
>>> c.Emax = 999
>>> c.ln(Decimal('0'))
- Decimal("-Infinity")
+ Decimal('-Infinity')
>>> c.ln(Decimal('1.000'))
- Decimal("0")
+ Decimal('0')
>>> c.ln(Decimal('2.71828183'))
- Decimal("1.00000000")
+ Decimal('1.00000000')
>>> c.ln(Decimal('10'))
- Decimal("2.30258509")
+ Decimal('2.30258509')
>>> c.ln(Decimal('+Infinity'))
- Decimal("Infinity")
+ Decimal('Infinity')
+ >>> c.ln(1)
+ Decimal('0')
"""
+ a = _convert_other(a, raiseit=True)
return a.ln(context=self)
def log10(self, a):
@@ -3960,20 +4484,25 @@
>>> c.Emin = -999
>>> c.Emax = 999
>>> c.log10(Decimal('0'))
- Decimal("-Infinity")
+ Decimal('-Infinity')
>>> c.log10(Decimal('0.001'))
- Decimal("-3")
+ Decimal('-3')
>>> c.log10(Decimal('1.000'))
- Decimal("0")
+ Decimal('0')
>>> c.log10(Decimal('2'))
- Decimal("0.301029996")
+ Decimal('0.301029996')
>>> c.log10(Decimal('10'))
- Decimal("1")
+ Decimal('1')
>>> c.log10(Decimal('70'))
- Decimal("1.84509804")
+ Decimal('1.84509804')
>>> c.log10(Decimal('+Infinity'))
- Decimal("Infinity")
+ Decimal('Infinity')
+ >>> c.log10(0)
+ Decimal('-Infinity')
+ >>> c.log10(1)
+ Decimal('0')
"""
+ a = _convert_other(a, raiseit=True)
return a.log10(context=self)
def logb(self, a):
@@ -3985,14 +4514,21 @@
value of that digit and without limiting the resulting exponent).
>>> ExtendedContext.logb(Decimal('250'))
- Decimal("2")
+ Decimal('2')
>>> ExtendedContext.logb(Decimal('2.50'))
- Decimal("0")
+ Decimal('0')
>>> ExtendedContext.logb(Decimal('0.03'))
- Decimal("-2")
+ Decimal('-2')
>>> ExtendedContext.logb(Decimal('0'))
- Decimal("-Infinity")
+ Decimal('-Infinity')
+ >>> ExtendedContext.logb(1)
+ Decimal('0')
+ >>> ExtendedContext.logb(10)
+ Decimal('1')
+ >>> ExtendedContext.logb(100)
+ Decimal('2')
"""
+ a = _convert_other(a, raiseit=True)
return a.logb(context=self)
def logical_and(self, a, b):
@@ -4001,18 +4537,25 @@
The operands must be both logical numbers.
>>> ExtendedContext.logical_and(Decimal('0'), Decimal('0'))
- Decimal("0")
+ Decimal('0')
>>> ExtendedContext.logical_and(Decimal('0'), Decimal('1'))
- Decimal("0")
+ Decimal('0')
>>> ExtendedContext.logical_and(Decimal('1'), Decimal('0'))
- Decimal("0")
+ Decimal('0')
>>> ExtendedContext.logical_and(Decimal('1'), Decimal('1'))
- Decimal("1")
+ Decimal('1')
>>> ExtendedContext.logical_and(Decimal('1100'), Decimal('1010'))
- Decimal("1000")
+ Decimal('1000')
>>> ExtendedContext.logical_and(Decimal('1111'), Decimal('10'))
- Decimal("10")
+ Decimal('10')
+ >>> ExtendedContext.logical_and(110, 1101)
+ Decimal('100')
+ >>> ExtendedContext.logical_and(Decimal(110), 1101)
+ Decimal('100')
+ >>> ExtendedContext.logical_and(110, Decimal(1101))
+ Decimal('100')
"""
+ a = _convert_other(a, raiseit=True)
return a.logical_and(b, context=self)
def logical_invert(self, a):
@@ -4021,14 +4564,17 @@
The operand must be a logical number.
>>> ExtendedContext.logical_invert(Decimal('0'))
- Decimal("111111111")
+ Decimal('111111111')
>>> ExtendedContext.logical_invert(Decimal('1'))
- Decimal("111111110")
+ Decimal('111111110')
>>> ExtendedContext.logical_invert(Decimal('111111111'))
- Decimal("0")
+ Decimal('0')
>>> ExtendedContext.logical_invert(Decimal('101010101'))
- Decimal("10101010")
+ Decimal('10101010')
+ >>> ExtendedContext.logical_invert(1101)
+ Decimal('111110010')
"""
+ a = _convert_other(a, raiseit=True)
return a.logical_invert(context=self)
def logical_or(self, a, b):
@@ -4037,18 +4583,25 @@
The operands must be both logical numbers.
>>> ExtendedContext.logical_or(Decimal('0'), Decimal('0'))
- Decimal("0")
+ Decimal('0')
>>> ExtendedContext.logical_or(Decimal('0'), Decimal('1'))
- Decimal("1")
+ Decimal('1')
>>> ExtendedContext.logical_or(Decimal('1'), Decimal('0'))
- Decimal("1")
+ Decimal('1')
>>> ExtendedContext.logical_or(Decimal('1'), Decimal('1'))
- Decimal("1")
+ Decimal('1')
>>> ExtendedContext.logical_or(Decimal('1100'), Decimal('1010'))
- Decimal("1110")
+ Decimal('1110')
>>> ExtendedContext.logical_or(Decimal('1110'), Decimal('10'))
- Decimal("1110")
+ Decimal('1110')
+ >>> ExtendedContext.logical_or(110, 1101)
+ Decimal('1111')
+ >>> ExtendedContext.logical_or(Decimal(110), 1101)
+ Decimal('1111')
+ >>> ExtendedContext.logical_or(110, Decimal(1101))
+ Decimal('1111')
"""
+ a = _convert_other(a, raiseit=True)
return a.logical_or(b, context=self)
def logical_xor(self, a, b):
@@ -4057,66 +4610,113 @@
The operands must be both logical numbers.
>>> ExtendedContext.logical_xor(Decimal('0'), Decimal('0'))
- Decimal("0")
+ Decimal('0')
>>> ExtendedContext.logical_xor(Decimal('0'), Decimal('1'))
- Decimal("1")
+ Decimal('1')
>>> ExtendedContext.logical_xor(Decimal('1'), Decimal('0'))
- Decimal("1")
+ Decimal('1')
>>> ExtendedContext.logical_xor(Decimal('1'), Decimal('1'))
- Decimal("0")
+ Decimal('0')
>>> ExtendedContext.logical_xor(Decimal('1100'), Decimal('1010'))
- Decimal("110")
+ Decimal('110')
>>> ExtendedContext.logical_xor(Decimal('1111'), Decimal('10'))
- Decimal("1101")
+ Decimal('1101')
+ >>> ExtendedContext.logical_xor(110, 1101)
+ Decimal('1011')
+ >>> ExtendedContext.logical_xor(Decimal(110), 1101)
+ Decimal('1011')
+ >>> ExtendedContext.logical_xor(110, Decimal(1101))
+ Decimal('1011')
"""
+ a = _convert_other(a, raiseit=True)
return a.logical_xor(b, context=self)
- def max(self, a,b):
+ def max(self, a, b):
"""max compares two values numerically and returns the maximum.
If either operand is a NaN then the general rules apply.
- Otherwise, the operands are compared as as though by the compare
+ Otherwise, the operands are compared as though by the compare
operation. If they are numerically equal then the left-hand operand
is chosen as the result. Otherwise the maximum (closer to positive
infinity) of the two operands is chosen as the result.
>>> ExtendedContext.max(Decimal('3'), Decimal('2'))
- Decimal("3")
+ Decimal('3')
>>> ExtendedContext.max(Decimal('-10'), Decimal('3'))
- Decimal("3")
+ Decimal('3')
>>> ExtendedContext.max(Decimal('1.0'), Decimal('1'))
- Decimal("1")
+ Decimal('1')
>>> ExtendedContext.max(Decimal('7'), Decimal('NaN'))
- Decimal("7")
+ Decimal('7')
+ >>> ExtendedContext.max(1, 2)
+ Decimal('2')
+ >>> ExtendedContext.max(Decimal(1), 2)
+ Decimal('2')
+ >>> ExtendedContext.max(1, Decimal(2))
+ Decimal('2')
"""
+ a = _convert_other(a, raiseit=True)
return a.max(b, context=self)
def max_mag(self, a, b):
- """Compares the values numerically with their sign ignored."""
+ """Compares the values numerically with their sign ignored.
+
+ >>> ExtendedContext.max_mag(Decimal('7'), Decimal('NaN'))
+ Decimal('7')
+ >>> ExtendedContext.max_mag(Decimal('7'), Decimal('-10'))
+ Decimal('-10')
+ >>> ExtendedContext.max_mag(1, -2)
+ Decimal('-2')
+ >>> ExtendedContext.max_mag(Decimal(1), -2)
+ Decimal('-2')
+ >>> ExtendedContext.max_mag(1, Decimal(-2))
+ Decimal('-2')
+ """
+ a = _convert_other(a, raiseit=True)
return a.max_mag(b, context=self)
- def min(self, a,b):
+ def min(self, a, b):
"""min compares two values numerically and returns the minimum.
If either operand is a NaN then the general rules apply.
- Otherwise, the operands are compared as as though by the compare
+ Otherwise, the operands are compared as though by the compare
operation. If they are numerically equal then the left-hand operand
is chosen as the result. Otherwise the minimum (closer to negative
infinity) of the two operands is chosen as the result.
>>> ExtendedContext.min(Decimal('3'), Decimal('2'))
- Decimal("2")
+ Decimal('2')
>>> ExtendedContext.min(Decimal('-10'), Decimal('3'))
- Decimal("-10")
+ Decimal('-10')
>>> ExtendedContext.min(Decimal('1.0'), Decimal('1'))
- Decimal("1.0")
+ Decimal('1.0')
>>> ExtendedContext.min(Decimal('7'), Decimal('NaN'))
- Decimal("7")
+ Decimal('7')
+ >>> ExtendedContext.min(1, 2)
+ Decimal('1')
+ >>> ExtendedContext.min(Decimal(1), 2)
+ Decimal('1')
+ >>> ExtendedContext.min(1, Decimal(29))
+ Decimal('1')
"""
+ a = _convert_other(a, raiseit=True)
return a.min(b, context=self)
def min_mag(self, a, b):
- """Compares the values numerically with their sign ignored."""
+ """Compares the values numerically with their sign ignored.
+
+ >>> ExtendedContext.min_mag(Decimal('3'), Decimal('-2'))
+ Decimal('-2')
+ >>> ExtendedContext.min_mag(Decimal('-3'), Decimal('NaN'))
+ Decimal('-3')
+ >>> ExtendedContext.min_mag(1, -2)
+ Decimal('1')
+ >>> ExtendedContext.min_mag(Decimal(1), -2)
+ Decimal('1')
+ >>> ExtendedContext.min_mag(1, Decimal(-2))
+ Decimal('1')
+ """
+ a = _convert_other(a, raiseit=True)
return a.min_mag(b, context=self)
def minus(self, a):
@@ -4127,32 +4727,46 @@
has the same exponent as the operand.
>>> ExtendedContext.minus(Decimal('1.3'))
- Decimal("-1.3")
+ Decimal('-1.3')
>>> ExtendedContext.minus(Decimal('-1.3'))
- Decimal("1.3")
+ Decimal('1.3')
+ >>> ExtendedContext.minus(1)
+ Decimal('-1')
"""
+ a = _convert_other(a, raiseit=True)
return a.__neg__(context=self)
def multiply(self, a, b):
"""multiply multiplies two operands.
If either operand is a special value then the general rules apply.
- Otherwise, the operands are multiplied together ('long multiplication'),
- resulting in a number which may be as long as the sum of the lengths
- of the two operands.
+ Otherwise, the operands are multiplied together
+ ('long multiplication'), resulting in a number which may be as long as
+ the sum of the lengths of the two operands.
>>> ExtendedContext.multiply(Decimal('1.20'), Decimal('3'))
- Decimal("3.60")
+ Decimal('3.60')
>>> ExtendedContext.multiply(Decimal('7'), Decimal('3'))
- Decimal("21")
+ Decimal('21')
>>> ExtendedContext.multiply(Decimal('0.9'), Decimal('0.8'))
- Decimal("0.72")
+ Decimal('0.72')
>>> ExtendedContext.multiply(Decimal('0.9'), Decimal('-0'))
- Decimal("-0.0")
+ Decimal('-0.0')
>>> ExtendedContext.multiply(Decimal('654321'), Decimal('654321'))
- Decimal("4.28135971E+11")
+ Decimal('4.28135971E+11')
+ >>> ExtendedContext.multiply(7, 7)
+ Decimal('49')
+ >>> ExtendedContext.multiply(Decimal(7), 7)
+ Decimal('49')
+ >>> ExtendedContext.multiply(7, Decimal(7))
+ Decimal('49')
"""
- return a.__mul__(b, context=self)
+ a = _convert_other(a, raiseit=True)
+ r = a.__mul__(b, context=self)
+ if r is NotImplemented:
+ raise TypeError("Unable to convert %s to Decimal" % b)
+ else:
+ return r
def next_minus(self, a):
"""Returns the largest representable number smaller than a.
@@ -4161,14 +4775,17 @@
>>> c.Emin = -999
>>> c.Emax = 999
>>> ExtendedContext.next_minus(Decimal('1'))
- Decimal("0.999999999")
+ Decimal('0.999999999')
>>> c.next_minus(Decimal('1E-1007'))
- Decimal("0E-1007")
+ Decimal('0E-1007')
>>> ExtendedContext.next_minus(Decimal('-1.00000003'))
- Decimal("-1.00000004")
+ Decimal('-1.00000004')
>>> c.next_minus(Decimal('Infinity'))
- Decimal("9.99999999E+999")
+ Decimal('9.99999999E+999')
+ >>> c.next_minus(1)
+ Decimal('0.999999999')
"""
+ a = _convert_other(a, raiseit=True)
return a.next_minus(context=self)
def next_plus(self, a):
@@ -4178,14 +4795,17 @@
>>> c.Emin = -999
>>> c.Emax = 999
>>> ExtendedContext.next_plus(Decimal('1'))
- Decimal("1.00000001")
+ Decimal('1.00000001')
>>> c.next_plus(Decimal('-1E-1007'))
- Decimal("-0E-1007")
+ Decimal('-0E-1007')
>>> ExtendedContext.next_plus(Decimal('-1.00000003'))
- Decimal("-1.00000002")
+ Decimal('-1.00000002')
>>> c.next_plus(Decimal('-Infinity'))
- Decimal("-9.99999999E+999")
+ Decimal('-9.99999999E+999')
+ >>> c.next_plus(1)
+ Decimal('1.00000001')
"""
+ a = _convert_other(a, raiseit=True)
return a.next_plus(context=self)
def next_toward(self, a, b):
@@ -4200,20 +4820,27 @@
>>> c.Emin = -999
>>> c.Emax = 999
>>> c.next_toward(Decimal('1'), Decimal('2'))
- Decimal("1.00000001")
+ Decimal('1.00000001')
>>> c.next_toward(Decimal('-1E-1007'), Decimal('1'))
- Decimal("-0E-1007")
+ Decimal('-0E-1007')
>>> c.next_toward(Decimal('-1.00000003'), Decimal('0'))
- Decimal("-1.00000002")
+ Decimal('-1.00000002')
>>> c.next_toward(Decimal('1'), Decimal('0'))
- Decimal("0.999999999")
+ Decimal('0.999999999')
>>> c.next_toward(Decimal('1E-1007'), Decimal('-100'))
- Decimal("0E-1007")
+ Decimal('0E-1007')
>>> c.next_toward(Decimal('-1.00000003'), Decimal('-10'))
- Decimal("-1.00000004")
+ Decimal('-1.00000004')
>>> c.next_toward(Decimal('0.00'), Decimal('-0.0000'))
- Decimal("-0.00")
+ Decimal('-0.00')
+ >>> c.next_toward(0, 1)
+ Decimal('1E-1007')
+ >>> c.next_toward(Decimal(0), 1)
+ Decimal('1E-1007')
+ >>> c.next_toward(0, Decimal(1))
+ Decimal('1E-1007')
"""
+ a = _convert_other(a, raiseit=True)
return a.next_toward(b, context=self)
def normalize(self, a):
@@ -4223,18 +4850,21 @@
result.
>>> ExtendedContext.normalize(Decimal('2.1'))
- Decimal("2.1")
+ Decimal('2.1')
>>> ExtendedContext.normalize(Decimal('-2.0'))
- Decimal("-2")
+ Decimal('-2')
>>> ExtendedContext.normalize(Decimal('1.200'))
- Decimal("1.2")
+ Decimal('1.2')
>>> ExtendedContext.normalize(Decimal('-120'))
- Decimal("-1.2E+2")
+ Decimal('-1.2E+2')
>>> ExtendedContext.normalize(Decimal('120.00'))
- Decimal("1.2E+2")
+ Decimal('1.2E+2')
>>> ExtendedContext.normalize(Decimal('0.00'))
- Decimal("0")
+ Decimal('0')
+ >>> ExtendedContext.normalize(6)
+ Decimal('6')
"""
+ a = _convert_other(a, raiseit=True)
return a.normalize(context=self)
def number_class(self, a):
@@ -4281,7 +4911,10 @@
'NaN'
>>> c.number_class(Decimal('sNaN'))
'sNaN'
+ >>> c.number_class(123)
+ '+Normal'
"""
+ a = _convert_other(a, raiseit=True)
return a.number_class(context=self)
def plus(self, a):
@@ -4292,10 +4925,13 @@
has the same exponent as the operand.
>>> ExtendedContext.plus(Decimal('1.3'))
- Decimal("1.3")
+ Decimal('1.3')
>>> ExtendedContext.plus(Decimal('-1.3'))
- Decimal("-1.3")
+ Decimal('-1.3')
+ >>> ExtendedContext.plus(-1)
+ Decimal('-1')
"""
+ a = _convert_other(a, raiseit=True)
return a.__pos__(context=self)
def power(self, a, b, modulo=None):
@@ -4324,48 +4960,59 @@
>>> c.Emin = -999
>>> c.Emax = 999
>>> c.power(Decimal('2'), Decimal('3'))
- Decimal("8")
+ Decimal('8')
>>> c.power(Decimal('-2'), Decimal('3'))
- Decimal("-8")
+ Decimal('-8')
>>> c.power(Decimal('2'), Decimal('-3'))
- Decimal("0.125")
+ Decimal('0.125')
>>> c.power(Decimal('1.7'), Decimal('8'))
- Decimal("69.7575744")
+ Decimal('69.7575744')
>>> c.power(Decimal('10'), Decimal('0.301029996'))
- Decimal("2.00000000")
+ Decimal('2.00000000')
>>> c.power(Decimal('Infinity'), Decimal('-1'))
- Decimal("0")
+ Decimal('0')
>>> c.power(Decimal('Infinity'), Decimal('0'))
- Decimal("1")
+ Decimal('1')
>>> c.power(Decimal('Infinity'), Decimal('1'))
- Decimal("Infinity")
+ Decimal('Infinity')
>>> c.power(Decimal('-Infinity'), Decimal('-1'))
- Decimal("-0")
+ Decimal('-0')
>>> c.power(Decimal('-Infinity'), Decimal('0'))
- Decimal("1")
+ Decimal('1')
>>> c.power(Decimal('-Infinity'), Decimal('1'))
- Decimal("-Infinity")
+ Decimal('-Infinity')
>>> c.power(Decimal('-Infinity'), Decimal('2'))
- Decimal("Infinity")
+ Decimal('Infinity')
>>> c.power(Decimal('0'), Decimal('0'))
- Decimal("NaN")
+ Decimal('NaN')
>>> c.power(Decimal('3'), Decimal('7'), Decimal('16'))
- Decimal("11")
+ Decimal('11')
>>> c.power(Decimal('-3'), Decimal('7'), Decimal('16'))
- Decimal("-11")
+ Decimal('-11')
>>> c.power(Decimal('-3'), Decimal('8'), Decimal('16'))
- Decimal("1")
+ Decimal('1')
>>> c.power(Decimal('3'), Decimal('7'), Decimal('-16'))
- Decimal("11")
+ Decimal('11')
>>> c.power(Decimal('23E12345'), Decimal('67E189'), Decimal('123456789'))
- Decimal("11729830")
+ Decimal('11729830')
>>> c.power(Decimal('-0'), Decimal('17'), Decimal('1729'))
- Decimal("-0")
+ Decimal('-0')
>>> c.power(Decimal('-23'), Decimal('0'), Decimal('65537'))
- Decimal("1")
+ Decimal('1')
+ >>> ExtendedContext.power(7, 7)
+ Decimal('823543')
+ >>> ExtendedContext.power(Decimal(7), 7)
+ Decimal('823543')
+ >>> ExtendedContext.power(7, Decimal(7), 2)
+ Decimal('1')
"""
- return a.__pow__(b, modulo, context=self)
+ a = _convert_other(a, raiseit=True)
+ r = a.__pow__(b, modulo, context=self)
+ if r is NotImplemented:
+ raise TypeError("Unable to convert %s to Decimal" % b)
+ else:
+ return r
def quantize(self, a, b):
"""Returns a value equal to 'a' (rounded), having the exponent of 'b'.
@@ -4386,43 +5033,50 @@
if the result is subnormal and inexact.
>>> ExtendedContext.quantize(Decimal('2.17'), Decimal('0.001'))
- Decimal("2.170")
+ Decimal('2.170')
>>> ExtendedContext.quantize(Decimal('2.17'), Decimal('0.01'))
- Decimal("2.17")
+ Decimal('2.17')
>>> ExtendedContext.quantize(Decimal('2.17'), Decimal('0.1'))
- Decimal("2.2")
+ Decimal('2.2')
>>> ExtendedContext.quantize(Decimal('2.17'), Decimal('1e+0'))
- Decimal("2")
+ Decimal('2')
>>> ExtendedContext.quantize(Decimal('2.17'), Decimal('1e+1'))
- Decimal("0E+1")
+ Decimal('0E+1')
>>> ExtendedContext.quantize(Decimal('-Inf'), Decimal('Infinity'))
- Decimal("-Infinity")
+ Decimal('-Infinity')
>>> ExtendedContext.quantize(Decimal('2'), Decimal('Infinity'))
- Decimal("NaN")
+ Decimal('NaN')
>>> ExtendedContext.quantize(Decimal('-0.1'), Decimal('1'))
- Decimal("-0")
+ Decimal('-0')
>>> ExtendedContext.quantize(Decimal('-0'), Decimal('1e+5'))
- Decimal("-0E+5")
+ Decimal('-0E+5')
>>> ExtendedContext.quantize(Decimal('+35236450.6'), Decimal('1e-2'))
- Decimal("NaN")
+ Decimal('NaN')
>>> ExtendedContext.quantize(Decimal('-35236450.6'), Decimal('1e-2'))
- Decimal("NaN")
+ Decimal('NaN')
>>> ExtendedContext.quantize(Decimal('217'), Decimal('1e-1'))
- Decimal("217.0")
+ Decimal('217.0')
>>> ExtendedContext.quantize(Decimal('217'), Decimal('1e-0'))
- Decimal("217")
+ Decimal('217')
>>> ExtendedContext.quantize(Decimal('217'), Decimal('1e+1'))
- Decimal("2.2E+2")
+ Decimal('2.2E+2')
>>> ExtendedContext.quantize(Decimal('217'), Decimal('1e+2'))
- Decimal("2E+2")
+ Decimal('2E+2')
+ >>> ExtendedContext.quantize(1, 2)
+ Decimal('1')
+ >>> ExtendedContext.quantize(Decimal(1), 2)
+ Decimal('1')
+ >>> ExtendedContext.quantize(1, Decimal(2))
+ Decimal('1')
"""
+ a = _convert_other(a, raiseit=True)
return a.quantize(b, context=self)
def radix(self):
"""Just returns 10, as this is Decimal, :)
>>> ExtendedContext.radix()
- Decimal("10")
+ Decimal('10')
"""
return Decimal(10)
@@ -4439,19 +5093,30 @@
remainder cannot be calculated).
>>> ExtendedContext.remainder(Decimal('2.1'), Decimal('3'))
- Decimal("2.1")
+ Decimal('2.1')
>>> ExtendedContext.remainder(Decimal('10'), Decimal('3'))
- Decimal("1")
+ Decimal('1')
>>> ExtendedContext.remainder(Decimal('-10'), Decimal('3'))
- Decimal("-1")
+ Decimal('-1')
>>> ExtendedContext.remainder(Decimal('10.2'), Decimal('1'))
- Decimal("0.2")
+ Decimal('0.2')
>>> ExtendedContext.remainder(Decimal('10'), Decimal('0.3'))
- Decimal("0.1")
+ Decimal('0.1')
>>> ExtendedContext.remainder(Decimal('3.6'), Decimal('1.3'))
- Decimal("1.0")
+ Decimal('1.0')
+ >>> ExtendedContext.remainder(22, 6)
+ Decimal('4')
+ >>> ExtendedContext.remainder(Decimal(22), 6)
+ Decimal('4')
+ >>> ExtendedContext.remainder(22, Decimal(6))
+ Decimal('4')
"""
- return a.__mod__(b, context=self)
+ a = _convert_other(a, raiseit=True)
+ r = a.__mod__(b, context=self)
+ if r is NotImplemented:
+ raise TypeError("Unable to convert %s to Decimal" % b)
+ else:
+ return r
def remainder_near(self, a, b):
"""Returns to be "a - b * n", where n is the integer nearest the exact
@@ -4464,20 +5129,27 @@
remainder cannot be calculated).
>>> ExtendedContext.remainder_near(Decimal('2.1'), Decimal('3'))
- Decimal("-0.9")
+ Decimal('-0.9')
>>> ExtendedContext.remainder_near(Decimal('10'), Decimal('6'))
- Decimal("-2")
+ Decimal('-2')
>>> ExtendedContext.remainder_near(Decimal('10'), Decimal('3'))
- Decimal("1")
+ Decimal('1')
>>> ExtendedContext.remainder_near(Decimal('-10'), Decimal('3'))
- Decimal("-1")
+ Decimal('-1')
>>> ExtendedContext.remainder_near(Decimal('10.2'), Decimal('1'))
- Decimal("0.2")
+ Decimal('0.2')
>>> ExtendedContext.remainder_near(Decimal('10'), Decimal('0.3'))
- Decimal("0.1")
+ Decimal('0.1')
>>> ExtendedContext.remainder_near(Decimal('3.6'), Decimal('1.3'))
- Decimal("-0.3")
+ Decimal('-0.3')
+ >>> ExtendedContext.remainder_near(3, 11)
+ Decimal('3')
+ >>> ExtendedContext.remainder_near(Decimal(3), 11)
+ Decimal('3')
+ >>> ExtendedContext.remainder_near(3, Decimal(11))
+ Decimal('3')
"""
+ a = _convert_other(a, raiseit=True)
return a.remainder_near(b, context=self)
def rotate(self, a, b):
@@ -4490,16 +5162,23 @@
positive or to the right otherwise.
>>> ExtendedContext.rotate(Decimal('34'), Decimal('8'))
- Decimal("400000003")
+ Decimal('400000003')
>>> ExtendedContext.rotate(Decimal('12'), Decimal('9'))
- Decimal("12")
+ Decimal('12')
>>> ExtendedContext.rotate(Decimal('123456789'), Decimal('-2'))
- Decimal("891234567")
+ Decimal('891234567')
>>> ExtendedContext.rotate(Decimal('123456789'), Decimal('0'))
- Decimal("123456789")
+ Decimal('123456789')
>>> ExtendedContext.rotate(Decimal('123456789'), Decimal('+2'))
- Decimal("345678912")
+ Decimal('345678912')
+ >>> ExtendedContext.rotate(1333333, 1)
+ Decimal('13333330')
+ >>> ExtendedContext.rotate(Decimal(1333333), 1)
+ Decimal('13333330')
+ >>> ExtendedContext.rotate(1333333, Decimal(1))
+ Decimal('13333330')
"""
+ a = _convert_other(a, raiseit=True)
return a.rotate(b, context=self)
def same_quantum(self, a, b):
@@ -4516,20 +5195,34 @@
False
>>> ExtendedContext.same_quantum(Decimal('Inf'), Decimal('-Inf'))
True
+ >>> ExtendedContext.same_quantum(10000, -1)
+ True
+ >>> ExtendedContext.same_quantum(Decimal(10000), -1)
+ True
+ >>> ExtendedContext.same_quantum(10000, Decimal(-1))
+ True
"""
+ a = _convert_other(a, raiseit=True)
return a.same_quantum(b)
def scaleb (self, a, b):
"""Returns the first operand after adding the second value its exp.
>>> ExtendedContext.scaleb(Decimal('7.50'), Decimal('-2'))
- Decimal("0.0750")
+ Decimal('0.0750')
>>> ExtendedContext.scaleb(Decimal('7.50'), Decimal('0'))
- Decimal("7.50")
+ Decimal('7.50')
>>> ExtendedContext.scaleb(Decimal('7.50'), Decimal('3'))
- Decimal("7.50E+3")
+ Decimal('7.50E+3')
+ >>> ExtendedContext.scaleb(1, 4)
+ Decimal('1E+4')
+ >>> ExtendedContext.scaleb(Decimal(1), 4)
+ Decimal('1E+4')
+ >>> ExtendedContext.scaleb(1, Decimal(4))
+ Decimal('1E+4')
"""
- return a.scaleb (b, context=self)
+ a = _convert_other(a, raiseit=True)
+ return a.scaleb(b, context=self)
def shift(self, a, b):
"""Returns a shifted copy of a, b times.
@@ -4542,16 +5235,23 @@
coefficient are zeros.
>>> ExtendedContext.shift(Decimal('34'), Decimal('8'))
- Decimal("400000000")
+ Decimal('400000000')
>>> ExtendedContext.shift(Decimal('12'), Decimal('9'))
- Decimal("0")
+ Decimal('0')
>>> ExtendedContext.shift(Decimal('123456789'), Decimal('-2'))
- Decimal("1234567")
+ Decimal('1234567')
>>> ExtendedContext.shift(Decimal('123456789'), Decimal('0'))
- Decimal("123456789")
+ Decimal('123456789')
>>> ExtendedContext.shift(Decimal('123456789'), Decimal('+2'))
- Decimal("345678900")
+ Decimal('345678900')
+ >>> ExtendedContext.shift(88888888, 2)
+ Decimal('888888800')
+ >>> ExtendedContext.shift(Decimal(88888888), 2)
+ Decimal('888888800')
+ >>> ExtendedContext.shift(88888888, Decimal(2))
+ Decimal('888888800')
"""
+ a = _convert_other(a, raiseit=True)
return a.shift(b, context=self)
def sqrt(self, a):
@@ -4561,45 +5261,60 @@
algorithm.
>>> ExtendedContext.sqrt(Decimal('0'))
- Decimal("0")
+ Decimal('0')
>>> ExtendedContext.sqrt(Decimal('-0'))
- Decimal("-0")
+ Decimal('-0')
>>> ExtendedContext.sqrt(Decimal('0.39'))
- Decimal("0.624499800")
+ Decimal('0.624499800')
>>> ExtendedContext.sqrt(Decimal('100'))
- Decimal("10")
+ Decimal('10')
>>> ExtendedContext.sqrt(Decimal('1'))
- Decimal("1")
+ Decimal('1')
>>> ExtendedContext.sqrt(Decimal('1.0'))
- Decimal("1.0")
+ Decimal('1.0')
>>> ExtendedContext.sqrt(Decimal('1.00'))
- Decimal("1.0")
+ Decimal('1.0')
>>> ExtendedContext.sqrt(Decimal('7'))
- Decimal("2.64575131")
+ Decimal('2.64575131')
>>> ExtendedContext.sqrt(Decimal('10'))
- Decimal("3.16227766")
+ Decimal('3.16227766')
+ >>> ExtendedContext.sqrt(2)
+ Decimal('1.41421356')
>>> ExtendedContext.prec
9
"""
+ a = _convert_other(a, raiseit=True)
return a.sqrt(context=self)
def subtract(self, a, b):
"""Return the difference between the two operands.
>>> ExtendedContext.subtract(Decimal('1.3'), Decimal('1.07'))
- Decimal("0.23")
+ Decimal('0.23')
>>> ExtendedContext.subtract(Decimal('1.3'), Decimal('1.30'))
- Decimal("0.00")
+ Decimal('0.00')
>>> ExtendedContext.subtract(Decimal('1.3'), Decimal('2.07'))
- Decimal("-0.77")
+ Decimal('-0.77')
+ >>> ExtendedContext.subtract(8, 5)
+ Decimal('3')
+ >>> ExtendedContext.subtract(Decimal(8), 5)
+ Decimal('3')
+ >>> ExtendedContext.subtract(8, Decimal(5))
+ Decimal('3')
"""
- return a.__sub__(b, context=self)
+ a = _convert_other(a, raiseit=True)
+ r = a.__sub__(b, context=self)
+ if r is NotImplemented:
+ raise TypeError("Unable to convert %s to Decimal" % b)
+ else:
+ return r
def to_eng_string(self, a):
"""Converts a number to a string, using scientific notation.
The operation is not affected by the context.
"""
+ a = _convert_other(a, raiseit=True)
return a.to_eng_string(context=self)
def to_sci_string(self, a):
@@ -4607,6 +5322,7 @@
The operation is not affected by the context.
"""
+ a = _convert_other(a, raiseit=True)
return a.__str__(context=self)
def to_integral_exact(self, a):
@@ -4620,22 +5336,23 @@
context.
>>> ExtendedContext.to_integral_exact(Decimal('2.1'))
- Decimal("2")
+ Decimal('2')
>>> ExtendedContext.to_integral_exact(Decimal('100'))
- Decimal("100")
+ Decimal('100')
>>> ExtendedContext.to_integral_exact(Decimal('100.0'))
- Decimal("100")
+ Decimal('100')
>>> ExtendedContext.to_integral_exact(Decimal('101.5'))
- Decimal("102")
+ Decimal('102')
>>> ExtendedContext.to_integral_exact(Decimal('-101.5'))
- Decimal("-102")
+ Decimal('-102')
>>> ExtendedContext.to_integral_exact(Decimal('10E+5'))
- Decimal("1.0E+6")
+ Decimal('1.0E+6')
>>> ExtendedContext.to_integral_exact(Decimal('7.89E+77'))
- Decimal("7.89E+77")
+ Decimal('7.89E+77')
>>> ExtendedContext.to_integral_exact(Decimal('-Inf'))
- Decimal("-Infinity")
+ Decimal('-Infinity')
"""
+ a = _convert_other(a, raiseit=True)
return a.to_integral_exact(context=self)
def to_integral_value(self, a):
@@ -4648,22 +5365,23 @@
be set. The rounding mode is taken from the context.
>>> ExtendedContext.to_integral_value(Decimal('2.1'))
- Decimal("2")
+ Decimal('2')
>>> ExtendedContext.to_integral_value(Decimal('100'))
- Decimal("100")
+ Decimal('100')
>>> ExtendedContext.to_integral_value(Decimal('100.0'))
- Decimal("100")
+ Decimal('100')
>>> ExtendedContext.to_integral_value(Decimal('101.5'))
- Decimal("102")
+ Decimal('102')
>>> ExtendedContext.to_integral_value(Decimal('-101.5'))
- Decimal("-102")
+ Decimal('-102')
>>> ExtendedContext.to_integral_value(Decimal('10E+5'))
- Decimal("1.0E+6")
+ Decimal('1.0E+6')
>>> ExtendedContext.to_integral_value(Decimal('7.89E+77'))
- Decimal("7.89E+77")
+ Decimal('7.89E+77')
>>> ExtendedContext.to_integral_value(Decimal('-Inf'))
- Decimal("-Infinity")
+ Decimal('-Infinity')
"""
+ a = _convert_other(a, raiseit=True)
return a.to_integral_value(context=self)
# the method name changed, but we provide also the old one, for compatibility
@@ -4854,7 +5572,7 @@
log_tenpower = f*M # exact
else:
log_d = 0 # error < 2.31
- log_tenpower = div_nearest(f, 10**-p) # error < 0.5
+ log_tenpower = _div_nearest(f, 10**-p) # error < 0.5
return _div_nearest(log_tenpower+log_d, 100)
@@ -5065,15 +5783,21 @@
##### Helper Functions ####################################################
-def _convert_other(other, raiseit=False):
+def _convert_other(other, raiseit=False, allow_float=False):
"""Convert other to Decimal.
Verifies that it's ok to use in an implicit construction.
+ If allow_float is true, allow conversion from float; this
+ is used in the comparison methods (__eq__ and friends).
+
"""
if isinstance(other, Decimal):
return other
if isinstance(other, (int, long)):
return Decimal(other)
+ if allow_float and isinstance(other, float):
+ return Decimal.from_float(other)
+
if raiseit:
raise TypeError("Unable to convert %s to Decimal" % other)
return NotImplemented
@@ -5111,8 +5835,7 @@
##### crud for parsing strings #############################################
-import re
-
+#
# Regular expression used for parsing numeric strings. Additional
# comments:
#
@@ -5124,47 +5847,300 @@
# number between the optional sign and the optional exponent must have
# at least one decimal digit, possibly after the decimal point. The
# lookahead expression '(?=\d|\.\d)' checks this.
-#
-# As the flag UNICODE is not enabled here, we're explicitly avoiding any
-# other meaning for \d than the numbers [0-9].
import re
-_parser = re.compile(r""" # A numeric string consists of:
+_parser = re.compile(r""" # A numeric string consists of:
# \s*
- (?P<sign>[-+])? # an optional sign, followed by either...
+ (?P<sign>[-+])? # an optional sign, followed by either...
(
- (?=\d|\.\d) # ...a number (with at least one digit)
- (?P<int>\d*) # consisting of a (possibly empty) integer part
- (\.(?P<frac>\d*))? # followed by an optional fractional part
- (E(?P<exp>[-+]?\d+))? # followed by an optional exponent, or...
+ (?=\d|\.\d) # ...a number (with at least one digit)
+ (?P<int>\d*) # having a (possibly empty) integer part
+ (\.(?P<frac>\d*))? # followed by an optional fractional part
+ (E(?P<exp>[-+]?\d+))? # followed by an optional exponent, or...
|
- Inf(inity)? # ...an infinity, or...
+ Inf(inity)? # ...an infinity, or...
|
- (?P<signal>s)? # ...an (optionally signaling)
- NaN # NaN
- (?P<diag>\d*) # with (possibly empty) diagnostic information.
+ (?P<signal>s)? # ...an (optionally signaling)
+ NaN # NaN
+ (?P<diag>\d*) # with (possibly empty) diagnostic info.
)
# \s*
- $
-""", re.VERBOSE | re.IGNORECASE).match
+ \Z
+""", re.VERBOSE | re.IGNORECASE | re.UNICODE).match
_all_zeros = re.compile('0*$').match
_exact_half = re.compile('50*$').match
+
+##### PEP3101 support functions ##############################################
+# The functions in this section have little to do with the Decimal
+# class, and could potentially be reused or adapted for other pure
+# Python numeric classes that want to implement __format__
+#
+# A format specifier for Decimal looks like:
+#
+# [[fill]align][sign][0][minimumwidth][,][.precision][type]
+
+_parse_format_specifier_regex = re.compile(r"""\A
+(?:
+ (?P<fill>.)?
+ (?P<align>[<>=^])
+)?
+(?P<sign>[-+ ])?
+(?P<zeropad>0)?
+(?P<minimumwidth>(?!0)\d+)?
+(?P<thousands_sep>,)?
+(?:\.(?P<precision>0|(?!0)\d+))?
+(?P<type>[eEfFgGn%])?
+\Z
+""", re.VERBOSE)
+
del re
+# The locale module is only needed for the 'n' format specifier. The
+# rest of the PEP 3101 code functions quite happily without it, so we
+# don't care too much if locale isn't present.
+try:
+ import locale as _locale
+except ImportError:
+ pass
+
+def _parse_format_specifier(format_spec, _localeconv=None):
+ """Parse and validate a format specifier.
+
+ Turns a standard numeric format specifier into a dict, with the
+ following entries:
+
+ fill: fill character to pad field to minimum width
+ align: alignment type, either '<', '>', '=' or '^'
+ sign: either '+', '-' or ' '
+ minimumwidth: nonnegative integer giving minimum width
+ zeropad: boolean, indicating whether to pad with zeros
+ thousands_sep: string to use as thousands separator, or ''
+ grouping: grouping for thousands separators, in format
+ used by localeconv
+ decimal_point: string to use for decimal point
+ precision: nonnegative integer giving precision, or None
+ type: one of the characters 'eEfFgG%', or None
+ unicode: boolean (always True for Python 3.x)
+
+ """
+ m = _parse_format_specifier_regex.match(format_spec)
+ if m is None:
+ raise ValueError("Invalid format specifier: " + format_spec)
+
+ # get the dictionary
+ format_dict = m.groupdict()
+
+ # zeropad; defaults for fill and alignment. If zero padding
+ # is requested, the fill and align fields should be absent.
+ fill = format_dict['fill']
+ align = format_dict['align']
+ format_dict['zeropad'] = (format_dict['zeropad'] is not None)
+ if format_dict['zeropad']:
+ if fill is not None:
+ raise ValueError("Fill character conflicts with '0'"
+ " in format specifier: " + format_spec)
+ if align is not None:
+ raise ValueError("Alignment conflicts with '0' in "
+ "format specifier: " + format_spec)
+ format_dict['fill'] = fill or ' '
+ # PEP 3101 originally specified that the default alignment should
+ # be left; it was later agreed that right-aligned makes more sense
+ # for numeric types. See http://bugs.python.org/issue6857.
+ format_dict['align'] = align or '>'
+
+ # default sign handling: '-' for negative, '' for positive
+ if format_dict['sign'] is None:
+ format_dict['sign'] = '-'
+
+ # minimumwidth defaults to 0; precision remains None if not given
+ format_dict['minimumwidth'] = int(format_dict['minimumwidth'] or '0')
+ if format_dict['precision'] is not None:
+ format_dict['precision'] = int(format_dict['precision'])
+
+ # if format type is 'g' or 'G' then a precision of 0 makes little
+ # sense; convert it to 1. Same if format type is unspecified.
+ if format_dict['precision'] == 0:
+ if format_dict['type'] is None or format_dict['type'] in 'gG':
+ format_dict['precision'] = 1
+
+ # determine thousands separator, grouping, and decimal separator, and
+ # add appropriate entries to format_dict
+ if format_dict['type'] == 'n':
+ # apart from separators, 'n' behaves just like 'g'
+ format_dict['type'] = 'g'
+ if _localeconv is None:
+ _localeconv = _locale.localeconv()
+ if format_dict['thousands_sep'] is not None:
+ raise ValueError("Explicit thousands separator conflicts with "
+ "'n' type in format specifier: " + format_spec)
+ format_dict['thousands_sep'] = _localeconv['thousands_sep']
+ format_dict['grouping'] = _localeconv['grouping']
+ format_dict['decimal_point'] = _localeconv['decimal_point']
+ else:
+ if format_dict['thousands_sep'] is None:
+ format_dict['thousands_sep'] = ''
+ format_dict['grouping'] = [3, 0]
+ format_dict['decimal_point'] = '.'
+
+ # record whether return type should be str or unicode
+ format_dict['unicode'] = isinstance(format_spec, unicode)
+
+ return format_dict
+
+def _format_align(sign, body, spec):
+ """Given an unpadded, non-aligned numeric string 'body' and sign
+ string 'sign', add padding and alignment conforming to the given
+ format specifier dictionary 'spec' (as produced by
+ _parse_format_specifier).
+
+ Also converts the result to unicode when spec['unicode'] is true.
+
+ """
+ # pad with fill up to minimumwidth; a negative count yields ''
+ minimumwidth = spec['minimumwidth']
+ fill = spec['fill']
+ padding = fill*(minimumwidth - len(sign) - len(body))
+
+ align = spec['align']
+ if align == '<':
+ result = sign + body + padding
+ elif align == '>':
+ result = padding + sign + body
+ elif align == '=':
+ result = sign + padding + body
+ elif align == '^':
+ half = len(padding)//2
+ result = padding[:half] + sign + body + padding[half:]
+ else:
+ raise ValueError('Unrecognised alignment field')
+
+ # make sure that result is unicode if necessary
+ if spec['unicode']:
+ result = unicode(result)
+
+ return result
+
+def _group_lengths(grouping):
+ """Convert a localeconv-style grouping into a (possibly infinite)
+ iterable of integers representing group lengths.
+ Raises ValueError if grouping is not in one of the forms below.
+ """
+ # The result from localeconv()['grouping'], and the input to this
+ # function, should be a list of integers in one of the
+ # following three forms:
+ #
+ # (1) an empty list, or
+ # (2) nonempty list of positive integers + [0] (last length repeats), or
+ # (3) list of positive integers + [locale.CHAR_MAX] (no further grouping).
+
+ from itertools import chain, repeat
+ if not grouping:
+ return []
+ elif grouping[-1] == 0 and len(grouping) >= 2:
+ return chain(grouping[:-1], repeat(grouping[-2]))
+ elif grouping[-1] == _locale.CHAR_MAX:
+ return grouping[:-1]
+ else:
+ raise ValueError('unrecognised format for grouping')
+
+def _insert_thousands_sep(digits, spec, min_width=1):
+ """Insert thousands separators into a digit string.
+
+ spec is a dictionary whose keys should include 'thousands_sep' and
+ 'grouping'; typically it's the result of parsing the format
+ specifier using _parse_format_specifier.
+
+ The min_width keyword argument gives the minimum length of the
+ result, which will be padded on the left with zeros if necessary.
+
+ If necessary, the zero padding adds an extra '0' on the left to
+ avoid a leading thousands separator. For example, inserting
+ commas every three digits in '123456', with min_width=8, gives
+ '0,123,456', even though that has length 9.
+
+ """
+
+ sep = spec['thousands_sep']
+ grouping = spec['grouping']
+
+ groups = []
+ for l in _group_lengths(grouping):
+ if l <= 0:
+ raise ValueError("group length should be positive")
+ # max(..., 1) forces at least 1 digit to the left of a separator
+ l = min(max(len(digits), min_width, 1), l)
+ groups.append('0'*(l - len(digits)) + digits[-l:])
+ digits = digits[:-l]
+ min_width -= l
+ if not digits and min_width <= 0:
+ break
+ min_width -= len(sep)
+ else:
+ l = max(len(digits), min_width, 1)
+ groups.append('0'*(l - len(digits)) + digits[-l:])
+ return sep.join(reversed(groups))
+
+def _format_sign(is_negative, spec):
+ """Determine sign character."""
+
+ if is_negative:
+ return '-'
+ elif spec['sign'] in ' +':
+ return spec['sign']
+ else:
+ return ''
+
+def _format_number(is_negative, intpart, fracpart, exp, spec):
+ """Format a number, given the following data:
+
+ is_negative: true if the number is negative, else false
+ intpart: string of digits that must appear before the decimal point
+ fracpart: string of digits that must come after the point
+ exp: exponent, as an integer
+ spec: dictionary resulting from parsing the format specifier
+
+ This function uses the information in spec to:
+ insert separators (decimal separator and thousands separators)
+ format the sign
+ format the exponent
+ add trailing '%' for the '%' type
+ zero-pad if necessary
+ fill and align if necessary
+ """
+
+ sign = _format_sign(is_negative, spec)
+
+ if fracpart:
+ fracpart = spec['decimal_point'] + fracpart
+
+ if exp != 0 or spec['type'] in 'eE':
+ echar = {'E': 'E', 'e': 'e', 'G': 'E', 'g': 'e'}[spec['type']]
+ fracpart += "{0}{1:+}".format(echar, exp)
+ if spec['type'] == '%':
+ fracpart += '%'
+
+ if spec['zeropad']:
+ min_width = spec['minimumwidth'] - len(fracpart) - len(sign)
+ else:
+ min_width = 0
+ intpart = _insert_thousands_sep(intpart, spec, min_width)
+
+ return _format_align(sign, intpart+fracpart, spec)
+
##### Useful Constants (internal use only) ################################
# Reusable defaults
-Inf = Decimal('Inf')
-negInf = Decimal('-Inf')
-NaN = Decimal('NaN')
-Dec_0 = Decimal(0)
-Dec_p1 = Decimal(1)
-Dec_n1 = Decimal(-1)
-
-# Infsign[sign] is infinity w/ that sign
-Infsign = (Inf, negInf)
+_Infinity = Decimal('Inf')
+_NegativeInfinity = Decimal('-Inf')
+_NaN = Decimal('NaN')
+_Zero = Decimal(0)
+_One = Decimal(1)
+_NegativeOne = Decimal(-1)
+
+# _SignedInfinity[sign] is infinity w/ that sign
+_SignedInfinity = (_Infinity, _NegativeInfinity)
diff --git a/Lib/distutils/ccompiler.py b/Lib/distutils/ccompiler.py
--- a/Lib/distutils/ccompiler.py
+++ b/Lib/distutils/ccompiler.py
@@ -3,21 +3,73 @@
Contains CCompiler, an abstract base class that defines the interface
for the Distutils compiler abstraction model."""
-# This module should be kept compatible with Python 2.1.
+__revision__ = "$Id: ccompiler.py 86238 2010-11-06 04:06:18Z eric.araujo $"
-__revision__ = "$Id: ccompiler.py 46331 2006-05-26 14:07:23Z bob.ippolito $"
+import sys
+import os
+import re
-import sys, os, re
-from types import *
-from copy import copy
-from distutils.errors import *
+from distutils.errors import (CompileError, LinkError, UnknownFileError,
+ DistutilsPlatformError, DistutilsModuleError)
from distutils.spawn import spawn
from distutils.file_util import move_file
from distutils.dir_util import mkpath
-from distutils.dep_util import newer_pairwise, newer_group
+from distutils.dep_util import newer_group
from distutils.util import split_quoted, execute
from distutils import log
+_sysconfig = __import__('sysconfig')
+
+def customize_compiler(compiler):
+ """Do any platform-specific customization of a CCompiler instance.
+
+ Mainly needed on Unix, so we can plug in the information that
+ varies across Unices and is stored in Python's Makefile.
+ """
+ if compiler.compiler_type == "unix":
+ (cc, cxx, opt, cflags, ccshared, ldshared, so_ext, ar, ar_flags) = \
+ _sysconfig.get_config_vars('CC', 'CXX', 'OPT', 'CFLAGS',
+ 'CCSHARED', 'LDSHARED', 'SO', 'AR',
+ 'ARFLAGS')
+
+ if 'CC' in os.environ:
+ cc = os.environ['CC']
+ if 'CXX' in os.environ:
+ cxx = os.environ['CXX']
+ if 'LDSHARED' in os.environ:
+ ldshared = os.environ['LDSHARED']
+ if 'CPP' in os.environ:
+ cpp = os.environ['CPP']
+ else:
+ cpp = cc + " -E" # not always
+ if 'LDFLAGS' in os.environ:
+ ldshared = ldshared + ' ' + os.environ['LDFLAGS']
+ if 'CFLAGS' in os.environ:
+ cflags = opt + ' ' + os.environ['CFLAGS']
+ ldshared = ldshared + ' ' + os.environ['CFLAGS']
+ if 'CPPFLAGS' in os.environ:
+ cpp = cpp + ' ' + os.environ['CPPFLAGS']
+ cflags = cflags + ' ' + os.environ['CPPFLAGS']
+ ldshared = ldshared + ' ' + os.environ['CPPFLAGS']
+ if 'AR' in os.environ:
+ ar = os.environ['AR']
+ if 'ARFLAGS' in os.environ:
+ archiver = ar + ' ' + os.environ['ARFLAGS']
+ else:
+ archiver = ar + ' ' + ar_flags
+
+ cc_cmd = cc + ' ' + cflags
+ compiler.set_executables(
+ preprocessor=cpp,
+ compiler=cc_cmd,
+ compiler_so=cc_cmd + ' ' + ccshared,
+ compiler_cxx=cxx,
+ linker_so=ldshared,
+ linker_exe=cc,
+ archiver=archiver)
+
+ compiler.shared_lib_extension = so_ext
+
class CCompiler:
"""Abstract base class to define the interface that must be implemented
by real compiler classes. Also has some utility methods used by
@@ -88,11 +140,7 @@
}
language_order = ["c++", "objc", "c"]
- def __init__ (self,
- verbose=0,
- dry_run=0,
- force=0):
-
+ def __init__ (self, verbose=0, dry_run=0, force=0):
self.dry_run = dry_run
self.force = force
self.verbose = verbose
@@ -128,11 +176,7 @@
for key in self.executables.keys():
self.set_executable(key, self.executables[key])
- # __init__ ()
-
-
- def set_executables (self, **args):
-
+ def set_executables(self, **args):
"""Define the executables (and options for them) that will be run
to perform the various stages of compilation. The exact set of
executables that may be specified here depends on the compiler
@@ -159,42 +203,37 @@
# basically the same things with Unix C compilers.
for key in args.keys():
- if not self.executables.has_key(key):
+ if key not in self.executables:
raise ValueError, \
"unknown executable '%s' for class %s" % \
(key, self.__class__.__name__)
self.set_executable(key, args[key])
- # set_executables ()
-
def set_executable(self, key, value):
- if type(value) is StringType:
+ if isinstance(value, str):
setattr(self, key, split_quoted(value))
else:
setattr(self, key, value)
-
- def _find_macro (self, name):
+ def _find_macro(self, name):
i = 0
for defn in self.macros:
if defn[0] == name:
return i
i = i + 1
-
return None
-
- def _check_macro_definitions (self, definitions):
+ def _check_macro_definitions(self, definitions):
"""Ensures that every element of 'definitions' is a valid macro
definition, ie. either (name,value) 2-tuple or a (name,) tuple. Do
nothing if all definitions are OK, raise TypeError otherwise.
"""
for defn in definitions:
- if not (type (defn) is TupleType and
+ if not (isinstance(defn, tuple) and
(len (defn) == 1 or
(len (defn) == 2 and
- (type (defn[1]) is StringType or defn[1] is None))) and
- type (defn[0]) is StringType):
+ (isinstance(defn[1], str) or defn[1] is None))) and
+ isinstance(defn[0], str)):
raise TypeError, \
("invalid macro definition '%s': " % defn) + \
"must be tuple (string,), (string, string), or " + \
@@ -203,7 +242,7 @@
# -- Bookkeeping methods -------------------------------------------
- def define_macro (self, name, value=None):
+ def define_macro(self, name, value=None):
"""Define a preprocessor macro for all compilations driven by this
compiler object. The optional parameter 'value' should be a
string; if it is not supplied, then the macro will be defined
@@ -219,8 +258,7 @@
defn = (name, value)
self.macros.append (defn)
-
- def undefine_macro (self, name):
+ def undefine_macro(self, name):
"""Undefine a preprocessor macro for all compilations driven by
this compiler object. If the same macro is defined by
'define_macro()' and undefined by 'undefine_macro()' the last call
@@ -238,8 +276,7 @@
undefn = (name,)
self.macros.append (undefn)
-
- def add_include_dir (self, dir):
+ def add_include_dir(self, dir):
"""Add 'dir' to the list of directories that will be searched for
header files. The compiler is instructed to search directories in
the order in which they are supplied by successive calls to
@@ -247,7 +284,7 @@
"""
self.include_dirs.append (dir)
- def set_include_dirs (self, dirs):
+ def set_include_dirs(self, dirs):
"""Set the list of directories that will be searched to 'dirs' (a
list of strings). Overrides any preceding calls to
'add_include_dir()'; subsequence calls to 'add_include_dir()' add
@@ -255,10 +292,9 @@
any list of standard include directories that the compiler may
search by default.
"""
- self.include_dirs = copy (dirs)
+ self.include_dirs = dirs[:]
-
- def add_library (self, libname):
+ def add_library(self, libname):
"""Add 'libname' to the list of libraries that will be included in
all links driven by this compiler object. Note that 'libname'
should *not* be the name of a file containing a library, but the
@@ -274,61 +310,59 @@
"""
self.libraries.append (libname)
- def set_libraries (self, libnames):
+ def set_libraries(self, libnames):
"""Set the list of libraries to be included in all links driven by
this compiler object to 'libnames' (a list of strings). This does
not affect any standard system libraries that the linker may
include by default.
"""
- self.libraries = copy (libnames)
+ self.libraries = libnames[:]
- def add_library_dir (self, dir):
+ def add_library_dir(self, dir):
"""Add 'dir' to the list of directories that will be searched for
libraries specified to 'add_library()' and 'set_libraries()'. The
linker will be instructed to search for libraries in the order they
are supplied to 'add_library_dir()' and/or 'set_library_dirs()'.
"""
- self.library_dirs.append (dir)
+ self.library_dirs.append(dir)
- def set_library_dirs (self, dirs):
+ def set_library_dirs(self, dirs):
"""Set the list of library search directories to 'dirs' (a list of
strings). This does not affect any standard library search path
that the linker may search by default.
"""
- self.library_dirs = copy (dirs)
+ self.library_dirs = dirs[:]
-
- def add_runtime_library_dir (self, dir):
+ def add_runtime_library_dir(self, dir):
"""Add 'dir' to the list of directories that will be searched for
shared libraries at runtime.
"""
- self.runtime_library_dirs.append (dir)
+ self.runtime_library_dirs.append(dir)
- def set_runtime_library_dirs (self, dirs):
+ def set_runtime_library_dirs(self, dirs):
"""Set the list of directories to search for shared libraries at
runtime to 'dirs' (a list of strings). This does not affect any
standard search path that the runtime linker may search by
default.
"""
- self.runtime_library_dirs = copy (dirs)
+ self.runtime_library_dirs = dirs[:]
-
- def add_link_object (self, object):
+ def add_link_object(self, object):
"""Add 'object' to the list of object files (or analogues, such as
explicitly named library files or the output of "resource
compilers") to be included in every link driven by this compiler
object.
"""
- self.objects.append (object)
+ self.objects.append(object)
- def set_link_objects (self, objects):
+ def set_link_objects(self, objects):
"""Set the list of object files (or analogues) to be included in
every link to 'objects'. This does not affect any standard object
files that the linker may include by default (such as system
libraries).
"""
- self.objects = copy (objects)
+ self.objects = objects[:]
# -- Private utility methods --------------------------------------
@@ -338,25 +372,22 @@
def _setup_compile(self, outdir, macros, incdirs, sources, depends,
extra):
- """Process arguments and decide which source files to compile.
-
- Merges _fix_compile_args() and _prep_compile().
- """
+ """Process arguments and decide which source files to compile."""
if outdir is None:
outdir = self.output_dir
- elif type(outdir) is not StringType:
+ elif not isinstance(outdir, str):
raise TypeError, "'output_dir' must be a string or None"
if macros is None:
macros = self.macros
- elif type(macros) is ListType:
+ elif isinstance(macros, list):
macros = macros + (self.macros or [])
else:
raise TypeError, "'macros' (if supplied) must be a list of tuples"
if incdirs is None:
incdirs = self.include_dirs
- elif type(incdirs) in (ListType, TupleType):
+ elif isinstance(incdirs, (list, tuple)):
incdirs = list(incdirs) + (self.include_dirs or [])
else:
raise TypeError, \
@@ -371,41 +402,6 @@
output_dir=outdir)
assert len(objects) == len(sources)
- # XXX should redo this code to eliminate skip_source entirely.
- # XXX instead create build and issue skip messages inline
-
- if self.force:
- skip_source = {} # rebuild everything
- for source in sources:
- skip_source[source] = 0
- elif depends is None:
- # If depends is None, figure out which source files we
- # have to recompile according to a simplistic check. We
- # just compare the source and object file, no deep
- # dependency checking involving header files.
- skip_source = {} # rebuild everything
- for source in sources: # no wait, rebuild nothing
- skip_source[source] = 1
-
- n_sources, n_objects = newer_pairwise(sources, objects)
- for source in n_sources: # no really, only rebuild what's
- skip_source[source] = 0 # out-of-date
- else:
- # If depends is a list of files, then do a different
- # simplistic check. Assume that each object depends on
- # its source and all files in the depends list.
- skip_source = {}
- # L contains all the depends plus a spot at the end for a
- # particular source file
- L = depends[:] + [None]
- for i in range(len(objects)):
- source = sources[i]
- L[-1] = source
- if newer_group(L, objects[i]):
- skip_source[source] = 0
- else:
- skip_source[source] = 1
-
pp_opts = gen_preprocess_options(macros, incdirs)
build = {}
@@ -414,10 +410,7 @@
obj = objects[i]
ext = os.path.splitext(src)[1]
self.mkpath(os.path.dirname(obj))
- if skip_source[src]:
- log.debug("skipping %s (%s up-to-date)", src, obj)
- else:
- build[obj] = src, ext
+ build[obj] = (src, ext)
return macros, objects, extra, pp_opts, build
@@ -430,7 +423,7 @@
cc_args[:0] = before
return cc_args
- def _fix_compile_args (self, output_dir, macros, include_dirs):
+ def _fix_compile_args(self, output_dir, macros, include_dirs):
"""Typecheck and fix-up some of the arguments to the 'compile()'
method, and return fixed-up values. Specifically: if 'output_dir'
is None, replaces it with 'self.output_dir'; ensures that 'macros'
@@ -442,19 +435,19 @@
"""
if output_dir is None:
output_dir = self.output_dir
- elif type (output_dir) is not StringType:
+ elif not isinstance(output_dir, str):
raise TypeError, "'output_dir' must be a string or None"
if macros is None:
macros = self.macros
- elif type (macros) is ListType:
+ elif isinstance(macros, list):
macros = macros + (self.macros or [])
else:
raise TypeError, "'macros' (if supplied) must be a list of tuples"
if include_dirs is None:
include_dirs = self.include_dirs
- elif type (include_dirs) in (ListType, TupleType):
+ elif isinstance(include_dirs, (list, tuple)):
include_dirs = list (include_dirs) + (self.include_dirs or [])
else:
raise TypeError, \
@@ -462,78 +455,25 @@
return output_dir, macros, include_dirs
- # _fix_compile_args ()
-
-
- def _prep_compile(self, sources, output_dir, depends=None):
- """Decide which souce files must be recompiled.
-
- Determine the list of object files corresponding to 'sources',
- and figure out which ones really need to be recompiled.
- Return a list of all object files and a dictionary telling
- which source files can be skipped.
- """
- # Get the list of expected output (object) files
- objects = self.object_filenames(sources, output_dir=output_dir)
- assert len(objects) == len(sources)
-
- if self.force:
- skip_source = {} # rebuild everything
- for source in sources:
- skip_source[source] = 0
- elif depends is None:
- # If depends is None, figure out which source files we
- # have to recompile according to a simplistic check. We
- # just compare the source and object file, no deep
- # dependency checking involving header files.
- skip_source = {} # rebuild everything
- for source in sources: # no wait, rebuild nothing
- skip_source[source] = 1
-
- n_sources, n_objects = newer_pairwise(sources, objects)
- for source in n_sources: # no really, only rebuild what's
- skip_source[source] = 0 # out-of-date
- else:
- # If depends is a list of files, then do a different
- # simplistic check. Assume that each object depends on
- # its source and all files in the depends list.
- skip_source = {}
- # L contains all the depends plus a spot at the end for a
- # particular source file
- L = depends[:] + [None]
- for i in range(len(objects)):
- source = sources[i]
- L[-1] = source
- if newer_group(L, objects[i]):
- skip_source[source] = 0
- else:
- skip_source[source] = 1
-
- return objects, skip_source
-
- # _prep_compile ()
-
-
- def _fix_object_args (self, objects, output_dir):
+ def _fix_object_args(self, objects, output_dir):
"""Typecheck and fix up some arguments supplied to various methods.
Specifically: ensure that 'objects' is a list; if output_dir is
None, replace with self.output_dir. Return fixed versions of
'objects' and 'output_dir'.
"""
- if type (objects) not in (ListType, TupleType):
+ if not isinstance(objects, (list, tuple)):
raise TypeError, \
"'objects' must be a list or tuple of strings"
objects = list (objects)
if output_dir is None:
output_dir = self.output_dir
- elif type (output_dir) is not StringType:
+ elif not isinstance(output_dir, str):
raise TypeError, "'output_dir' must be a string or None"
return (objects, output_dir)
-
- def _fix_lib_args (self, libraries, library_dirs, runtime_library_dirs):
+ def _fix_lib_args(self, libraries, library_dirs, runtime_library_dirs):
"""Typecheck and fix up some of the arguments supplied to the
'link_*' methods. Specifically: ensure that all arguments are
lists, and augment them with their permanent versions
@@ -542,7 +482,7 @@
"""
if libraries is None:
libraries = self.libraries
- elif type (libraries) in (ListType, TupleType):
+ elif isinstance(libraries, (list, tuple)):
libraries = list (libraries) + (self.libraries or [])
else:
raise TypeError, \
@@ -550,7 +490,7 @@
if library_dirs is None:
library_dirs = self.library_dirs
- elif type (library_dirs) in (ListType, TupleType):
+ elif isinstance(library_dirs, (list, tuple)):
library_dirs = list (library_dirs) + (self.library_dirs or [])
else:
raise TypeError, \
@@ -558,7 +498,7 @@
if runtime_library_dirs is None:
runtime_library_dirs = self.runtime_library_dirs
- elif type (runtime_library_dirs) in (ListType, TupleType):
+ elif isinstance(runtime_library_dirs, (list, tuple)):
runtime_library_dirs = (list (runtime_library_dirs) +
(self.runtime_library_dirs or []))
else:
@@ -568,10 +508,7 @@
return (libraries, library_dirs, runtime_library_dirs)
- # _fix_lib_args ()
-
-
- def _need_link (self, objects, output_file):
+ def _need_link(self, objects, output_file):
"""Return true if we need to relink the files listed in 'objects'
to recreate 'output_file'.
"""
@@ -584,13 +521,11 @@
newer = newer_group (objects, output_file)
return newer
- # _need_link ()
-
- def detect_language (self, sources):
+ def detect_language(self, sources):
"""Detect the language of a given file, or list of files. Uses
language_map, and language_order to do the job.
"""
- if type(sources) is not ListType:
+ if not isinstance(sources, list):
sources = [sources]
lang = None
index = len(self.language_order)
@@ -606,18 +541,11 @@
pass
return lang
- # detect_language ()
-
# -- Worker methods ------------------------------------------------
# (must be implemented by subclasses)
- def preprocess (self,
- source,
- output_file=None,
- macros=None,
- include_dirs=None,
- extra_preargs=None,
- extra_postargs=None):
+ def preprocess(self, source, output_file=None, macros=None,
+ include_dirs=None, extra_preargs=None, extra_postargs=None):
"""Preprocess a single C/C++ source file, named in 'source'.
Output will be written to file named 'output_file', or stdout if
'output_file' not supplied. 'macros' is a list of macro
@@ -680,7 +608,6 @@
Raises CompileError on failure.
"""
-
# A concrete compiler class can either override this method
# entirely or implement _compile().
@@ -706,12 +633,8 @@
# should implement _compile().
pass
- def create_static_lib (self,
- objects,
- output_libname,
- output_dir=None,
- debug=0,
- target_lang=None):
+ def create_static_lib(self, objects, output_libname, output_dir=None,
+ debug=0, target_lang=None):
"""Link a bunch of stuff together to create a static library file.
The "bunch of stuff" consists of the list of object files supplied
as 'objects', the extra object files supplied to
@@ -736,26 +659,15 @@
"""
pass
-
# values for target_desc parameter in link()
SHARED_OBJECT = "shared_object"
SHARED_LIBRARY = "shared_library"
EXECUTABLE = "executable"
- def link (self,
- target_desc,
- objects,
- output_filename,
- output_dir=None,
- libraries=None,
- library_dirs=None,
- runtime_library_dirs=None,
- export_symbols=None,
- debug=0,
- extra_preargs=None,
- extra_postargs=None,
- build_temp=None,
- target_lang=None):
+ def link(self, target_desc, objects, output_filename, output_dir=None,
+ libraries=None, library_dirs=None, runtime_library_dirs=None,
+ export_symbols=None, debug=0, extra_preargs=None,
+ extra_postargs=None, build_temp=None, target_lang=None):
"""Link a bunch of stuff together to create an executable or
shared library file.
@@ -804,19 +716,11 @@
# Old 'link_*()' methods, rewritten to use the new 'link()' method.
- def link_shared_lib (self,
- objects,
- output_libname,
- output_dir=None,
- libraries=None,
- library_dirs=None,
- runtime_library_dirs=None,
- export_symbols=None,
- debug=0,
- extra_preargs=None,
- extra_postargs=None,
- build_temp=None,
- target_lang=None):
+ def link_shared_lib(self, objects, output_libname, output_dir=None,
+ libraries=None, library_dirs=None,
+ runtime_library_dirs=None, export_symbols=None,
+ debug=0, extra_preargs=None, extra_postargs=None,
+ build_temp=None, target_lang=None):
self.link(CCompiler.SHARED_LIBRARY, objects,
self.library_filename(output_libname, lib_type='shared'),
output_dir,
@@ -825,37 +729,21 @@
extra_preargs, extra_postargs, build_temp, target_lang)
- def link_shared_object (self,
- objects,
- output_filename,
- output_dir=None,
- libraries=None,
- library_dirs=None,
- runtime_library_dirs=None,
- export_symbols=None,
- debug=0,
- extra_preargs=None,
- extra_postargs=None,
- build_temp=None,
- target_lang=None):
+ def link_shared_object(self, objects, output_filename, output_dir=None,
+ libraries=None, library_dirs=None,
+ runtime_library_dirs=None, export_symbols=None,
+ debug=0, extra_preargs=None, extra_postargs=None,
+ build_temp=None, target_lang=None):
self.link(CCompiler.SHARED_OBJECT, objects,
output_filename, output_dir,
libraries, library_dirs, runtime_library_dirs,
export_symbols, debug,
extra_preargs, extra_postargs, build_temp, target_lang)
-
- def link_executable (self,
- objects,
- output_progname,
- output_dir=None,
- libraries=None,
- library_dirs=None,
- runtime_library_dirs=None,
- debug=0,
- extra_preargs=None,
- extra_postargs=None,
- target_lang=None):
+ def link_executable(self, objects, output_progname, output_dir=None,
+ libraries=None, library_dirs=None,
+ runtime_library_dirs=None, debug=0, extra_preargs=None,
+ extra_postargs=None, target_lang=None):
self.link(CCompiler.EXECUTABLE, objects,
self.executable_filename(output_progname), output_dir,
libraries, library_dirs, runtime_library_dirs, None,
@@ -867,29 +755,26 @@
# no appropriate default implementation so subclasses should
# implement all of these.
- def library_dir_option (self, dir):
+ def library_dir_option(self, dir):
"""Return the compiler option to add 'dir' to the list of
directories searched for libraries.
"""
raise NotImplementedError
- def runtime_library_dir_option (self, dir):
+ def runtime_library_dir_option(self, dir):
"""Return the compiler option to add 'dir' to the list of
directories searched for runtime libraries.
"""
raise NotImplementedError
- def library_option (self, lib):
+ def library_option(self, lib):
"""Return the compiler option to add 'dir' to the list of libraries
linked into the shared library or executable.
"""
raise NotImplementedError
- def has_function(self, funcname,
- includes=None,
- include_dirs=None,
- libraries=None,
- library_dirs=None):
+ def has_function(self, funcname, includes=None, include_dirs=None,
+ libraries=None, library_dirs=None):
"""Return a boolean indicating whether funcname is supported on
the current platform. The optional arguments can be used to
augment the compilation environment.
@@ -909,14 +794,16 @@
library_dirs = []
fd, fname = tempfile.mkstemp(".c", funcname, text=True)
f = os.fdopen(fd, "w")
- for incl in includes:
- f.write("""#include "%s"\n""" % incl)
- f.write("""\
+ try:
+ for incl in includes:
+ f.write("""#include "%s"\n""" % incl)
+ f.write("""\
main (int argc, char **argv) {
%s();
}
""" % funcname)
- f.close()
+ finally:
+ f.close()
try:
objects = self.compile([fname], include_dirs=include_dirs)
except CompileError:
@@ -1020,28 +907,28 @@
# -- Utility methods -----------------------------------------------
- def announce (self, msg, level=1):
+ def announce(self, msg, level=1):
log.debug(msg)
- def debug_print (self, msg):
+ def debug_print(self, msg):
from distutils.debug import DEBUG
if DEBUG:
print msg
- def warn (self, msg):
- sys.stderr.write ("warning: %s\n" % msg)
+ def warn(self, msg):
+ sys.stderr.write("warning: %s\n" % msg)
- def execute (self, func, args, msg=None, level=1):
+ def execute(self, func, args, msg=None, level=1):
execute(func, args, msg, self.dry_run)
- def spawn (self, cmd):
- spawn (cmd, dry_run=self.dry_run)
+ def spawn(self, cmd):
+ spawn(cmd, dry_run=self.dry_run)
- def move_file (self, src, dst):
- return move_file (src, dst, dry_run=self.dry_run)
+ def move_file(self, src, dst):
+ return move_file(src, dst, dry_run=self.dry_run)
- def mkpath (self, name, mode=0777):
- mkpath (name, mode, self.dry_run)
+ def mkpath(self, name, mode=0777):
+ mkpath(name, mode, dry_run=self.dry_run)
# class CCompiler
@@ -1064,12 +951,10 @@
# OS name mappings
('posix', 'unix'),
('nt', 'msvc'),
- ('mac', 'mwerks'),
)
def get_default_compiler(osname=None, platform=None):
-
""" Determine the default compiler to use for the given platform.
osname should be one of the standard Python OS names (i.e. the
@@ -1104,8 +989,6 @@
"Mingw32 port of GNU C Compiler for Win32"),
'bcpp': ('bcppcompiler', 'BCPPCompiler',
"Borland C++ Compiler"),
- 'mwerks': ('mwerkscompiler', 'MWerksCompiler',
- "MetroWerks CodeWarrior"),
'emx': ('emxccompiler', 'EMXCCompiler',
"EMX port of GNU C Compiler for OS/2"),
'jython': ('jythoncompiler', 'JythonCompiler',
@@ -1129,11 +1012,7 @@
pretty_printer.print_help("List of available compilers:")
-def new_compiler (plat=None,
- compiler=None,
- verbose=0,
- dry_run=0,
- force=0):
+def new_compiler(plat=None, compiler=None, verbose=0, dry_run=0, force=0):
"""Generate an instance of some CCompiler subclass for the supplied
platform/compiler combination. 'plat' defaults to 'os.name'
(eg. 'posix', 'nt'), and 'compiler' defaults to the default compiler
@@ -1175,10 +1054,10 @@
# XXX The None is necessary to preserve backwards compatibility
# with classes that expect verbose to be the first positional
# argument.
- return klass (None, dry_run, force)
+ return klass(None, dry_run, force)
-def gen_preprocess_options (macros, include_dirs):
+def gen_preprocess_options(macros, include_dirs):
"""Generate C pre-processor options (-D, -U, -I) as used by at least
two types of compilers: the typical Unix compiler and Visual C++.
'macros' is the usual thing, a list of 1- or 2-tuples, where (name,)
@@ -1203,7 +1082,7 @@
pp_opts = []
for macro in macros:
- if not (type (macro) is TupleType and
+ if not (isinstance(macro, tuple) and
1 <= len (macro) <= 2):
raise TypeError, \
("bad macro definition '%s': " +
@@ -1226,27 +1105,27 @@
return pp_opts
-# gen_preprocess_options ()
+def gen_lib_options(compiler, library_dirs, runtime_library_dirs, libraries):
+ """Generate linker options for searching library directories and
+ linking with specific libraries.
-def gen_lib_options (compiler, library_dirs, runtime_library_dirs, libraries):
- """Generate linker options for searching library directories and
- linking with specific libraries. 'libraries' and 'library_dirs' are,
- respectively, lists of library names (not filenames!) and search
- directories. Returns a list of command-line options suitable for use
- with some compiler (depending on the two format strings passed in).
+ 'libraries' and 'library_dirs' are, respectively, lists of library names
+ (not filenames!) and search directories. Returns a list of command-line
+ options suitable for use with some compiler (depending on the two format
+ strings passed in).
"""
lib_opts = []
for dir in library_dirs:
- lib_opts.append (compiler.library_dir_option (dir))
+ lib_opts.append(compiler.library_dir_option(dir))
for dir in runtime_library_dirs:
- opt = compiler.runtime_library_dir_option (dir)
- if type(opt) is ListType:
- lib_opts = lib_opts + opt
+ opt = compiler.runtime_library_dir_option(dir)
+ if isinstance(opt, list):
+ lib_opts.extend(opt)
else:
- lib_opts.append (opt)
+ lib_opts.append(opt)
# XXX it's important that we *not* remove redundant library mentions!
# sometimes you really do have to say "-lfoo -lbar -lfoo" in order to
@@ -1255,17 +1134,15 @@
# pretty nasty way to arrange your C code.
for lib in libraries:
- (lib_dir, lib_name) = os.path.split (lib)
- if lib_dir:
- lib_file = compiler.find_library_file ([lib_dir], lib_name)
- if lib_file:
- lib_opts.append (lib_file)
+ lib_dir, lib_name = os.path.split(lib)
+ if lib_dir != '':
+ lib_file = compiler.find_library_file([lib_dir], lib_name)
+ if lib_file is not None:
+ lib_opts.append(lib_file)
else:
- compiler.warn ("no library file corresponding to "
- "'%s' found (skipping)" % lib)
+ compiler.warn("no library file corresponding to "
+ "'%s' found (skipping)" % lib)
else:
- lib_opts.append (compiler.library_option (lib))
+ lib_opts.append(compiler.library_option(lib))
return lib_opts
-
-# gen_lib_options ()
diff --git a/Lib/distutils/command/bdist.py b/Lib/distutils/command/bdist.py
--- a/Lib/distutils/command/bdist.py
+++ b/Lib/distutils/command/bdist.py
@@ -5,9 +5,9 @@
# This module should be kept compatible with Python 2.1.
-__revision__ = "$Id: bdist.py 37828 2004-11-10 22:23:15Z loewis $"
+__revision__ = "$Id: bdist.py 62197 2008-04-07 01:53:39Z mark.hammond $"
-import os, string
+import os
from types import *
from distutils.core import Command
from distutils.errors import *
@@ -98,7 +98,10 @@
def finalize_options (self):
# have to finalize 'plat_name' before 'bdist_base'
if self.plat_name is None:
- self.plat_name = get_platform()
+ if self.skip_build:
+ self.plat_name = get_platform()
+ else:
+ self.plat_name = self.get_finalized_command('build').plat_name
# 'bdist_base' -- parent of per-built-distribution-format
# temporary directories (eg. we'll probably have
@@ -122,7 +125,6 @@
# finalize_options()
-
def run (self):
# Figure out which sub-commands we need to run.
diff --git a/Lib/distutils/command/bdist_dumb.py b/Lib/distutils/command/bdist_dumb.py
--- a/Lib/distutils/command/bdist_dumb.py
+++ b/Lib/distutils/command/bdist_dumb.py
@@ -4,21 +4,21 @@
distribution -- i.e., just an archive to be unpacked under $prefix or
$exec_prefix)."""
-# This module should be kept compatible with Python 2.1.
-
-__revision__ = "$Id: bdist_dumb.py 38697 2005-03-23 18:54:36Z loewis $"
+__revision__ = "$Id: bdist_dumb.py 77761 2010-01-26 22:46:15Z tarek.ziade $"
import os
+
+from sysconfig import get_python_version
+
+from distutils.util import get_platform
from distutils.core import Command
-from distutils.util import get_platform
-from distutils.dir_util import create_tree, remove_tree, ensure_relative
-from distutils.errors import *
-from distutils.sysconfig import get_python_version
+from distutils.dir_util import remove_tree, ensure_relative
+from distutils.errors import DistutilsPlatformError
from distutils import log
class bdist_dumb (Command):
- description = "create a \"dumb\" built distribution"
+ description = 'create a "dumb" built distribution'
user_options = [('bdist-dir=', 'd',
"temporary directory for creating the distribution"),
@@ -37,6 +37,12 @@
('relative', None,
"build the archive using relative paths"
"(default: false)"),
+ ('owner=', 'u',
+ "Owner name used when creating a tar file"
+ " [default: current user]"),
+ ('group=', 'g',
+ "Group name used when creating a tar file"
+ " [default: current group]"),
]
boolean_options = ['keep-temp', 'skip-build', 'relative']
@@ -55,12 +61,10 @@
self.dist_dir = None
self.skip_build = 0
self.relative = 0
+ self.owner = None
+ self.group = None
- # initialize_options()
-
-
- def finalize_options (self):
-
+ def finalize_options(self):
if self.bdist_dir is None:
bdist_base = self.get_finalized_command('bdist').bdist_base
self.bdist_dir = os.path.join(bdist_base, 'dumb')
@@ -77,11 +81,7 @@
('dist_dir', 'dist_dir'),
('plat_name', 'plat_name'))
- # finalize_options()
-
-
- def run (self):
-
+ def run(self):
if not self.skip_build:
self.run_command('build')
@@ -120,7 +120,8 @@
# Make the archive
filename = self.make_archive(pseudoinstall_root,
- self.format, root_dir=archive_root)
+ self.format, root_dir=archive_root,
+ owner=self.owner, group=self.group)
if self.distribution.has_ext_modules():
pyversion = get_python_version()
else:
@@ -130,7 +131,3 @@
if not self.keep_temp:
remove_tree(self.bdist_dir, dry_run=self.dry_run)
-
- # run()
-
-# class bdist_dumb
diff --git a/Lib/distutils/command/install_scripts.py b/Lib/distutils/command/install_scripts.py
--- a/Lib/distutils/command/install_scripts.py
+++ b/Lib/distutils/command/install_scripts.py
@@ -5,9 +5,7 @@
# contributed by Bastian Kleineidam
-# This module should be kept compatible with Python 2.1.
-
-__revision__ = "$Id: install_scripts.py 37828 2004-11-10 22:23:15Z loewis $"
+__revision__ = "$Id: install_scripts.py 68943 2009-01-25 22:09:10Z tarek.ziade $"
import os
from distutils.core import Command
diff --git a/Lib/distutils/file_util.py b/Lib/distutils/file_util.py
--- a/Lib/distutils/file_util.py
+++ b/Lib/distutils/file_util.py
@@ -3,58 +3,55 @@
Utility functions for operating on single files.
"""
-# This module should be kept compatible with Python 2.1.
-
-__revision__ = "$Id: file_util.py 37828 2004-11-10 22:23:15Z loewis $"
+__revision__ = "$Id: file_util.py 86238 2010-11-06 04:06:18Z eric.araujo $"
import os
from distutils.errors import DistutilsFileError
from distutils import log
# for generating verbose output in 'copy_file()'
-_copy_action = { None: 'copying',
- 'hard': 'hard linking',
- 'sym': 'symbolically linking' }
+_copy_action = {None: 'copying',
+ 'hard': 'hard linking',
+ 'sym': 'symbolically linking'}
-def _copy_file_contents (src, dst, buffer_size=16*1024):
- """Copy the file 'src' to 'dst'; both must be filenames. Any error
- opening either file, reading from 'src', or writing to 'dst', raises
- DistutilsFileError. Data is read/written in chunks of 'buffer_size'
- bytes (default 16k). No attempt is made to handle anything apart from
- regular files.
+def _copy_file_contents(src, dst, buffer_size=16*1024):
+ """Copy the file 'src' to 'dst'.
+
+ Both must be filenames. Any error opening either file, reading from
+ 'src', or writing to 'dst', raises DistutilsFileError. Data is
+ read/written in chunks of 'buffer_size' bytes (default 16k). No attempt
+ is made to handle anything apart from regular files.
"""
# Stolen from shutil module in the standard library, but with
# custom error-handling added.
-
fsrc = None
fdst = None
try:
try:
fsrc = open(src, 'rb')
except os.error, (errno, errstr):
- raise DistutilsFileError, \
- "could not open '%s': %s" % (src, errstr)
+ raise DistutilsFileError("could not open '%s': %s" % (src, errstr))
if os.path.exists(dst):
try:
os.unlink(dst)
except os.error, (errno, errstr):
- raise DistutilsFileError, \
- "could not delete '%s': %s" % (dst, errstr)
+ raise DistutilsFileError(
+ "could not delete '%s': %s" % (dst, errstr))
try:
fdst = open(dst, 'wb')
except os.error, (errno, errstr):
- raise DistutilsFileError, \
- "could not create '%s': %s" % (dst, errstr)
+ raise DistutilsFileError(
+ "could not create '%s': %s" % (dst, errstr))
while 1:
try:
buf = fsrc.read(buffer_size)
except os.error, (errno, errstr):
- raise DistutilsFileError, \
- "could not read from '%s': %s" % (src, errstr)
+ raise DistutilsFileError(
+ "could not read from '%s': %s" % (src, errstr))
if not buf:
break
@@ -62,8 +59,8 @@
try:
fdst.write(buf)
except os.error, (errno, errstr):
- raise DistutilsFileError, \
- "could not write to '%s': %s" % (dst, errstr)
+ raise DistutilsFileError(
+ "could not write to '%s': %s" % (dst, errstr))
finally:
if fdst:
@@ -71,25 +68,18 @@
if fsrc:
fsrc.close()
-# _copy_file_contents()
+def copy_file(src, dst, preserve_mode=1, preserve_times=1, update=0,
+ link=None, verbose=1, dry_run=0):
+ """Copy a file 'src' to 'dst'.
-def copy_file (src, dst,
- preserve_mode=1,
- preserve_times=1,
- update=0,
- link=None,
- verbose=0,
- dry_run=0):
-
- """Copy a file 'src' to 'dst'. If 'dst' is a directory, then 'src' is
- copied there with the same name; otherwise, it must be a filename. (If
- the file exists, it will be ruthlessly clobbered.) If 'preserve_mode'
- is true (the default), the file's mode (type and permission bits, or
- whatever is analogous on the current platform) is copied. If
- 'preserve_times' is true (the default), the last-modified and
- last-access times are copied as well. If 'update' is true, 'src' will
- only be copied if 'dst' does not exist, or if 'dst' does exist but is
- older than 'src'.
+ If 'dst' is a directory, then 'src' is copied there with the same name;
+ otherwise, it must be a filename. (If the file exists, it will be
+ ruthlessly clobbered.) If 'preserve_mode' is true (the default),
+ the file's mode (type and permission bits, or whatever is analogous on
+ the current platform) is copied. If 'preserve_times' is true (the
+ default), the last-modified and last-access times are copied as well.
+ If 'update' is true, 'src' will only be copied if 'dst' does not exist,
+ or if 'dst' does exist but is older than 'src'.
'link' allows you to make hard links (os.link) or symbolic links
(os.symlink) instead of copying: set it to "hard" or "sym"; if it is
@@ -115,8 +105,8 @@
from stat import ST_ATIME, ST_MTIME, ST_MODE, S_IMODE
if not os.path.isfile(src):
- raise DistutilsFileError, \
- "can't copy '%s': doesn't exist or not a regular file" % src
+ raise DistutilsFileError(
+ "can't copy '%s': doesn't exist or not a regular file" % src)
if os.path.isdir(dst):
dir = dst
@@ -125,34 +115,27 @@
dir = os.path.dirname(dst)
if update and not newer(src, dst):
- log.debug("not copying %s (output up-to-date)", src)
+ if verbose >= 1:
+ log.debug("not copying %s (output up-to-date)", src)
return dst, 0
try:
action = _copy_action[link]
except KeyError:
- raise ValueError, \
- "invalid value '%s' for 'link' argument" % link
- if os.path.basename(dst) == os.path.basename(src):
- log.info("%s %s -> %s", action, src, dir)
- else:
- log.info("%s %s -> %s", action, src, dst)
+ raise ValueError("invalid value '%s' for 'link' argument" % link)
+
+ if verbose >= 1:
+ if os.path.basename(dst) == os.path.basename(src):
+ log.info("%s %s -> %s", action, src, dir)
+ else:
+ log.info("%s %s -> %s", action, src, dst)
if dry_run:
return (dst, 1)
- # On Mac OS, use the native file copy routine
- if os.name == 'mac':
- import macostools
- try:
- macostools.copy(src, dst, 0, preserve_times)
- except os.error, exc:
- raise DistutilsFileError, \
- "could not copy '%s' to '%s': %s" % (src, dst, exc[-1])
-
# If linking (hard or symbolic), use the appropriate system call
# (Unix only, of course, but that's the caller's responsibility)
- elif link == 'hard':
+ if link == 'hard':
if not (os.path.exists(dst) and os.path.samefile(src, dst)):
os.link(src, dst)
elif link == 'sym':
@@ -175,17 +158,13 @@
return (dst, 1)
-# copy_file ()
+# XXX I suspect this is Unix-specific -- need porting help!
+def move_file (src, dst, verbose=1, dry_run=0):
+ """Move a file 'src' to 'dst'.
-
-# XXX I suspect this is Unix-specific -- need porting help!
-def move_file (src, dst,
- verbose=0,
- dry_run=0):
-
- """Move a file 'src' to 'dst'. If 'dst' is a directory, the file will
- be moved into it with the same name; otherwise, 'src' is just renamed
- to 'dst'. Return the new full name of the file.
+ If 'dst' is a directory, the file will be moved into it with the same
+ name; otherwise, 'src' is just renamed to 'dst'. Return the new
+ full name of the file.
Handles cross-device moves on Unix using 'copy_file()'. What about
other systems???
@@ -193,26 +172,26 @@
from os.path import exists, isfile, isdir, basename, dirname
import errno
- log.info("moving %s -> %s", src, dst)
+ if verbose >= 1:
+ log.info("moving %s -> %s", src, dst)
if dry_run:
return dst
if not isfile(src):
- raise DistutilsFileError, \
- "can't move '%s': not a regular file" % src
+ raise DistutilsFileError("can't move '%s': not a regular file" % src)
if isdir(dst):
dst = os.path.join(dst, basename(src))
elif exists(dst):
- raise DistutilsFileError, \
- "can't move '%s': destination '%s' already exists" % \
- (src, dst)
+ raise DistutilsFileError(
+ "can't move '%s': destination '%s' already exists" %
+ (src, dst))
if not isdir(dirname(dst)):
- raise DistutilsFileError, \
+ raise DistutilsFileError(
"can't move '%s': destination '%s' not a valid path" % \
- (src, dst)
+ (src, dst))
copy_it = 0
try:
@@ -221,11 +200,11 @@
if num == errno.EXDEV:
copy_it = 1
else:
- raise DistutilsFileError, \
- "couldn't move '%s' to '%s': %s" % (src, dst, msg)
+ raise DistutilsFileError(
+ "couldn't move '%s' to '%s': %s" % (src, dst, msg))
if copy_it:
- copy_file(src, dst)
+ copy_file(src, dst, verbose=verbose)
try:
os.unlink(src)
except os.error, (num, msg):
@@ -233,21 +212,20 @@
os.unlink(dst)
except os.error:
pass
- raise DistutilsFileError, \
+ raise DistutilsFileError(
("couldn't move '%s' to '%s' by copy/delete: " +
- "delete '%s' failed: %s") % \
- (src, dst, src, msg)
-
+ "delete '%s' failed: %s") %
+ (src, dst, src, msg))
return dst
-# move_file ()
-
def write_file (filename, contents):
"""Create a file with the specified name and write 'contents' (a
sequence of strings without line terminators) to it.
"""
f = open(filename, "w")
- for line in contents:
- f.write(line + "\n")
- f.close()
+ try:
+ for line in contents:
+ f.write(line + "\n")
+ finally:
+ f.close()
diff --git a/Lib/distutils/spawn.py b/Lib/distutils/spawn.py
--- a/Lib/distutils/spawn.py
+++ b/Lib/distutils/spawn.py
@@ -6,21 +6,18 @@
executable name.
"""
-# This module should be kept compatible with Python 2.1.
+__revision__ = "$Id: spawn.py 73147 2009-06-02 15:58:43Z tarek.ziade $"
-__revision__ = "$Id: spawn.py 37828 2004-11-10 22:23:15Z loewis $"
+import sys
+import os
-import sys, os, string
-from distutils.errors import *
+from distutils.errors import DistutilsPlatformError, DistutilsExecError
from distutils import log
-def spawn (cmd,
- search_path=1,
- verbose=0,
- dry_run=0):
+def spawn(cmd, search_path=1, verbose=0, dry_run=0):
+ """Run another program, specified as a command list 'cmd', in a new process.
- """Run another program, specified as a command list 'cmd', in a new
- process. 'cmd' is just the argument list for the new process, ie.
+ 'cmd' is just the argument list for the new process, ie.
cmd[0] is the program to run and cmd[1:] are the rest of its arguments.
There is no way to run a program with a name different from that of its
executable.
@@ -45,37 +42,29 @@
raise DistutilsPlatformError, \
"don't know how to spawn programs on platform '%s'" % os.name
-# spawn ()
+def _nt_quote_args(args):
+ """Quote command-line arguments for DOS/Windows conventions.
-
-def _nt_quote_args (args):
- """Quote command-line arguments for DOS/Windows conventions: just
- wraps every argument which contains blanks in double quotes, and
+ Just wraps every argument which contains blanks in double quotes, and
returns a new argument list.
"""
-
# XXX this doesn't seem very robust to me -- but if the Windows guys
# say it'll work, I guess I'll have to accept it. (What if an arg
# contains quotes? What other magic characters, other than spaces,
# have to be escaped? Is there an escaping mechanism other than
# quoting?)
-
- for i in range(len(args)):
- if string.find(args[i], ' ') != -1:
- args[i] = '"%s"' % args[i]
+ for i, arg in enumerate(args):
+ if ' ' in arg:
+ args[i] = '"%s"' % arg
return args
-def _spawn_nt (cmd,
- search_path=1,
- verbose=0,
- dry_run=0):
-
+def _spawn_nt(cmd, search_path=1, verbose=0, dry_run=0):
executable = cmd[0]
cmd = _nt_quote_args(cmd)
if search_path:
# either we find one or it stays the same
executable = find_executable(executable) or executable
- log.info(string.join([executable] + cmd[1:], ' '))
+ log.info(' '.join([executable] + cmd[1:]))
if not dry_run:
# spawn for NT requires a full path to the .exe
try:
@@ -89,18 +78,12 @@
raise DistutilsExecError, \
"command '%s' failed with exit status %d" % (cmd[0], rc)
-
-def _spawn_os2 (cmd,
- search_path=1,
- verbose=0,
- dry_run=0):
-
+def _spawn_os2(cmd, search_path=1, verbose=0, dry_run=0):
executable = cmd[0]
- #cmd = _nt_quote_args(cmd)
if search_path:
# either we find one or it stays the same
executable = find_executable(executable) or executable
- log.info(string.join([executable] + cmd[1:], ' '))
+ log.info(' '.join([executable] + cmd[1:]))
if not dry_run:
# spawnv for OS/2 EMX requires a full path to the .exe
try:
@@ -111,27 +94,20 @@
"command '%s' failed: %s" % (cmd[0], exc[-1])
if rc != 0:
# and this reflects the command running but failing
- print "command '%s' failed with exit status %d" % (cmd[0], rc)
+ log.debug("command '%s' failed with exit status %d" % (cmd[0], rc))
raise DistutilsExecError, \
"command '%s' failed with exit status %d" % (cmd[0], rc)
-def _spawn_posix (cmd,
- search_path=1,
- verbose=0,
- dry_run=0):
-
- log.info(string.join(cmd, ' '))
+def _spawn_posix(cmd, search_path=1, verbose=0, dry_run=0):
+ log.info(' '.join(cmd))
if dry_run:
return
exec_fn = search_path and os.execvp or os.execv
-
pid = os.fork()
- if pid == 0: # in the child
+ if pid == 0: # in the child
try:
- #print "cmd[0] =", cmd[0]
- #print "cmd =", cmd
exec_fn(cmd[0], cmd)
except OSError, e:
sys.stderr.write("unable to execute %s: %s\n" %
@@ -140,14 +116,12 @@
sys.stderr.write("unable to execute %s for unknown reasons" % cmd[0])
os._exit(1)
-
-
- else: # in the parent
+ else: # in the parent
# Loop until the child either exits or is terminated by a signal
# (ie. keep waiting if it's merely stopped)
while 1:
try:
- (pid, status) = os.waitpid(pid, 0)
+ pid, status = os.waitpid(pid, 0)
except OSError, exc:
import errno
if exc.errno == errno.EINTR:
@@ -162,7 +136,7 @@
elif os.WIFEXITED(status):
exit_status = os.WEXITSTATUS(status)
if exit_status == 0:
- return # hey, it succeeded!
+ return # hey, it succeeded!
else:
raise DistutilsExecError, \
"command '%s' failed with exit status %d" % \
@@ -175,8 +149,6 @@
raise DistutilsExecError, \
"unknown error executing '%s': termination status %d" % \
(cmd[0], status)
-# _spawn_posix ()
-
def _spawn_java(cmd,
search_path=1,
@@ -200,17 +172,19 @@
def find_executable(executable, path=None):
- """Try to find 'executable' in the directories listed in 'path' (a
- string listing directories separated by 'os.pathsep'; defaults to
- os.environ['PATH']). Returns the complete filename or None if not
- found.
+ """Tries to find 'executable' in the directories listed in 'path'.
+
+ A string listing directories separated by 'os.pathsep'; defaults to
+ os.environ['PATH']. Returns the complete filename or None if not found.
"""
if path is None:
path = os.environ['PATH']
- paths = string.split(path, os.pathsep)
- (base, ext) = os.path.splitext(executable)
+ paths = path.split(os.pathsep)
+ base, ext = os.path.splitext(executable)
+
if (sys.platform == 'win32' or os.name == 'os2') and (ext != '.exe'):
executable = executable + '.exe'
+
if not os.path.isfile(executable):
for p in paths:
f = os.path.join(p, executable)
@@ -220,5 +194,3 @@
return None
else:
return executable
-
-# find_executable()
diff --git a/Lib/distutils/sysconfig.py b/Lib/distutils/sysconfig.py
--- a/Lib/distutils/sysconfig.py
+++ b/Lib/distutils/sysconfig.py
@@ -9,7 +9,7 @@
Email: <fdrake at acm.org>
"""
-__revision__ = "$Id: sysconfig.py 52234 2006-10-08 17:50:26Z ronald.oussoren $"
+__revision__ = "$Id: sysconfig.py 83688 2010-08-03 21:18:06Z mark.dickinson $"
import os
import re
@@ -22,16 +22,32 @@
PREFIX = os.path.normpath(sys.prefix)
EXEC_PREFIX = os.path.normpath(sys.exec_prefix)
+# Path to the base directory of the project. On Windows the binary may
+# live in project/PCBuild9. If we're dealing with an x64 Windows build,
+# it'll live in project/PCbuild/amd64.
+project_base = os.path.dirname(os.path.realpath(sys.executable))
+if os.name == "nt" and "pcbuild" in project_base[-8:].lower():
+ project_base = os.path.abspath(os.path.join(project_base, os.path.pardir))
+# PC/VS7.1
+if os.name == "nt" and "\\pc\\v" in project_base[-10:].lower():
+ project_base = os.path.abspath(os.path.join(project_base, os.path.pardir,
+ os.path.pardir))
+# PC/AMD64
+if os.name == "nt" and "\\pcbuild\\amd64" in project_base[-14:].lower():
+ project_base = os.path.abspath(os.path.join(project_base, os.path.pardir,
+ os.path.pardir))
+
# python_build: (Boolean) if true, we're either building Python or
# building an extension with an un-installed Python, so we use
# different (hard-wired) directories.
-
-argv0_path = os.path.dirname(os.path.abspath(sys.executable))
-landmark = os.path.join(argv0_path, "Modules", "Setup")
-
-python_build = os.path.isfile(landmark)
-
-del landmark
+# Setup.local is available for Makefile builds including VPATH builds,
+# Setup.dist is available on Windows
+def _python_build():
+ for fn in ("Setup.dist", "Setup.local"):
+ if os.path.isfile(os.path.join(project_base, "Modules", fn)):
+ return True
+ return False
+python_build = _python_build()
def get_python_version():
@@ -55,15 +71,19 @@
"""
if prefix is None:
prefix = plat_specific and EXEC_PREFIX or PREFIX
+
if os.name == "posix":
if python_build:
- base = os.path.dirname(os.path.abspath(sys.executable))
+ buildir = os.path.dirname(os.path.realpath(sys.executable))
if plat_specific:
- inc_dir = base
+ # python.h is located in the buildir
+ inc_dir = buildir
else:
- inc_dir = os.path.join(base, "Include")
- if not os.path.exists(inc_dir):
- inc_dir = os.path.join(os.path.dirname(base), "Include")
+ # the source dir is relative to the buildir
+ srcdir = os.path.abspath(os.path.join(buildir,
+ get_config_var('srcdir')))
+ # Include is located in the srcdir
+ inc_dir = os.path.join(srcdir, "Include")
return inc_dir
return os.path.join(prefix, "include", "python" + get_python_version())
elif os.name == "nt":
@@ -113,7 +133,7 @@
if get_python_version() < "2.2":
return prefix
else:
- return os.path.join(PREFIX, "Lib", "site-packages")
+ return os.path.join(prefix, "Lib", "site-packages")
elif os.name == "mac":
if plat_specific:
@@ -129,9 +149,9 @@
elif os.name == "os2" or os.name == "java":
if standard_lib:
- return os.path.join(PREFIX, "Lib")
+ return os.path.join(prefix, "Lib")
else:
- return os.path.join(PREFIX, "Lib", "site-packages")
+ return os.path.join(prefix, "Lib", "site-packages")
else:
raise DistutilsPlatformError(
@@ -150,22 +170,22 @@
get_config_vars('CC', 'CXX', 'OPT', 'CFLAGS',
'CCSHARED', 'LDSHARED', 'SO')
- if os.environ.has_key('CC'):
+ if 'CC' in os.environ:
cc = os.environ['CC']
- if os.environ.has_key('CXX'):
+ if 'CXX' in os.environ:
cxx = os.environ['CXX']
- if os.environ.has_key('LDSHARED'):
+ if 'LDSHARED' in os.environ:
ldshared = os.environ['LDSHARED']
- if os.environ.has_key('CPP'):
+ if 'CPP' in os.environ:
cpp = os.environ['CPP']
else:
cpp = cc + " -E" # not always
- if os.environ.has_key('LDFLAGS'):
+ if 'LDFLAGS' in os.environ:
ldshared = ldshared + ' ' + os.environ['LDFLAGS']
- if os.environ.has_key('CFLAGS'):
+ if 'CFLAGS' in os.environ:
cflags = opt + ' ' + os.environ['CFLAGS']
ldshared = ldshared + ' ' + os.environ['CFLAGS']
- if os.environ.has_key('CPPFLAGS'):
+ if 'CPPFLAGS' in os.environ:
cpp = cpp + ' ' + os.environ['CPPFLAGS']
cflags = cflags + ' ' + os.environ['CPPFLAGS']
ldshared = ldshared + ' ' + os.environ['CPPFLAGS']
@@ -185,7 +205,10 @@
def get_config_h_filename():
"""Return full pathname of installed pyconfig.h file."""
if python_build:
- inc_dir = argv0_path
+ if os.name == "nt":
+ inc_dir = os.path.join(project_base, "PC")
+ else:
+ inc_dir = project_base
else:
inc_dir = get_python_inc(plat_specific=1)
if get_python_version() < '2.2':
@@ -199,7 +222,8 @@
def get_makefile_filename():
"""Return full pathname of installed Makefile from the Python build."""
if python_build:
- return os.path.join(os.path.dirname(sys.executable), "Makefile")
+ return os.path.join(os.path.dirname(os.path.realpath(sys.executable)),
+ "Makefile")
lib_dir = get_python_lib(plat_specific=1, standard_lib=1)
return os.path.join(lib_dir, "config", "Makefile")
@@ -256,18 +280,25 @@
while 1:
line = fp.readline()
- if line is None: # eof
+ if line is None: # eof
break
m = _variable_rx.match(line)
if m:
n, v = m.group(1, 2)
- v = string.strip(v)
- if "$" in v:
+ v = v.strip()
+ # `$$' is a literal `$' in make
+ tmpv = v.replace('$$', '')
+
+ if "$" in tmpv:
notdone[n] = v
else:
- try: v = int(v)
- except ValueError: pass
- done[n] = v
+ try:
+ v = int(v)
+ except ValueError:
+ # insert literal `$'
+ done[n] = v.replace('$$', '$')
+ else:
+ done[n] = v
# do variable interpolation here
while notdone:
@@ -277,12 +308,12 @@
if m:
n = m.group(1)
found = True
- if done.has_key(n):
+ if n in done:
item = str(done[n])
- elif notdone.has_key(n):
+ elif n in notdone:
# get it on a subsequent round
found = False
- elif os.environ.has_key(n):
+ elif n in os.environ:
# do it like make: fall back to environment
item = os.environ[n]
else:
@@ -295,7 +326,7 @@
else:
try: value = int(value)
except ValueError:
- done[name] = string.strip(value)
+ done[name] = value.strip()
else:
done[name] = value
del notdone[name]
@@ -366,7 +397,7 @@
# MACOSX_DEPLOYMENT_TARGET: configure bases some choices on it so
# it needs to be compatible.
# If it isn't set we set it to the configure-time value
- if sys.platform == 'darwin' and g.has_key('MACOSX_DEPLOYMENT_TARGET'):
+ if sys.platform == 'darwin' and 'MACOSX_DEPLOYMENT_TARGET' in g:
cfg_target = g['MACOSX_DEPLOYMENT_TARGET']
cur_target = os.getenv('MACOSX_DEPLOYMENT_TARGET', '')
if cur_target == '':
@@ -428,6 +459,8 @@
g['SO'] = '.pyd'
g['EXE'] = ".exe"
+ g['VERSION'] = get_python_version().replace(".", "")
+ g['BINDIR'] = os.path.dirname(os.path.realpath(sys.executable))
global _config_vars
_config_vars = g
@@ -521,15 +554,57 @@
# are in CFLAGS or LDFLAGS and remove them if they are.
# This is needed when building extensions on a 10.3 system
# using a universal build of python.
- for key in ('LDFLAGS', 'BASECFLAGS',
+ for key in ('LDFLAGS', 'BASECFLAGS', 'LDSHARED',
+ # a number of derived variables. These need to be
+ # patched up as well.
+ 'CFLAGS', 'PY_CFLAGS', 'BLDSHARED'):
+ flags = _config_vars[key]
+ flags = re.sub('-arch\s+\w+\s', ' ', flags)
+ flags = re.sub('-isysroot [^ \t]*', ' ', flags)
+ _config_vars[key] = flags
+
+ else:
+
+ # Allow the user to override the architecture flags using
+ # an environment variable.
+ # NOTE: This name was introduced by Apple in OSX 10.5 and
+ # is used by several scripting languages distributed with
+ # that OS release.
+
+ if 'ARCHFLAGS' in os.environ:
+ arch = os.environ['ARCHFLAGS']
+ for key in ('LDFLAGS', 'BASECFLAGS', 'LDSHARED',
# a number of derived variables. These need to be
# patched up as well.
'CFLAGS', 'PY_CFLAGS', 'BLDSHARED'):
- flags = _config_vars[key]
- flags = re.sub('-arch\s+\w+\s', ' ', flags)
- flags = re.sub('-isysroot [^ \t]*', ' ', flags)
- _config_vars[key] = flags
+ flags = _config_vars[key]
+ flags = re.sub('-arch\s+\w+\s', ' ', flags)
+ flags = flags + ' ' + arch
+ _config_vars[key] = flags
+
+ # If we're on OSX 10.5 or later and the user tries to
+ # compiles an extension using an SDK that is not present
+ # on the current machine it is better to not use an SDK
+ # than to fail.
+ #
+ # The major usecase for this is users using a Python.org
+ # binary installer on OSX 10.6: that installer uses
+ # the 10.4u SDK, but that SDK is not installed by default
+ # when you install Xcode.
+ #
+ m = re.search('-isysroot\s+(\S+)', _config_vars['CFLAGS'])
+ if m is not None:
+ sdk = m.group(1)
+ if not os.path.exists(sdk):
+ for key in ('LDFLAGS', 'BASECFLAGS', 'LDSHARED',
+ # a number of derived variables. These need to be
+ # patched up as well.
+ 'CFLAGS', 'PY_CFLAGS', 'BLDSHARED'):
+
+ flags = _config_vars[key]
+ flags = re.sub('-isysroot\s+\S+(\s|$)', ' ', flags)
+ _config_vars[key] = flags
if args:
vals = []
diff --git a/Lib/distutils/tests/test_build_py.py b/Lib/distutils/tests/test_build_py.py
--- a/Lib/distutils/tests/test_build_py.py
+++ b/Lib/distutils/tests/test_build_py.py
@@ -72,6 +72,7 @@
open(os.path.join(testdir, "testfile"), "w").close()
os.chdir(sources)
+ old_stdout = sys.stdout
sys.stdout = StringIO.StringIO()
try:
@@ -90,7 +91,23 @@
finally:
# Restore state.
os.chdir(cwd)
- sys.stdout = sys.__stdout__
+ sys.stdout = old_stdout
+
+ def test_dont_write_bytecode(self):
+ # makes sure byte_compile is not used
+ pkg_dir, dist = self.create_dist()
+ cmd = build_py(dist)
+ cmd.compile = 1
+ cmd.optimize = 1
+
+ old_dont_write_bytecode = sys.dont_write_bytecode
+ sys.dont_write_bytecode = True
+ try:
+ cmd.byte_compile([])
+ finally:
+ sys.dont_write_bytecode = old_dont_write_bytecode
+
+ self.assertTrue('byte-compiling is disabled' in self.logs[0][1])
def test_suite():
return unittest.makeSuite(BuildPyTestCase)
diff --git a/Lib/distutils/util.py b/Lib/distutils/util.py
--- a/Lib/distutils/util.py
+++ b/Lib/distutils/util.py
@@ -4,13 +4,14 @@
one of the other *util.py modules.
"""
-__revision__ = "$Id: util.py 59116 2007-11-22 10:14:26Z ronald.oussoren $"
+__revision__ = "$Id: util.py 83588 2010-08-02 21:35:06Z ezio.melotti $"
import sys, os, string, re
from distutils.errors import DistutilsPlatformError
from distutils.dep_util import newer
from distutils.spawn import spawn
from distutils import log
+from distutils.errors import DistutilsByteCompileError
def get_platform ():
"""Return a string that identifies the current platform. This is used
@@ -29,8 +30,27 @@
irix-5.3
irix64-6.2
- For non-POSIX platforms, currently just returns 'sys.platform'.
+ Windows will return one of:
+ win-amd64 (64bit Windows on AMD64 (aka x86_64, Intel64, EM64T, etc)
+ win-ia64 (64bit Windows on Itanium)
+ win32 (all others - specifically, sys.platform is returned)
+
+ For other non-POSIX platforms, currently just returns 'sys.platform'.
"""
+ if os.name == 'nt':
+ # sniff sys.version for architecture.
+ prefix = " bit ("
+ i = string.find(sys.version, prefix)
+ if i == -1:
+ return sys.platform
+ j = string.find(sys.version, ")", i)
+ look = sys.version[i+len(prefix):j].lower()
+ if look=='amd64':
+ return 'win-amd64'
+ if look=='itanium':
+ return 'win-ia64'
+ return sys.platform
+
if os.name != "posix" or not hasattr(os, 'uname'):
# XXX what about the architecture? NT is Intel or Alpha,
# Mac OS is M68k or PPC, etc.
@@ -81,7 +101,11 @@
if not macver:
macver = cfgvars.get('MACOSX_DEPLOYMENT_TARGET')
- if not macver:
+ if 1:
+ # Always calculate the release of the running machine,
+ # needed to determine if we can build fat binaries or not.
+
+ macrelease = macver
# Get the system version. Reading this plist is a documented
# way to get the system version (see the documentation for
# the Gestalt Manager)
@@ -97,25 +121,62 @@
r'<string>(.*?)</string>', f.read())
f.close()
if m is not None:
- macver = '.'.join(m.group(1).split('.')[:2])
+ macrelease = '.'.join(m.group(1).split('.')[:2])
# else: fall back to the default behaviour
+ if not macver:
+ macver = macrelease
+
if macver:
from distutils.sysconfig import get_config_vars
release = macver
osname = "macosx"
-
- if (release + '.') >= '10.4.' and \
- get_config_vars().get('UNIVERSALSDK', '').strip():
+ if (macrelease + '.') >= '10.4.' and \
+ '-arch' in get_config_vars().get('CFLAGS', '').strip():
# The universal build will build fat binaries, but not on
# systems before 10.4
+ #
+ # Try to detect 4-way universal builds, those have machine-type
+ # 'universal' instead of 'fat'.
+
machine = 'fat'
+ cflags = get_config_vars().get('CFLAGS')
+
+ archs = re.findall('-arch\s+(\S+)', cflags)
+ archs = tuple(sorted(set(archs)))
+
+ if len(archs) == 1:
+ machine = archs[0]
+ elif archs == ('i386', 'ppc'):
+ machine = 'fat'
+ elif archs == ('i386', 'x86_64'):
+ machine = 'intel'
+ elif archs == ('i386', 'ppc', 'x86_64'):
+ machine = 'fat3'
+ elif archs == ('ppc64', 'x86_64'):
+ machine = 'fat64'
+ elif archs == ('i386', 'ppc', 'ppc64', 'x86_64'):
+ machine = 'universal'
+ else:
+ raise ValueError(
+ "Don't know machine value for archs=%r"%(archs,))
+
+ elif machine == 'i386':
+ # On OSX the machine type returned by uname is always the
+ # 32-bit variant, even if the executable architecture is
+ # the 64-bit variant
+ if sys.maxint >= 2**32:
+ machine = 'x86_64'
elif machine in ('PowerPC', 'Power_Macintosh'):
# Pick a sane name for the PPC architecture.
machine = 'ppc'
+ # See 'i386' case
+ if sys.maxint >= 2**32:
+ machine = 'ppc64'
+
return "%s-%s-%s" % (osname, release, machine)
# get_platform ()
@@ -144,7 +205,7 @@
paths.remove('.')
if not paths:
return os.curdir
- return apply(os.path.join, paths)
+ return os.path.join(*paths)
# convert_path ()
@@ -201,11 +262,11 @@
if _environ_checked:
return
- if os.name == 'posix' and not os.environ.has_key('HOME'):
+ if os.name == 'posix' and 'HOME' not in os.environ:
import pwd
os.environ['HOME'] = pwd.getpwuid(os.getuid())[5]
- if not os.environ.has_key('PLAT'):
+ if 'PLAT' not in os.environ:
os.environ['PLAT'] = get_platform()
_environ_checked = 1
@@ -223,7 +284,7 @@
check_environ()
def _subst (match, local_vars=local_vars):
var_name = match.group(1)
- if local_vars.has_key(var_name):
+ if var_name in local_vars:
return str(local_vars[var_name])
else:
return os.environ[var_name]
@@ -345,7 +406,7 @@
log.info(msg)
if not dry_run:
- apply(func, args)
+ func(*args)
def strtobool (val):
@@ -397,6 +458,9 @@
generated in indirect mode; unless you know what you're doing, leave
it set to None.
"""
+ # nothing is done if sys.dont_write_bytecode is True
+ if sys.dont_write_bytecode:
+ raise DistutilsByteCompileError('byte-compiling is disabled.')
# First, if the caller didn't force us into direct or indirect mode,
# figure out which mode we should be in. We take a conservative
@@ -512,6 +576,5 @@
RFC-822 header, by ensuring there are 8 spaces space after each newline.
"""
lines = string.split(header, '\n')
- lines = map(string.strip, lines)
header = string.join(lines, '\n' + 8*' ')
return header
diff --git a/Lib/filecmp.py b/Lib/filecmp.py
--- a/Lib/filecmp.py
+++ b/Lib/filecmp.py
@@ -11,7 +11,6 @@
import os
import stat
-import warnings
from itertools import ifilter, ifilterfalse, imap, izip
__all__ = ["cmp","dircmp","cmpfiles"]
@@ -136,9 +135,9 @@
def phase1(self): # Compute common names
a = dict(izip(imap(os.path.normcase, self.left_list), self.left_list))
b = dict(izip(imap(os.path.normcase, self.right_list), self.right_list))
- self.common = map(a.__getitem__, ifilter(b.has_key, a))
- self.left_only = map(a.__getitem__, ifilterfalse(b.has_key, a))
- self.right_only = map(b.__getitem__, ifilterfalse(a.has_key, b))
+ self.common = map(a.__getitem__, ifilter(b.__contains__, a))
+ self.left_only = map(a.__getitem__, ifilterfalse(b.__contains__, a))
+ self.right_only = map(b.__getitem__, ifilterfalse(a.__contains__, b))
def phase2(self): # Distinguish files, directories, funnies
self.common_dirs = []
diff --git a/Lib/fileinput.py b/Lib/fileinput.py
--- a/Lib/fileinput.py
+++ b/Lib/fileinput.py
@@ -226,7 +226,7 @@
self._mode = mode
if inplace and openhook:
raise ValueError("FileInput cannot use an opening hook in inplace mode")
- elif openhook and not callable(openhook):
+ elif openhook and not hasattr(openhook, '__call__'):
raise ValueError("FileInput openhook must be callable")
self._openhook = openhook
diff --git a/Lib/gettext.py b/Lib/gettext.py
--- a/Lib/gettext.py
+++ b/Lib/gettext.py
@@ -472,7 +472,7 @@
# once.
result = None
for mofile in mofiles:
- key = os.path.abspath(mofile)
+ key = (class_, os.path.abspath(mofile))
t = _translations.get(key)
if t is None:
with open(mofile, 'rb') as fp:
diff --git a/Lib/mailbox.py b/Lib/mailbox.py
--- a/Lib/mailbox.py
+++ b/Lib/mailbox.py
@@ -16,9 +16,8 @@
import errno
import copy
import email
-import email.Message
-import email.Generator
-import rfc822
+import email.message
+import email.generator
import StringIO
try:
if sys.platform == 'os2emx':
@@ -28,6 +27,13 @@
except ImportError:
fcntl = None
+import warnings
+with warnings.catch_warnings():
+ if sys.py3kwarning:
+ warnings.filterwarnings("ignore", ".*rfc822 has been removed",
+ DeprecationWarning)
+ import rfc822
+
__all__ = [ 'Mailbox', 'Maildir', 'mbox', 'MH', 'Babyl', 'MMDF',
'Message', 'MaildirMessage', 'mboxMessage', 'MHMessage',
'BabylMessage', 'MMDFMessage', 'UnixMailbox',
@@ -196,9 +202,9 @@
# To get native line endings on disk, the user-friendly \n line endings
# used in strings and by email.Message are translated here.
"""Dump message contents to target file."""
- if isinstance(message, email.Message.Message):
+ if isinstance(message, email.message.Message):
buffer = StringIO.StringIO()
- gen = email.Generator.Generator(buffer, mangle_from_, 0)
+ gen = email.generator.Generator(buffer, mangle_from_, 0)
gen.flatten(message)
buffer.seek(0)
target.write(buffer.read().replace('\n', os.linesep))
@@ -237,14 +243,21 @@
else:
raise NoSuchMailboxError(self._path)
self._toc = {}
+ self._last_read = None # Records last time we read cur/new
+ # NOTE: we manually invalidate _last_read each time we do any
+ # modifications ourselves, otherwise we might get tripped up by
+ # bogus mtime behaviour on some systems (see issue #6896).
def add(self, message):
"""Add message and return assigned key."""
tmp_file = self._create_tmp()
try:
self._dump_message(message, tmp_file)
- finally:
- _sync_close(tmp_file)
+ except BaseException:
+ tmp_file.close()
+ os.remove(tmp_file.name)
+ raise
+ _sync_close(tmp_file)
if isinstance(message, MaildirMessage):
subdir = message.get_subdir()
suffix = self.colon + message.get_info()
@@ -270,11 +283,15 @@
raise
if isinstance(message, MaildirMessage):
os.utime(dest, (os.path.getatime(dest), message.get_date()))
+ # Invalidate cached toc
+ self._last_read = None
return uniq
def remove(self, key):
"""Remove the keyed message; raise KeyError if it doesn't exist."""
os.remove(os.path.join(self._path, self._lookup(key)))
+ # Invalidate cached toc (only on success)
+ self._last_read = None
def discard(self, key):
"""If the keyed message exists, remove it."""
@@ -309,6 +326,8 @@
if isinstance(message, MaildirMessage):
os.utime(new_path, (os.path.getatime(new_path),
message.get_date()))
+ # Invalidate cached toc
+ self._last_read = None
def get_message(self, key):
"""Return a Message representation or raise a KeyError."""
@@ -363,7 +382,9 @@
def flush(self):
"""Write any pending changes to disk."""
- return # Maildir changes are always written immediately.
+ # Maildir changes are always written immediately, so there's nothing
+ # to do except invalidate our cached toc.
+ self._last_read = None
def lock(self):
"""Lock the mailbox."""
@@ -398,7 +419,8 @@
result = Maildir(path, factory=self._factory)
maildirfolder_path = os.path.join(path, 'maildirfolder')
if not os.path.exists(maildirfolder_path):
- os.close(os.open(maildirfolder_path, os.O_CREAT | os.O_WRONLY))
+ os.close(os.open(maildirfolder_path, os.O_CREAT | os.O_WRONLY,
+ 0666))
return result
def remove_folder(self, folder):
@@ -460,16 +482,37 @@
def _refresh(self):
"""Update table of contents mapping."""
+ if self._last_read is not None:
+ for subdir in ('new', 'cur'):
+ mtime = os.path.getmtime(os.path.join(self._path, subdir))
+ if mtime > self._last_read:
+ break
+ else:
+ return
+
+ # We record the current time - 1sec so that, if _refresh() is called
+ # again in the same second, we will always re-read the mailbox
+ # just in case it's been modified. (os.path.mtime() only has
+ # 1sec resolution.) This results in a few unnecessary re-reads
+ # when _refresh() is called multiple times in the same second,
+ # but once the clock ticks over, we will only re-read as needed.
+ now = time.time() - 1
+
self._toc = {}
- for subdir in ('new', 'cur'):
- subdir_path = os.path.join(self._path, subdir)
- for entry in os.listdir(subdir_path):
- p = os.path.join(subdir_path, entry)
+ def update_dir (subdir):
+ path = os.path.join(self._path, subdir)
+ for entry in os.listdir(path):
+ p = os.path.join(path, entry)
if os.path.isdir(p):
continue
uniq = entry.split(self.colon)[0]
self._toc[uniq] = os.path.join(subdir, entry)
+ update_dir('new')
+ update_dir('cur')
+
+ self._last_read = now
+
def _lookup(self, key):
"""Use TOC to return subpath for given key, or raise a KeyError."""
try:
@@ -511,7 +554,7 @@
f = open(self._path, 'wb+')
else:
raise NoSuchMailboxError(self._path)
- elif e.errno == errno.EACCES:
+ elif e.errno in (errno.EACCES, errno.EROFS):
f = open(self._path, 'rb')
else:
raise
@@ -520,6 +563,7 @@
self._next_key = 0
self._pending = False # No changes require rewriting the file.
self._locked = False
+ self._file_length = None # Used to record mailbox size
def add(self, message):
"""Add message and return assigned key."""
@@ -573,7 +617,21 @@
"""Write any pending changes to disk."""
if not self._pending:
return
- self._lookup()
+
+ # In order to be writing anything out at all, self._toc must
+ # already have been generated (and presumably has been modified
+ # by adding or deleting an item).
+ assert self._toc is not None
+
+ # Check length of self._file; if it's changed, some other process
+ # has modified the mailbox since we scanned it.
+ self._file.seek(0, 2)
+ cur_len = self._file.tell()
+ if cur_len != self._file_length:
+ raise ExternalClashError('Size of mailbox file changed '
+ '(expected %i, found %i)' %
+ (self._file_length, cur_len))
+
new_file = _create_temporary(self._path)
try:
new_toc = {}
@@ -645,10 +703,16 @@
def _append_message(self, message):
"""Append message to mailbox and return (start, stop) offsets."""
self._file.seek(0, 2)
- self._pre_message_hook(self._file)
- offsets = self._install_message(message)
- self._post_message_hook(self._file)
+ before = self._file.tell()
+ try:
+ self._pre_message_hook(self._file)
+ offsets = self._install_message(message)
+ self._post_message_hook(self._file)
+ except BaseException:
+ self._file.truncate(before)
+ raise
self._file.flush()
+ self._file_length = self._file.tell() # Record current length of mailbox
return offsets
@@ -698,7 +762,7 @@
message = ''
elif isinstance(message, _mboxMMDFMessage):
from_line = 'From ' + message.get_from()
- elif isinstance(message, email.Message.Message):
+ elif isinstance(message, email.message.Message):
from_line = message.get_unixfrom() # May be None.
if from_line is None:
from_line = 'From MAILER-DAEMON %s' % time.asctime(time.gmtime())
@@ -740,6 +804,7 @@
break
self._toc = dict(enumerate(zip(starts, stops)))
self._next_key = len(self._toc)
+ self._file_length = self._file.tell()
class MMDF(_mboxMMDF):
@@ -783,6 +848,8 @@
break
self._toc = dict(enumerate(zip(starts, stops)))
self._next_key = len(self._toc)
+ self._file.seek(0, 2)
+ self._file_length = self._file.tell()
class MH(Mailbox):
@@ -809,18 +876,29 @@
new_key = max(keys) + 1
new_path = os.path.join(self._path, str(new_key))
f = _create_carefully(new_path)
+ closed = False
try:
if self._locked:
_lock_file(f)
try:
- self._dump_message(message, f)
+ try:
+ self._dump_message(message, f)
+ except BaseException:
+ # Unlock and close so it can be deleted on Windows
+ if self._locked:
+ _unlock_file(f)
+ _sync_close(f)
+ closed = True
+ os.remove(new_path)
+ raise
if isinstance(message, MHMessage):
self._dump_sequences(message, new_key)
finally:
if self._locked:
_unlock_file(f)
finally:
- _sync_close(f)
+ if not closed:
+ _sync_close(f)
return new_key
def remove(self, key):
@@ -833,17 +911,9 @@
raise KeyError('No message with key: %s' % key)
else:
raise
- try:
- if self._locked:
- _lock_file(f)
- try:
- f.close()
- os.remove(os.path.join(self._path, str(key)))
- finally:
- if self._locked:
- _unlock_file(f)
- finally:
+ else:
f.close()
+ os.remove(path)
def __setitem__(self, key, message):
"""Replace the keyed message; raise KeyError if it doesn't exist."""
@@ -891,7 +961,7 @@
_unlock_file(f)
finally:
f.close()
- for name, key_list in self.get_sequences():
+ for name, key_list in self.get_sequences().iteritems():
if key in key_list:
msg.add_sequence(name)
return msg
@@ -1209,6 +1279,8 @@
self._toc = dict(enumerate(zip(starts, stops)))
self._labels = dict(enumerate(label_lists))
self._next_key = len(self._toc)
+ self._file.seek(0, 2)
+ self._file_length = self._file.tell()
def _pre_mailbox_hook(self, f):
"""Called before writing the mailbox to file f."""
@@ -1244,9 +1316,9 @@
self._file.write(os.linesep)
else:
self._file.write('1,,' + os.linesep)
- if isinstance(message, email.Message.Message):
+ if isinstance(message, email.message.Message):
orig_buffer = StringIO.StringIO()
- orig_generator = email.Generator.Generator(orig_buffer, False, 0)
+ orig_generator = email.generator.Generator(orig_buffer, False, 0)
orig_generator.flatten(message)
orig_buffer.seek(0)
while True:
@@ -1257,7 +1329,7 @@
self._file.write('*** EOOH ***' + os.linesep)
if isinstance(message, BabylMessage):
vis_buffer = StringIO.StringIO()
- vis_generator = email.Generator.Generator(vis_buffer, False, 0)
+ vis_generator = email.generator.Generator(vis_buffer, False, 0)
vis_generator.flatten(message.get_visible())
while True:
line = vis_buffer.readline()
@@ -1313,12 +1385,12 @@
return (start, stop)
-class Message(email.Message.Message):
+class Message(email.message.Message):
"""Message with mailbox-format-specific properties."""
def __init__(self, message=None):
"""Initialize a Message instance."""
- if isinstance(message, email.Message.Message):
+ if isinstance(message, email.message.Message):
self._become_message(copy.deepcopy(message))
if isinstance(message, Message):
message._explain_to(self)
@@ -1327,7 +1399,7 @@
elif hasattr(message, "read"):
self._become_message(email.message_from_file(message))
elif message is None:
- email.Message.Message.__init__(self)
+ email.message.Message.__init__(self)
else:
raise TypeError('Invalid message type: %s' % type(message))
@@ -1458,7 +1530,7 @@
def __init__(self, message=None):
"""Initialize an mboxMMDFMessage instance."""
self.set_from('MAILER-DAEMON', True)
- if isinstance(message, email.Message.Message):
+ if isinstance(message, email.message.Message):
unixfrom = message.get_unixfrom()
if unixfrom is not None and unixfrom.startswith('From '):
self.set_from(unixfrom[5:])
@@ -1835,7 +1907,7 @@
try:
fcntl.lockf(f, fcntl.LOCK_EX | fcntl.LOCK_NB)
except IOError, e:
- if e.errno in (errno.EAGAIN, errno.EACCES):
+ if e.errno in (errno.EAGAIN, errno.EACCES, errno.EROFS):
raise ExternalClashError('lockf: lock unavailable: %s' %
f.name)
else:
@@ -1845,7 +1917,7 @@
pre_lock = _create_temporary(f.name + '.lock')
pre_lock.close()
except IOError, e:
- if e.errno == errno.EACCES:
+ if e.errno in (errno.EACCES, errno.EROFS):
return # Without write access, just skip dotlocking.
else:
raise
@@ -1881,7 +1953,7 @@
def _create_carefully(path):
"""Create a file if it doesn't exist and open for reading and writing."""
- fd = os.open(path, os.O_CREAT | os.O_EXCL | os.O_RDWR)
+ fd = os.open(path, os.O_CREAT | os.O_EXCL | os.O_RDWR, 0666)
try:
return open(path, 'rb+')
finally:
diff --git a/Lib/netrc.py b/Lib/netrc.py
--- a/Lib/netrc.py
+++ b/Lib/netrc.py
@@ -35,11 +35,15 @@
def _parse(self, file, fp):
lexer = shlex.shlex(fp)
lexer.wordchars += r"""!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~"""
+ lexer.commenters = lexer.commenters.replace('#', '')
while 1:
# Look for a machine, default, or macdef top-level keyword
toplevel = tt = lexer.get_token()
if not tt:
break
+ elif tt[0] == '#':
+ fp.readline();
+ continue;
elif tt == 'machine':
entryname = lexer.get_token()
elif tt == 'default':
diff --git a/Lib/new.py b/Lib/new.py
--- a/Lib/new.py
+++ b/Lib/new.py
@@ -3,6 +3,10 @@
This module is no longer required except for backward compatibility.
Objects of most types can now be created by calling the type object.
"""
+from warnings import warnpy3k
+warnpy3k("The 'new' module has been removed in Python 3.0; use the 'types' "
+ "module instead.", stacklevel=2)
+del warnpy3k
from types import ClassType as classobj
from types import FunctionType as function
diff --git a/Lib/py_compile.py b/Lib/py_compile.py
--- a/Lib/py_compile.py
+++ b/Lib/py_compile.py
@@ -114,11 +114,15 @@
"""
if args is None:
args = sys.argv[1:]
+ rv = 0
for filename in args:
try:
compile(filename, doraise=True)
- except PyCompileError,err:
+ except PyCompileError, err:
+ # return value to indicate at least one failure
+ rv = 1
sys.stderr.write(err.msg)
+ return rv
if __name__ == "__main__":
- main()
+ sys.exit(main())
diff --git a/Lib/robotparser.py b/Lib/robotparser.py
--- a/Lib/robotparser.py
+++ b/Lib/robotparser.py
@@ -133,7 +133,12 @@
return True
# search for given user agent matches
# the first match counts
- url = urllib.quote(urlparse.urlparse(urllib.unquote(url))[2]) or "/"
+ parsed_url = urlparse.urlparse(urllib.unquote(url))
+ url = urlparse.urlunparse(('', '', parsed_url.path,
+ parsed_url.params, parsed_url.query, parsed_url.fragment))
+ url = urllib.quote(url)
+ if not url:
+ url = "/"
for entry in self.entries:
if entry.applies_to(useragent):
return entry.allowance(url)
diff --git a/Lib/test/list_tests.py b/Lib/test/list_tests.py
--- a/Lib/test/list_tests.py
+++ b/Lib/test/list_tests.py
@@ -5,7 +5,6 @@
import sys
import os
-import unittest
from test import test_support, seq_tests
class CommonTest(seq_tests.CommonTest):
@@ -37,7 +36,7 @@
self.assertEqual(str(a0), str(l0))
self.assertEqual(repr(a0), repr(l0))
- self.assertEqual(`a2`, `l2`)
+ self.assertEqual(repr(a2), repr(l2))
self.assertEqual(str(a2), "[0, 1, 2]")
self.assertEqual(repr(a2), "[0, 1, 2]")
@@ -46,6 +45,11 @@
self.assertEqual(str(a2), "[0, 1, 2, [...], 3]")
self.assertEqual(repr(a2), "[0, 1, 2, [...], 3]")
+ l0 = []
+ for i in xrange(sys.getrecursionlimit() + 100):
+ l0 = [l0]
+ self.assertRaises(RuntimeError, repr, l0)
+
def test_print(self):
d = self.type2test(xrange(200))
d.append(d)
@@ -53,13 +57,11 @@
d.append(d)
d.append(400)
try:
- fo = open(test_support.TESTFN, "wb")
- print >> fo, d,
- fo.close()
- fo = open(test_support.TESTFN, "rb")
- self.assertEqual(fo.read(), repr(d))
+ with open(test_support.TESTFN, "wb") as fo:
+ print >> fo, d,
+ with open(test_support.TESTFN, "rb") as fo:
+ self.assertEqual(fo.read(), repr(d))
finally:
- fo.close()
os.remove(test_support.TESTFN)
def test_set_subscript(self):
@@ -80,6 +82,8 @@
self.assertRaises(StopIteration, r.next)
self.assertEqual(list(reversed(self.type2test())),
self.type2test())
+ # Bug 3689: make sure list-reversed-iterator doesn't have __len__
+ self.assertRaises(TypeError, len, reversed([1,2,3]))
def test_setitem(self):
a = self.type2test([0, 1])
@@ -179,8 +183,10 @@
self.assertEqual(a, self.type2test(range(10)))
self.assertRaises(TypeError, a.__setslice__, 0, 1, 5)
+ self.assertRaises(TypeError, a.__setitem__, slice(0, 1, 5))
self.assertRaises(TypeError, a.__setslice__)
+ self.assertRaises(TypeError, a.__setitem__)
def test_delslice(self):
a = self.type2test([0, 1])
@@ -324,7 +330,7 @@
self.assertRaises(BadExc, d.remove, 'c')
for x, y in zip(d, e):
# verify that original order and values are retained.
- self.assert_(x is y)
+ self.assertIs(x, y)
def test_count(self):
a = self.type2test([0, 1, 2])*3
@@ -413,6 +419,11 @@
self.assertRaises(TypeError, u.reverse, 42)
def test_sort(self):
+ with test_support.check_py3k_warnings(
+ ("the cmp argument is not supported", DeprecationWarning)):
+ self._test_sort()
+
+ def _test_sort(self):
u = self.type2test([1, 0])
u.sort()
self.assertEqual(u, [0, 1])
@@ -455,7 +466,7 @@
u = self.type2test([0, 1])
u2 = u
u += [2, 3]
- self.assert_(u is u2)
+ self.assertIs(u, u2)
u = self.type2test("spam")
u += "eggs"
@@ -515,13 +526,14 @@
a = self.type2test(range(10))
a[::2] = tuple(range(5))
self.assertEqual(a, self.type2test([0, 1, 1, 3, 2, 5, 3, 7, 4, 9]))
+ # test issue7788
+ a = self.type2test(range(10))
+ del a[9::1<<333]
# XXX: CPython specific, PyList doesn't len() during init
def _test_constructor_exception_handling(self):
# Bug #1242657
class F(object):
def __iter__(self):
- yield 23
- def __len__(self):
raise KeyboardInterrupt
self.assertRaises(KeyboardInterrupt, list, F())
diff --git a/Lib/test/pickletester.py b/Lib/test/pickletester.py
--- a/Lib/test/pickletester.py
+++ b/Lib/test/pickletester.py
@@ -1,6 +1,8 @@
import unittest
import pickle
import cPickle
+import StringIO
+import cStringIO
import pickletools
import copy_reg
@@ -12,6 +14,42 @@
assert pickle.HIGHEST_PROTOCOL == cPickle.HIGHEST_PROTOCOL == 2
protocols = range(pickle.HIGHEST_PROTOCOL + 1)
+# Copy of test.test_support.run_with_locale. This is needed to support Python
+# 2.4, which didn't include it. This is all to support test_xpickle, which
+# bounces pickled objects through older Python versions to test backwards
+# compatibility.
+def run_with_locale(catstr, *locales):
+ def decorator(func):
+ def inner(*args, **kwds):
+ try:
+ import locale
+ category = getattr(locale, catstr)
+ orig_locale = locale.setlocale(category)
+ except AttributeError:
+ # if the test author gives us an invalid category string
+ raise
+ except:
+ # cannot retrieve original locale, so do nothing
+ locale = orig_locale = None
+ else:
+ for loc in locales:
+ try:
+ locale.setlocale(category, loc)
+ break
+ except:
+ pass
+
+ # now run the function, resetting the locale on exceptions
+ try:
+ return func(*args, **kwds)
+ finally:
+ if locale and orig_locale:
+ locale.setlocale(category, orig_locale)
+ inner.func_name = func.func_name
+ inner.__doc__ = func.__doc__
+ return inner
+ return decorator
+
# Return True if opcode code appears in the pickle, else False.
def opcode_in_pickle(code, pickle):
@@ -408,12 +446,11 @@
# is a mystery. cPickle also suppresses PUT for objects with a refcount
# of 1.
def dont_test_disassembly(self):
- from cStringIO import StringIO
from pickletools import dis
for proto, expected in (0, DATA0_DIS), (1, DATA1_DIS):
s = self.dumps(self._testdata, proto)
- filelike = StringIO()
+ filelike = cStringIO.StringIO()
dis(s, out=filelike)
got = filelike.getvalue()
self.assertEqual(expected, got)
@@ -424,9 +461,18 @@
for proto in protocols:
s = self.dumps(l, proto)
x = self.loads(s)
- self.assertEqual(x, l)
- self.assertEqual(x, x[0])
- self.assertEqual(id(x), id(x[0]))
+ self.assertEqual(len(x), 1)
+ self.assertTrue(x is x[0])
+
+ def test_recursive_tuple(self):
+ t = ([],)
+ t[0].append(t)
+ for proto in protocols:
+ s = self.dumps(t, proto)
+ x = self.loads(s)
+ self.assertEqual(len(x), 1)
+ self.assertEqual(len(x[0]), 1)
+ self.assertTrue(x is x[0][0])
def test_recursive_dict(self):
d = {}
@@ -434,9 +480,8 @@
for proto in protocols:
s = self.dumps(d, proto)
x = self.loads(s)
- self.assertEqual(x, d)
- self.assertEqual(x[1], x)
- self.assertEqual(id(x[1]), id(x))
+ self.assertEqual(x.keys(), [1])
+ self.assertTrue(x[1] is x)
def test_recursive_inst(self):
i = C()
@@ -444,9 +489,8 @@
for proto in protocols:
s = self.dumps(i, 2)
x = self.loads(s)
- self.assertEqual(x, i)
- self.assertEqual(x.attr, x)
- self.assertEqual(id(x.attr), id(x))
+ self.assertEqual(dir(x), dir(i))
+ self.assertTrue(x.attr is x)
def test_recursive_multi(self):
l = []
@@ -457,12 +501,10 @@
for proto in protocols:
s = self.dumps(l, proto)
x = self.loads(s)
- self.assertEqual(x, l)
- self.assertEqual(x[0], i)
- self.assertEqual(x[0].attr, d)
- self.assertEqual(x[0].attr[1], x)
- self.assertEqual(x[0].attr[1][0], i)
- self.assertEqual(x[0].attr[1][0].attr, d)
+ self.assertEqual(len(x), 1)
+ self.assertEqual(dir(x[0]), dir(i))
+ self.assertEqual(x[0].attr.keys(), [1])
+ self.assertTrue(x[0].attr[1] is x)
def test_garyp(self):
self.assertRaises(self.error, self.loads, 'garyp')
@@ -484,14 +526,21 @@
if have_unicode:
def test_unicode(self):
- endcases = [unicode(''), unicode('<\\u>'), unicode('<\\\u1234>'),
- unicode('<\n>'), unicode('<\\>')]
+ endcases = [u'', u'<\\u>', u'<\\\u1234>', u'<\n>',
+ u'<\\>', u'<\\\U00012345>']
for proto in protocols:
for u in endcases:
p = self.dumps(u, proto)
u2 = self.loads(p)
self.assertEqual(u2, u)
+ def test_unicode_high_plane(self):
+ t = u'\U00012345'
+ for proto in protocols:
+ p = self.dumps(t, proto)
+ t2 = self.loads(p)
+ self.assertEqual(t2, t)
+
def test_ints(self):
import sys
for proto in protocols:
@@ -534,6 +583,21 @@
got = self.loads(p)
self.assertEqual(n, got)
+ def test_float(self):
+ test_values = [0.0, 4.94e-324, 1e-310, 7e-308, 6.626e-34, 0.1, 0.5,
+ 3.14, 263.44582062374053, 6.022e23, 1e30]
+ test_values = test_values + [-x for x in test_values]
+ for proto in protocols:
+ for value in test_values:
+ pickle = self.dumps(value, proto)
+ got = self.loads(pickle)
+ self.assertEqual(value, got)
+
+ @run_with_locale('LC_ALL', 'de_DE', 'fr_FR')
+ def test_float_format(self):
+ # make sure that floats are formatted locale independent
+ self.assertEqual(self.dumps(1.2)[0:3], 'F1.')
+
def test_reduce(self):
pass
@@ -583,7 +647,7 @@
try:
self.loads(badpickle)
except ValueError, detail:
- self.failUnless(str(detail).startswith(
+ self.assertTrue(str(detail).startswith(
"unsupported pickle protocol"))
else:
self.fail("expected bad protocol number to raise ValueError")
@@ -655,7 +719,7 @@
for x in None, False, True:
s = self.dumps(x, proto)
y = self.loads(s)
- self.assert_(x is y, (proto, x, s, y))
+ self.assertTrue(x is y, (proto, x, s, y))
expected = expected_opcode[proto, x]
self.assertEqual(opcode_in_pickle(expected, s), True)
@@ -705,8 +769,8 @@
# Dump using protocol 1 for comparison.
s1 = self.dumps(x, 1)
- self.assert_(__name__ in s1)
- self.assert_("MyList" in s1)
+ self.assertIn(__name__, s1)
+ self.assertIn("MyList", s1)
self.assertEqual(opcode_in_pickle(opcode, s1), False)
y = self.loads(s1)
@@ -715,8 +779,8 @@
# Dump using protocol 2 for test.
s2 = self.dumps(x, 2)
- self.assert_(__name__ not in s2)
- self.assert_("MyList" not in s2)
+ self.assertNotIn(__name__, s2)
+ self.assertNotIn("MyList", s2)
self.assertEqual(opcode_in_pickle(opcode, s2), True)
y = self.loads(s2)
@@ -760,7 +824,7 @@
if proto == 0:
self.assertEqual(num_appends, 0)
else:
- self.failUnless(num_appends >= 2)
+ self.assertTrue(num_appends >= 2)
def test_dict_chunking(self):
n = 10 # too small to chunk
@@ -782,7 +846,7 @@
if proto == 0:
self.assertEqual(num_setitems, 0)
else:
- self.failUnless(num_setitems >= 2)
+ self.assertTrue(num_setitems >= 2)
def test_simple_newobj(self):
x = object.__new__(SimpleNewObj) # avoid __init__
@@ -806,7 +870,7 @@
self.assertEqual(x.bar, y.bar)
def test_reduce_overrides_default_reduce_ex(self):
- for proto in 0, 1, 2:
+ for proto in protocols:
x = REX_one()
self.assertEqual(x._reduce_called, 0)
s = self.dumps(x, proto)
@@ -815,7 +879,7 @@
self.assertEqual(y._reduce_called, 0)
def test_reduce_ex_called(self):
- for proto in 0, 1, 2:
+ for proto in protocols:
x = REX_two()
self.assertEqual(x._proto, None)
s = self.dumps(x, proto)
@@ -824,7 +888,7 @@
self.assertEqual(y._proto, None)
def test_reduce_ex_overrides_reduce(self):
- for proto in 0, 1, 2:
+ for proto in protocols:
x = REX_three()
self.assertEqual(x._proto, None)
s = self.dumps(x, proto)
@@ -832,6 +896,76 @@
y = self.loads(s)
self.assertEqual(y._proto, None)
+ def test_reduce_ex_calls_base(self):
+ for proto in protocols:
+ x = REX_four()
+ self.assertEqual(x._proto, None)
+ s = self.dumps(x, proto)
+ self.assertEqual(x._proto, proto)
+ y = self.loads(s)
+ self.assertEqual(y._proto, proto)
+
+ def test_reduce_calls_base(self):
+ for proto in protocols:
+ x = REX_five()
+ self.assertEqual(x._reduce_called, 0)
+ s = self.dumps(x, proto)
+ self.assertEqual(x._reduce_called, 1)
+ y = self.loads(s)
+ self.assertEqual(y._reduce_called, 1)
+
+ def test_reduce_bad_iterator(self):
+ # Issue4176: crash when 4th and 5th items of __reduce__()
+ # are not iterators
+ class C(object):
+ def __reduce__(self):
+ # 4th item is not an iterator
+ return list, (), None, [], None
+ class D(object):
+ def __reduce__(self):
+ # 5th item is not an iterator
+ return dict, (), None, None, []
+
+ # Protocol 0 is less strict and also accepts iterables, so either success or a pickling error is tolerated; only a crash fails.
+ for proto in protocols:
+ try:
+ self.dumps(C(), proto)
+ except (AttributeError, pickle.PickleError, cPickle.PickleError):
+ pass
+ try:
+ self.dumps(D(), proto)
+ except (AttributeError, pickle.PickleError, cPickle.PickleError):
+ pass
+
+ def test_many_puts_and_gets(self):
+ # Test that internal data structures correctly deal with lots of
+ # puts/gets.
+ keys = ("aaa" + str(i) for i in xrange(100))
+ large_dict = dict((k, [4, 5, 6]) for k in keys)
+ obj = [dict(large_dict), dict(large_dict), dict(large_dict)]
+
+ for proto in protocols:
+ dumped = self.dumps(obj, proto)
+ loaded = self.loads(dumped)
+ self.assertEqual(loaded, obj,
+ "Failed protocol %d: %r != %r"
+ % (proto, obj, loaded))
+
+ def test_attribute_name_interning(self):
+ # Test that attribute names of pickled objects are interned when
+ # unpickling.
+ for proto in protocols:
+ x = C()
+ x.foo = 42
+ x.bar = "hello"
+ s = self.dumps(x, proto)
+ y = self.loads(s)
+ x_keys = sorted(x.__dict__)
+ y_keys = sorted(y.__dict__)
+ for x_key, y_key in zip(x_keys, y_keys):
+ self.assertIs(x_key, y_key)
+
+
# Test classes for reduce_ex
class REX_one(object):
@@ -856,6 +990,20 @@
def __reduce__(self):
raise TestFailed, "This __reduce__ shouldn't be called"
+class REX_four(object):
+ _proto = None
+ def __reduce_ex__(self, proto):
+ self._proto = proto
+ return object.__reduce_ex__(self, proto)
+ # Calling base class method should succeed
+
+class REX_five(object):
+ _reduce_called = 0
+ def __reduce__(self):
+ self._reduce_called = 1
+ return object.__reduce__(self)
+ # This one used to fail with infinite recursion
+
# Test classes for newobj
class MyInt(int):
@@ -919,10 +1067,50 @@
finally:
os.remove(TESTFN)
+ def test_load_from_and_dump_to_file(self):
+ stream = cStringIO.StringIO()
+ data = [123, {}, 124]
+ self.module.dump(data, stream)
+ stream.seek(0)
+ unpickled = self.module.load(stream)
+ self.assertEqual(unpickled, data)
+
def test_highest_protocol(self):
# Of course this needs to be changed when HIGHEST_PROTOCOL changes.
self.assertEqual(self.module.HIGHEST_PROTOCOL, 2)
+ def test_callapi(self):
+ f = cStringIO.StringIO()
+ # With and without keyword arguments
+ self.module.dump(123, f, -1)
+ self.module.dump(123, file=f, protocol=-1)
+ self.module.dumps(123, -1)
+ self.module.dumps(123, protocol=-1)
+ self.module.Pickler(f, -1)
+ self.module.Pickler(f, protocol=-1)
+
+ def test_incomplete_input(self):
+ s = StringIO.StringIO("X''.")
+ self.assertRaises(EOFError, self.module.load, s)
+
+ def test_restricted(self):
+ # issue7128: cPickle failed in restricted mode
+ builtins = {self.module.__name__: self.module,
+ '__import__': __import__}
+ d = {}
+ teststr = "def f(): {0}.dumps(0)".format(self.module.__name__)
+ exec teststr in {'__builtins__': builtins}, d
+ d['f']()
+
+ def test_bad_input(self):
+ # Test issue4298
+ s = '\x58\0\0\0\x54'
+ self.assertRaises(EOFError, self.module.loads, s)
+ # Test issue7455
+ s = '0'
+ # XXX Why doesn't pickle raise UnpicklingError?
+ self.assertRaises((IndexError, cPickle.UnpicklingError),
+ self.module.loads, s)
class AbstractPersistentPicklerTests(unittest.TestCase):
@@ -958,3 +1146,116 @@
self.assertEqual(self.loads(self.dumps(L, 1)), L)
self.assertEqual(self.id_count, 5)
self.assertEqual(self.load_count, 5)
+
+class AbstractPicklerUnpicklerObjectTests(unittest.TestCase):
+
+ pickler_class = None
+ unpickler_class = None
+
+ def setUp(self):
+ assert self.pickler_class
+ assert self.unpickler_class
+
+ def test_clear_pickler_memo(self):
+ # To test whether clear_memo() has any effect, we pickle an object,
+ # then pickle it again without clearing the memo; the two serialized
+ # forms should be different. If we clear_memo() and then pickle the
+ # object again, the third serialized form should be identical to the
+ # first one we obtained.
+ data = ["abcdefg", "abcdefg", 44]
+ f = cStringIO.StringIO()
+ pickler = self.pickler_class(f)
+
+ pickler.dump(data)
+ first_pickled = f.getvalue()
+
+ # Reset StringIO object.
+ f.seek(0)
+ f.truncate()
+
+ pickler.dump(data)
+ second_pickled = f.getvalue()
+
+ # Reset the Pickler and StringIO objects.
+ pickler.clear_memo()
+ f.seek(0)
+ f.truncate()
+
+ pickler.dump(data)
+ third_pickled = f.getvalue()
+
+ self.assertNotEqual(first_pickled, second_pickled)
+ self.assertEqual(first_pickled, third_pickled)
+
+ def test_priming_pickler_memo(self):
+ # Verify that we can set the Pickler's memo attribute.
+ data = ["abcdefg", "abcdefg", 44]
+ f = cStringIO.StringIO()
+ pickler = self.pickler_class(f)
+
+ pickler.dump(data)
+ first_pickled = f.getvalue()
+
+ f = cStringIO.StringIO()
+ primed = self.pickler_class(f)
+ primed.memo = pickler.memo
+
+ primed.dump(data)
+ primed_pickled = f.getvalue()
+
+ self.assertNotEqual(first_pickled, primed_pickled)
+
+ def test_priming_unpickler_memo(self):
+ # Verify that we can set the Unpickler's memo attribute.
+ data = ["abcdefg", "abcdefg", 44]
+ f = cStringIO.StringIO()
+ pickler = self.pickler_class(f)
+
+ pickler.dump(data)
+ first_pickled = f.getvalue()
+
+ f = cStringIO.StringIO()
+ primed = self.pickler_class(f)
+ primed.memo = pickler.memo
+
+ primed.dump(data)
+ primed_pickled = f.getvalue()
+
+ unpickler = self.unpickler_class(cStringIO.StringIO(first_pickled))
+ unpickled_data1 = unpickler.load()
+
+ self.assertEqual(unpickled_data1, data)
+
+ primed = self.unpickler_class(cStringIO.StringIO(primed_pickled))
+ primed.memo = unpickler.memo
+ unpickled_data2 = primed.load()
+
+ primed.memo.clear()
+
+ self.assertEqual(unpickled_data2, data)
+ self.assertTrue(unpickled_data2 is unpickled_data1)
+
+ def test_reusing_unpickler_objects(self):
+ data1 = ["abcdefg", "abcdefg", 44]
+ f = cStringIO.StringIO()
+ pickler = self.pickler_class(f)
+ pickler.dump(data1)
+ pickled1 = f.getvalue()
+
+ data2 = ["abcdefg", 44, 44]
+ f = cStringIO.StringIO()
+ pickler = self.pickler_class(f)
+ pickler.dump(data2)
+ pickled2 = f.getvalue()
+
+ f = cStringIO.StringIO()
+ f.write(pickled1)
+ f.seek(0)
+ unpickler = self.unpickler_class(f)
+ self.assertEqual(unpickler.load(), data1)
+
+ f.seek(0)
+ f.truncate()
+ f.write(pickled2)
+ f.seek(0)
+ self.assertEqual(unpickler.load(), data2)
diff --git a/Lib/test/test_array.py b/Lib/test/test_array.py
--- a/Lib/test/test_array.py
+++ b/Lib/test/test_array.py
@@ -6,8 +6,8 @@
import unittest
from test import test_support
from weakref import proxy
-import array, cStringIO, math
-from cPickle import loads, dumps
+import array, cStringIO
+from cPickle import loads, dumps, HIGHEST_PROTOCOL
if test_support.is_jython:
import operator
@@ -18,7 +18,7 @@
class ArraySubclassWithKwargs(array.array):
def __init__(self, typecode, newarg=None):
- array.array.__init__(typecode)
+ array.array.__init__(self, typecode)
tests = [] # list to accumulate all tests
typecodes = "cubBhHiIlLfd"
@@ -52,7 +52,7 @@
def test_constructor(self):
a = array.array(self.typecode)
self.assertEqual(a.typecode, self.typecode)
- self.assert_(a.itemsize>=self.minitemsize)
+ self.assertTrue(a.itemsize>=self.minitemsize)
self.assertRaises(TypeError, array.array, self.typecode, None)
def test_len(self):
@@ -67,10 +67,10 @@
a = array.array(self.typecode, self.example)
self.assertRaises(TypeError, a.buffer_info, 42)
bi = a.buffer_info()
- self.assert_(isinstance(bi, tuple))
+ self.assertIsInstance(bi, tuple)
self.assertEqual(len(bi), 2)
- self.assert_(isinstance(bi[0], (int, long)))
- self.assert_(isinstance(bi[1], int))
+ self.assertIsInstance(bi[0], (int, long))
+ self.assertIsInstance(bi[1], int)
self.assertEqual(bi[1], len(a))
def test_byteswap(self):
@@ -105,7 +105,7 @@
self.assertEqual(a, b)
def test_pickle(self):
- for protocol in (0, 1, 2):
+ for protocol in range(HIGHEST_PROTOCOL + 1):
a = array.array(self.typecode, self.example)
b = loads(dumps(a, protocol))
self.assertNotEqual(id(a), id(b))
@@ -120,7 +120,7 @@
self.assertEqual(type(a), type(b))
def test_pickle_for_empty_array(self):
- for protocol in (0, 1, 2):
+ for protocol in range(HIGHEST_PROTOCOL + 1):
a = array.array(self.typecode)
b = loads(dumps(a, protocol))
self.assertNotEqual(id(a), id(b))
@@ -171,6 +171,7 @@
a = array.array(self.typecode, 2*self.example)
self.assertRaises(TypeError, a.tofile)
self.assertRaises(TypeError, a.tofile, cStringIO.StringIO())
+ test_support.unlink(test_support.TESTFN)
f = open(test_support.TESTFN, 'wb')
try:
a.tofile(f)
@@ -195,6 +196,36 @@
f.close()
test_support.unlink(test_support.TESTFN)
+ def test_fromfile_ioerror(self):
+ # Issue #5395: Check if fromfile raises a proper IOError
+ # instead of EOFError.
+ a = array.array(self.typecode)
+ f = open(test_support.TESTFN, 'wb')
+ try:
+ self.assertRaises(IOError, a.fromfile, f, len(self.example))
+ finally:
+ f.close()
+ test_support.unlink(test_support.TESTFN)
+
+ def test_filewrite(self):
+ a = array.array(self.typecode, 2*self.example)
+ f = open(test_support.TESTFN, 'wb')
+ try:
+ f.write(a)
+ f.close()
+ b = array.array(self.typecode)
+ f = open(test_support.TESTFN, 'rb')
+ b.fromfile(f, len(self.example))
+ self.assertEqual(b, array.array(self.typecode, self.example))
+ self.assertNotEqual(a, b)
+ b.fromfile(f, len(self.example))
+ self.assertEqual(a, b)
+ f.close()
+ finally:
+ if not f.closed:
+ f.close()
+ test_support.unlink(test_support.TESTFN)
+
def test_tofromlist(self):
a = array.array(self.typecode, 2*self.example)
b = array.array(self.typecode)
@@ -248,39 +279,39 @@
def test_cmp(self):
a = array.array(self.typecode, self.example)
- self.assert_((a == 42) is False)
- self.assert_((a != 42) is True)
+ self.assertTrue((a == 42) is False)
+ self.assertTrue((a != 42) is True)
- self.assert_((a == a) is True)
- self.assert_((a != a) is False)
- self.assert_((a < a) is False)
- self.assert_((a <= a) is True)
- self.assert_((a > a) is False)
- self.assert_((a >= a) is True)
+ self.assertTrue((a == a) is True)
+ self.assertTrue((a != a) is False)
+ self.assertTrue((a < a) is False)
+ self.assertTrue((a <= a) is True)
+ self.assertTrue((a > a) is False)
+ self.assertTrue((a >= a) is True)
al = array.array(self.typecode, self.smallerexample)
ab = array.array(self.typecode, self.biggerexample)
- self.assert_((a == 2*a) is False)
- self.assert_((a != 2*a) is True)
- self.assert_((a < 2*a) is True)
- self.assert_((a <= 2*a) is True)
- self.assert_((a > 2*a) is False)
- self.assert_((a >= 2*a) is False)
+ self.assertTrue((a == 2*a) is False)
+ self.assertTrue((a != 2*a) is True)
+ self.assertTrue((a < 2*a) is True)
+ self.assertTrue((a <= 2*a) is True)
+ self.assertTrue((a > 2*a) is False)
+ self.assertTrue((a >= 2*a) is False)
- self.assert_((a == al) is False)
- self.assert_((a != al) is True)
- self.assert_((a < al) is False)
- self.assert_((a <= al) is False)
- self.assert_((a > al) is True)
- self.assert_((a >= al) is True)
+ self.assertTrue((a == al) is False)
+ self.assertTrue((a != al) is True)
+ self.assertTrue((a < al) is False)
+ self.assertTrue((a <= al) is False)
+ self.assertTrue((a > al) is True)
+ self.assertTrue((a >= al) is True)
- self.assert_((a == ab) is False)
- self.assert_((a != ab) is True)
- self.assert_((a < ab) is True)
- self.assert_((a <= ab) is True)
- self.assert_((a > ab) is False)
- self.assert_((a >= ab) is False)
+ self.assertTrue((a == ab) is False)
+ self.assertTrue((a != ab) is True)
+ self.assertTrue((a < ab) is True)
+ self.assertTrue((a <= ab) is True)
+ self.assertTrue((a > ab) is False)
+ self.assertTrue((a >= ab) is False)
def test_add(self):
a = array.array(self.typecode, self.example) \
@@ -302,11 +333,17 @@
a = array.array(self.typecode, self.example[::-1])
b = a
a += array.array(self.typecode, 2*self.example)
- self.assert_(a is b)
+ self.assertTrue(a is b)
self.assertEqual(
a,
array.array(self.typecode, self.example[::-1]+2*self.example)
)
+ a = array.array(self.typecode, self.example)
+ a += a
+ self.assertEqual(
+ a,
+ array.array(self.typecode, self.example + self.example)
+ )
b = array.array(self.badtypecode())
if test_support.is_jython:
@@ -351,22 +388,22 @@
b = a
a *= 5
- self.assert_(a is b)
+ self.assertTrue(a is b)
self.assertEqual(
a,
array.array(self.typecode, 5*self.example)
)
a *= 0
- self.assert_(a is b)
+ self.assertTrue(a is b)
self.assertEqual(a, array.array(self.typecode))
a *= 1000
- self.assert_(a is b)
+ self.assertTrue(a is b)
self.assertEqual(a, array.array(self.typecode))
a *= -1
- self.assert_(a is b)
+ self.assertTrue(a is b)
self.assertEqual(a, array.array(self.typecode))
a = array.array(self.typecode, self.example)
@@ -513,6 +550,18 @@
array.array(self.typecode)
)
+ def test_extended_getslice(self):
+ # Test extended slicing by comparing with list slicing
+ # (Assumes list conversion works correctly, too)
+ a = array.array(self.typecode, self.example)
+ indices = (0, None, 1, 3, 19, 100, -1, -2, -31, -100)
+ for start in indices:
+ for stop in indices:
+ # Everything except the initial 0 (invalid step)
+ for step in indices[1:]:
+ self.assertEqual(list(a[start:stop:step]),
+ list(a)[start:stop:step])
+
def test_setslice(self):
a = array.array(self.typecode, self.example)
a[:1] = a
@@ -596,12 +645,34 @@
a = array.array(self.typecode, self.example)
self.assertRaises(TypeError, a.__setslice__, 0, 0, None)
+ self.assertRaises(TypeError, a.__setitem__, slice(0, 0), None)
self.assertRaises(TypeError, a.__setitem__, slice(0, 1), None)
b = array.array(self.badtypecode())
self.assertRaises(TypeError, a.__setslice__, 0, 0, b)
+ self.assertRaises(TypeError, a.__setitem__, slice(0, 0), b)
self.assertRaises(TypeError, a.__setitem__, slice(0, 1), b)
+ def test_extended_set_del_slice(self):
+ indices = (0, None, 1, 3, 19, 100, -1, -2, -31, -100)
+ for start in indices:
+ for stop in indices:
+ # Everything except the initial 0 (invalid step)
+ for step in indices[1:]:
+ a = array.array(self.typecode, self.example)
+ L = list(a)
+ # Make sure we have a slice of exactly the right length,
+ # but with (hopefully) different data.
+ data = L[start:stop:step]
+ data.reverse()
+ L[start:stop:step] = data
+ a[start:stop:step] = array.array(self.typecode, data)
+ self.assertEqual(a, array.array(self.typecode, L))
+
+ del L[start:stop:step]
+ del a[start:stop:step]
+ self.assertEqual(a, array.array(self.typecode, L))
+
def test_index(self):
example = 2*self.example
a = array.array(self.typecode, example)
@@ -679,6 +750,13 @@
array.array(self.typecode, self.example+self.example[::-1])
)
+ a = array.array(self.typecode, self.example)
+ a.extend(a)
+ self.assertEqual(
+ a,
+ array.array(self.typecode, self.example+self.example)
+ )
+
b = array.array(self.badtypecode())
self.assertRaises(TypeError, a.extend, b)
@@ -721,7 +799,8 @@
def test_buffer(self):
a = array.array(self.typecode, self.example)
- b = buffer(a)
+ with test_support.check_py3k_warnings():
+ b = buffer(a)
self.assertEqual(b[0], a.tostring()[0])
def test_weakref(self):
@@ -769,7 +848,6 @@
return array.array.__new__(cls, 'c', s)
def __init__(self, s, color='blue'):
- array.array.__init__(self, 'c', s)
self.color = color
def strip(self):
@@ -857,6 +935,9 @@
a = array.array(self.typecode, range(10))
del a[::1000]
self.assertEqual(a, array.array(self.typecode, [1,2,3,4,5,6,7,8,9]))
+ # test issue7788
+ a = array.array(self.typecode, range(10))
+ del a[9::1<<333]
def test_assignment(self):
a = array.array(self.typecode, range(10))
@@ -1023,6 +1104,24 @@
class DoubleTest(FPTest):
typecode = 'd'
minitemsize = 8
+
+ def test_alloc_overflow(self):
+ from sys import maxsize
+ a = array.array('d', [-1]*65536)
+ try:
+ a *= maxsize//65536 + 1
+ except MemoryError:
+ pass
+ else:
+ self.fail("Array of size > maxsize created - MemoryError expected")
+ b = array.array('d', [ 2.71828183, 3.14159265, -1])
+ try:
+ b * (maxsize//3 + 1)
+ except MemoryError:
+ pass
+ else:
+ self.fail("Array of size > maxsize created - MemoryError expected")
+
tests.append(DoubleTest)
def test_main(verbose=None):
diff --git a/Lib/test/test_code.py b/Lib/test/test_code.py
--- a/Lib/test/test_code.py
+++ b/Lib/test/test_code.py
@@ -70,6 +70,11 @@
"""
+import unittest
+import weakref
+import _testcapi
+
+
def consts(t):
"""Yield a doctest-safe sequence of object reprs."""
for elt in t:
@@ -85,7 +90,47 @@
"freevars", "nlocals"]:
print "%s: %s" % (attr, getattr(co, "co_" + attr))
+
+class CodeTest(unittest.TestCase):
+
+ def test_newempty(self):
+ co = _testcapi.code_newempty("filename", "funcname", 15)
+ self.assertEqual(co.co_filename, "filename")
+ self.assertEqual(co.co_name, "funcname")
+ self.assertEqual(co.co_firstlineno, 15)
+
+
+class CodeWeakRefTest(unittest.TestCase):
+
+ def test_basic(self):
+ # Create a code object in a clean environment so that we know we have
+ # the only reference to it left.
+ namespace = {}
+ exec "def f(): pass" in globals(), namespace
+ f = namespace["f"]
+ del namespace
+
+ self.called = False
+ def callback(code):
+ self.called = True
+
+ # f is now the last reference to the function, and through it, the code
+ # object. While we hold it, check that we can create a weakref and
+ # deref it. Then delete it, and check that the callback gets called and
+ # the reference dies.
+ coderef = weakref.ref(f.__code__, callback)
+ self.assertTrue(bool(coderef()))
+ del f
+ self.assertFalse(bool(coderef()))
+ self.assertTrue(self.called)
+
+
def test_main(verbose=None):
- from test.test_support import run_doctest
+ from test.test_support import run_doctest, run_unittest
from test import test_code
run_doctest(test_code, verbose)
+ run_unittest(CodeTest, CodeWeakRefTest)
+
+
+if __name__ == "__main__":
+ test_main()
diff --git a/Lib/test/test_codeccallbacks.py b/Lib/test/test_codeccallbacks.py
--- a/Lib/test/test_codeccallbacks.py
+++ b/Lib/test/test_codeccallbacks.py
@@ -109,7 +109,7 @@
# useful that the error handler is not called for every single
# unencodable character, but for a complete sequence of
# unencodable characters, otherwise we would output many
- # unneccessary escape sequences.
+ # unnecessary escape sequences.
def uninamereplace(exc):
if not isinstance(exc, UnicodeEncodeError):
@@ -153,28 +153,30 @@
sout += "\\U%08x" % sys.maxunicode
self.assertEqual(sin.encode("iso-8859-15", "backslashreplace"), sout)
- def test_decoderelaxedutf8(self):
- # This is the test for a decoding callback handler,
- # that relaxes the UTF-8 minimal encoding restriction.
- # A null byte that is encoded as "\xc0\x80" will be
- # decoded as a null byte. All other illegal sequences
- # will be handled strictly.
+ def test_decoding_callbacks(self):
+ # This is a test for a decoding callback handler
+ # that allows the decoding of the invalid sequence
+ # "\xc0\x80" and returns "\x00" instead of raising an error.
+ # All other illegal sequences will be handled strictly.
def relaxedutf8(exc):
if not isinstance(exc, UnicodeDecodeError):
raise TypeError("don't know how to handle %r" % exc)
- if exc.object[exc.start:exc.end].startswith("\xc0\x80"):
+ if exc.object[exc.start:exc.start+2] == "\xc0\x80":
return (u"\x00", exc.start+2) # retry after two bytes
else:
raise exc
- codecs.register_error(
- "test.relaxedutf8", relaxedutf8)
+ codecs.register_error("test.relaxedutf8", relaxedutf8)
+ # all the "\xc0\x80" will be decoded to "\x00"
sin = "a\x00b\xc0\x80c\xc3\xbc\xc0\x80\xc0\x80"
sout = u"a\x00b\x00c\xfc\x00\x00"
self.assertEqual(sin.decode("utf-8", "test.relaxedutf8"), sout)
+
+ # "\xc0\x81" is not valid and a UnicodeDecodeError will be raised
sin = "\xc0\x80\xc0\x81"
- self.assertRaises(UnicodeError, sin.decode, "utf-8", "test.relaxedutf8")
+ self.assertRaises(UnicodeDecodeError, sin.decode,
+ "utf-8", "test.relaxedutf8")
def test_charmapencode(self):
# For charmap encodings the replacement string will be
@@ -184,7 +186,7 @@
charmap = dict([ (ord(c), 2*c.upper()) for c in "abcdefgh"])
sin = u"abc"
sout = "AABBCC"
- self.assertEquals(codecs.charmap_encode(sin, "strict", charmap)[0], sout)
+ self.assertEqual(codecs.charmap_encode(sin, "strict", charmap)[0], sout)
sin = u"abcA"
self.assertRaises(UnicodeError, codecs.charmap_encode, sin, "strict", charmap)
@@ -192,7 +194,7 @@
charmap[ord("?")] = "XYZ"
sin = u"abcDEF"
sout = "AABBCCXYZXYZXYZ"
- self.assertEquals(codecs.charmap_encode(sin, "replace", charmap)[0], sout)
+ self.assertEqual(codecs.charmap_encode(sin, "replace", charmap)[0], sout)
charmap[ord("?")] = u"XYZ"
self.assertRaises(TypeError, codecs.charmap_encode, sin, "replace", charmap)
@@ -285,7 +287,8 @@
def test_longstrings(self):
# test long strings to check for memory overflow problems
- errors = [ "strict", "ignore", "replace", "xmlcharrefreplace", "backslashreplace"]
+ errors = [ "strict", "ignore", "replace", "xmlcharrefreplace",
+ "backslashreplace"]
# register the handlers under different names,
# to prevent the codec from recognizing the name
for err in errors:
@@ -293,7 +296,8 @@
l = 1000
errors += [ "test." + err for err in errors ]
for uni in [ s*l for s in (u"x", u"\u3042", u"a\xe4") ]:
- for enc in ("ascii", "latin-1", "iso-8859-1", "iso-8859-15", "utf-8", "utf-7", "utf-16"):
+ for enc in ("ascii", "latin-1", "iso-8859-1", "iso-8859-15",
+ "utf-8", "utf-7", "utf-16", "utf-32"):
for err in errors:
try:
uni.encode(enc, err)
@@ -323,7 +327,7 @@
# check with the correct number and type of arguments
exc = exctype(*args)
- self.assertEquals(str(exc), msg)
+ self.assertEqual(str(exc), msg)
def test_unicodeencodeerror(self):
self.check_exceptionobjectargs(
@@ -433,15 +437,15 @@
UnicodeError("ouch")
)
# If the correct exception is passed in, "ignore" returns an empty replacement
- self.assertEquals(
+ self.assertEqual(
codecs.ignore_errors(UnicodeEncodeError("ascii", u"\u3042", 0, 1, "ouch")),
(u"", 1)
)
- self.assertEquals(
+ self.assertEqual(
codecs.ignore_errors(UnicodeDecodeError("ascii", "\xff", 0, 1, "ouch")),
(u"", 1)
)
- self.assertEquals(
+ self.assertEqual(
codecs.ignore_errors(UnicodeTranslateError(u"\u3042", 0, 1, "ouch")),
(u"", 1)
)
@@ -470,15 +474,15 @@
BadObjectUnicodeDecodeError()
)
# With the correct exception, "replace" returns an "?" or u"\ufffd" replacement
- self.assertEquals(
+ self.assertEqual(
codecs.replace_errors(UnicodeEncodeError("ascii", u"\u3042", 0, 1, "ouch")),
(u"?", 1)
)
- self.assertEquals(
+ self.assertEqual(
codecs.replace_errors(UnicodeDecodeError("ascii", "\xff", 0, 1, "ouch")),
(u"\ufffd", 1)
)
- self.assertEquals(
+ self.assertEqual(
codecs.replace_errors(UnicodeTranslateError(u"\u3042", 0, 1, "ouch")),
(u"\ufffd", 1)
)
@@ -510,7 +514,7 @@
# Use the correct exception
cs = (0, 1, 9, 10, 99, 100, 999, 1000, 9999, 10000, 0x3042)
s = "".join(unichr(c) for c in cs)
- self.assertEquals(
+ self.assertEqual(
codecs.xmlcharrefreplace_errors(
UnicodeEncodeError("ascii", s, 0, len(s), "ouch")
),
@@ -542,32 +546,32 @@
UnicodeTranslateError(u"\u3042", 0, 1, "ouch")
)
# Use the correct exception
- self.assertEquals(
+ self.assertEqual(
codecs.backslashreplace_errors(UnicodeEncodeError("ascii", u"\u3042", 0, 1, "ouch")),
(u"\\u3042", 1)
)
- self.assertEquals(
+ self.assertEqual(
codecs.backslashreplace_errors(UnicodeEncodeError("ascii", u"\x00", 0, 1, "ouch")),
(u"\\x00", 1)
)
- self.assertEquals(
+ self.assertEqual(
codecs.backslashreplace_errors(UnicodeEncodeError("ascii", u"\xff", 0, 1, "ouch")),
(u"\\xff", 1)
)
- self.assertEquals(
+ self.assertEqual(
codecs.backslashreplace_errors(UnicodeEncodeError("ascii", u"\u0100", 0, 1, "ouch")),
(u"\\u0100", 1)
)
- self.assertEquals(
+ self.assertEqual(
codecs.backslashreplace_errors(UnicodeEncodeError("ascii", u"\uffff", 0, 1, "ouch")),
(u"\\uffff", 1)
)
if sys.maxunicode>0xffff:
- self.assertEquals(
+ self.assertEqual(
codecs.backslashreplace_errors(UnicodeEncodeError("ascii", u"\U00010000", 0, 1, "ouch")),
(u"\\U00010000", 1)
)
- self.assertEquals(
+ self.assertEqual(
codecs.backslashreplace_errors(UnicodeEncodeError("ascii", u"\U0010ffff", 0, 1, "ouch")),
(u"\\U0010ffff", 1)
)
@@ -577,7 +581,7 @@
encs = ("ascii", "latin-1", "iso-8859-1", "iso-8859-15")
for res in results:
- codecs.register_error("test.badhandler", lambda: res)
+ codecs.register_error("test.badhandler", lambda x: res)
for enc in encs:
self.assertRaises(
TypeError,
@@ -599,14 +603,14 @@
)
def test_lookup(self):
- self.assertEquals(codecs.strict_errors, codecs.lookup_error("strict"))
- self.assertEquals(codecs.ignore_errors, codecs.lookup_error("ignore"))
- self.assertEquals(codecs.strict_errors, codecs.lookup_error("strict"))
- self.assertEquals(
+ self.assertEqual(codecs.strict_errors, codecs.lookup_error("strict"))
+ self.assertEqual(codecs.ignore_errors, codecs.lookup_error("ignore"))
+ self.assertEqual(codecs.strict_errors, codecs.lookup_error("strict"))
+ self.assertEqual(
codecs.xmlcharrefreplace_errors,
codecs.lookup_error("xmlcharrefreplace")
)
- self.assertEquals(
+ self.assertEqual(
codecs.backslashreplace_errors,
codecs.lookup_error("backslashreplace")
)
@@ -682,11 +686,11 @@
# Valid negative position
handler.pos = -1
- self.assertEquals("\xff0".decode("ascii", "test.posreturn"), u"<?>0")
+ self.assertEqual("\xff0".decode("ascii", "test.posreturn"), u"<?>0")
# Valid negative position
handler.pos = -2
- self.assertEquals("\xff0".decode("ascii", "test.posreturn"), u"<?><?>")
+ self.assertEqual("\xff0".decode("ascii", "test.posreturn"), u"<?><?>")
# Negative position out of bounds
handler.pos = -3
@@ -694,11 +698,11 @@
# Valid positive position
handler.pos = 1
- self.assertEquals("\xff0".decode("ascii", "test.posreturn"), u"<?>0")
+ self.assertEqual("\xff0".decode("ascii", "test.posreturn"), u"<?>0")
# Largest valid positive position (one beyond end of input)
handler.pos = 2
- self.assertEquals("\xff0".decode("ascii", "test.posreturn"), u"<?>")
+ self.assertEqual("\xff0".decode("ascii", "test.posreturn"), u"<?>")
# Invalid positive position
handler.pos = 3
@@ -706,7 +710,7 @@
# Restart at the "0"
handler.pos = 6
- self.assertEquals("\\uyyyy0".decode("raw-unicode-escape", "test.posreturn"), u"<?>0")
+ self.assertEqual("\\uyyyy0".decode("raw-unicode-escape", "test.posreturn"), u"<?>0")
class D(dict):
def __getitem__(self, key):
@@ -736,11 +740,11 @@
# Valid negative position
handler.pos = -1
- self.assertEquals(u"\xff0".encode("ascii", "test.posreturn"), "<?>0")
+ self.assertEqual(u"\xff0".encode("ascii", "test.posreturn"), "<?>0")
# Valid negative position
handler.pos = -2
- self.assertEquals(u"\xff0".encode("ascii", "test.posreturn"), "<?><?>")
+ self.assertEqual(u"\xff0".encode("ascii", "test.posreturn"), "<?><?>")
# Negative position out of bounds
handler.pos = -3
@@ -748,11 +752,11 @@
# Valid positive position
handler.pos = 1
- self.assertEquals(u"\xff0".encode("ascii", "test.posreturn"), "<?>0")
+ self.assertEqual(u"\xff0".encode("ascii", "test.posreturn"), "<?>0")
# Largest valid positive position (one beyond end of input
handler.pos = 2
- self.assertEquals(u"\xff0".encode("ascii", "test.posreturn"), "<?>")
+ self.assertEqual(u"\xff0".encode("ascii", "test.posreturn"), "<?>")
# Invalid positive position
handler.pos = 3
diff --git a/Lib/test/test_compile.py b/Lib/test/test_compile.py
--- a/Lib/test/test_compile.py
+++ b/Lib/test/test_compile.py
@@ -1,7 +1,8 @@
import unittest
-import warnings
import sys
+import _ast
from test import test_support
+import textwrap
class TestSpecifics(unittest.TestCase):
@@ -137,6 +138,9 @@
def test_complex_args(self):
+ with test_support.check_py3k_warnings(
+ ("tuple parameter unpacking has been removed", SyntaxWarning)):
+ exec textwrap.dedent('''
def comp_args((a, b)):
return a,b
self.assertEqual(comp_args((1, 2)), (1, 2))
@@ -154,6 +158,7 @@
return a, b, c
self.assertEqual(comp_args(1, (2, 3)), (1, 2, 3))
self.assertEqual(comp_args(), (2, 3, 4))
+ ''')
def test_argument_order(self):
try:
@@ -190,7 +195,9 @@
def test_literals_with_leading_zeroes(self):
for arg in ["077787", "0xj", "0x.", "0e", "090000000000000",
- "080000000000000", "000000000000009", "000000000000008"]:
+ "080000000000000", "000000000000009", "000000000000008",
+ "0b42", "0BADCAFE", "0o123456789", "0b1.1", "0o4.2",
+ "0b101j2", "0o153j2", "0b100e1", "0o777e1", "0o8", "0o78"]:
self.assertRaises(SyntaxError, eval, arg)
self.assertEqual(eval("0777"), 511)
@@ -218,6 +225,10 @@
self.assertEqual(eval("000000000000007"), 7)
self.assertEqual(eval("000000000000008."), 8.)
self.assertEqual(eval("000000000000009."), 9.)
+ self.assertEqual(eval("0b101010"), 42)
+ self.assertEqual(eval("-0b000000000010"), -2)
+ self.assertEqual(eval("0o777"), 511)
+ self.assertEqual(eval("-0o0000010"), -8)
self.assertEqual(eval("020000000000.0"), 20000000000.0)
self.assertEqual(eval("037777777777e0"), 37777777777.0)
self.assertEqual(eval("01000000000000000000000.0"),
@@ -417,9 +428,58 @@
del d[..., ...]
self.assertEqual((Ellipsis, Ellipsis) in d, False)
- def test_nested_classes(self):
- # Verify that it does not leak
- compile("class A:\n class B: pass", 'tmp', 'exec')
+ def test_mangling(self):
+ class A:
+ def f():
+ __mangled = 1
+ __not_mangled__ = 2
+ import __mangled_mod
+ import __package__.module
+
+ self.assertIn("_A__mangled", A.f.func_code.co_varnames)
+ self.assertIn("__not_mangled__", A.f.func_code.co_varnames)
+ self.assertIn("_A__mangled_mod", A.f.func_code.co_varnames)
+ self.assertIn("__package__", A.f.func_code.co_varnames)
+
+ def test_compile_ast(self):
+ fname = __file__
+ if fname.lower().endswith(('pyc', 'pyo')):
+ fname = fname[:-1]
+ with open(fname, 'r') as f:
+ fcontents = f.read()
+ sample_code = [
+ ['<assign>', 'x = 5'],
+ ['<print1>', 'print 1'],
+ ['<printv>', 'print v'],
+ ['<printTrue>', 'print True'],
+ ['<printList>', 'print []'],
+ ['<ifblock>', """if True:\n pass\n"""],
+ ['<forblock>', """for n in [1, 2, 3]:\n print n\n"""],
+ ['<deffunc>', """def foo():\n pass\nfoo()\n"""],
+ [fname, fcontents],
+ ]
+
+ for fname, code in sample_code:
+ co1 = compile(code, '%s1' % fname, 'exec')
+ ast = compile(code, '%s2' % fname, 'exec', _ast.PyCF_ONLY_AST)
+ self.assertEqual(type(ast), _ast.Module)
+ co2 = compile(ast, '%s3' % fname, 'exec')
+ self.assertEqual(co1, co2)
+ # the code object's filename comes from the second compilation step
+ self.assertEqual(co2.co_filename, '%s3' % fname)
+
+ # raise exception when node type doesn't match with compile mode
+ co1 = compile('print 1', '<string>', 'exec', _ast.PyCF_ONLY_AST)
+ self.assertRaises(TypeError, compile, co1, '<ast>', 'eval')
+
+ # raise exception when node type is no start node
+ self.assertRaises(TypeError, compile, _ast.If(), '<ast>', 'exec')
+
+ # raise exception when node has invalid children
+ ast = _ast.Module()
+ ast.body = [_ast.BoolOp()]
+ self.assertRaises(TypeError, compile, ast, '<ast>', 'exec')
+
def test_main():
test_support.run_unittest(TestSpecifics)
diff --git a/Lib/test/test_copy.py b/Lib/test/test_copy.py
--- a/Lib/test/test_copy.py
+++ b/Lib/test/test_copy.py
@@ -1,6 +1,5 @@
"""Unit tests for the copy module."""
-import sys
import copy
import copy_reg
@@ -439,6 +438,7 @@
return (C, (), self.__dict__)
def __cmp__(self, other):
return cmp(self.__dict__, other.__dict__)
+ __hash__ = None # Silence Py3k warning
x = C()
x.foo = [42]
y = copy.copy(x)
@@ -455,6 +455,7 @@
self.__dict__.update(state)
def __cmp__(self, other):
return cmp(self.__dict__, other.__dict__)
+ __hash__ = None # Silence Py3k warning
x = C()
x.foo = [42]
y = copy.copy(x)
@@ -481,6 +482,7 @@
def __cmp__(self, other):
return (cmp(list(self), list(other)) or
cmp(self.__dict__, other.__dict__))
+ __hash__ = None # Silence Py3k warning
x = C([[1, 2], 3])
y = copy.copy(x)
self.assertEqual(x, y)
@@ -498,6 +500,7 @@
def __cmp__(self, other):
return (cmp(dict(self), list(dict)) or
cmp(self.__dict__, other.__dict__))
+ __hash__ = None # Silence Py3k warning
x = C([("foo", [1, 2]), ("bar", 3)])
y = copy.copy(x)
self.assertEqual(x, y)
diff --git a/Lib/test/test_descrtut.py b/Lib/test/test_descrtut.py
--- a/Lib/test/test_descrtut.py
+++ b/Lib/test/test_descrtut.py
@@ -54,7 +54,7 @@
{1: 3.25}
>>> print a[1] # show the new item
3.25
- >>> print a[0] # a non-existant item
+ >>> print a[0] # a non-existent item
0.0
>>> a.merge({1:100, 2:200}) # use a dict method
>>> print sortdict(a) # show the result
@@ -66,7 +66,7 @@
statement or the built-in function eval():
>>> def sorted(seq):
- ... seq.sort()
+ ... seq.sort(key=str)
... return seq
>>> print sorted(a.keys())
[1, 2]
@@ -183,6 +183,7 @@
'__delslice__',
'__doc__',
'__eq__',
+ '__format__',
'__ge__',
'__getattribute__',
'__getitem__',
@@ -207,7 +208,9 @@
'__setattr__',
'__setitem__',
'__setslice__',
+ '__sizeof__',
'__str__',
+ '__subclasshook__',
'append',
'count',
'extend',
diff --git a/Lib/test/test_dumbdbm.py b/Lib/test/test_dumbdbm.py
--- a/Lib/test/test_dumbdbm.py
+++ b/Lib/test/test_dumbdbm.py
@@ -38,6 +38,30 @@
self.read_helper(f)
f.close()
+ def test_dumbdbm_creation_mode(self):
+ # On platforms without chmod, don't do anything.
+ if not (hasattr(os, 'chmod') and hasattr(os, 'umask')):
+ return
+
+ try:
+ old_umask = os.umask(0002)
+ f = dumbdbm.open(_fname, 'c', 0637)
+ f.close()
+ finally:
+ os.umask(old_umask)
+
+ expected_mode = 0635
+ if os.name != 'posix':
+ # Windows only supports setting the read-only attribute.
+ # This shouldn't fail, but doesn't work like Unix either.
+ expected_mode = 0666
+
+ import stat
+ st = os.stat(_fname + '.dat')
+ self.assertEqual(stat.S_IMODE(st.st_mode), expected_mode)
+ st = os.stat(_fname + '.dir')
+ self.assertEqual(stat.S_IMODE(st.st_mode), expected_mode)
+
def test_close_twice(self):
f = dumbdbm.open(_fname)
f['a'] = 'b'
diff --git a/Lib/test/test_genexps.py b/Lib/test/test_genexps.py
--- a/Lib/test/test_genexps.py
+++ b/Lib/test/test_genexps.py
@@ -98,7 +98,7 @@
Verify that parenthesis are required when used as a keyword argument value
>>> dict(a = (i for i in xrange(10))) #doctest: +ELLIPSIS
- {'a': <generator object at ...>}
+ {'a': <generator object <genexpr> at ...>}
Verify early binding for the outermost for-expression
diff --git a/Lib/test/test_hashlib.py b/Lib/test/test_hashlib.py
--- a/Lib/test/test_hashlib.py
+++ b/Lib/test/test_hashlib.py
@@ -1,14 +1,26 @@
# Test hashlib module
#
-# $Id: test_hashlib.py 39316 2005-08-21 18:45:59Z greg $
+# $Id: test_hashlib.py 80564 2010-04-27 22:59:35Z victor.stinner $
#
-# Copyright (C) 2005 Gregory P. Smith (greg at electricrain.com)
+# Copyright (C) 2005-2010 Gregory P. Smith (greg at krypto.org)
# Licensed to PSF under a Contributor Agreement.
#
+import array
import hashlib
+import itertools
+import sys
+try:
+ import threading
+except ImportError:
+ threading = None
import unittest
+import warnings
from test import test_support
+from test.test_support import _4G, precisionbigmemtest
+
+# Were we compiled --with-pydebug or with #define Py_DEBUG?
+COMPILED_WITH_PYDEBUG = hasattr(sys, 'gettotalrefcount')
def hexstr(s):
@@ -26,24 +38,93 @@
'sha224', 'SHA224', 'sha256', 'SHA256',
'sha384', 'SHA384', 'sha512', 'SHA512' )
+ _warn_on_extension_import = COMPILED_WITH_PYDEBUG
+
+ def _conditional_import_module(self, module_name):
+ """Import a module and return a reference to it or None on failure."""
+ try:
+ exec('import '+module_name)
+ except ImportError, error:
+ if self._warn_on_extension_import:
+ warnings.warn('Did a C extension fail to compile? %s' % error)
+ return locals().get(module_name)
+
+ def __init__(self, *args, **kwargs):
+ algorithms = set()
+ for algorithm in self.supported_hash_names:
+ algorithms.add(algorithm.lower())
+ self.constructors_to_test = {}
+ for algorithm in algorithms:
+ self.constructors_to_test[algorithm] = set()
+
+ # For each algorithm, test the direct constructor and the use
+ # of hashlib.new given the algorithm name.
+ for algorithm, constructors in self.constructors_to_test.items():
+ constructors.add(getattr(hashlib, algorithm))
+ def _test_algorithm_via_hashlib_new(data=None, _alg=algorithm):
+ if data is None:
+ return hashlib.new(_alg)
+ return hashlib.new(_alg, data)
+ constructors.add(_test_algorithm_via_hashlib_new)
+
+ _hashlib = self._conditional_import_module('_hashlib')
+ if _hashlib:
+ # These two algorithms should always be present when this module
+ # is compiled. If not, something was compiled wrong.
+ assert hasattr(_hashlib, 'openssl_md5')
+ assert hasattr(_hashlib, 'openssl_sha1')
+ for algorithm, constructors in self.constructors_to_test.items():
+ constructor = getattr(_hashlib, 'openssl_'+algorithm, None)
+ if constructor:
+ constructors.add(constructor)
+
+ _md5 = self._conditional_import_module('_md5')
+ if _md5:
+ self.constructors_to_test['md5'].add(_md5.new)
+ _sha = self._conditional_import_module('_sha')
+ if _sha:
+ self.constructors_to_test['sha1'].add(_sha.new)
+ _sha256 = self._conditional_import_module('_sha256')
+ if _sha256:
+ self.constructors_to_test['sha224'].add(_sha256.sha224)
+ self.constructors_to_test['sha256'].add(_sha256.sha256)
+ _sha512 = self._conditional_import_module('_sha512')
+ if _sha512:
+ self.constructors_to_test['sha384'].add(_sha512.sha384)
+ self.constructors_to_test['sha512'].add(_sha512.sha512)
+
+ super(HashLibTestCase, self).__init__(*args, **kwargs)
+
+ def test_hash_array(self):
+ a = array.array("b", range(10))
+ constructors = self.constructors_to_test.itervalues()
+ for cons in itertools.chain.from_iterable(constructors):
+ c = cons(a)
+ c.hexdigest()
+
+ def test_algorithms_attribute(self):
+ self.assertEqual(hashlib.algorithms,
+ tuple([_algo for _algo in self.supported_hash_names if
+ _algo.islower()]))
+
def test_unknown_hash(self):
try:
hashlib.new('spam spam spam spam spam')
except ValueError:
pass
else:
- self.assert_(0 == "hashlib didn't reject bogus hash name")
+ self.assertTrue(0 == "hashlib didn't reject bogus hash name")
def test_hexdigest(self):
for name in self.supported_hash_names:
h = hashlib.new(name)
- self.assert_(hexstr(h.digest()) == h.hexdigest())
-
+ self.assertTrue(hexstr(h.digest()) == h.hexdigest())
def test_large_update(self):
aas = 'a' * 128
bees = 'b' * 127
cees = 'c' * 126
+ abcs = aas + bees + cees
for name in self.supported_hash_names:
m1 = hashlib.new(name)
@@ -52,18 +133,39 @@
m1.update(cees)
m2 = hashlib.new(name)
- m2.update(aas + bees + cees)
- self.assertEqual(m1.digest(), m2.digest())
+ m2.update(abcs)
+ self.assertEqual(m1.digest(), m2.digest(), name+' update problem.')
+ m3 = hashlib.new(name, abcs)
+ self.assertEqual(m1.digest(), m3.digest(), name+' new problem.')
def check(self, name, data, digest):
- # test the direct constructors
- computed = getattr(hashlib, name)(data).hexdigest()
- self.assert_(computed == digest)
- # test the general new() interface
- computed = hashlib.new(name, data).hexdigest()
- self.assert_(computed == digest)
+ constructors = self.constructors_to_test[name]
+ # 2 is for hashlib.name(...) and hashlib.new(name, ...)
+ self.assertGreaterEqual(len(constructors), 2)
+ for hash_object_constructor in constructors:
+ computed = hash_object_constructor(data).hexdigest()
+ self.assertEqual(
+ computed, digest,
+ "Hash algorithm %s constructed using %s returned hexdigest"
+ " %r for %d byte input data that should have hashed to %r."
+ % (name, hash_object_constructor,
+ computed, len(data), digest))
+ def check_unicode(self, algorithm_name):
+ # Unicode objects are not allowed as input.
+ expected = hashlib.new(algorithm_name, str(u'spam')).hexdigest()
+ self.check(algorithm_name, u'spam', expected)
+
+ def test_unicode(self):
+ # In python 2.x unicode is auto-encoded to the system default encoding
+ # when passed to hashlib functions.
+ self.check_unicode('md5')
+ self.check_unicode('sha1')
+ self.check_unicode('sha224')
+ self.check_unicode('sha256')
+ self.check_unicode('sha384')
+ self.check_unicode('sha512')
def test_case_md5_0(self):
self.check('md5', '', 'd41d8cd98f00b204e9800998ecf8427e')
@@ -75,6 +177,21 @@
self.check('md5', 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789',
'd174ab98d277d9f5a5611c2c9f419d9f')
+ @precisionbigmemtest(size=_4G + 5, memuse=1)
+ def test_case_md5_huge(self, size):
+ if size == _4G + 5:
+ try:
+ self.check('md5', 'A'*size, 'c9af2dff37468ce5dfee8f2cfc0a9c6d')
+ except OverflowError:
+ pass # 32-bit arch
+
+ @precisionbigmemtest(size=_4G - 1, memuse=1)
+ def test_case_md5_uintmax(self, size):
+ if size == _4G - 1:
+ try:
+ self.check('md5', 'A'*size, '28138d306ff1b8281f1a9067e1a1a2b3')
+ except OverflowError:
+ pass # 32-bit arch
# use the three examples from Federal Information Processing Standards
# Publication 180-1, Secure Hash Standard, 1995 April 17
@@ -182,6 +299,42 @@
"e718483d0ce769644e2e42c7bc15b4638e1f98b13b2044285632a803afa973eb"+
"de0ff244877ea60a4cb0432ce577c31beb009c5c2c49aa2e4eadb217ad8cc09b")
+ @unittest.skipUnless(threading, 'Threading required for this test.')
+ @test_support.reap_threads
+ def test_threaded_hashing(self):
+ # Updating the same hash object from several threads at once
+ # using data chunk sizes containing the same byte sequences.
+ #
+ # If the internal locks are working to prevent multiple
+ # updates on the same object from running at once, the resulting
+ # hash will be the same as doing it single threaded upfront.
+ hasher = hashlib.sha1()
+ num_threads = 5
+ smallest_data = 'swineflu'
+ data = smallest_data*200000
+ expected_hash = hashlib.sha1(data*num_threads).hexdigest()
+
+ def hash_in_chunks(chunk_size, event):
+ index = 0
+ while index < len(data):
+ hasher.update(data[index:index+chunk_size])
+ index += chunk_size
+ event.set()
+
+ events = []
+ for threadnum in xrange(num_threads):
+ chunk_size = len(data) // (10**threadnum)
+ assert chunk_size > 0
+ assert chunk_size % len(smallest_data) == 0
+ event = threading.Event()
+ events.append(event)
+ threading.Thread(target=hash_in_chunks,
+ args=(chunk_size, event)).start()
+
+ for event in events:
+ event.wait()
+
+ self.assertEqual(expected_hash, hasher.hexdigest())
def test_main():
if test_support.is_jython:
@@ -196,6 +349,5 @@
test_support.run_unittest(HashLibTestCase)
-
if __name__ == "__main__":
test_main()
diff --git a/Lib/test/test_hmac.py b/Lib/test/test_hmac.py
--- a/Lib/test/test_hmac.py
+++ b/Lib/test/test_hmac.py
@@ -1,7 +1,7 @@
import hmac
-import sha
import hashlib
import unittest
+import warnings
from test import test_support
class TestVectorsTestCase(unittest.TestCase):
@@ -44,7 +44,7 @@
def test_sha_vectors(self):
def shatest(key, data, digest):
- h = hmac.HMAC(key, data, digestmod=sha)
+ h = hmac.HMAC(key, data, digestmod=hashlib.sha1)
self.assertEqual(h.hexdigest().upper(), digest.upper())
shatest(chr(0x0b) * 20,
@@ -200,6 +200,29 @@
def test_sha512_rfc4231(self):
self._rfc4231_test_cases(hashlib.sha512)
+ def test_legacy_block_size_warnings(self):
+ class MockCrazyHash(object):
+ """Ain't no block_size attribute here."""
+ def __init__(self, *args):
+ self._x = hashlib.sha1(*args)
+ self.digest_size = self._x.digest_size
+ def update(self, v):
+ self._x.update(v)
+ def digest(self):
+ return self._x.digest()
+
+ with warnings.catch_warnings():
+ warnings.simplefilter('error', RuntimeWarning)
+ with self.assertRaises(RuntimeWarning):
+ hmac.HMAC('a', 'b', digestmod=MockCrazyHash)
+ self.fail('Expected warning about missing block_size')
+
+ MockCrazyHash.block_size = 1
+ with self.assertRaises(RuntimeWarning):
+ hmac.HMAC('a', 'b', digestmod=MockCrazyHash)
+ self.fail('Expected warning about small block_size')
+
+
class ConstructorTestCase(unittest.TestCase):
@@ -220,20 +243,18 @@
def test_withmodule(self):
# Constructor call with text and digest module.
- import sha
try:
- h = hmac.HMAC("key", "", sha)
+ h = hmac.HMAC("key", "", hashlib.sha1)
except:
- self.fail("Constructor call with sha module raised exception.")
+ self.fail("Constructor call with hashlib.sha1 raised exception.")
class SanityTestCase(unittest.TestCase):
def test_default_is_md5(self):
# Testing if HMAC defaults to MD5 algorithm.
# NOTE: this whitebox test depends on the hmac class internals
- import hashlib
h = hmac.HMAC("key")
- self.failUnless(h.digest_cons == hashlib.md5)
+ self.assertTrue(h.digest_cons == hashlib.md5)
def test_exercise_all_methods(self):
# Exercising all methods once.
@@ -253,11 +274,11 @@
# Testing if attributes are of same type.
h1 = hmac.HMAC("key")
h2 = h1.copy()
- self.failUnless(h1.digest_cons == h2.digest_cons,
+ self.assertTrue(h1.digest_cons == h2.digest_cons,
"digest constructors don't match.")
- self.failUnless(type(h1.inner) == type(h2.inner),
+ self.assertTrue(type(h1.inner) == type(h2.inner),
"Types of inner don't match.")
- self.failUnless(type(h1.outer) == type(h2.outer),
+ self.assertTrue(type(h1.outer) == type(h2.outer),
"Types of outer don't match.")
def test_realcopy(self):
@@ -265,10 +286,10 @@
h1 = hmac.HMAC("key")
h2 = h1.copy()
# Using id() in case somebody has overridden __cmp__.
- self.failUnless(id(h1) != id(h2), "No real copy of the HMAC instance.")
- self.failUnless(id(h1.inner) != id(h2.inner),
+ self.assertTrue(id(h1) != id(h2), "No real copy of the HMAC instance.")
+ self.assertTrue(id(h1.inner) != id(h2.inner),
"No real copy of the attribute 'inner'.")
- self.failUnless(id(h1.outer) != id(h2.outer),
+ self.assertTrue(id(h1.outer) != id(h2.outer),
"No real copy of the attribute 'outer'.")
def test_equality(self):
@@ -276,9 +297,9 @@
h1 = hmac.HMAC("key")
h1.update("some random text")
h2 = h1.copy()
- self.failUnless(h1.digest() == h2.digest(),
+ self.assertTrue(h1.digest() == h2.digest(),
"Digest of copy doesn't match original digest.")
- self.failUnless(h1.hexdigest() == h2.hexdigest(),
+ self.assertTrue(h1.hexdigest() == h2.hexdigest(),
"Hexdigest of copy doesn't match original hexdigest.")
def test_main():
diff --git a/Lib/test/test_iter.py b/Lib/test/test_iter.py
--- a/Lib/test/test_iter.py
+++ b/Lib/test/test_iter.py
@@ -1,7 +1,8 @@
# Test iterators.
import unittest
-from test.test_support import run_unittest, TESTFN, unlink, have_unicode
+from test.test_support import run_unittest, TESTFN, unlink, have_unicode, \
+ _check_py3k_warnings
# Test result of triple loop (too big to inline)
TRIPLETS = [(0, 0, 0), (0, 0, 1), (0, 0, 2),
@@ -389,21 +390,24 @@
# Test map()'s use of iterators.
def test_builtin_map(self):
- self.assertEqual(map(None, SequenceClass(5)), range(5))
self.assertEqual(map(lambda x: x+1, SequenceClass(5)), range(1, 6))
d = {"one": 1, "two": 2, "three": 3}
- self.assertEqual(map(None, d), d.keys())
self.assertEqual(map(lambda k, d=d: (k, d[k]), d), d.items())
dkeys = d.keys()
expected = [(i < len(d) and dkeys[i] or None,
i,
i < len(d) and dkeys[i] or None)
for i in range(5)]
- self.assertEqual(map(None, d,
- SequenceClass(5),
- iter(d.iterkeys())),
- expected)
+
+ # Deprecated map(None, ...)
+ with _check_py3k_warnings():
+ self.assertEqual(map(None, SequenceClass(5)), range(5))
+ self.assertEqual(map(None, d), d.keys())
+ self.assertEqual(map(None, d,
+ SequenceClass(5),
+ iter(d.iterkeys())),
+ expected)
f = open(TESTFN, "w")
try:
@@ -499,7 +503,11 @@
self.assertEqual(zip(x, y), expected)
# Test reduces()'s use of iterators.
- def test_builtin_reduce(self):
+ def test_deprecated_builtin_reduce(self):
+ with _check_py3k_warnings():
+ self._test_builtin_reduce()
+
+ def _test_builtin_reduce(self):
from operator import add
self.assertEqual(reduce(add, SequenceClass(5)), 10)
self.assertEqual(reduce(add, SequenceClass(5), 42), 52)
diff --git a/Lib/test/test_logging.py b/Lib/test/test_logging.py
--- a/Lib/test/test_logging.py
+++ b/Lib/test/test_logging.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
#
-# Copyright 2001-2004 by Vinay Sajip. All Rights Reserved.
+# Copyright 2001-2010 by Vinay Sajip. All Rights Reserved.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose and without fee is hereby granted,
@@ -15,201 +15,292 @@
# ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER
# IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
-#
-# This file is part of the Python logging distribution. See
-# http://www.red-dove.com/python_logging.html
-#
+
"""Test harness for the logging module. Run all tests.
-Copyright (C) 2001-2002 Vinay Sajip. All Rights Reserved.
+Copyright (C) 2001-2010 Vinay Sajip. All Rights Reserved.
"""
+import logging
+import logging.handlers
+import logging.config
+
+import codecs
+import cPickle
+import cStringIO
+import gc
+import json
+import os
+import re
import select
-import os, sys, string, struct, types, cPickle, cStringIO
-import socket, tempfile, threading, time
-import logging, logging.handlers, logging.config
-from test.test_support import run_with_locale
+import socket
+from SocketServer import ThreadingTCPServer, StreamRequestHandler
+import struct
+import sys
+import tempfile
+from test.test_support import captured_stdout, run_with_locale, run_unittest
+import textwrap
+import unittest
+import warnings
+import weakref
+try:
+ import threading
+except ImportError:
+ threading = None
-BANNER = "-- %-10s %-6s ---------------------------------------------------\n"
+class BaseTest(unittest.TestCase):
-FINISH_UP = "Finish up, it's closing time. Messages should bear numbers 0 through 24."
+ """Base class for logging tests."""
-#----------------------------------------------------------------------------
-# Log receiver
-#----------------------------------------------------------------------------
+ log_format = "%(name)s -> %(levelname)s: %(message)s"
+ expected_log_pat = r"^([\w.]+) -> ([\w]+): ([\d]+)$"
+ message_num = 0
-TIMEOUT = 10
+ def setUp(self):
+ """Setup the default logging stream to an internal StringIO instance,
+ so that we can examine log output as we want."""
+ logger_dict = logging.getLogger().manager.loggerDict
+ logging._acquireLock()
+ try:
+ self.saved_handlers = logging._handlers.copy()
+ self.saved_handler_list = logging._handlerList[:]
+ self.saved_loggers = logger_dict.copy()
+ self.saved_level_names = logging._levelNames.copy()
+ finally:
+ logging._releaseLock()
-from SocketServer import ThreadingTCPServer, StreamRequestHandler
+ # Set two unused loggers: one non-ASCII and one Unicode.
+ # This is to test correct operation when sorting existing
+ # loggers in the configuration code. See issue 8201.
+ logging.getLogger("\xab\xd7\xbb")
+ logging.getLogger(u"\u013f\u00d6\u0047")
-class LogRecordStreamHandler(StreamRequestHandler):
- """
- Handler for a streaming logging request. It basically logs the record
- using whatever logging policy is configured locally.
- """
+ self.root_logger = logging.getLogger("")
+ self.original_logging_level = self.root_logger.getEffectiveLevel()
- def handle(self):
- """
- Handle multiple requests - each expected to be a 4-byte length,
- followed by the LogRecord in pickle format. Logs the record
- according to whatever policy is configured locally.
- """
- while 1:
- try:
- chunk = self.connection.recv(4)
- if len(chunk) < 4:
- break
- slen = struct.unpack(">L", chunk)[0]
- chunk = self.connection.recv(slen)
- while len(chunk) < slen:
- chunk = chunk + self.connection.recv(slen - len(chunk))
- obj = self.unPickle(chunk)
- record = logging.makeLogRecord(obj)
- self.handleLogRecord(record)
- except:
- raise
+ self.stream = cStringIO.StringIO()
+ self.root_logger.setLevel(logging.DEBUG)
+ self.root_hdlr = logging.StreamHandler(self.stream)
+ self.root_formatter = logging.Formatter(self.log_format)
+ self.root_hdlr.setFormatter(self.root_formatter)
+ self.root_logger.addHandler(self.root_hdlr)
- def unPickle(self, data):
- return cPickle.loads(data)
+ def tearDown(self):
+ """Remove our logging stream, and restore the original logging
+ level."""
+ self.stream.close()
+ self.root_logger.removeHandler(self.root_hdlr)
+ while self.root_logger.handlers:
+ h = self.root_logger.handlers[0]
+ self.root_logger.removeHandler(h)
+ h.close()
+ self.root_logger.setLevel(self.original_logging_level)
+ logging._acquireLock()
+ try:
+ logging._levelNames.clear()
+ logging._levelNames.update(self.saved_level_names)
+ logging._handlers.clear()
+ logging._handlers.update(self.saved_handlers)
+ logging._handlerList[:] = self.saved_handler_list
+ loggerDict = logging.getLogger().manager.loggerDict
+ loggerDict.clear()
+ loggerDict.update(self.saved_loggers)
+ finally:
+ logging._releaseLock()
- def handleLogRecord(self, record):
- logname = "logrecv.tcp." + record.name
- #If the end-of-messages sentinel is seen, tell the server to terminate
- if record.msg == FINISH_UP:
- self.server.abort = 1
- record.msg = record.msg + " (via " + logname + ")"
- logger = logging.getLogger(logname)
- logger.handle(record)
+ def assert_log_lines(self, expected_values, stream=None):
+ """Match the collected log lines against the regular expression
+ self.expected_log_pat, and compare the extracted group values to
+ the expected_values list of tuples."""
+ stream = stream or self.stream
+ pat = re.compile(self.expected_log_pat)
+ try:
+ stream.reset()
+ actual_lines = stream.readlines()
+ except AttributeError:
+ # StringIO.StringIO lacks a reset() method.
+ actual_lines = stream.getvalue().splitlines()
+ self.assertEqual(len(actual_lines), len(expected_values))
+ for actual, expected in zip(actual_lines, expected_values):
+ match = pat.search(actual)
+ if not match:
+ self.fail("Log line does not match expected pattern:\n" +
+ actual)
+ self.assertEqual(tuple(match.groups()), expected)
+ s = stream.read()
+ if s:
+ self.fail("Remaining output at end of log stream:\n" + s)
-# The server sets socketDataProcessed when it's done.
-socketDataProcessed = threading.Event()
+ def next_message(self):
+ """Generate a message consisting solely of an auto-incrementing
+ integer."""
+ self.message_num += 1
+ return "%d" % self.message_num
-class LogRecordSocketReceiver(ThreadingTCPServer):
- """
- A simple-minded TCP socket-based logging receiver suitable for test
- purposes.
- """
- allow_reuse_address = 1
+class BuiltinLevelsTest(BaseTest):
+ """Test builtin levels and their inheritance."""
- def __init__(self, host='localhost',
- port=logging.handlers.DEFAULT_TCP_LOGGING_PORT,
- handler=LogRecordStreamHandler):
- ThreadingTCPServer.__init__(self, (host, port), handler)
- self.abort = 0
- self.timeout = 1
+ def test_flat(self):
+ #Logging levels in a flat logger namespace.
+ m = self.next_message
- def serve_until_stopped(self):
- if sys.platform.startswith('java'):
- # XXX: There's a problem using cpython_compatibile_select
- # here: it seems to be due to the fact that
- # cpython_compatible_select switches blocking mode on while
- # a separate thread is reading from the same socket, causing
- # a read of 0 in LogRecordStreamHandler.handle (which
- # deadlocks this test)
- self.socket.setblocking(0)
- while not self.abort:
- rd, wr, ex = select.select([self.socket.fileno()], [], [],
- self.timeout)
- if rd:
- self.handle_request()
- #notify the main thread that we're about to exit
- socketDataProcessed.set()
- # close the listen socket
- self.server_close()
+ ERR = logging.getLogger("ERR")
+ ERR.setLevel(logging.ERROR)
+ INF = logging.getLogger("INF")
+ INF.setLevel(logging.INFO)
+ DEB = logging.getLogger("DEB")
+ DEB.setLevel(logging.DEBUG)
- def process_request(self, request, client_address):
- #import threading
- t = threading.Thread(target = self.finish_request,
- args = (request, client_address))
- t.start()
+ # These should log.
+ ERR.log(logging.CRITICAL, m())
+ ERR.error(m())
-def runTCP(tcpserver):
- tcpserver.serve_until_stopped()
+ INF.log(logging.CRITICAL, m())
+ INF.error(m())
+ INF.warn(m())
+ INF.info(m())
-#----------------------------------------------------------------------------
-# Test 0
-#----------------------------------------------------------------------------
+ DEB.log(logging.CRITICAL, m())
+ DEB.error(m())
+ DEB.warn (m())
+ DEB.info (m())
+ DEB.debug(m())
-msgcount = 0
+ # These should not log.
+ ERR.warn(m())
+ ERR.info(m())
+ ERR.debug(m())
-def nextmessage():
- global msgcount
- rv = "Message %d" % msgcount
- msgcount = msgcount + 1
- return rv
+ INF.debug(m())
-def test0():
- ERR = logging.getLogger("ERR")
- ERR.setLevel(logging.ERROR)
- INF = logging.getLogger("INF")
- INF.setLevel(logging.INFO)
- INF_ERR = logging.getLogger("INF.ERR")
- INF_ERR.setLevel(logging.ERROR)
- DEB = logging.getLogger("DEB")
- DEB.setLevel(logging.DEBUG)
+ self.assert_log_lines([
+ ('ERR', 'CRITICAL', '1'),
+ ('ERR', 'ERROR', '2'),
+ ('INF', 'CRITICAL', '3'),
+ ('INF', 'ERROR', '4'),
+ ('INF', 'WARNING', '5'),
+ ('INF', 'INFO', '6'),
+ ('DEB', 'CRITICAL', '7'),
+ ('DEB', 'ERROR', '8'),
+ ('DEB', 'WARNING', '9'),
+ ('DEB', 'INFO', '10'),
+ ('DEB', 'DEBUG', '11'),
+ ])
- INF_UNDEF = logging.getLogger("INF.UNDEF")
- INF_ERR_UNDEF = logging.getLogger("INF.ERR.UNDEF")
- UNDEF = logging.getLogger("UNDEF")
+ def test_nested_explicit(self):
+ # Logging levels in a nested namespace, all explicitly set.
+ m = self.next_message
- GRANDCHILD = logging.getLogger("INF.BADPARENT.UNDEF")
- CHILD = logging.getLogger("INF.BADPARENT")
+ INF = logging.getLogger("INF")
+ INF.setLevel(logging.INFO)
+ INF_ERR = logging.getLogger("INF.ERR")
+ INF_ERR.setLevel(logging.ERROR)
- #These should log
- ERR.log(logging.FATAL, nextmessage())
- ERR.error(nextmessage())
+ # These should log.
+ INF_ERR.log(logging.CRITICAL, m())
+ INF_ERR.error(m())
- INF.log(logging.FATAL, nextmessage())
- INF.error(nextmessage())
- INF.warn(nextmessage())
- INF.info(nextmessage())
+ # These should not log.
+ INF_ERR.warn(m())
+ INF_ERR.info(m())
+ INF_ERR.debug(m())
- INF_UNDEF.log(logging.FATAL, nextmessage())
- INF_UNDEF.error(nextmessage())
- INF_UNDEF.warn (nextmessage())
- INF_UNDEF.info (nextmessage())
+ self.assert_log_lines([
+ ('INF.ERR', 'CRITICAL', '1'),
+ ('INF.ERR', 'ERROR', '2'),
+ ])
- INF_ERR.log(logging.FATAL, nextmessage())
- INF_ERR.error(nextmessage())
+ def test_nested_inherited(self):
+ #Logging levels in a nested namespace, inherited from parent loggers.
+ m = self.next_message
- INF_ERR_UNDEF.log(logging.FATAL, nextmessage())
- INF_ERR_UNDEF.error(nextmessage())
+ INF = logging.getLogger("INF")
+ INF.setLevel(logging.INFO)
+ INF_ERR = logging.getLogger("INF.ERR")
+ INF_ERR.setLevel(logging.ERROR)
+ INF_UNDEF = logging.getLogger("INF.UNDEF")
+ INF_ERR_UNDEF = logging.getLogger("INF.ERR.UNDEF")
+ UNDEF = logging.getLogger("UNDEF")
- DEB.log(logging.FATAL, nextmessage())
- DEB.error(nextmessage())
- DEB.warn (nextmessage())
- DEB.info (nextmessage())
- DEB.debug(nextmessage())
+ # These should log.
+ INF_UNDEF.log(logging.CRITICAL, m())
+ INF_UNDEF.error(m())
+ INF_UNDEF.warn(m())
+ INF_UNDEF.info(m())
+ INF_ERR_UNDEF.log(logging.CRITICAL, m())
+ INF_ERR_UNDEF.error(m())
- UNDEF.log(logging.FATAL, nextmessage())
- UNDEF.error(nextmessage())
- UNDEF.warn (nextmessage())
- UNDEF.info (nextmessage())
+ # These should not log.
+ INF_UNDEF.debug(m())
+ INF_ERR_UNDEF.warn(m())
+ INF_ERR_UNDEF.info(m())
+ INF_ERR_UNDEF.debug(m())
- GRANDCHILD.log(logging.FATAL, nextmessage())
- CHILD.log(logging.FATAL, nextmessage())
+ self.assert_log_lines([
+ ('INF.UNDEF', 'CRITICAL', '1'),
+ ('INF.UNDEF', 'ERROR', '2'),
+ ('INF.UNDEF', 'WARNING', '3'),
+ ('INF.UNDEF', 'INFO', '4'),
+ ('INF.ERR.UNDEF', 'CRITICAL', '5'),
+ ('INF.ERR.UNDEF', 'ERROR', '6'),
+ ])
- #These should not log
- ERR.warn(nextmessage())
- ERR.info(nextmessage())
- ERR.debug(nextmessage())
+ def test_nested_with_virtual_parent(self):
+ # Logging levels when some parent does not exist yet.
+ m = self.next_message
- INF.debug(nextmessage())
- INF_UNDEF.debug(nextmessage())
+ INF = logging.getLogger("INF")
+ GRANDCHILD = logging.getLogger("INF.BADPARENT.UNDEF")
+ CHILD = logging.getLogger("INF.BADPARENT")
+ INF.setLevel(logging.INFO)
- INF_ERR.warn(nextmessage())
- INF_ERR.info(nextmessage())
- INF_ERR.debug(nextmessage())
- INF_ERR_UNDEF.warn(nextmessage())
- INF_ERR_UNDEF.info(nextmessage())
- INF_ERR_UNDEF.debug(nextmessage())
+ # These should log.
+ GRANDCHILD.log(logging.FATAL, m())
+ GRANDCHILD.info(m())
+ CHILD.log(logging.FATAL, m())
+ CHILD.info(m())
- INF.info(FINISH_UP)
+ # These should not log.
+ GRANDCHILD.debug(m())
+ CHILD.debug(m())
-#----------------------------------------------------------------------------
-# Test 1
-#----------------------------------------------------------------------------
+ self.assert_log_lines([
+ ('INF.BADPARENT.UNDEF', 'CRITICAL', '1'),
+ ('INF.BADPARENT.UNDEF', 'INFO', '2'),
+ ('INF.BADPARENT', 'CRITICAL', '3'),
+ ('INF.BADPARENT', 'INFO', '4'),
+ ])
+
+
+class BasicFilterTest(BaseTest):
+
+ """Test the bundled Filter class."""
+
+ def test_filter(self):
+ # Only messages satisfying the specified criteria pass through the
+ # filter.
+ filter_ = logging.Filter("spam.eggs")
+ handler = self.root_logger.handlers[0]
+ try:
+ handler.addFilter(filter_)
+ spam = logging.getLogger("spam")
+ spam_eggs = logging.getLogger("spam.eggs")
+ spam_eggs_fish = logging.getLogger("spam.eggs.fish")
+ spam_bakedbeans = logging.getLogger("spam.bakedbeans")
+
+ spam.info(self.next_message())
+ spam_eggs.info(self.next_message()) # Good.
+ spam_eggs_fish.info(self.next_message()) # Good.
+ spam_bakedbeans.info(self.next_message())
+
+ self.assert_log_lines([
+ ('spam.eggs', 'INFO', '2'),
+ ('spam.eggs.fish', 'INFO', '3'),
+ ])
+ finally:
+ handler.removeFilter(filter_)
+
#
# First, we define our levels. There can be as many as you want - the only
@@ -219,16 +310,16 @@
# mapping dictionary to convert between your application levels and the
# logging system.
#
-SILENT = 10
-TACITURN = 9
-TERSE = 8
-EFFUSIVE = 7
-SOCIABLE = 6
-VERBOSE = 5
-TALKATIVE = 4
-GARRULOUS = 3
-CHATTERBOX = 2
-BORING = 1
+SILENT = 120
+TACITURN = 119
+TERSE = 118
+EFFUSIVE = 117
+SOCIABLE = 116
+VERBOSE = 115
+TALKATIVE = 114
+GARRULOUS = 113
+CHATTERBOX = 112
+BORING = 111
LEVEL_RANGE = range(BORING, SILENT + 1)
@@ -249,445 +340,1442 @@
BORING : 'Boring',
}
-#
-# Now, to demonstrate filtering: suppose for some perverse reason we only
-# want to print out all except GARRULOUS messages. Let's create a filter for
-# this purpose...
-#
-class SpecificLevelFilter(logging.Filter):
- def __init__(self, lvl):
- self.level = lvl
+class GarrulousFilter(logging.Filter):
+
+ """A filter which blocks garrulous messages."""
def filter(self, record):
- return self.level != record.levelno
+ return record.levelno != GARRULOUS
-class GarrulousFilter(SpecificLevelFilter):
- def __init__(self):
- SpecificLevelFilter.__init__(self, GARRULOUS)
+class VerySpecificFilter(logging.Filter):
-#
-# Now, let's demonstrate filtering at the logger. This time, use a filter
-# which excludes SOCIABLE and TACITURN messages. Note that GARRULOUS events
-# are still excluded.
-#
-class VerySpecificFilter(logging.Filter):
+ """A filter which blocks sociable and taciturn messages."""
+
def filter(self, record):
return record.levelno not in [SOCIABLE, TACITURN]
-def message(s):
- sys.stdout.write("%s\n" % s)
-SHOULD1 = "This should only be seen at the '%s' logging level (or lower)"
+class CustomLevelsAndFiltersTest(BaseTest):
-def test1():
-#
-# Now, tell the logging system to associate names with our levels.
-#
- for lvl in my_logging_levels.keys():
- logging.addLevelName(lvl, my_logging_levels[lvl])
+ """Test various filtering possibilities with custom logging levels."""
-#
-# Now, define a test function which logs an event at each of our levels.
-#
+ # Skip the logger name group.
+ expected_log_pat = r"^[\w.]+ -> ([\w]+): ([\d]+)$"
- def doLog(log):
+ def setUp(self):
+ BaseTest.setUp(self)
+ for k, v in my_logging_levels.items():
+ logging.addLevelName(k, v)
+
+ def log_at_all_levels(self, logger):
for lvl in LEVEL_RANGE:
- log.log(lvl, SHOULD1, logging.getLevelName(lvl))
+ logger.log(lvl, self.next_message())
- log = logging.getLogger("")
- hdlr = log.handlers[0]
-#
-# Set the logging level to each different value and call the utility
-# function to log events.
-# In the output, you should see that each time round the loop, the number of
-# logging events which are actually output decreases.
-#
- for lvl in LEVEL_RANGE:
- message("-- setting logging level to '%s' -----" %
- logging.getLevelName(lvl))
- log.setLevel(lvl)
- doLog(log)
- #
- # Now, we demonstrate level filtering at the handler level. Tell the
- # handler defined above to filter at level 'SOCIABLE', and repeat the
- # above loop. Compare the output from the two runs.
- #
- hdlr.setLevel(SOCIABLE)
- message("-- Filtering at handler level to SOCIABLE --")
- for lvl in LEVEL_RANGE:
- message("-- setting logging level to '%s' -----" %
- logging.getLevelName(lvl))
- log.setLevel(lvl)
- doLog(log)
+ def test_logger_filter(self):
+ # Filter at logger level.
+ self.root_logger.setLevel(VERBOSE)
+ # Levels >= 'Verbose' are good.
+ self.log_at_all_levels(self.root_logger)
+ self.assert_log_lines([
+ ('Verbose', '5'),
+ ('Sociable', '6'),
+ ('Effusive', '7'),
+ ('Terse', '8'),
+ ('Taciturn', '9'),
+ ('Silent', '10'),
+ ])
- hdlr.setLevel(0) #turn off level filtering at the handler
+ def test_handler_filter(self):
+ # Filter at handler level.
+ self.root_logger.handlers[0].setLevel(SOCIABLE)
+ try:
+ # Levels >= 'Sociable' are good.
+ self.log_at_all_levels(self.root_logger)
+ self.assert_log_lines([
+ ('Sociable', '6'),
+ ('Effusive', '7'),
+ ('Terse', '8'),
+ ('Taciturn', '9'),
+ ('Silent', '10'),
+ ])
+ finally:
+ self.root_logger.handlers[0].setLevel(logging.NOTSET)
- garr = GarrulousFilter()
- hdlr.addFilter(garr)
- message("-- Filtering using GARRULOUS filter --")
- for lvl in LEVEL_RANGE:
- message("-- setting logging level to '%s' -----" %
- logging.getLevelName(lvl))
- log.setLevel(lvl)
- doLog(log)
- spec = VerySpecificFilter()
- log.addFilter(spec)
- message("-- Filtering using specific filter for SOCIABLE, TACITURN --")
- for lvl in LEVEL_RANGE:
- message("-- setting logging level to '%s' -----" %
- logging.getLevelName(lvl))
- log.setLevel(lvl)
- doLog(log)
+ def test_specific_filters(self):
+ # Set a specific filter object on the handler, and then add another
+ # filter object on the logger itself.
+ handler = self.root_logger.handlers[0]
+ specific_filter = None
+ garr = GarrulousFilter()
+ handler.addFilter(garr)
+ try:
+ self.log_at_all_levels(self.root_logger)
+ first_lines = [
+ # Notice how 'Garrulous' is missing
+ ('Boring', '1'),
+ ('Chatterbox', '2'),
+ ('Talkative', '4'),
+ ('Verbose', '5'),
+ ('Sociable', '6'),
+ ('Effusive', '7'),
+ ('Terse', '8'),
+ ('Taciturn', '9'),
+ ('Silent', '10'),
+ ]
+ self.assert_log_lines(first_lines)
- log.removeFilter(spec)
- hdlr.removeFilter(garr)
- #Undo the one level which clashes...for regression tests
- logging.addLevelName(logging.DEBUG, "DEBUG")
+ specific_filter = VerySpecificFilter()
+ self.root_logger.addFilter(specific_filter)
+ self.log_at_all_levels(self.root_logger)
+ self.assert_log_lines(first_lines + [
+ # Not only 'Garrulous' is still missing, but also 'Sociable'
+ # and 'Taciturn'
+ ('Boring', '11'),
+ ('Chatterbox', '12'),
+ ('Talkative', '14'),
+ ('Verbose', '15'),
+ ('Effusive', '17'),
+ ('Terse', '18'),
+ ('Silent', '20'),
+ ])
+ finally:
+ if specific_filter:
+ self.root_logger.removeFilter(specific_filter)
+ handler.removeFilter(garr)
-#----------------------------------------------------------------------------
-# Test 2
-#----------------------------------------------------------------------------
-MSG = "-- logging %d at INFO, messages should be seen every 10 events --"
-def test2():
- logger = logging.getLogger("")
- sh = logger.handlers[0]
- sh.close()
- logger.removeHandler(sh)
- mh = logging.handlers.MemoryHandler(10,logging.WARNING, sh)
- logger.setLevel(logging.DEBUG)
- logger.addHandler(mh)
- message("-- logging at DEBUG, nothing should be seen yet --")
- logger.debug("Debug message")
- message("-- logging at INFO, nothing should be seen yet --")
- logger.info("Info message")
- message("-- logging at WARNING, 3 messages should be seen --")
- logger.warn("Warn message")
- for i in xrange(102):
- message(MSG % i)
- logger.info("Info index = %d", i)
- mh.close()
- logger.removeHandler(mh)
- logger.addHandler(sh)
+class MemoryHandlerTest(BaseTest):
-#----------------------------------------------------------------------------
-# Test 3
-#----------------------------------------------------------------------------
+ """Tests for the MemoryHandler."""
-FILTER = "a.b"
+ # Do not bother with a logger name group.
+ expected_log_pat = r"^[\w.]+ -> ([\w]+): ([\d]+)$"
-def doLog3():
- logging.getLogger("a").info("Info 1")
- logging.getLogger("a.b").info("Info 2")
- logging.getLogger("a.c").info("Info 3")
- logging.getLogger("a.b.c").info("Info 4")
- logging.getLogger("a.b.c.d").info("Info 5")
- logging.getLogger("a.bb.c").info("Info 6")
- logging.getLogger("b").info("Info 7")
- logging.getLogger("b.a").info("Info 8")
- logging.getLogger("c.a.b").info("Info 9")
- logging.getLogger("a.bb").info("Info 10")
+ def setUp(self):
+ BaseTest.setUp(self)
+ self.mem_hdlr = logging.handlers.MemoryHandler(10, logging.WARNING,
+ self.root_hdlr)
+ self.mem_logger = logging.getLogger('mem')
+ self.mem_logger.propagate = 0
+ self.mem_logger.addHandler(self.mem_hdlr)
-def test3():
- root = logging.getLogger()
- root.setLevel(logging.DEBUG)
- hand = root.handlers[0]
- message("Unfiltered...")
- doLog3()
- message("Filtered with '%s'..." % FILTER)
- filt = logging.Filter(FILTER)
- hand.addFilter(filt)
- doLog3()
- hand.removeFilter(filt)
+ def tearDown(self):
+ self.mem_hdlr.close()
+ BaseTest.tearDown(self)
-#----------------------------------------------------------------------------
-# Test 4
-#----------------------------------------------------------------------------
+ def test_flush(self):
+ # The memory handler flushes to its target handler based on specific
+ # criteria (message count and message level).
+ self.mem_logger.debug(self.next_message())
+ self.assert_log_lines([])
+ self.mem_logger.info(self.next_message())
+ self.assert_log_lines([])
+ # This will flush because the level is >= logging.WARNING
+ self.mem_logger.warn(self.next_message())
+ lines = [
+ ('DEBUG', '1'),
+ ('INFO', '2'),
+ ('WARNING', '3'),
+ ]
+ self.assert_log_lines(lines)
+ for n in (4, 14):
+ for i in range(9):
+ self.mem_logger.debug(self.next_message())
+ self.assert_log_lines(lines)
+ # This will flush because it's the 10th message since the last
+ # flush.
+ self.mem_logger.debug(self.next_message())
+ lines = lines + [('DEBUG', str(i)) for i in range(n, n + 10)]
+ self.assert_log_lines(lines)
-# config0 is a standard configuration.
-config0 = """
-[loggers]
-keys=root
+ self.mem_logger.debug(self.next_message())
+ self.assert_log_lines(lines)
-[handlers]
-keys=hand1
-[formatters]
-keys=form1
+class ExceptionFormatter(logging.Formatter):
+ """A special exception formatter."""
+ def formatException(self, ei):
+ return "Got a [%s]" % ei[0].__name__
-[logger_root]
-level=NOTSET
-handlers=hand1
-[handler_hand1]
-class=StreamHandler
-level=NOTSET
-formatter=form1
-args=(sys.stdout,)
+class ConfigFileTest(BaseTest):
-[formatter_form1]
-format=%(levelname)s:%(name)s:%(message)s
-datefmt=
-"""
+ """Reading logging config from a .ini-style config file."""
-# config1 adds a little to the standard configuration.
-config1 = """
-[loggers]
-keys=root,parser
+ expected_log_pat = r"^([\w]+) \+\+ ([\w]+)$"
-[handlers]
-keys=hand1
+ # config0 is a standard configuration.
+ config0 = """
+ [loggers]
+ keys=root
-[formatters]
-keys=form1
+ [handlers]
+ keys=hand1
-[logger_root]
-level=NOTSET
-handlers=hand1
+ [formatters]
+ keys=form1
-[logger_parser]
-level=DEBUG
-handlers=hand1
-propagate=1
-qualname=compiler.parser
+ [logger_root]
+ level=WARNING
+ handlers=hand1
-[handler_hand1]
-class=StreamHandler
-level=NOTSET
-formatter=form1
-args=(sys.stdout,)
+ [handler_hand1]
+ class=StreamHandler
+ level=NOTSET
+ formatter=form1
+ args=(sys.stdout,)
-[formatter_form1]
-format=%(levelname)s:%(name)s:%(message)s
-datefmt=
-"""
+ [formatter_form1]
+ format=%(levelname)s ++ %(message)s
+ datefmt=
+ """
-# config2 has a subtle configuration error that should be reported
-config2 = string.replace(config1, "sys.stdout", "sys.stbout")
+ # config1 adds a little to the standard configuration.
+ config1 = """
+ [loggers]
+ keys=root,parser
-# config3 has a less subtle configuration error
-config3 = string.replace(
- config1, "formatter=form1", "formatter=misspelled_name")
+ [handlers]
+ keys=hand1
-def test4():
- for i in range(4):
- conf = globals()['config%d' % i]
- sys.stdout.write('config%d: ' % i)
- loggerDict = logging.getLogger().manager.loggerDict
- logging._acquireLock()
+ [formatters]
+ keys=form1
+
+ [logger_root]
+ level=WARNING
+ handlers=
+
+ [logger_parser]
+ level=DEBUG
+ handlers=hand1
+ propagate=1
+ qualname=compiler.parser
+
+ [handler_hand1]
+ class=StreamHandler
+ level=NOTSET
+ formatter=form1
+ args=(sys.stdout,)
+
+ [formatter_form1]
+ format=%(levelname)s ++ %(message)s
+ datefmt=
+ """
+
+ # config2 has a subtle configuration error that should be reported
+ config2 = config1.replace("sys.stdout", "sys.stbout")
+
+ # config3 has a less subtle configuration error
+ config3 = config1.replace("formatter=form1", "formatter=misspelled_name")
+
+ # config4 specifies a custom formatter class to be loaded
+ config4 = """
+ [loggers]
+ keys=root
+
+ [handlers]
+ keys=hand1
+
+ [formatters]
+ keys=form1
+
+ [logger_root]
+ level=NOTSET
+ handlers=hand1
+
+ [handler_hand1]
+ class=StreamHandler
+ level=NOTSET
+ formatter=form1
+ args=(sys.stdout,)
+
+ [formatter_form1]
+ class=""" + __name__ + """.ExceptionFormatter
+ format=%(levelname)s:%(name)s:%(message)s
+ datefmt=
+ """
+
+ # config5 specifies a custom handler class to be loaded
+ config5 = config1.replace('class=StreamHandler', 'class=logging.StreamHandler')
+
+ # config6 uses ', ' delimiters in the handlers and formatters sections
+ config6 = """
+ [loggers]
+ keys=root,parser
+
+ [handlers]
+ keys=hand1, hand2
+
+ [formatters]
+ keys=form1, form2
+
+ [logger_root]
+ level=WARNING
+ handlers=
+
+ [logger_parser]
+ level=DEBUG
+ handlers=hand1
+ propagate=1
+ qualname=compiler.parser
+
+ [handler_hand1]
+ class=StreamHandler
+ level=NOTSET
+ formatter=form1
+ args=(sys.stdout,)
+
+ [handler_hand2]
+ class=StreamHandler
+ level=NOTSET
+ formatter=form1
+ args=(sys.stderr,)
+
+ [formatter_form1]
+ format=%(levelname)s ++ %(message)s
+ datefmt=
+
+ [formatter_form2]
+ format=%(message)s
+ datefmt=
+ """
+
+ def apply_config(self, conf):
+ file = cStringIO.StringIO(textwrap.dedent(conf))
+ logging.config.fileConfig(file)
+
+ def test_config0_ok(self):
+ # A simple config file which overrides the default settings.
+ with captured_stdout() as output:
+ self.apply_config(self.config0)
+ logger = logging.getLogger()
+ # Won't output anything
+ logger.info(self.next_message())
+ # Outputs a message
+ logger.error(self.next_message())
+ self.assert_log_lines([
+ ('ERROR', '2'),
+ ], stream=output)
+ # Original logger output is empty.
+ self.assert_log_lines([])
+
+ def test_config1_ok(self, config=config1):
+ # A config file defining a sub-parser as well.
+ with captured_stdout() as output:
+ self.apply_config(config)
+ logger = logging.getLogger("compiler.parser")
+ # Both will output a message
+ logger.info(self.next_message())
+ logger.error(self.next_message())
+ self.assert_log_lines([
+ ('INFO', '1'),
+ ('ERROR', '2'),
+ ], stream=output)
+ # Original logger output is empty.
+ self.assert_log_lines([])
+
+ def test_config2_failure(self):
+ # A simple config file which overrides the default settings.
+ self.assertRaises(StandardError, self.apply_config, self.config2)
+
+ def test_config3_failure(self):
+ # A simple config file which overrides the default settings.
+ self.assertRaises(StandardError, self.apply_config, self.config3)
+
+ def test_config4_ok(self):
+ # A config file specifying a custom formatter class.
+ with captured_stdout() as output:
+ self.apply_config(self.config4)
+ logger = logging.getLogger()
+ try:
+ raise RuntimeError()
+ except RuntimeError:
+ logging.exception("just testing")
+ sys.stdout.seek(0)
+ self.assertEqual(output.getvalue(),
+ "ERROR:root:just testing\nGot a [RuntimeError]\n")
+ # Original logger output is empty
+ self.assert_log_lines([])
+
+ def test_config5_ok(self):
+ self.test_config1_ok(config=self.config5)
+
+ def test_config6_ok(self):
+ self.test_config1_ok(config=self.config6)
+
+class LogRecordStreamHandler(StreamRequestHandler):
+
+ """Handler for a streaming logging request. It saves the log message in the
+ TCP server's 'log_output' attribute."""
+
+ TCP_LOG_END = "!!!END!!!"
+
+ def handle(self):
+ """Handle multiple requests - each expected to be of 4-byte length,
+ followed by the LogRecord in pickle format. Logs the record
+ according to whatever policy is configured locally."""
+ while True:
+ chunk = self.connection.recv(4)
+ if len(chunk) < 4:
+ break
+ slen = struct.unpack(">L", chunk)[0]
+ chunk = self.connection.recv(slen)
+ while len(chunk) < slen:
+ chunk = chunk + self.connection.recv(slen - len(chunk))
+ obj = self.unpickle(chunk)
+ record = logging.makeLogRecord(obj)
+ self.handle_log_record(record)
+
+ def unpickle(self, data):
+ return cPickle.loads(data)
+
+ def handle_log_record(self, record):
+ # If the end-of-messages sentinel is seen, tell the server to
+ # terminate.
+ if self.TCP_LOG_END in record.msg:
+ self.server.abort = 1
+ return
+ self.server.log_output += record.msg + "\n"
+
+
+class LogRecordSocketReceiver(ThreadingTCPServer):
+
+ """A simple-minded TCP socket-based logging receiver suitable for test
+ purposes."""
+
+ allow_reuse_address = 1
+ log_output = ""
+
+ def __init__(self, host='localhost',
+ port=logging.handlers.DEFAULT_TCP_LOGGING_PORT,
+ handler=LogRecordStreamHandler):
+ ThreadingTCPServer.__init__(self, (host, port), handler)
+ self.abort = False
+ self.timeout = 0.1
+ self.finished = threading.Event()
+
+ def serve_until_stopped(self):
+ if sys.platform.startswith('java'):
+ # XXX: There's a problem using cpython_compatibile_select
+ # here: it seems to be due to the fact that
+ # cpython_compatible_select switches blocking mode on while
+ # a separate thread is reading from the same socket, causing
+ # a read of 0 in LogRecordStreamHandler.handle (which
+ # deadlocks this test)
+ self.socket.setblocking(0)
+ while not self.abort:
+ rd, wr, ex = select.select([self.socket.fileno()], [], [],
+ self.timeout)
+ if rd:
+ self.handle_request()
+ # Notify the main thread that we're about to exit
+ self.finished.set()
+ # close the listen socket
+ self.server_close()
+
+
+@unittest.skipUnless(threading, 'Threading required for this test.')
+class SocketHandlerTest(BaseTest):
+
+ """Test for SocketHandler objects."""
+
+ def setUp(self):
+ """Set up a TCP server to receive log messages, and a SocketHandler
+ pointing to that server's address and port."""
+ BaseTest.setUp(self)
+ self.tcpserver = LogRecordSocketReceiver(port=0)
+ self.port = self.tcpserver.socket.getsockname()[1]
+ self.threads = [
+ threading.Thread(target=self.tcpserver.serve_until_stopped)]
+ for thread in self.threads:
+ thread.start()
+
+ self.sock_hdlr = logging.handlers.SocketHandler('localhost', self.port)
+ self.sock_hdlr.setFormatter(self.root_formatter)
+ self.root_logger.removeHandler(self.root_logger.handlers[0])
+ self.root_logger.addHandler(self.sock_hdlr)
+
+ def tearDown(self):
+ """Shutdown the TCP server."""
try:
- saved_handlers = logging._handlers.copy()
- saved_handler_list = logging._handlerList[:]
- saved_loggers = loggerDict.copy()
+ self.tcpserver.abort = True
+ del self.tcpserver
+ self.root_logger.removeHandler(self.sock_hdlr)
+ self.sock_hdlr.close()
+ for thread in self.threads:
+ thread.join(2.0)
finally:
- logging._releaseLock()
+ BaseTest.tearDown(self)
+
+ def get_output(self):
+ """Get the log output as received by the TCP server."""
+ # Signal the TCP receiver and wait for it to terminate.
+ self.root_logger.critical(LogRecordStreamHandler.TCP_LOG_END)
+ self.tcpserver.finished.wait(2.0)
+ return self.tcpserver.log_output
+
+ def test_output(self):
+ # The log message sent to the SocketHandler is properly received.
+ logger = logging.getLogger("tcp")
+ logger.error("spam")
+ logger.debug("eggs")
+ self.assertEqual(self.get_output(), "spam\neggs\n")
+
+
+class MemoryTest(BaseTest):
+
+ """Test memory persistence of logger objects."""
+
+ def setUp(self):
+ """Create a dict to remember potentially destroyed objects."""
+ BaseTest.setUp(self)
+ self._survivors = {}
+
+ def _watch_for_survival(self, *args):
+ """Watch the given objects for survival, by creating weakrefs to
+ them."""
+ for obj in args:
+ key = id(obj), repr(obj)
+ self._survivors[key] = weakref.ref(obj)
+
+ def _assertTruesurvival(self):
+ """Assert that all objects watched for survival have survived."""
+ # Trigger cycle breaking.
+ gc.collect()
+ dead = []
+ for (id_, repr_), ref in self._survivors.items():
+ if ref() is None:
+ dead.append(repr_)
+ if dead:
+ self.fail("%d objects should have survived "
+ "but have been destroyed: %s" % (len(dead), ", ".join(dead)))
+
+ def test_persistent_loggers(self):
+ # Logger objects are persistent and retain their configuration, even
+ # if visible references are destroyed.
+ self.root_logger.setLevel(logging.INFO)
+ foo = logging.getLogger("foo")
+ self._watch_for_survival(foo)
+ foo.setLevel(logging.DEBUG)
+ self.root_logger.debug(self.next_message())
+ foo.debug(self.next_message())
+ self.assert_log_lines([
+ ('foo', 'DEBUG', '2'),
+ ])
+ del foo
+ # foo has survived.
+ self._assertTruesurvival()
+ # foo has retained its settings.
+ bar = logging.getLogger("foo")
+ bar.debug(self.next_message())
+ self.assert_log_lines([
+ ('foo', 'DEBUG', '2'),
+ ('foo', 'DEBUG', '3'),
+ ])
+
+
+class EncodingTest(BaseTest):
+ def test_encoding_plain_file(self):
+ # In Python 2.x, a plain file object is treated as having no encoding.
+ log = logging.getLogger("test")
+ fn = tempfile.mktemp(".log")
+ # the non-ascii data we write to the log.
+ data = "foo\x80"
try:
- fn = tempfile.mktemp(".ini")
- f = open(fn, "w")
- f.write(conf)
- f.close()
+ handler = logging.FileHandler(fn)
+ log.addHandler(handler)
try:
- logging.config.fileConfig(fn)
- #call again to make sure cleanup is correct
- logging.config.fileConfig(fn)
- except:
- t = sys.exc_info()[0]
- message(str(t))
- else:
- message('ok.')
- os.remove(fn)
+ # write non-ascii data to the log.
+ log.warning(data)
+ finally:
+ log.removeHandler(handler)
+ handler.close()
+ # check we wrote exactly those bytes, ignoring trailing \n etc
+ f = open(fn)
+ try:
+ self.assertEqual(f.read().rstrip(), data)
+ finally:
+ f.close()
finally:
- logging._acquireLock()
+ if os.path.isfile(fn):
+ os.remove(fn)
+
+ def test_encoding_cyrillic_unicode(self):
+ log = logging.getLogger("test")
+ #Get a message in Unicode: Do svidanya in Cyrillic (meaning goodbye)
+ message = u'\u0434\u043e \u0441\u0432\u0438\u0434\u0430\u043d\u0438\u044f'
+ #Ensure it's written in a Cyrillic encoding
+ writer_class = codecs.getwriter('cp1251')
+ writer_class.encoding = 'cp1251'
+ stream = cStringIO.StringIO()
+ writer = writer_class(stream, 'strict')
+ handler = logging.StreamHandler(writer)
+ log.addHandler(handler)
+ try:
+ log.warning(message)
+ finally:
+ log.removeHandler(handler)
+ handler.close()
+ # check we wrote exactly those bytes, ignoring trailing \n etc
+ s = stream.getvalue()
+ #Compare against what the data should be when encoded in CP-1251
+ self.assertEqual(s, '\xe4\xee \xf1\xe2\xe8\xe4\xe0\xed\xe8\xff\n')
+
+
+class WarningsTest(BaseTest):
+
+ def test_warnings(self):
+ with warnings.catch_warnings():
+ logging.captureWarnings(True)
try:
- logging._handlers.clear()
- logging._handlers.update(saved_handlers)
- logging._handlerList[:] = saved_handler_list
- loggerDict = logging.getLogger().manager.loggerDict
- loggerDict.clear()
- loggerDict.update(saved_loggers)
+ warnings.filterwarnings("always", category=UserWarning)
+ file = cStringIO.StringIO()
+ h = logging.StreamHandler(file)
+ logger = logging.getLogger("py.warnings")
+ logger.addHandler(h)
+ warnings.warn("I'm warning you...")
+ logger.removeHandler(h)
+ s = file.getvalue()
+ h.close()
+ self.assertTrue(s.find("UserWarning: I'm warning you...\n") > 0)
+
+ #See if an explicit file uses the original implementation
+ file = cStringIO.StringIO()
+ warnings.showwarning("Explicit", UserWarning, "dummy.py", 42,
+ file, "Dummy line")
+ s = file.getvalue()
+ file.close()
+ self.assertEqual(s,
+ "dummy.py:42: UserWarning: Explicit\n Dummy line\n")
finally:
- logging._releaseLock()
+ logging.captureWarnings(False)
-#----------------------------------------------------------------------------
-# Test 5
-#----------------------------------------------------------------------------
-test5_config = """
-[loggers]
-keys=root
+def formatFunc(format, datefmt=None):
+ return logging.Formatter(format, datefmt)
-[handlers]
-keys=hand1
+def handlerFunc():
+ return logging.StreamHandler()
-[formatters]
-keys=form1
+class CustomHandler(logging.StreamHandler):
+ pass
-[logger_root]
-level=NOTSET
-handlers=hand1
+class ConfigDictTest(BaseTest):
-[handler_hand1]
-class=StreamHandler
-level=NOTSET
-formatter=form1
-args=(sys.stdout,)
+ """Reading logging config from a dictionary."""
-[formatter_form1]
-class=test.test_logging.FriendlyFormatter
-format=%(levelname)s:%(name)s:%(message)s
-datefmt=
-"""
+ expected_log_pat = r"^([\w]+) \+\+ ([\w]+)$"
-class FriendlyFormatter (logging.Formatter):
- def formatException(self, ei):
- return "%s... Don't panic!" % str(ei[0])
+ # config0 is a standard configuration.
+ config0 = {
+ 'version': 1,
+ 'formatters': {
+ 'form1' : {
+ 'format' : '%(levelname)s ++ %(message)s',
+ },
+ },
+ 'handlers' : {
+ 'hand1' : {
+ 'class' : 'logging.StreamHandler',
+ 'formatter' : 'form1',
+ 'level' : 'NOTSET',
+ 'stream' : 'ext://sys.stdout',
+ },
+ },
+ 'root' : {
+ 'level' : 'WARNING',
+ 'handlers' : ['hand1'],
+ },
+ }
+ # config1 adds a little to the standard configuration.
+ config1 = {
+ 'version': 1,
+ 'formatters': {
+ 'form1' : {
+ 'format' : '%(levelname)s ++ %(message)s',
+ },
+ },
+ 'handlers' : {
+ 'hand1' : {
+ 'class' : 'logging.StreamHandler',
+ 'formatter' : 'form1',
+ 'level' : 'NOTSET',
+ 'stream' : 'ext://sys.stdout',
+ },
+ },
+ 'loggers' : {
+ 'compiler.parser' : {
+ 'level' : 'DEBUG',
+ 'handlers' : ['hand1'],
+ },
+ },
+ 'root' : {
+ 'level' : 'WARNING',
+ },
+ }
-def test5():
- loggerDict = logging.getLogger().manager.loggerDict
- logging._acquireLock()
- try:
- saved_handlers = logging._handlers.copy()
- saved_handler_list = logging._handlerList[:]
- saved_loggers = loggerDict.copy()
- finally:
- logging._releaseLock()
- try:
- fn = tempfile.mktemp(".ini")
- f = open(fn, "w")
- f.write(test5_config)
- f.close()
- logging.config.fileConfig(fn)
+ # config2 has a subtle configuration error that should be reported
+ config2 = {
+ 'version': 1,
+ 'formatters': {
+ 'form1' : {
+ 'format' : '%(levelname)s ++ %(message)s',
+ },
+ },
+ 'handlers' : {
+ 'hand1' : {
+ 'class' : 'logging.StreamHandler',
+ 'formatter' : 'form1',
+ 'level' : 'NOTSET',
+ 'stream' : 'ext://sys.stdbout',
+ },
+ },
+ 'loggers' : {
+ 'compiler.parser' : {
+ 'level' : 'DEBUG',
+ 'handlers' : ['hand1'],
+ },
+ },
+ 'root' : {
+ 'level' : 'WARNING',
+ },
+ }
+
+ #As config1 but with a misspelt level on a handler
+ config2a = {
+ 'version': 1,
+ 'formatters': {
+ 'form1' : {
+ 'format' : '%(levelname)s ++ %(message)s',
+ },
+ },
+ 'handlers' : {
+ 'hand1' : {
+ 'class' : 'logging.StreamHandler',
+ 'formatter' : 'form1',
+ 'level' : 'NTOSET',
+ 'stream' : 'ext://sys.stdout',
+ },
+ },
+ 'loggers' : {
+ 'compiler.parser' : {
+ 'level' : 'DEBUG',
+ 'handlers' : ['hand1'],
+ },
+ },
+ 'root' : {
+ 'level' : 'WARNING',
+ },
+ }
+
+
+ #As config1 but with a misspelt level on a logger
+ config2b = {
+ 'version': 1,
+ 'formatters': {
+ 'form1' : {
+ 'format' : '%(levelname)s ++ %(message)s',
+ },
+ },
+ 'handlers' : {
+ 'hand1' : {
+ 'class' : 'logging.StreamHandler',
+ 'formatter' : 'form1',
+ 'level' : 'NOTSET',
+ 'stream' : 'ext://sys.stdout',
+ },
+ },
+ 'loggers' : {
+ 'compiler.parser' : {
+ 'level' : 'DEBUG',
+ 'handlers' : ['hand1'],
+ },
+ },
+ 'root' : {
+ 'level' : 'WRANING',
+ },
+ }
+
+ # config3 has a less subtle configuration error
+ config3 = {
+ 'version': 1,
+ 'formatters': {
+ 'form1' : {
+ 'format' : '%(levelname)s ++ %(message)s',
+ },
+ },
+ 'handlers' : {
+ 'hand1' : {
+ 'class' : 'logging.StreamHandler',
+ 'formatter' : 'misspelled_name',
+ 'level' : 'NOTSET',
+ 'stream' : 'ext://sys.stdout',
+ },
+ },
+ 'loggers' : {
+ 'compiler.parser' : {
+ 'level' : 'DEBUG',
+ 'handlers' : ['hand1'],
+ },
+ },
+ 'root' : {
+ 'level' : 'WARNING',
+ },
+ }
+
+ # config4 specifies a custom formatter class to be loaded
+ config4 = {
+ 'version': 1,
+ 'formatters': {
+ 'form1' : {
+ '()' : __name__ + '.ExceptionFormatter',
+ 'format' : '%(levelname)s:%(name)s:%(message)s',
+ },
+ },
+ 'handlers' : {
+ 'hand1' : {
+ 'class' : 'logging.StreamHandler',
+ 'formatter' : 'form1',
+ 'level' : 'NOTSET',
+ 'stream' : 'ext://sys.stdout',
+ },
+ },
+ 'root' : {
+ 'level' : 'NOTSET',
+ 'handlers' : ['hand1'],
+ },
+ }
+
+ # As config4 but using an actual callable rather than a string
+ config4a = {
+ 'version': 1,
+ 'formatters': {
+ 'form1' : {
+ '()' : ExceptionFormatter,
+ 'format' : '%(levelname)s:%(name)s:%(message)s',
+ },
+ 'form2' : {
+ '()' : __name__ + '.formatFunc',
+ 'format' : '%(levelname)s:%(name)s:%(message)s',
+ },
+ 'form3' : {
+ '()' : formatFunc,
+ 'format' : '%(levelname)s:%(name)s:%(message)s',
+ },
+ },
+ 'handlers' : {
+ 'hand1' : {
+ 'class' : 'logging.StreamHandler',
+ 'formatter' : 'form1',
+ 'level' : 'NOTSET',
+ 'stream' : 'ext://sys.stdout',
+ },
+ 'hand2' : {
+ '()' : handlerFunc,
+ },
+ },
+ 'root' : {
+ 'level' : 'NOTSET',
+ 'handlers' : ['hand1'],
+ },
+ }
+
+ # config5 specifies a custom handler class to be loaded
+ config5 = {
+ 'version': 1,
+ 'formatters': {
+ 'form1' : {
+ 'format' : '%(levelname)s ++ %(message)s',
+ },
+ },
+ 'handlers' : {
+ 'hand1' : {
+ 'class' : __name__ + '.CustomHandler',
+ 'formatter' : 'form1',
+ 'level' : 'NOTSET',
+ 'stream' : 'ext://sys.stdout',
+ },
+ },
+ 'loggers' : {
+ 'compiler.parser' : {
+ 'level' : 'DEBUG',
+ 'handlers' : ['hand1'],
+ },
+ },
+ 'root' : {
+ 'level' : 'WARNING',
+ },
+ }
+
+ # config6 specifies a custom handler class to be loaded
+ # but has bad arguments
+ config6 = {
+ 'version': 1,
+ 'formatters': {
+ 'form1' : {
+ 'format' : '%(levelname)s ++ %(message)s',
+ },
+ },
+ 'handlers' : {
+ 'hand1' : {
+ 'class' : __name__ + '.CustomHandler',
+ 'formatter' : 'form1',
+ 'level' : 'NOTSET',
+ 'stream' : 'ext://sys.stdout',
+ '9' : 'invalid parameter name',
+ },
+ },
+ 'loggers' : {
+ 'compiler.parser' : {
+ 'level' : 'DEBUG',
+ 'handlers' : ['hand1'],
+ },
+ },
+ 'root' : {
+ 'level' : 'WARNING',
+ },
+ }
+
+ #config 7 does not define compiler.parser but defines compiler.lexer
+ #so compiler.parser should be disabled after applying it
+ config7 = {
+ 'version': 1,
+ 'formatters': {
+ 'form1' : {
+ 'format' : '%(levelname)s ++ %(message)s',
+ },
+ },
+ 'handlers' : {
+ 'hand1' : {
+ 'class' : 'logging.StreamHandler',
+ 'formatter' : 'form1',
+ 'level' : 'NOTSET',
+ 'stream' : 'ext://sys.stdout',
+ },
+ },
+ 'loggers' : {
+ 'compiler.lexer' : {
+ 'level' : 'DEBUG',
+ 'handlers' : ['hand1'],
+ },
+ },
+ 'root' : {
+ 'level' : 'WARNING',
+ },
+ }
+
+ config8 = {
+ 'version': 1,
+ 'disable_existing_loggers' : False,
+ 'formatters': {
+ 'form1' : {
+ 'format' : '%(levelname)s ++ %(message)s',
+ },
+ },
+ 'handlers' : {
+ 'hand1' : {
+ 'class' : 'logging.StreamHandler',
+ 'formatter' : 'form1',
+ 'level' : 'NOTSET',
+ 'stream' : 'ext://sys.stdout',
+ },
+ },
+ 'loggers' : {
+ 'compiler' : {
+ 'level' : 'DEBUG',
+ 'handlers' : ['hand1'],
+ },
+ 'compiler.lexer' : {
+ },
+ },
+ 'root' : {
+ 'level' : 'WARNING',
+ },
+ }
+
+ config9 = {
+ 'version': 1,
+ 'formatters': {
+ 'form1' : {
+ 'format' : '%(levelname)s ++ %(message)s',
+ },
+ },
+ 'handlers' : {
+ 'hand1' : {
+ 'class' : 'logging.StreamHandler',
+ 'formatter' : 'form1',
+ 'level' : 'WARNING',
+ 'stream' : 'ext://sys.stdout',
+ },
+ },
+ 'loggers' : {
+ 'compiler.parser' : {
+ 'level' : 'WARNING',
+ 'handlers' : ['hand1'],
+ },
+ },
+ 'root' : {
+ 'level' : 'NOTSET',
+ },
+ }
+
+ config9a = {
+ 'version': 1,
+ 'incremental' : True,
+ 'handlers' : {
+ 'hand1' : {
+ 'level' : 'WARNING',
+ },
+ },
+ 'loggers' : {
+ 'compiler.parser' : {
+ 'level' : 'INFO',
+ },
+ },
+ }
+
+ config9b = {
+ 'version': 1,
+ 'incremental' : True,
+ 'handlers' : {
+ 'hand1' : {
+ 'level' : 'INFO',
+ },
+ },
+ 'loggers' : {
+ 'compiler.parser' : {
+ 'level' : 'INFO',
+ },
+ },
+ }
+
+ #As config1 but with a filter added
+ config10 = {
+ 'version': 1,
+ 'formatters': {
+ 'form1' : {
+ 'format' : '%(levelname)s ++ %(message)s',
+ },
+ },
+ 'filters' : {
+ 'filt1' : {
+ 'name' : 'compiler.parser',
+ },
+ },
+ 'handlers' : {
+ 'hand1' : {
+ 'class' : 'logging.StreamHandler',
+ 'formatter' : 'form1',
+ 'level' : 'NOTSET',
+ 'stream' : 'ext://sys.stdout',
+ 'filters' : ['filt1'],
+ },
+ },
+ 'loggers' : {
+ 'compiler.parser' : {
+ 'level' : 'DEBUG',
+ 'filters' : ['filt1'],
+ },
+ },
+ 'root' : {
+ 'level' : 'WARNING',
+ 'handlers' : ['hand1'],
+ },
+ }
+
+ #As config1 but using cfg:// references
+ config11 = {
+ 'version': 1,
+ 'true_formatters': {
+ 'form1' : {
+ 'format' : '%(levelname)s ++ %(message)s',
+ },
+ },
+ 'handler_configs': {
+ 'hand1' : {
+ 'class' : 'logging.StreamHandler',
+ 'formatter' : 'form1',
+ 'level' : 'NOTSET',
+ 'stream' : 'ext://sys.stdout',
+ },
+ },
+ 'formatters' : 'cfg://true_formatters',
+ 'handlers' : {
+ 'hand1' : 'cfg://handler_configs[hand1]',
+ },
+ 'loggers' : {
+ 'compiler.parser' : {
+ 'level' : 'DEBUG',
+ 'handlers' : ['hand1'],
+ },
+ },
+ 'root' : {
+ 'level' : 'WARNING',
+ },
+ }
+
+ #As config11 but missing the version key
+ config12 = {
+ 'true_formatters': {
+ 'form1' : {
+ 'format' : '%(levelname)s ++ %(message)s',
+ },
+ },
+ 'handler_configs': {
+ 'hand1' : {
+ 'class' : 'logging.StreamHandler',
+ 'formatter' : 'form1',
+ 'level' : 'NOTSET',
+ 'stream' : 'ext://sys.stdout',
+ },
+ },
+ 'formatters' : 'cfg://true_formatters',
+ 'handlers' : {
+ 'hand1' : 'cfg://handler_configs[hand1]',
+ },
+ 'loggers' : {
+ 'compiler.parser' : {
+ 'level' : 'DEBUG',
+ 'handlers' : ['hand1'],
+ },
+ },
+ 'root' : {
+ 'level' : 'WARNING',
+ },
+ }
+
+ #As config11 but using an unsupported version
+ config13 = {
+ 'version': 2,
+ 'true_formatters': {
+ 'form1' : {
+ 'format' : '%(levelname)s ++ %(message)s',
+ },
+ },
+ 'handler_configs': {
+ 'hand1' : {
+ 'class' : 'logging.StreamHandler',
+ 'formatter' : 'form1',
+ 'level' : 'NOTSET',
+ 'stream' : 'ext://sys.stdout',
+ },
+ },
+ 'formatters' : 'cfg://true_formatters',
+ 'handlers' : {
+ 'hand1' : 'cfg://handler_configs[hand1]',
+ },
+ 'loggers' : {
+ 'compiler.parser' : {
+ 'level' : 'DEBUG',
+ 'handlers' : ['hand1'],
+ },
+ },
+ 'root' : {
+ 'level' : 'WARNING',
+ },
+ }
+
+ def apply_config(self, conf):
+ logging.config.dictConfig(conf)
+
+ def test_config0_ok(self):
+ # A simple config which overrides the default settings.
+ with captured_stdout() as output:
+ self.apply_config(self.config0)
+ logger = logging.getLogger()
+ # Won't output anything
+ logger.info(self.next_message())
+ # Outputs a message
+ logger.error(self.next_message())
+ self.assert_log_lines([
+ ('ERROR', '2'),
+ ], stream=output)
+ # Original logger output is empty.
+ self.assert_log_lines([])
+
+ def test_config1_ok(self, config=config1):
+ # A config defining a sub-parser as well.
+ with captured_stdout() as output:
+ self.apply_config(config)
+ logger = logging.getLogger("compiler.parser")
+ # Both will output a message
+ logger.info(self.next_message())
+ logger.error(self.next_message())
+ self.assert_log_lines([
+ ('INFO', '1'),
+ ('ERROR', '2'),
+ ], stream=output)
+ # Original logger output is empty.
+ self.assert_log_lines([])
+
+ def test_config2_failure(self):
+ # A simple config which overrides the default settings.
+ self.assertRaises(StandardError, self.apply_config, self.config2)
+
+ def test_config2a_failure(self):
+ # A simple config which overrides the default settings.
+ self.assertRaises(StandardError, self.apply_config, self.config2a)
+
+ def test_config2b_failure(self):
+ # A simple config which overrides the default settings.
+ self.assertRaises(StandardError, self.apply_config, self.config2b)
+
+ def test_config3_failure(self):
+ # A simple config which overrides the default settings.
+ self.assertRaises(StandardError, self.apply_config, self.config3)
+
+ def test_config4_ok(self):
+ # A config specifying a custom formatter class.
+ with captured_stdout() as output:
+ self.apply_config(self.config4)
+ #logger = logging.getLogger()
+ try:
+ raise RuntimeError()
+ except RuntimeError:
+ logging.exception("just testing")
+ sys.stdout.seek(0)
+ self.assertEqual(output.getvalue(),
+ "ERROR:root:just testing\nGot a [RuntimeError]\n")
+ # Original logger output is empty
+ self.assert_log_lines([])
+
+ def test_config4a_ok(self):
+ # A config specifying a custom formatter class.
+ with captured_stdout() as output:
+ self.apply_config(self.config4a)
+ #logger = logging.getLogger()
+ try:
+ raise RuntimeError()
+ except RuntimeError:
+ logging.exception("just testing")
+ sys.stdout.seek(0)
+ self.assertEqual(output.getvalue(),
+ "ERROR:root:just testing\nGot a [RuntimeError]\n")
+ # Original logger output is empty
+ self.assert_log_lines([])
+
+ def test_config5_ok(self):
+ self.test_config1_ok(config=self.config5)
+
+ def test_config6_failure(self):
+ self.assertRaises(StandardError, self.apply_config, self.config6)
+
+ def test_config7_ok(self):
+ with captured_stdout() as output:
+ self.apply_config(self.config1)
+ logger = logging.getLogger("compiler.parser")
+ # Both will output a message
+ logger.info(self.next_message())
+ logger.error(self.next_message())
+ self.assert_log_lines([
+ ('INFO', '1'),
+ ('ERROR', '2'),
+ ], stream=output)
+ # Original logger output is empty.
+ self.assert_log_lines([])
+ with captured_stdout() as output:
+ self.apply_config(self.config7)
+ logger = logging.getLogger("compiler.parser")
+ self.assertTrue(logger.disabled)
+ logger = logging.getLogger("compiler.lexer")
+ # Both will output a message
+ logger.info(self.next_message())
+ logger.error(self.next_message())
+ self.assert_log_lines([
+ ('INFO', '3'),
+ ('ERROR', '4'),
+ ], stream=output)
+ # Original logger output is empty.
+ self.assert_log_lines([])
+
+ #Same as test_config_7_ok but don't disable old loggers.
+ def test_config_8_ok(self):
+ with captured_stdout() as output:
+ self.apply_config(self.config1)
+ logger = logging.getLogger("compiler.parser")
+ # Both will output a message
+ logger.info(self.next_message())
+ logger.error(self.next_message())
+ self.assert_log_lines([
+ ('INFO', '1'),
+ ('ERROR', '2'),
+ ], stream=output)
+ # Original logger output is empty.
+ self.assert_log_lines([])
+ with captured_stdout() as output:
+ self.apply_config(self.config8)
+ logger = logging.getLogger("compiler.parser")
+ self.assertFalse(logger.disabled)
+ # Both will output a message
+ logger.info(self.next_message())
+ logger.error(self.next_message())
+ logger = logging.getLogger("compiler.lexer")
+ # Both will output a message
+ logger.info(self.next_message())
+ logger.error(self.next_message())
+ self.assert_log_lines([
+ ('INFO', '3'),
+ ('ERROR', '4'),
+ ('INFO', '5'),
+ ('ERROR', '6'),
+ ], stream=output)
+ # Original logger output is empty.
+ self.assert_log_lines([])
+
+ def test_config_9_ok(self):
+ with captured_stdout() as output:
+ self.apply_config(self.config9)
+ logger = logging.getLogger("compiler.parser")
+ #Nothing will be output since both handler and logger are set to WARNING
+ logger.info(self.next_message())
+ self.assert_log_lines([], stream=output)
+ self.apply_config(self.config9a)
+ #Nothing will be output since both handler is still set to WARNING
+ logger.info(self.next_message())
+ self.assert_log_lines([], stream=output)
+ self.apply_config(self.config9b)
+ #Message should now be output
+ logger.info(self.next_message())
+ self.assert_log_lines([
+ ('INFO', '3'),
+ ], stream=output)
+
+ def test_config_10_ok(self):
+ with captured_stdout() as output:
+ self.apply_config(self.config10)
+ logger = logging.getLogger("compiler.parser")
+ logger.warning(self.next_message())
+ logger = logging.getLogger('compiler')
+ #Not output, because filtered
+ logger.warning(self.next_message())
+ logger = logging.getLogger('compiler.lexer')
+ #Not output, because filtered
+ logger.warning(self.next_message())
+ logger = logging.getLogger("compiler.parser.codegen")
+ #Output, as not filtered
+ logger.error(self.next_message())
+ self.assert_log_lines([
+ ('WARNING', '1'),
+ ('ERROR', '4'),
+ ], stream=output)
+
+ def test_config11_ok(self):
+ self.test_config1_ok(self.config11)
+
+ def test_config12_failure(self):
+ self.assertRaises(StandardError, self.apply_config, self.config12)
+
+ def test_config13_failure(self):
+ self.assertRaises(StandardError, self.apply_config, self.config13)
+
+ @unittest.skipUnless(threading, 'listen() needs threading to work')
+ def setup_via_listener(self, text):
+ # Ask for a randomly assigned port (by using port 0)
+ t = logging.config.listen(0)
+ t.start()
+ t.ready.wait()
+ # Now get the port allocated
+ port = t.port
+ t.ready.clear()
try:
- raise KeyError
- except KeyError:
- logging.exception("just testing")
- os.remove(fn)
- hdlr = logging.getLogger().handlers[0]
- logging.getLogger().handlers.remove(hdlr)
- finally:
- logging._acquireLock()
- try:
- logging._handlers.clear()
- logging._handlers.update(saved_handlers)
- logging._handlerList[:] = saved_handler_list
- loggerDict = logging.getLogger().manager.loggerDict
- loggerDict.clear()
- loggerDict.update(saved_loggers)
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock.settimeout(2.0)
+ sock.connect(('localhost', port))
+
+ slen = struct.pack('>L', len(text))
+ s = slen + text
+ sentsofar = 0
+ left = len(s)
+ while left > 0:
+ sent = sock.send(s[sentsofar:])
+ sentsofar += sent
+ left -= sent
+ sock.close()
finally:
- logging._releaseLock()
+ t.ready.wait(2.0)
+ logging.config.stopListening()
+ t.join(2.0)
+ def test_listen_config_10_ok(self):
+ with captured_stdout() as output:
+ self.setup_via_listener(json.dumps(self.config10))
+ logger = logging.getLogger("compiler.parser")
+ logger.warning(self.next_message())
+ logger = logging.getLogger('compiler')
+ #Not output, because filtered
+ logger.warning(self.next_message())
+ logger = logging.getLogger('compiler.lexer')
+ #Not output, because filtered
+ logger.warning(self.next_message())
+ logger = logging.getLogger("compiler.parser.codegen")
+ #Output, as not filtered
+ logger.error(self.next_message())
+ self.assert_log_lines([
+ ('WARNING', '1'),
+ ('ERROR', '4'),
+ ], stream=output)
-#----------------------------------------------------------------------------
-# Test Harness
-#----------------------------------------------------------------------------
-def banner(nm, typ):
- sep = BANNER % (nm, typ)
- sys.stdout.write(sep)
- sys.stdout.flush()
+ def test_listen_config_1_ok(self):
+ with captured_stdout() as output:
+ self.setup_via_listener(textwrap.dedent(ConfigFileTest.config1))
+ logger = logging.getLogger("compiler.parser")
+ # Both will output a message
+ logger.info(self.next_message())
+ logger.error(self.next_message())
+ self.assert_log_lines([
+ ('INFO', '1'),
+ ('ERROR', '2'),
+ ], stream=output)
+ # Original logger output is empty.
+ self.assert_log_lines([])
-def test_main_inner():
- rootLogger = logging.getLogger("")
- rootLogger.setLevel(logging.DEBUG)
- hdlr = logging.StreamHandler(sys.stdout)
- fmt = logging.Formatter(logging.BASIC_FORMAT)
- hdlr.setFormatter(fmt)
- rootLogger.addHandler(hdlr)
- # Find an unused port number
- port = logging.handlers.DEFAULT_TCP_LOGGING_PORT
- while port < logging.handlers.DEFAULT_TCP_LOGGING_PORT+100:
- try:
- tcpserver = LogRecordSocketReceiver(port=port)
- except socket.error:
- port += 1
- else:
- break
- else:
- raise ImportError, "Could not find unused port"
+class ManagerTest(BaseTest):
+ def test_manager_loggerclass(self):
+ logged = []
+ class MyLogger(logging.Logger):
+ def _log(self, level, msg, args, exc_info=None, extra=None):
+ logged.append(msg)
- #Set up a handler such that all events are sent via a socket to the log
- #receiver (logrecv).
- #The handler will only be added to the rootLogger for some of the tests
- shdlr = logging.handlers.SocketHandler('localhost', port)
+ man = logging.Manager(None)
+ self.assertRaises(TypeError, man.setLoggerClass, int)
+ man.setLoggerClass(MyLogger)
+ logger = man.getLogger('test')
+ logger.warning('should appear in logged')
+ logging.warning('should not appear in logged')
- #Configure the logger for logrecv so events do not propagate beyond it.
- #The sockLogger output is buffered in memory until the end of the test,
- #and printed at the end.
- sockOut = cStringIO.StringIO()
- sockLogger = logging.getLogger("logrecv")
- sockLogger.setLevel(logging.DEBUG)
- sockhdlr = logging.StreamHandler(sockOut)
- sockhdlr.setFormatter(logging.Formatter(
- "%(name)s -> %(levelname)s: %(message)s"))
- sockLogger.addHandler(sockhdlr)
- sockLogger.propagate = 0
+ self.assertEqual(logged, ['should appear in logged'])
- #Set up servers
- threads = []
- #sys.stdout.write("About to start TCP server...\n")
- threads.append(threading.Thread(target=runTCP, args=(tcpserver,)))
- for thread in threads:
- thread.start()
- try:
- banner("log_test0", "begin")
+class ChildLoggerTest(BaseTest):
+ def test_child_loggers(self):
+ r = logging.getLogger()
+ l1 = logging.getLogger('abc')
+ l2 = logging.getLogger('def.ghi')
+ c1 = r.getChild('xyz')
+ c2 = r.getChild('uvw.xyz')
+ self.assertTrue(c1 is logging.getLogger('xyz'))
+ self.assertTrue(c2 is logging.getLogger('uvw.xyz'))
+ c1 = l1.getChild('def')
+ c2 = c1.getChild('ghi')
+ c3 = l1.getChild('def.ghi')
+ self.assertTrue(c1 is logging.getLogger('abc.def'))
+ self.assertTrue(c2 is logging.getLogger('abc.def.ghi'))
+ self.assertTrue(c2 is c3)
- rootLogger.addHandler(shdlr)
- test0()
- # XXX(nnorwitz): Try to fix timing related test failures.
- # This sleep gives us some extra time to read messages.
- # The test generally only fails on Solaris without this sleep.
- time.sleep(2.0)
- shdlr.close()
- rootLogger.removeHandler(shdlr)
-
- banner("log_test0", "end")
-
- for t in range(1,6):
- banner("log_test%d" % t, "begin")
- globals()['test%d' % t]()
- banner("log_test%d" % t, "end")
-
- finally:
- #wait for TCP receiver to terminate
- socketDataProcessed.wait()
- # ensure the server dies
- tcpserver.abort = 1
- for thread in threads:
- thread.join(2.0)
- banner("logrecv output", "begin")
- sys.stdout.write(sockOut.getvalue())
- sockOut.close()
- sockLogger.removeHandler(sockhdlr)
- sockhdlr.close()
- banner("logrecv output", "end")
- sys.stdout.flush()
- try:
- hdlr.close()
- except:
- pass
- rootLogger.removeHandler(hdlr)
# Set the locale to the platform-dependent default. I have no idea
# why the test does this, but in any case we save the current locale
# first and restore it at the end.
@run_with_locale('LC_ALL', '')
def test_main():
- # Save and restore the original root logger level across the tests.
- # Otherwise, e.g., if any test using cookielib runs after test_logging,
- # cookielib's debug-level logger tries to log messages, leading to
- # confusing:
- # No handlers could be found for logger "cookielib"
- # output while the tests are running.
- root_logger = logging.getLogger("")
- original_logging_level = root_logger.getEffectiveLevel()
- try:
- test_main_inner()
- finally:
- root_logger.setLevel(original_logging_level)
+ run_unittest(BuiltinLevelsTest, BasicFilterTest,
+ CustomLevelsAndFiltersTest, MemoryHandlerTest,
+ ConfigFileTest, SocketHandlerTest, MemoryTest,
+ EncodingTest, WarningsTest, ConfigDictTest, ManagerTest,
+ ChildLoggerTest)
if __name__ == "__main__":
- sys.stdout.write("test_logging\n")
test_main()
diff --git a/Lib/test/test_new.py b/Lib/test/test_new.py
--- a/Lib/test/test_new.py
+++ b/Lib/test/test_new.py
@@ -153,7 +153,7 @@
d = new.code(argcount, nlocals, stacksize, flags, codestring,
constants, t, varnames, filename, name,
firstlineno, lnotab)
- self.assert_(type(t[0]) is S, "eek, tuple changed under us!")
+ self.assertTrue(type(t[0]) is S, "eek, tuple changed under us!")
def test_main():
test_support.run_unittest(NewTest)
diff --git a/Lib/test/test_operator.py b/Lib/test/test_operator.py
--- a/Lib/test/test_operator.py
+++ b/Lib/test/test_operator.py
@@ -57,6 +57,7 @@
class C(object):
def __eq__(self, other):
raise SyntaxError
+ __hash__ = None # Silence Py3k warning
self.failUnlessRaises(TypeError, operator.eq)
self.failUnlessRaises(SyntaxError, operator.eq, C(), C())
self.failIf(operator.eq(1, 0))
@@ -193,7 +194,9 @@
class C:
pass
def check(self, o, v):
- self.assert_(operator.isCallable(o) == callable(o) == v)
+ self.assertEqual(operator.isCallable(o), v)
+ with test_support._check_py3k_warnings():
+ self.assertEqual(callable(o), v)
check(self, 4, 0)
check(self, operator.isCallable, 1)
check(self, C, 1)
@@ -305,12 +308,12 @@
self.assertRaises(ValueError, operator.rshift, 2, -1)
def test_contains(self):
- self.failUnlessRaises(TypeError, operator.contains)
- self.failUnlessRaises(TypeError, operator.contains, None, None)
- self.failUnless(operator.contains(range(4), 2))
- self.failIf(operator.contains(range(4), 5))
- self.failUnless(operator.sequenceIncludes(range(4), 2))
- self.failIf(operator.sequenceIncludes(range(4), 5))
+ self.assertRaises(TypeError, operator.contains)
+ self.assertRaises(TypeError, operator.contains, None, None)
+ self.assertTrue(operator.contains(range(4), 2))
+ self.assertFalse(operator.contains(range(4), 5))
+ self.assertTrue(operator.sequenceIncludes(range(4), 2))
+ self.assertFalse(operator.sequenceIncludes(range(4), 5))
def test_setitem(self):
a = range(3)
@@ -386,9 +389,29 @@
self.assertRaises(TypeError, operator.attrgetter('x', (), 'y'), record)
class C(object):
- def __getattr(self, name):
+ def __getattr__(self, name):
raise SyntaxError
- self.failUnlessRaises(AttributeError, operator.attrgetter('foo'), C())
+ self.failUnlessRaises(SyntaxError, operator.attrgetter('foo'), C())
+
+ # recursive gets
+ a = A()
+ a.name = 'arthur'
+ a.child = A()
+ a.child.name = 'thomas'
+ f = operator.attrgetter('child.name')
+ self.assertEqual(f(a), 'thomas')
+ self.assertRaises(AttributeError, f, a.child)
+ f = operator.attrgetter('name', 'child.name')
+ self.assertEqual(f(a), ('arthur', 'thomas'))
+ f = operator.attrgetter('name', 'child.name', 'child.child.name')
+ self.assertRaises(AttributeError, f, a)
+
+ a.child.child = A()
+ a.child.child.name = 'johnson'
+ f = operator.attrgetter('child.child.name')
+ self.assertEqual(f(a), 'johnson')
+ f = operator.attrgetter('name', 'child.name', 'child.child.name')
+ self.assertEqual(f(a), ('arthur', 'thomas', 'johnson'))
def test_itemgetter(self):
a = 'ABCDE'
@@ -398,9 +421,9 @@
self.assertRaises(IndexError, f, a)
class C(object):
- def __getitem(self, name):
+ def __getitem__(self, name):
raise SyntaxError
- self.failUnlessRaises(TypeError, operator.itemgetter(42), C())
+ self.failUnlessRaises(SyntaxError, operator.itemgetter(42), C())
f = operator.itemgetter('name')
self.assertRaises(TypeError, f, a)
@@ -424,6 +447,24 @@
self.assertEqual(operator.itemgetter(2,10,5)(data), ('2', '10', '5'))
self.assertRaises(TypeError, operator.itemgetter(2, 'x', 5), data)
+ def test_methodcaller(self):
+ self.assertRaises(TypeError, operator.methodcaller)
+ class A:
+ def foo(self, *args, **kwds):
+ return args[0] + args[1]
+ def bar(self, f=42):
+ return f
+ a = A()
+ f = operator.methodcaller('foo')
+ self.assertRaises(IndexError, f, a)
+ f = operator.methodcaller('foo', 1, 2)
+ self.assertEquals(f(a), 3)
+ f = operator.methodcaller('bar')
+ self.assertEquals(f(a), 42)
+ self.assertRaises(TypeError, f, a, a)
+ f = operator.methodcaller('bar', f=5)
+ self.assertEquals(f(a), 5)
+
def test_inplace(self):
class C(object):
def __iadd__ (self, other): return "iadd"
diff --git a/Lib/test/test_os.py b/Lib/test/test_os.py
--- a/Lib/test/test_os.py
+++ b/Lib/test/test_os.py
@@ -3,10 +3,16 @@
# portable than they had been thought to be.
import os
+import errno
import unittest
import warnings
import sys
+import signal
+import subprocess
+import time
from test import test_support
+import mmap
+import uuid
warnings.filterwarnings("ignore", "tempnam", RuntimeWarning, __name__)
warnings.filterwarnings("ignore", "tmpnam", RuntimeWarning, __name__)
@@ -21,8 +27,30 @@
def test_access(self):
f = os.open(test_support.TESTFN, os.O_CREAT|os.O_RDWR)
os.close(f)
- self.assert_(os.access(test_support.TESTFN, os.W_OK))
+ self.assertTrue(os.access(test_support.TESTFN, os.W_OK))
+ def test_closerange(self):
+ first = os.open(test_support.TESTFN, os.O_CREAT|os.O_RDWR)
+ # We must allocate two consecutive file descriptors, otherwise
+ # it will mess up other file descriptors (perhaps even the three
+ # standard ones).
+ second = os.dup(first)
+ try:
+ retries = 0
+ while second != first + 1:
+ os.close(first)
+ retries += 1
+ if retries > 10:
+ # XXX test skipped
+ self.skipTest("couldn't allocate two consecutive fds")
+ first, second = second, os.dup(second)
+ finally:
+ os.close(second)
+ # close a fd that is open, and one that isn't
+ os.closerange(first, first + 2)
+ self.assertRaises(OSError, os.write, first, "a")
+
+ @test_support.cpython_only
def test_rename(self):
path = unicode(test_support.TESTFN)
if not test_support.is_jython:
@@ -45,7 +73,7 @@
def check_tempfile(self, name):
# make sure it doesn't already exist:
- self.failIf(os.path.exists(name),
+ self.assertFalse(os.path.exists(name),
"file already exists for temporary file")
# make sure we can create the file
open(name, "w")
@@ -54,16 +82,18 @@
def test_tempnam(self):
if not hasattr(os, "tempnam"):
return
- warnings.filterwarnings("ignore", "tempnam", RuntimeWarning,
- r"test_os$")
- self.check_tempfile(os.tempnam())
+ with warnings.catch_warnings():
+ warnings.filterwarnings("ignore", "tempnam", RuntimeWarning,
+ r"test_os$")
+ warnings.filterwarnings("ignore", "tempnam", DeprecationWarning)
+ self.check_tempfile(os.tempnam())
- name = os.tempnam(test_support.TESTFN)
- self.check_tempfile(name)
+ name = os.tempnam(test_support.TESTFN)
+ self.check_tempfile(name)
- name = os.tempnam(test_support.TESTFN, "pfx")
- self.assert_(os.path.basename(name)[:3] == "pfx")
- self.check_tempfile(name)
+ name = os.tempnam(test_support.TESTFN, "pfx")
+ self.assertTrue(os.path.basename(name)[:3] == "pfx")
+ self.check_tempfile(name)
def test_tmpfile(self):
if not hasattr(os, "tmpfile"):
@@ -82,64 +112,69 @@
# test that a subsequent call to os.tmpfile() raises the same error. If
# it doesn't, assume we're on XP or below and the user running the test
# has administrative privileges, and proceed with the test as normal.
- if sys.platform == 'win32':
- name = '\\python_test_os_test_tmpfile.txt'
- if os.path.exists(name):
- os.remove(name)
- try:
- fp = open(name, 'w')
- except IOError, first:
- # open() failed, assert tmpfile() fails in the same way.
- # Although open() raises an IOError and os.tmpfile() raises an
- # OSError(), 'args' will be (13, 'Permission denied') in both
- # cases.
+ with warnings.catch_warnings():
+ warnings.filterwarnings("ignore", "tmpfile", DeprecationWarning)
+
+ if sys.platform == 'win32':
+ name = '\\python_test_os_test_tmpfile.txt'
+ if os.path.exists(name):
+ os.remove(name)
try:
- fp = os.tmpfile()
- except OSError, second:
- self.assertEqual(first.args, second.args)
+ fp = open(name, 'w')
+ except IOError, first:
+ # open() failed, assert tmpfile() fails in the same way.
+ # Although open() raises an IOError and os.tmpfile() raises an
+ # OSError(), 'args' will be (13, 'Permission denied') in both
+ # cases.
+ try:
+ fp = os.tmpfile()
+ except OSError, second:
+ self.assertEqual(first.args, second.args)
+ else:
+ self.fail("expected os.tmpfile() to raise OSError")
+ return
else:
- self.fail("expected os.tmpfile() to raise OSError")
- return
- else:
- # open() worked, therefore, tmpfile() should work. Close our
- # dummy file and proceed with the test as normal.
- fp.close()
- os.remove(name)
+ # open() worked, therefore, tmpfile() should work. Close our
+ # dummy file and proceed with the test as normal.
+ fp.close()
+ os.remove(name)
- fp = os.tmpfile()
- fp.write("foobar")
- fp.seek(0,0)
- s = fp.read()
- fp.close()
- self.assert_(s == "foobar")
+ fp = os.tmpfile()
+ fp.write("foobar")
+ fp.seek(0,0)
+ s = fp.read()
+ fp.close()
+ self.assertTrue(s == "foobar")
def test_tmpnam(self):
- import sys
if not hasattr(os, "tmpnam"):
return
- warnings.filterwarnings("ignore", "tmpnam", RuntimeWarning,
- r"test_os$")
- name = os.tmpnam()
- if sys.platform in ("win32",):
- # The Windows tmpnam() seems useless. From the MS docs:
- #
- # The character string that tmpnam creates consists of
- # the path prefix, defined by the entry P_tmpdir in the
- # file STDIO.H, followed by a sequence consisting of the
- # digit characters '0' through '9'; the numerical value
- # of this string is in the range 1 - 65,535. Changing the
- # definitions of L_tmpnam or P_tmpdir in STDIO.H does not
- # change the operation of tmpnam.
- #
- # The really bizarre part is that, at least under MSVC6,
- # P_tmpdir is "\\". That is, the path returned refers to
- # the root of the current drive. That's a terrible place to
- # put temp files, and, depending on privileges, the user
- # may not even be able to open a file in the root directory.
- self.failIf(os.path.exists(name),
- "file already exists for temporary file")
- else:
- self.check_tempfile(name)
+ with warnings.catch_warnings():
+ warnings.filterwarnings("ignore", "tmpnam", RuntimeWarning,
+ r"test_os$")
+ warnings.filterwarnings("ignore", "tmpnam", DeprecationWarning)
+
+ name = os.tmpnam()
+ if sys.platform in ("win32",):
+ # The Windows tmpnam() seems useless. From the MS docs:
+ #
+ # The character string that tmpnam creates consists of
+ # the path prefix, defined by the entry P_tmpdir in the
+ # file STDIO.H, followed by a sequence consisting of the
+ # digit characters '0' through '9'; the numerical value
+ # of this string is in the range 1 - 65,535. Changing the
+ # definitions of L_tmpnam or P_tmpdir in STDIO.H does not
+ # change the operation of tmpnam.
+ #
+ # The really bizarre part is that, at least under MSVC6,
+ # P_tmpdir is "\\". That is, the path returned refers to
+ # the root of the current drive. That's a terrible place to
+ # put temp files, and, depending on privileges, the user
+ # may not even be able to open a file in the root directory.
+ self.assertFalse(os.path.exists(name),
+ "file already exists for temporary file")
+ else:
+ self.check_tempfile(name)
# Test attributes on return values from os.*stat* family.
class StatAttributeTests(unittest.TestCase):
@@ -162,10 +197,8 @@
result = os.stat(self.fname)
# Make sure direct access works
- self.assertEquals(result[stat.ST_SIZE], 3)
- self.assertEquals(result.st_size, 3)
-
- import sys
+ self.assertEqual(result[stat.ST_SIZE], 3)
+ self.assertEqual(result.st_size, 3)
# Make sure all the attributes are there
members = dir(result)
@@ -176,9 +209,9 @@
def trunc(x): return int(x)
else:
def trunc(x): return x
- self.assertEquals(trunc(getattr(result, attr)),
- result[getattr(stat, name)])
- self.assert_(attr in members)
+ self.assertEqual(trunc(getattr(result, attr)),
+ result[getattr(stat, name)])
+ self.assertIn(attr, members)
try:
result[200]
@@ -190,7 +223,7 @@
try:
result.st_mode = 1
self.fail("No exception thrown")
- except TypeError:
+ except (AttributeError, TypeError):
pass
try:
@@ -223,26 +256,21 @@
if not hasattr(os, "statvfs"):
return
- import statvfs
try:
result = os.statvfs(self.fname)
except OSError, e:
# On AtheOS, glibc always returns ENOSYS
- import errno
if e.errno == errno.ENOSYS:
return
# Make sure direct access works
- self.assertEquals(result.f_bfree, result[statvfs.F_BFREE])
+ self.assertEqual(result.f_bfree, result[3])
- # Make sure all the attributes are there
- members = dir(result)
- for name in dir(statvfs):
- if name[:2] == 'F_':
- attr = name.lower()
- self.assertEquals(getattr(result, attr),
- result[getattr(statvfs, name)])
- self.assert_(attr in members)
+ # Make sure all the attributes are there.
+ members = ('bsize', 'frsize', 'blocks', 'bfree', 'bavail', 'files',
+ 'ffree', 'favail', 'flag', 'namemax')
+ for value, member in enumerate(members):
+ self.assertEqual(getattr(result, 'f_' + member), result[value])
# Make sure that assignment really fails
try:
@@ -270,6 +298,15 @@
except TypeError:
pass
+ def test_utime_dir(self):
+ delta = 1000000
+ st = os.stat(test_support.TESTFN)
+ # round to int, because some systems may support sub-second
+ # time stamps in stat, but not in utime.
+ os.utime(test_support.TESTFN, (st.st_atime, int(st.st_mtime-delta)))
+ st2 = os.stat(test_support.TESTFN)
+ self.assertEqual(st2.st_mtime, int(st.st_mtime-delta))
+
# Restrict test to Win32, since there is no guarantee other
# systems support centiseconds
if sys.platform == 'win32':
@@ -285,14 +322,19 @@
def test_1565150(self):
t1 = 1159195039.25
os.utime(self.fname, (t1, t1))
- self.assertEquals(os.stat(self.fname).st_mtime, t1)
+ self.assertEqual(os.stat(self.fname).st_mtime, t1)
+
+ def test_large_time(self):
+ t1 = 5000000000 # some day in 2128
+ os.utime(self.fname, (t1, t1))
+ self.assertEqual(os.stat(self.fname).st_mtime, t1)
def test_1686475(self):
# Verify that an open file can be stat'ed
try:
os.stat(r"c:\pagefile.sys")
except WindowsError, e:
- if e == 2: # file does not exist; cannot run test
+ if e.errno == 2: # file does not exist; cannot run test
return
self.fail("Could not stat pagefile.sys")
@@ -317,8 +359,9 @@
def test_update2(self):
if os.path.exists("/bin/sh"):
os.environ.update(HELLO="World")
- value = os.popen("/bin/sh -c 'echo $HELLO'").read().strip()
- self.assertEquals(value, "World")
+ with os.popen("/bin/sh -c 'echo $HELLO'") as popen:
+ value = popen.read().strip()
+ self.assertEqual(value, "World")
class WalkTests(unittest.TestCase):
"""Tests for os.walk()."""
@@ -328,75 +371,104 @@
from os.path import join
# Build:
- # TESTFN/ a file kid and two directory kids
+ # TESTFN/
+ # TEST1/ a file kid and two directory kids
# tmp1
# SUB1/ a file kid and a directory kid
- # tmp2
- # SUB11/ no kids
- # SUB2/ just a file kid
- # tmp3
- sub1_path = join(test_support.TESTFN, "SUB1")
+ # tmp2
+ # SUB11/ no kids
+ # SUB2/ a file kid and a dirsymlink kid
+ # tmp3
+ # link/ a symlink to TESTFN.2
+ # TEST2/
+ # tmp4 a lone file
+ walk_path = join(test_support.TESTFN, "TEST1")
+ sub1_path = join(walk_path, "SUB1")
sub11_path = join(sub1_path, "SUB11")
- sub2_path = join(test_support.TESTFN, "SUB2")
- tmp1_path = join(test_support.TESTFN, "tmp1")
+ sub2_path = join(walk_path, "SUB2")
+ tmp1_path = join(walk_path, "tmp1")
tmp2_path = join(sub1_path, "tmp2")
tmp3_path = join(sub2_path, "tmp3")
+ link_path = join(sub2_path, "link")
+ t2_path = join(test_support.TESTFN, "TEST2")
+ tmp4_path = join(test_support.TESTFN, "TEST2", "tmp4")
# Create stuff.
os.makedirs(sub11_path)
os.makedirs(sub2_path)
- for path in tmp1_path, tmp2_path, tmp3_path:
+ os.makedirs(t2_path)
+ for path in tmp1_path, tmp2_path, tmp3_path, tmp4_path:
f = file(path, "w")
f.write("I'm " + path + " and proud of it. Blame test_os.\n")
f.close()
+ if hasattr(os, "symlink"):
+ os.symlink(os.path.abspath(t2_path), link_path)
+ sub2_tree = (sub2_path, ["link"], ["tmp3"])
+ else:
+ sub2_tree = (sub2_path, [], ["tmp3"])
# Walk top-down.
- all = list(os.walk(test_support.TESTFN))
+ all = list(os.walk(walk_path))
self.assertEqual(len(all), 4)
# We can't know which order SUB1 and SUB2 will appear in.
# Not flipped: TESTFN, SUB1, SUB11, SUB2
# flipped: TESTFN, SUB2, SUB1, SUB11
flipped = all[0][1][0] != "SUB1"
all[0][1].sort()
- self.assertEqual(all[0], (test_support.TESTFN, ["SUB1", "SUB2"], ["tmp1"]))
+ self.assertEqual(all[0], (walk_path, ["SUB1", "SUB2"], ["tmp1"]))
self.assertEqual(all[1 + flipped], (sub1_path, ["SUB11"], ["tmp2"]))
self.assertEqual(all[2 + flipped], (sub11_path, [], []))
- self.assertEqual(all[3 - 2 * flipped], (sub2_path, [], ["tmp3"]))
+ self.assertEqual(all[3 - 2 * flipped], sub2_tree)
# Prune the search.
all = []
- for root, dirs, files in os.walk(test_support.TESTFN):
+ for root, dirs, files in os.walk(walk_path):
all.append((root, dirs, files))
# Don't descend into SUB1.
if 'SUB1' in dirs:
# Note that this also mutates the dirs we appended to all!
dirs.remove('SUB1')
self.assertEqual(len(all), 2)
- self.assertEqual(all[0], (test_support.TESTFN, ["SUB2"], ["tmp1"]))
- self.assertEqual(all[1], (sub2_path, [], ["tmp3"]))
+ self.assertEqual(all[0], (walk_path, ["SUB2"], ["tmp1"]))
+ self.assertEqual(all[1], sub2_tree)
# Walk bottom-up.
- all = list(os.walk(test_support.TESTFN, topdown=False))
+ all = list(os.walk(walk_path, topdown=False))
self.assertEqual(len(all), 4)
# We can't know which order SUB1 and SUB2 will appear in.
# Not flipped: SUB11, SUB1, SUB2, TESTFN
# flipped: SUB2, SUB11, SUB1, TESTFN
flipped = all[3][1][0] != "SUB1"
all[3][1].sort()
- self.assertEqual(all[3], (test_support.TESTFN, ["SUB1", "SUB2"], ["tmp1"]))
+ self.assertEqual(all[3], (walk_path, ["SUB1", "SUB2"], ["tmp1"]))
self.assertEqual(all[flipped], (sub11_path, [], []))
self.assertEqual(all[flipped + 1], (sub1_path, ["SUB11"], ["tmp2"]))
- self.assertEqual(all[2 - 2 * flipped], (sub2_path, [], ["tmp3"]))
+ self.assertEqual(all[2 - 2 * flipped], sub2_tree)
+ if hasattr(os, "symlink"):
+ # Walk, following symlinks.
+ for root, dirs, files in os.walk(walk_path, followlinks=True):
+ if root == link_path:
+ self.assertEqual(dirs, [])
+ self.assertEqual(files, ["tmp4"])
+ break
+ else:
+ self.fail("Didn't follow symlink with followlinks=True")
+
+ def tearDown(self):
# Tear everything down. This is a decent use for bottom-up on
# Windows, which doesn't have a recursive delete command. The
# (not so) subtlety is that rmdir will fail unless the dir's
# kids are removed first, so bottom up is essential.
for root, dirs, files in os.walk(test_support.TESTFN, topdown=False):
for name in files:
- os.remove(join(root, name))
+ os.remove(os.path.join(root, name))
for name in dirs:
- os.rmdir(join(root, name))
+ dirname = os.path.join(root, name)
+ if not os.path.islink(dirname):
+ os.rmdir(dirname)
+ else:
+ os.remove(dirname)
os.rmdir(test_support.TESTFN)
class MakedirTests (unittest.TestCase):
@@ -411,7 +483,7 @@
os.makedirs(path)
# Try paths with a '.' in them
- self.failUnlessRaises(OSError, os.makedirs, os.curdir)
+ self.assertRaises(OSError, os.makedirs, os.curdir)
path = os.path.join(base, 'dir1', 'dir2', 'dir3', 'dir4', 'dir5', os.curdir)
os.makedirs(path)
path = os.path.join(base, 'dir1', os.curdir, 'dir2', 'dir3', 'dir4',
@@ -448,9 +520,16 @@
self.assertEqual(len(os.urandom(10)), 10)
self.assertEqual(len(os.urandom(100)), 100)
self.assertEqual(len(os.urandom(1000)), 1000)
+ # see http://bugs.python.org/issue3708
+ self.assertRaises(TypeError, os.urandom, 0.9)
+ self.assertRaises(TypeError, os.urandom, 1.1)
+ self.assertRaises(TypeError, os.urandom, 2.0)
except NotImplementedError:
pass
+ def test_execvpe_with_bad_arglist(self):
+ self.assertRaises(ValueError, os.execvpe, 'notepad', [], None)
+
class Win32ErrorTests(unittest.TestCase):
def test_rename(self):
self.assertRaises(WindowsError, os.rename, test_support.TESTFN, test_support.TESTFN+".bak")
@@ -462,21 +541,271 @@
self.assertRaises(WindowsError, os.chdir, test_support.TESTFN)
def test_mkdir(self):
- self.assertRaises(WindowsError, os.chdir, test_support.TESTFN)
+ f = open(test_support.TESTFN, "w")
+ try:
+ self.assertRaises(WindowsError, os.mkdir, test_support.TESTFN)
+ finally:
+ f.close()
+ os.unlink(test_support.TESTFN)
def test_utime(self):
self.assertRaises(WindowsError, os.utime, test_support.TESTFN, None)
- def test_access(self):
- self.assertRaises(WindowsError, os.utime, test_support.TESTFN, 0)
+ def test_chmod(self):
+ self.assertRaises(WindowsError, os.chmod, test_support.TESTFN, 0)
- def test_chmod(self):
- self.assertRaises(WindowsError, os.utime, test_support.TESTFN, 0)
+class TestInvalidFD(unittest.TestCase):
+ singles = ["fchdir", "fdopen", "dup", "fdatasync", "fstat",
+ "fstatvfs", "fsync", "tcgetpgrp", "ttyname"]
+ #singles.append("close")
+ #We omit close because it doesn'r raise an exception on some platforms
+ def get_single(f):
+ def helper(self):
+ if hasattr(os, f):
+ self.check(getattr(os, f))
+ return helper
+ for f in singles:
+ locals()["test_"+f] = get_single(f)
+
+ def check(self, f, *args):
+ try:
+ f(test_support.make_bad_fd(), *args)
+ except OSError as e:
+ self.assertEqual(e.errno, errno.EBADF)
+ else:
+ self.fail("%r didn't raise a OSError with a bad file descriptor"
+ % f)
+
+ def test_isatty(self):
+ if hasattr(os, "isatty"):
+ self.assertEqual(os.isatty(test_support.make_bad_fd()), False)
+
+ def test_closerange(self):
+ if hasattr(os, "closerange"):
+ fd = test_support.make_bad_fd()
+ # Make sure none of the descriptors we are about to close are
+ # currently valid (issue 6542).
+ for i in range(10):
+ try: os.fstat(fd+i)
+ except OSError:
+ pass
+ else:
+ break
+ if i < 2:
+ raise unittest.SkipTest(
+ "Unable to acquire a range of invalid file descriptors")
+ self.assertEqual(os.closerange(fd, fd + i-1), None)
+
+ def test_dup2(self):
+ if hasattr(os, "dup2"):
+ self.check(os.dup2, 20)
+
+ def test_fchmod(self):
+ if hasattr(os, "fchmod"):
+ self.check(os.fchmod, 0)
+
+ def test_fchown(self):
+ if hasattr(os, "fchown"):
+ self.check(os.fchown, -1, -1)
+
+ def test_fpathconf(self):
+ if hasattr(os, "fpathconf"):
+ self.check(os.fpathconf, "PC_NAME_MAX")
+
+ def test_ftruncate(self):
+ if hasattr(os, "ftruncate"):
+ self.check(os.ftruncate, 0)
+
+ def test_lseek(self):
+ if hasattr(os, "lseek"):
+ self.check(os.lseek, 0, 0)
+
+ def test_read(self):
+ if hasattr(os, "read"):
+ self.check(os.read, 1)
+
+ def test_tcsetpgrpt(self):
+ if hasattr(os, "tcsetpgrp"):
+ self.check(os.tcsetpgrp, 0)
+
+ def test_write(self):
+ if hasattr(os, "write"):
+ self.check(os.write, " ")
if sys.platform != 'win32':
class Win32ErrorTests(unittest.TestCase):
pass
+ class PosixUidGidTests(unittest.TestCase):
+ if hasattr(os, 'setuid'):
+ def test_setuid(self):
+ if os.getuid() != 0:
+ self.assertRaises(os.error, os.setuid, 0)
+ self.assertRaises(OverflowError, os.setuid, 1<<32)
+
+ if hasattr(os, 'setgid'):
+ def test_setgid(self):
+ if os.getuid() != 0:
+ self.assertRaises(os.error, os.setgid, 0)
+ self.assertRaises(OverflowError, os.setgid, 1<<32)
+
+ if hasattr(os, 'seteuid'):
+ def test_seteuid(self):
+ if os.getuid() != 0:
+ self.assertRaises(os.error, os.seteuid, 0)
+ self.assertRaises(OverflowError, os.seteuid, 1<<32)
+
+ if hasattr(os, 'setegid'):
+ def test_setegid(self):
+ if os.getuid() != 0:
+ self.assertRaises(os.error, os.setegid, 0)
+ self.assertRaises(OverflowError, os.setegid, 1<<32)
+
+ if hasattr(os, 'setreuid'):
+ def test_setreuid(self):
+ if os.getuid() != 0:
+ self.assertRaises(os.error, os.setreuid, 0, 0)
+ self.assertRaises(OverflowError, os.setreuid, 1<<32, 0)
+ self.assertRaises(OverflowError, os.setreuid, 0, 1<<32)
+
+ def test_setreuid_neg1(self):
+ # Needs to accept -1. We run this in a subprocess to avoid
+ # altering the test runner's process state (issue8045).
+ subprocess.check_call([
+ sys.executable, '-c',
+ 'import os,sys;os.setreuid(-1,-1);sys.exit(0)'])
+
+ if hasattr(os, 'setregid'):
+ def test_setregid(self):
+ if os.getuid() != 0:
+ self.assertRaises(os.error, os.setregid, 0, 0)
+ self.assertRaises(OverflowError, os.setregid, 1<<32, 0)
+ self.assertRaises(OverflowError, os.setregid, 0, 1<<32)
+
+ def test_setregid_neg1(self):
+ # Needs to accept -1. We run this in a subprocess to avoid
+ # altering the test runner's process state (issue8045).
+ subprocess.check_call([
+ sys.executable, '-c',
+ 'import os,sys;os.setregid(-1,-1);sys.exit(0)'])
+else:
+ class PosixUidGidTests(unittest.TestCase):
+ pass
+
+ at unittest.skipUnless(sys.platform == "win32", "Win32 specific tests")
+class Win32KillTests(unittest.TestCase):
+ def _kill(self, sig):
+ # Start sys.executable as a subprocess and communicate from the
+ # subprocess to the parent that the interpreter is ready. When it
+ # becomes ready, send *sig* via os.kill to the subprocess and check
+ # that the return code is equal to *sig*.
+ import ctypes
+ from ctypes import wintypes
+ import msvcrt
+
+ # Since we can't access the contents of the process' stdout until the
+ # process has exited, use PeekNamedPipe to see what's inside stdout
+ # without waiting. This is done so we can tell that the interpreter
+ # is started and running at a point where it could handle a signal.
+ PeekNamedPipe = ctypes.windll.kernel32.PeekNamedPipe
+ PeekNamedPipe.restype = wintypes.BOOL
+ PeekNamedPipe.argtypes = (wintypes.HANDLE, # Pipe handle
+ ctypes.POINTER(ctypes.c_char), # stdout buf
+ wintypes.DWORD, # Buffer size
+ ctypes.POINTER(wintypes.DWORD), # bytes read
+ ctypes.POINTER(wintypes.DWORD), # bytes avail
+ ctypes.POINTER(wintypes.DWORD)) # bytes left
+ msg = "running"
+ proc = subprocess.Popen([sys.executable, "-c",
+ "import sys;"
+ "sys.stdout.write('{}');"
+ "sys.stdout.flush();"
+ "input()".format(msg)],
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ stdin=subprocess.PIPE)
+ self.addCleanup(proc.stdout.close)
+ self.addCleanup(proc.stderr.close)
+ self.addCleanup(proc.stdin.close)
+
+ count, max = 0, 100
+ while count < max and proc.poll() is None:
+ # Create a string buffer to store the result of stdout from the pipe
+ buf = ctypes.create_string_buffer(len(msg))
+ # Obtain the text currently in proc.stdout
+ # Bytes read/avail/left are left as NULL and unused
+ rslt = PeekNamedPipe(msvcrt.get_osfhandle(proc.stdout.fileno()),
+ buf, ctypes.sizeof(buf), None, None, None)
+ self.assertNotEqual(rslt, 0, "PeekNamedPipe failed")
+ if buf.value:
+ self.assertEqual(msg, buf.value)
+ break
+ time.sleep(0.1)
+ count += 1
+ else:
+ self.fail("Did not receive communication from the subprocess")
+
+ os.kill(proc.pid, sig)
+ self.assertEqual(proc.wait(), sig)
+
+ def test_kill_sigterm(self):
+ # SIGTERM doesn't mean anything special, but make sure it works
+ self._kill(signal.SIGTERM)
+
+ def test_kill_int(self):
+ # os.kill on Windows can take an int which gets set as the exit code
+ self._kill(100)
+
+ def _kill_with_event(self, event, name):
+ tagname = "test_os_%s" % uuid.uuid1()
+ m = mmap.mmap(-1, 1, tagname)
+ m[0] = '0'
+ # Run a script which has console control handling enabled.
+ proc = subprocess.Popen([sys.executable,
+ os.path.join(os.path.dirname(__file__),
+ "win_console_handler.py"), tagname],
+ creationflags=subprocess.CREATE_NEW_PROCESS_GROUP)
+ # Let the interpreter startup before we send signals. See #3137.
+ count, max = 0, 20
+ while count < max and proc.poll() is None:
+ if m[0] == '1':
+ break
+ time.sleep(0.5)
+ count += 1
+ else:
+ self.fail("Subprocess didn't finish initialization")
+ os.kill(proc.pid, event)
+ # proc.send_signal(event) could also be done here.
+ # Allow time for the signal to be passed and the process to exit.
+ time.sleep(0.5)
+ if not proc.poll():
+ # Forcefully kill the process if we weren't able to signal it.
+ os.kill(proc.pid, signal.SIGINT)
+ self.fail("subprocess did not stop on {}".format(name))
+
+ @unittest.skip("subprocesses aren't inheriting CTRL+C property")
+ def test_CTRL_C_EVENT(self):
+ from ctypes import wintypes
+ import ctypes
+
+ # Make a NULL value by creating a pointer with no argument.
+ NULL = ctypes.POINTER(ctypes.c_int)()
+ SetConsoleCtrlHandler = ctypes.windll.kernel32.SetConsoleCtrlHandler
+ SetConsoleCtrlHandler.argtypes = (ctypes.POINTER(ctypes.c_int),
+ wintypes.BOOL)
+ SetConsoleCtrlHandler.restype = wintypes.BOOL
+
+ # Calling this with NULL and FALSE causes the calling process to
+ # handle CTRL+C, rather than ignore it. This property is inherited
+ # by subprocesses.
+ SetConsoleCtrlHandler(NULL, 0)
+
+ self._kill_with_event(signal.CTRL_C_EVENT, "CTRL_C_EVENT")
+
+ def test_CTRL_BREAK_EVENT(self):
+ self._kill_with_event(signal.CTRL_BREAK_EVENT, "CTRL_BREAK_EVENT")
+
+
def test_main():
test_support.run_unittest(
FileTests,
@@ -487,7 +816,10 @@
MakedirTests,
DevNullTests,
URandomTests,
- Win32ErrorTests
+ Win32ErrorTests,
+ TestInvalidFD,
+ PosixUidGidTests,
+ Win32KillTests
)
if __name__ == "__main__":
diff --git a/Lib/test/test_pkgimport.py b/Lib/test/test_pkgimport.py
--- a/Lib/test/test_pkgimport.py
+++ b/Lib/test/test_pkgimport.py
@@ -6,14 +6,14 @@
def __init__(self, *args, **kw):
self.package_name = 'PACKAGE_'
- while sys.modules.has_key(self.package_name):
+ while self.package_name in sys.modules:
self.package_name += random.choose(string.letters)
self.module_name = self.package_name + '.foo'
unittest.TestCase.__init__(self, *args, **kw)
def remove_modules(self):
for module_name in (self.package_name, self.module_name):
- if sys.modules.has_key(module_name):
+ if module_name in sys.modules:
del sys.modules[module_name]
def setUp(self):
@@ -59,8 +59,8 @@
try: __import__(self.module_name)
except SyntaxError: pass
else: raise RuntimeError, 'Failed to induce SyntaxError'
- self.assert_(not sys.modules.has_key(self.module_name) and
- not hasattr(sys.modules[self.package_name], 'foo'))
+ self.assertTrue(self.module_name not in sys.modules)
+ self.assertFalse(hasattr(sys.modules[self.package_name], 'foo'))
# ...make up a variable name that isn't bound in __builtins__
import __builtin__
diff --git a/Lib/test/test_pprint.py b/Lib/test/test_pprint.py
--- a/Lib/test/test_pprint.py
+++ b/Lib/test/test_pprint.py
@@ -1,6 +1,7 @@
import pprint
import test.test_support
import unittest
+import test.test_set
try:
uni = unicode
@@ -39,20 +40,19 @@
def test_basic(self):
# Verify .isrecursive() and .isreadable() w/o recursion
- verify = self.assert_
pp = pprint.PrettyPrinter()
for safe in (2, 2.0, 2j, "abc", [3], (2,2), {3: 3}, uni("yaddayadda"),
self.a, self.b):
# module-level convenience functions
- verify(not pprint.isrecursive(safe),
- "expected not isrecursive for %r" % (safe,))
- verify(pprint.isreadable(safe),
- "expected isreadable for %r" % (safe,))
+ self.assertFalse(pprint.isrecursive(safe),
+ "expected not isrecursive for %r" % (safe,))
+ self.assertTrue(pprint.isreadable(safe),
+ "expected isreadable for %r" % (safe,))
# PrettyPrinter methods
- verify(not pp.isrecursive(safe),
- "expected not isrecursive for %r" % (safe,))
- verify(pp.isreadable(safe),
- "expected isreadable for %r" % (safe,))
+ self.assertFalse(pp.isrecursive(safe),
+ "expected not isrecursive for %r" % (safe,))
+ self.assertTrue(pp.isreadable(safe),
+ "expected isreadable for %r" % (safe,))
def test_knotted(self):
# Verify .isrecursive() and .isreadable() w/ recursion
@@ -62,14 +62,13 @@
self.d = {}
self.d[0] = self.d[1] = self.d[2] = self.d
- verify = self.assert_
pp = pprint.PrettyPrinter()
for icky in self.a, self.b, self.d, (self.d, self.d):
- verify(pprint.isrecursive(icky), "expected isrecursive")
- verify(not pprint.isreadable(icky), "expected not isreadable")
- verify(pp.isrecursive(icky), "expected isrecursive")
- verify(not pp.isreadable(icky), "expected not isreadable")
+ self.assertTrue(pprint.isrecursive(icky), "expected isrecursive")
+ self.assertFalse(pprint.isreadable(icky), "expected not isreadable")
+ self.assertTrue(pp.isrecursive(icky), "expected isrecursive")
+ self.assertFalse(pp.isreadable(icky), "expected not isreadable")
# Break the cycles.
self.d.clear()
@@ -78,31 +77,30 @@
for safe in self.a, self.b, self.d, (self.d, self.d):
# module-level convenience functions
- verify(not pprint.isrecursive(safe),
- "expected not isrecursive for %r" % (safe,))
- verify(pprint.isreadable(safe),
- "expected isreadable for %r" % (safe,))
+ self.assertFalse(pprint.isrecursive(safe),
+ "expected not isrecursive for %r" % (safe,))
+ self.assertTrue(pprint.isreadable(safe),
+ "expected isreadable for %r" % (safe,))
# PrettyPrinter methods
- verify(not pp.isrecursive(safe),
- "expected not isrecursive for %r" % (safe,))
- verify(pp.isreadable(safe),
- "expected isreadable for %r" % (safe,))
+ self.assertFalse(pp.isrecursive(safe),
+ "expected not isrecursive for %r" % (safe,))
+ self.assertTrue(pp.isreadable(safe),
+ "expected isreadable for %r" % (safe,))
def test_unreadable(self):
# Not recursive but not readable anyway
- verify = self.assert_
pp = pprint.PrettyPrinter()
for unreadable in type(3), pprint, pprint.isrecursive:
# module-level convenience functions
- verify(not pprint.isrecursive(unreadable),
- "expected not isrecursive for %r" % (unreadable,))
- verify(not pprint.isreadable(unreadable),
- "expected not isreadable for %r" % (unreadable,))
+ self.assertFalse(pprint.isrecursive(unreadable),
+ "expected not isrecursive for %r" % (unreadable,))
+ self.assertFalse(pprint.isreadable(unreadable),
+ "expected not isreadable for %r" % (unreadable,))
# PrettyPrinter methods
- verify(not pp.isrecursive(unreadable),
- "expected not isrecursive for %r" % (unreadable,))
- verify(not pp.isreadable(unreadable),
- "expected not isreadable for %r" % (unreadable,))
+ self.assertFalse(pp.isrecursive(unreadable),
+ "expected not isrecursive for %r" % (unreadable,))
+ self.assertFalse(pp.isreadable(unreadable),
+ "expected not isreadable for %r" % (unreadable,))
def test_same_as_repr(self):
# Simple objects, small containers and classes that overwrite __repr__
@@ -113,12 +111,11 @@
# it sorted a dict display if and only if the display required
# multiple lines. For that reason, dicts with more than one element
# aren't tested here.
- verify = self.assert_
for simple in (0, 0L, 0+0j, 0.0, "", uni(""),
(), tuple2(), tuple3(),
[], list2(), list3(),
{}, dict2(), dict3(),
- verify, pprint,
+ self.assertTrue, pprint,
-6, -6L, -6-6j, -1.5, "x", uni("x"), (3,), [3], {3: 6},
(1,2), [3,4], {5: 6},
tuple2((1,2)), tuple3((1,2)), tuple3(range(100)),
@@ -130,8 +127,9 @@
for function in "pformat", "saferepr":
f = getattr(pprint, function)
got = f(simple)
- verify(native == got, "expected %s got %s from pprint.%s" %
- (native, got, function))
+ self.assertEqual(native, got,
+ "expected %s got %s from pprint.%s" %
+ (native, got, function))
def test_basic_line_wrap(self):
# verify basic line-wrapping operation
@@ -169,6 +167,17 @@
for type in [list, list2]:
self.assertEqual(pprint.pformat(type(o), indent=4), exp)
+ def test_nested_indentations(self):
+ o1 = list(range(10))
+ o2 = dict(first=1, second=2, third=3)
+ o = [o1, o2]
+ expected = """\
+[ [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
+ { 'first': 1,
+ 'second': 2,
+ 'third': 3}]"""
+ self.assertEqual(pprint.pformat(o, indent=4, width=42), expected)
+
def test_sorted_dict(self):
# Starting in Python 2.5, pprint sorts dict displays by key regardless
# of how small the dictionary may be.
@@ -195,6 +204,212 @@
others.should.not.be: like.this}"""
self.assertEqual(DottedPrettyPrinter().pformat(o), exp)
+ def test_set_reprs(self):
+ self.assertEqual(pprint.pformat(set()), 'set()')
+ self.assertEqual(pprint.pformat(set(range(3))), 'set([0, 1, 2])')
+ self.assertEqual(pprint.pformat(frozenset()), 'frozenset()')
+ self.assertEqual(pprint.pformat(frozenset(range(3))), 'frozenset([0, 1, 2])')
+ cube_repr_tgt = """\
+{frozenset([]): frozenset([frozenset([2]), frozenset([0]), frozenset([1])]),
+ frozenset([0]): frozenset([frozenset(),
+ frozenset([0, 2]),
+ frozenset([0, 1])]),
+ frozenset([1]): frozenset([frozenset(),
+ frozenset([1, 2]),
+ frozenset([0, 1])]),
+ frozenset([2]): frozenset([frozenset(),
+ frozenset([1, 2]),
+ frozenset([0, 2])]),
+ frozenset([1, 2]): frozenset([frozenset([2]),
+ frozenset([1]),
+ frozenset([0, 1, 2])]),
+ frozenset([0, 2]): frozenset([frozenset([2]),
+ frozenset([0]),
+ frozenset([0, 1, 2])]),
+ frozenset([0, 1]): frozenset([frozenset([0]),
+ frozenset([1]),
+ frozenset([0, 1, 2])]),
+ frozenset([0, 1, 2]): frozenset([frozenset([1, 2]),
+ frozenset([0, 2]),
+ frozenset([0, 1])])}"""
+ cube = test.test_set.cube(3)
+ self.assertEqual(pprint.pformat(cube), cube_repr_tgt)
+ cubo_repr_tgt = """\
+{frozenset([frozenset([0, 2]), frozenset([0])]): frozenset([frozenset([frozenset([0,
+ 2]),
+ frozenset([0,
+ 1,
+ 2])]),
+ frozenset([frozenset([0]),
+ frozenset([0,
+ 1])]),
+ frozenset([frozenset(),
+ frozenset([0])]),
+ frozenset([frozenset([2]),
+ frozenset([0,
+ 2])])]),
+ frozenset([frozenset([0, 1]), frozenset([1])]): frozenset([frozenset([frozenset([0,
+ 1]),
+ frozenset([0,
+ 1,
+ 2])]),
+ frozenset([frozenset([0]),
+ frozenset([0,
+ 1])]),
+ frozenset([frozenset([1]),
+ frozenset([1,
+ 2])]),
+ frozenset([frozenset(),
+ frozenset([1])])]),
+ frozenset([frozenset([1, 2]), frozenset([1])]): frozenset([frozenset([frozenset([1,
+ 2]),
+ frozenset([0,
+ 1,
+ 2])]),
+ frozenset([frozenset([2]),
+ frozenset([1,
+ 2])]),
+ frozenset([frozenset(),
+ frozenset([1])]),
+ frozenset([frozenset([1]),
+ frozenset([0,
+ 1])])]),
+ frozenset([frozenset([1, 2]), frozenset([2])]): frozenset([frozenset([frozenset([1,
+ 2]),
+ frozenset([0,
+ 1,
+ 2])]),
+ frozenset([frozenset([1]),
+ frozenset([1,
+ 2])]),
+ frozenset([frozenset([2]),
+ frozenset([0,
+ 2])]),
+ frozenset([frozenset(),
+ frozenset([2])])]),
+ frozenset([frozenset([]), frozenset([0])]): frozenset([frozenset([frozenset([0]),
+ frozenset([0,
+ 1])]),
+ frozenset([frozenset([0]),
+ frozenset([0,
+ 2])]),
+ frozenset([frozenset(),
+ frozenset([1])]),
+ frozenset([frozenset(),
+ frozenset([2])])]),
+ frozenset([frozenset([]), frozenset([1])]): frozenset([frozenset([frozenset(),
+ frozenset([0])]),
+ frozenset([frozenset([1]),
+ frozenset([1,
+ 2])]),
+ frozenset([frozenset(),
+ frozenset([2])]),
+ frozenset([frozenset([1]),
+ frozenset([0,
+ 1])])]),
+ frozenset([frozenset([2]), frozenset([])]): frozenset([frozenset([frozenset([2]),
+ frozenset([1,
+ 2])]),
+ frozenset([frozenset(),
+ frozenset([0])]),
+ frozenset([frozenset(),
+ frozenset([1])]),
+ frozenset([frozenset([2]),
+ frozenset([0,
+ 2])])]),
+ frozenset([frozenset([0, 1, 2]), frozenset([0, 1])]): frozenset([frozenset([frozenset([1,
+ 2]),
+ frozenset([0,
+ 1,
+ 2])]),
+ frozenset([frozenset([0,
+ 2]),
+ frozenset([0,
+ 1,
+ 2])]),
+ frozenset([frozenset([0]),
+ frozenset([0,
+ 1])]),
+ frozenset([frozenset([1]),
+ frozenset([0,
+ 1])])]),
+ frozenset([frozenset([0]), frozenset([0, 1])]): frozenset([frozenset([frozenset(),
+ frozenset([0])]),
+ frozenset([frozenset([0,
+ 1]),
+ frozenset([0,
+ 1,
+ 2])]),
+ frozenset([frozenset([0]),
+ frozenset([0,
+ 2])]),
+ frozenset([frozenset([1]),
+ frozenset([0,
+ 1])])]),
+ frozenset([frozenset([2]), frozenset([0, 2])]): frozenset([frozenset([frozenset([0,
+ 2]),
+ frozenset([0,
+ 1,
+ 2])]),
+ frozenset([frozenset([2]),
+ frozenset([1,
+ 2])]),
+ frozenset([frozenset([0]),
+ frozenset([0,
+ 2])]),
+ frozenset([frozenset(),
+ frozenset([2])])]),
+ frozenset([frozenset([0, 1, 2]), frozenset([0, 2])]): frozenset([frozenset([frozenset([1,
+ 2]),
+ frozenset([0,
+ 1,
+ 2])]),
+ frozenset([frozenset([0,
+ 1]),
+ frozenset([0,
+ 1,
+ 2])]),
+ frozenset([frozenset([0]),
+ frozenset([0,
+ 2])]),
+ frozenset([frozenset([2]),
+ frozenset([0,
+ 2])])]),
+ frozenset([frozenset([1, 2]), frozenset([0, 1, 2])]): frozenset([frozenset([frozenset([0,
+ 2]),
+ frozenset([0,
+ 1,
+ 2])]),
+ frozenset([frozenset([0,
+ 1]),
+ frozenset([0,
+ 1,
+ 2])]),
+ frozenset([frozenset([2]),
+ frozenset([1,
+ 2])]),
+ frozenset([frozenset([1]),
+ frozenset([1,
+ 2])])])}"""
+
+ cubo = test.test_set.linegraph(cube)
+ self.assertEqual(pprint.pformat(cubo), cubo_repr_tgt)
+
+ def test_depth(self):
+ nested_tuple = (1, (2, (3, (4, (5, 6)))))
+ nested_dict = {1: {2: {3: {4: {5: {6: 6}}}}}}
+ nested_list = [1, [2, [3, [4, [5, [6, []]]]]]]
+ self.assertEqual(pprint.pformat(nested_tuple), repr(nested_tuple))
+ self.assertEqual(pprint.pformat(nested_dict), repr(nested_dict))
+ self.assertEqual(pprint.pformat(nested_list), repr(nested_list))
+
+ lv1_tuple = '(1, (...))'
+ lv1_dict = '{1: {...}}'
+ lv1_list = '[1, [...]]'
+ self.assertEqual(pprint.pformat(nested_tuple, depth=1), lv1_tuple)
+ self.assertEqual(pprint.pformat(nested_dict, depth=1), lv1_dict)
+ self.assertEqual(pprint.pformat(nested_list, depth=1), lv1_list)
+
class DottedPrettyPrinter(pprint.PrettyPrinter):
diff --git a/Lib/test/test_profilehooks.py b/Lib/test/test_profilehooks.py
--- a/Lib/test/test_profilehooks.py
+++ b/Lib/test/test_profilehooks.py
@@ -10,6 +10,22 @@
from test import test_support
+class TestGetProfile(unittest.TestCase):
+ def setUp(self):
+ sys.setprofile(None)
+
+ def tearDown(self):
+ sys.setprofile(None)
+
+ def test_empty(self):
+ assert sys.getprofile() == None
+
+ def test_setget(self):
+ def fn(*args):
+ pass
+
+ sys.setprofile(fn)
+ assert sys.getprofile() == fn
class HookWatcher:
def __init__(self):
@@ -100,7 +116,7 @@
def test_exception(self):
def f(p):
- 1/0
+ 1./0
f_ident = ident(f)
self.check_events(f, [(1, 'call', f_ident),
(1, 'return', f_ident),
@@ -108,7 +124,7 @@
def test_caught_exception(self):
def f(p):
- try: 1/0
+ try: 1./0
except: pass
f_ident = ident(f)
self.check_events(f, [(1, 'call', f_ident),
@@ -117,7 +133,7 @@
def test_caught_nested_exception(self):
def f(p):
- try: 1/0
+ try: 1./0
except: pass
f_ident = ident(f)
self.check_events(f, [(1, 'call', f_ident),
@@ -126,7 +142,7 @@
def test_nested_exception(self):
def f(p):
- 1/0
+ 1./0
f_ident = ident(f)
self.check_events(f, [(1, 'call', f_ident),
# This isn't what I expected:
@@ -137,7 +153,7 @@
def test_exception_in_except_clause(self):
def f(p):
- 1/0
+ 1./0
def g(p):
try:
f(p)
@@ -156,7 +172,7 @@
def test_exception_propogation(self):
def f(p):
- 1/0
+ 1./0
def g(p):
try: f(p)
finally: p.add_event("falling through")
@@ -171,8 +187,8 @@
def test_raise_twice(self):
def f(p):
- try: 1/0
- except: 1/0
+ try: 1./0
+ except: 1./0
f_ident = ident(f)
self.check_events(f, [(1, 'call', f_ident),
(1, 'return', f_ident),
@@ -180,7 +196,7 @@
def test_raise_reraise(self):
def f(p):
- try: 1/0
+ try: 1./0
except: raise
f_ident = ident(f)
self.check_events(f, [(1, 'call', f_ident),
@@ -197,7 +213,7 @@
def test_distant_exception(self):
def f():
- 1/0
+ 1./0
def g():
f()
def h():
@@ -282,7 +298,7 @@
def test_basic_exception(self):
def f(p):
- 1/0
+ 1./0
f_ident = ident(f)
self.check_events(f, [(1, 'call', f_ident),
(1, 'return', f_ident),
@@ -290,7 +306,7 @@
def test_caught_exception(self):
def f(p):
- try: 1/0
+ try: 1./0
except: pass
f_ident = ident(f)
self.check_events(f, [(1, 'call', f_ident),
@@ -299,7 +315,7 @@
def test_distant_exception(self):
def f():
- 1/0
+ 1./0
def g():
f()
def h():
@@ -379,6 +395,7 @@
del ProfileSimulatorTestCase.test_distant_exception
test_support.run_unittest(
+ TestGetProfile,
ProfileHookTestCase,
ProfileSimulatorTestCase
)
diff --git a/Lib/test/test_random.py b/Lib/test/test_random.py
--- a/Lib/test/test_random.py
+++ b/Lib/test/test_random.py
@@ -5,7 +5,8 @@
import time
import pickle
import warnings
-from math import log, exp, sqrt, pi
+from math import log, exp, pi, fsum, sin
+from functools import reduce
from test import test_support
class TestBasicOps(unittest.TestCase):
@@ -52,10 +53,9 @@
state3 = self.gen.getstate() # s/b distinct from state2
self.assertNotEqual(state2, state3)
- self.assertRaises(TypeError, self.gen.jumpahead) # needs an arg
- self.assertRaises(TypeError, self.gen.jumpahead, "ick") # wrong type
- self.assertRaises(TypeError, self.gen.jumpahead, 2.3) # wrong type
- self.assertRaises(TypeError, self.gen.jumpahead, 2, 3) # too many
+ with test_support.check_py3k_warnings(quiet=True):
+ self.assertRaises(TypeError, self.gen.jumpahead) # needs an arg
+ self.assertRaises(TypeError, self.gen.jumpahead, 2, 3) # too many
def test_sample(self):
# For the entire allowable range of 0 <= k <= N, validate that
@@ -67,7 +67,7 @@
self.assertEqual(len(s), k)
uniq = set(s)
self.assertEqual(len(uniq), k)
- self.failUnless(uniq <= set(population))
+ self.assertTrue(uniq <= set(population))
self.assertEqual(self.gen.sample([], 0), []) # test edge case N==k==0
def test_sample_distribution(self):
@@ -112,7 +112,7 @@
samp = self.gen.sample(d, k)
# Verify that we got ints back (keys); the values are complex.
for x in samp:
- self.assert_(type(x) is int)
+ self.assertTrue(type(x) is int)
samp.sort()
self.assertEqual(samp, range(N))
@@ -140,6 +140,19 @@
restoredseq = [newgen.random() for i in xrange(10)]
self.assertEqual(origseq, restoredseq)
+ def test_bug_1727780(self):
+ # verify that version-2-pickles can be loaded
+ # fine, whether they are created on 32-bit or 64-bit
+ # platforms, and that version-3-pickles load fine.
+ files = [("randv2_32.pck", 780),
+ ("randv2_64.pck", 866),
+ ("randv3.pck", 343)]
+ for file, value in files:
+ f = open(test_support.findfile(file),"rb")
+ r = pickle.load(f)
+ f.close()
+ self.assertEqual(r.randrange(1000), value)
+
class WichmannHill_TestBasicOps(TestBasicOps):
gen = random.WichmannHill()
@@ -178,10 +191,9 @@
def test_bigrand(self):
# Verify warnings are raised when randrange is too large for random()
- oldfilters = warnings.filters[:]
- warnings.filterwarnings("error", "Underlying random")
- self.assertRaises(UserWarning, self.gen.randrange, 2**60)
- warnings.filters[:] = oldfilters
+ with warnings.catch_warnings():
+ warnings.filterwarnings("error", "Underlying random")
+ self.assertRaises(UserWarning, self.gen.randrange, 2**60)
class SystemRandom_TestBasicOps(TestBasicOps):
gen = random.SystemRandom()
@@ -225,7 +237,7 @@
cum = 0
for i in xrange(100):
r = self.gen.randrange(span)
- self.assert_(0 <= r < span)
+ self.assertTrue(0 <= r < span)
cum |= r
self.assertEqual(cum, span-1)
@@ -235,7 +247,7 @@
stop = self.gen.randrange(2 ** (i-2))
if stop <= start:
return
- self.assert_(start <= self.gen.randrange(start, stop) < stop)
+ self.assertTrue(start <= self.gen.randrange(start, stop) < stop)
def test_rangelimits(self):
for start, stop in [(-2,0), (-(2**60)-2,-(2**60)), (2**60,2**60+2)]:
@@ -245,7 +257,7 @@
def test_genrandbits(self):
# Verify ranges
for k in xrange(1, 1000):
- self.assert_(0 <= self.gen.getrandbits(k) < 2**k)
+ self.assertTrue(0 <= self.gen.getrandbits(k) < 2**k)
# Verify all bits active
getbits = self.gen.getrandbits
@@ -271,17 +283,17 @@
numbits = i+1
k = int(1.00001 + _log(n, 2))
self.assertEqual(k, numbits)
- self.assert_(n == 2**(k-1))
+ self.assertTrue(n == 2**(k-1))
n += n - 1 # check 1 below the next power of two
k = int(1.00001 + _log(n, 2))
- self.assert_(k in [numbits, numbits+1])
- self.assert_(2**k > n > 2**(k-2))
+ self.assertIn(k, [numbits, numbits+1])
+ self.assertTrue(2**k > n > 2**(k-2))
n -= n >> 15 # check a little farther below the next power of two
k = int(1.00001 + _log(n, 2))
self.assertEqual(k, numbits) # note the stronger assertion
- self.assert_(2**k > n > 2**(k-1)) # note the stronger assertion
+ self.assertTrue(2**k > n > 2**(k-1)) # note the stronger assertion
class MersenneTwister_TestBasicOps(TestBasicOps):
@@ -377,7 +389,7 @@
cum = 0
for i in xrange(100):
r = self.gen.randrange(span)
- self.assert_(0 <= r < span)
+ self.assertTrue(0 <= r < span)
cum |= r
self.assertEqual(cum, span-1)
@@ -387,7 +399,7 @@
stop = self.gen.randrange(2 ** (i-2))
if stop <= start:
return
- self.assert_(start <= self.gen.randrange(start, stop) < stop)
+ self.assertTrue(start <= self.gen.randrange(start, stop) < stop)
def test_rangelimits(self):
for start, stop in [(-2,0), (-(2**60)-2,-(2**60)), (2**60,2**60+2)]:
@@ -401,7 +413,7 @@
97904845777343510404718956115L)
# Verify ranges
for k in xrange(1, 1000):
- self.assert_(0 <= self.gen.getrandbits(k) < 2**k)
+ self.assertTrue(0 <= self.gen.getrandbits(k) < 2**k)
# Verify all bits active
getbits = self.gen.getrandbits
@@ -427,37 +439,43 @@
numbits = i+1
k = int(1.00001 + _log(n, 2))
self.assertEqual(k, numbits)
- self.assert_(n == 2**(k-1))
+ self.assertTrue(n == 2**(k-1))
n += n - 1 # check 1 below the next power of two
k = int(1.00001 + _log(n, 2))
- self.assert_(k in [numbits, numbits+1])
- self.assert_(2**k > n > 2**(k-2))
+ self.assertIn(k, [numbits, numbits+1])
+ self.assertTrue(2**k > n > 2**(k-2))
n -= n >> 15 # check a little farther below the next power of two
k = int(1.00001 + _log(n, 2))
self.assertEqual(k, numbits) # note the stronger assertion
- self.assert_(2**k > n > 2**(k-1)) # note the stronger assertion
+ self.assertTrue(2**k > n > 2**(k-1)) # note the stronger assertion
def test_randrange_bug_1590891(self):
start = 1000000000000
stop = -100000000000000000000
step = -200
x = self.gen.randrange(start, stop, step)
- self.assert_(stop < x <= start)
+ self.assertTrue(stop < x <= start)
self.assertEqual((x+stop)%step, 0)
-_gammacoeff = (0.9999999999995183, 676.5203681218835, -1259.139216722289,
- 771.3234287757674, -176.6150291498386, 12.50734324009056,
- -0.1385710331296526, 0.9934937113930748e-05, 0.1659470187408462e-06)
-
-def gamma(z, cof=_gammacoeff, g=7):
- z -= 1.0
- sum = cof[0]
- for i in xrange(1,len(cof)):
- sum += cof[i] / (z+i)
- z += 0.5
- return (z+g)**z / exp(z+g) * sqrt(2*pi) * sum
+def gamma(z, sqrt2pi=(2.0*pi)**0.5):
+ # Reflection to right half of complex plane
+ if z < 0.5:
+ return pi / sin(pi*z) / gamma(1.0-z)
+ # Lanczos approximation with g=7
+ az = z + (7.0 - 0.5)
+ return az ** (z-0.5) / exp(az) * sqrt2pi * fsum([
+ 0.9999999999995183,
+ 676.5203681218835 / z,
+ -1259.139216722289 / (z+1.0),
+ 771.3234287757674 / (z+2.0),
+ -176.6150291498386 / (z+3.0),
+ 12.50734324009056 / (z+4.0),
+ -0.1385710331296526 / (z+5.0),
+ 0.9934937113930748e-05 / (z+6.0),
+ 0.1659470187408462e-06 / (z+7.0),
+ ])
class TestDistributions(unittest.TestCase):
def test_zeroinputs(self):
@@ -476,6 +494,7 @@
g.random = x[:].pop; g.gammavariate(1.0, 1.0)
g.random = x[:].pop; g.gammavariate(200.0, 1.0)
g.random = x[:].pop; g.betavariate(3.0, 3.0)
+ g.random = x[:].pop; g.triangular(0.0, 1.0, 1.0/3.0)
def test_avg_std(self):
# Use integration to test distribution average and standard deviation.
@@ -485,6 +504,7 @@
x = [i/float(N) for i in xrange(1,N)]
for variate, args, mu, sigmasqrd in [
(g.uniform, (1.0,10.0), (10.0+1.0)/2, (10.0-1.0)**2/12),
+ (g.triangular, (0.0, 1.0, 1.0/3.0), 4.0/9.0, 7.0/9.0/18.0),
(g.expovariate, (1.5,), 1/1.5, 1/1.5**2),
(g.paretovariate, (5.0,), 5.0/(5.0-1),
5.0/((5.0-1)**2*(5.0-2))),
@@ -514,7 +534,7 @@
def test__all__(self):
# tests validity but not completeness of the __all__ list
- self.failUnless(set(random.__all__) <= set(dir(random)))
+ self.assertTrue(set(random.__all__) <= set(dir(random)))
def test_random_subclass_with_kwargs(self):
# SF bug #1486663 -- this used to erroneously raise a TypeError
diff --git a/Lib/test/test_repr.py b/Lib/test/test_repr.py
--- a/Lib/test/test_repr.py
+++ b/Lib/test/test_repr.py
@@ -8,7 +8,7 @@
import shutil
import unittest
-from test.test_support import run_unittest
+from test.test_support import run_unittest, check_py3k_warnings
from repr import repr as r # Don't shadow builtin repr
from repr import Repr
@@ -22,7 +22,7 @@
class ReprTests(unittest.TestCase):
def test_string(self):
- eq = self.assertEquals
+ eq = self.assertEqual
eq(r("abc"), "'abc'")
eq(r("abcdefghijklmnop"),"'abcdefghijklmnop'")
@@ -36,7 +36,7 @@
eq(r(s), expected)
def test_tuple(self):
- eq = self.assertEquals
+ eq = self.assertEqual
eq(r((1,)), "(1,)")
t3 = (1, 2, 3)
@@ -51,7 +51,7 @@
from array import array
from collections import deque
- eq = self.assertEquals
+ eq = self.assertEqual
# Tuples give up after 6 elements
eq(r(()), "()")
eq(r((1,)), "(1,)")
@@ -101,7 +101,7 @@
"array('i', [1, 2, 3, 4, 5, ...])")
def test_numbers(self):
- eq = self.assertEquals
+ eq = self.assertEqual
eq(r(123), repr(123))
eq(r(123L), repr(123L))
eq(r(1.0/3), repr(1.0/3))
@@ -111,7 +111,7 @@
eq(r(n), expected)
def test_instance(self):
- eq = self.assertEquals
+ eq = self.assertEqual
i1 = ClassWithRepr("a")
eq(r(i1), repr(i1))
@@ -123,40 +123,39 @@
eq(r(i3), ("<ClassWithFailingRepr instance at %x>"%id(i3)))
s = r(ClassWithFailingRepr)
- self.failUnless(s.startswith("<class "))
- self.failUnless(s.endswith(">"))
- self.failUnless(s.find("...") == 8)
+ self.assertTrue(s.startswith("<class "))
+ self.assertTrue(s.endswith(">"))
+ self.assertTrue(s.find("...") == 8)
def test_file(self):
fp = open(unittest.__file__)
- self.failUnless(repr(fp).startswith(
+ self.assertTrue(repr(fp).startswith(
"<open file '%s', mode 'r' at 0x" % unittest.__file__))
fp.close()
- self.failUnless(repr(fp).startswith(
+ self.assertTrue(repr(fp).startswith(
"<closed file '%s', mode 'r' at 0x" % unittest.__file__))
def test_lambda(self):
- self.failUnless(repr(lambda x: x).startswith(
+ self.assertTrue(repr(lambda x: x).startswith(
"<function <lambda"))
# XXX anonymous functions? see func_repr
def test_builtin_function(self):
- eq = self.assertEquals
+ eq = self.assertEqual
# Functions
eq(repr(hash), '<built-in function hash>')
# Methods
- self.failUnless(repr(''.split).startswith(
+ self.assertTrue(repr(''.split).startswith(
'<built-in method split of str object at 0x'))
def test_xrange(self):
- import warnings
- eq = self.assertEquals
+ eq = self.assertEqual
eq(repr(xrange(1)), 'xrange(1)')
eq(repr(xrange(1, 2)), 'xrange(1, 2)')
eq(repr(xrange(1, 2, 3)), 'xrange(1, 4, 3)')
def test_nesting(self):
- eq = self.assertEquals
+ eq = self.assertEqual
# everything is meant to give up after 6 levels.
eq(r([[[[[[[]]]]]]]), "[[[[[[[]]]]]]]")
eq(r([[[[[[[[]]]]]]]]), "[[[[[[[...]]]]]]]")
@@ -175,15 +174,16 @@
def test_buffer(self):
# XXX doesn't test buffers with no b_base or read-write buffers (see
# bufferobject.c). The test is fairly incomplete too. Sigh.
- x = buffer('foo')
- self.failUnless(repr(x).startswith('<read-only buffer for 0x'))
+ with check_py3k_warnings():
+ x = buffer('foo')
+ self.assertTrue(repr(x).startswith('<read-only buffer for 0x'))
def test_cell(self):
# XXX Hmm? How to get at a cell object?
pass
def test_descriptors(self):
- eq = self.assertEquals
+ eq = self.assertEqual
# method descriptors
eq(repr(dict.items), "<method 'items' of 'dict' objects>")
# XXX member descriptors
@@ -193,9 +193,9 @@
class C:
def foo(cls): pass
x = staticmethod(C.foo)
- self.failUnless(repr(x).startswith('<staticmethod object at 0x'))
+ self.assertTrue(repr(x).startswith('<staticmethod object at 0x'))
x = classmethod(C.foo)
- self.failUnless(repr(x).startswith('<classmethod object at 0x'))
+ self.assertTrue(repr(x).startswith('<classmethod object at 0x'))
def test_unsortable(self):
# Repr.repr() used to call sorted() on sets, frozensets and dicts
@@ -212,10 +212,6 @@
fp.write(text)
fp.close()
-def zap(actions, dirname, names):
- for name in names:
- actions.append(os.path.join(dirname, name))
-
class LongReprTest(unittest.TestCase):
def setUp(self):
longname = 'areallylongpackageandmodulenametotestreprtruncation'
@@ -234,7 +230,9 @@
def tearDown(self):
actions = []
- os.path.walk(self.pkgname, zap, actions)
+ for dirpath, dirnames, filenames in os.walk(self.pkgname):
+ for name in dirnames + filenames:
+ actions.append(os.path.join(dirpath, name))
actions.append(self.pkgname)
actions.sort()
actions.reverse()
@@ -246,7 +244,7 @@
del sys.path[0]
def test_module(self):
- eq = self.assertEquals
+ eq = self.assertEqual
touch(os.path.join(self.subpkgname, self.pkgname + os.extsep + 'py'))
from areallylongpackageandmodulenametotestreprtruncation.areallylongpackageandmodulenametotestreprtruncation import areallylongpackageandmodulenametotestreprtruncation
eq(repr(areallylongpackageandmodulenametotestreprtruncation),
@@ -255,7 +253,7 @@
#eq(repr(sys), "<module 'sys' (built-in)>")
def test_type(self):
- eq = self.assertEquals
+ eq = self.assertEqual
touch(os.path.join(self.subpkgname, 'foo'+os.extsep+'py'), '''\
class foo(object):
pass
@@ -276,7 +274,7 @@
''')
from areallylongpackageandmodulenametotestreprtruncation.areallylongpackageandmodulenametotestreprtruncation import bar
# Module name may be prefixed with "test.", depending on how run.
- self.failUnless(repr(bar.bar).startswith(
+ self.assertTrue(repr(bar.bar).startswith(
"<class %s.bar at 0x" % bar.__name__))
def test_instance(self):
@@ -286,11 +284,11 @@
''')
from areallylongpackageandmodulenametotestreprtruncation.areallylongpackageandmodulenametotestreprtruncation import baz
ibaz = baz.baz()
- self.failUnless(repr(ibaz).startswith(
+ self.assertTrue(repr(ibaz).startswith(
"<%s.baz instance at 0x" % baz.__name__))
def test_method(self):
- eq = self.assertEquals
+ eq = self.assertEqual
touch(os.path.join(self.subpkgname, 'qux'+os.extsep+'py'), '''\
class aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa:
def amethod(self): pass
@@ -301,7 +299,7 @@
'<unbound method aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa.amethod>')
# Bound method next
iqux = qux.aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa()
- self.failUnless(repr(iqux.amethod).startswith(
+ self.assertTrue(repr(iqux.amethod).startswith(
'<bound method aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa.amethod of <%s.aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa instance at 0x' \
% (qux.__name__,) ))
@@ -327,8 +325,7 @@
# XXX: Jython lacks the buffer type
del ReprTests.test_buffer
run_unittest(ReprTests)
- if os.name != 'mac':
- run_unittest(LongReprTest)
+ run_unittest(LongReprTest)
if __name__ == "__main__":
diff --git a/Lib/test/test_robotparser.py b/Lib/test/test_robotparser.py
--- a/Lib/test/test_robotparser.py
+++ b/Lib/test/test_robotparser.py
@@ -20,9 +20,9 @@
url = self.url
agent = self.agent
if self.good:
- self.failUnless(self.parser.can_fetch(agent, url))
+ self.assertTrue(self.parser.can_fetch(agent, url))
else:
- self.failIf(self.parser.can_fetch(agent, url))
+ self.assertFalse(self.parser.can_fetch(agent, url))
def __str__(self):
return self.str
@@ -202,7 +202,18 @@
RobotTest(13, doc, good, bad, agent="googlebot")
-# 14. For issue #4108 (obey first * entry)
+# 14. For issue #6325 (query string support)
+doc = """
+User-agent: *
+Disallow: /some/path?name=value
+"""
+
+good = ['/some/path']
+bad = ['/some/path?name=value']
+
+RobotTest(14, doc, good, bad)
+
+# 15. For issue #4108 (obey first * entry)
doc = """
User-agent: *
Disallow: /some/path
@@ -214,22 +225,36 @@
good = ['/another/path']
bad = ['/some/path']
-RobotTest(14, doc, good, bad)
+RobotTest(15, doc, good, bad)
-class TestCase(unittest.TestCase):
- def runTest(self):
+class NetworkTestCase(unittest.TestCase):
+
+ def testPasswordProtectedSite(self):
test_support.requires('network')
- # whole site is password-protected.
- url = 'http://mueblesmoraleda.com'
- parser = robotparser.RobotFileParser()
- parser.set_url(url)
- parser.read()
- self.assertEqual(parser.can_fetch("*", url+"/robots.txt"), False)
+ with test_support.transient_internet('mueblesmoraleda.com'):
+ url = 'http://mueblesmoraleda.com'
+ parser = robotparser.RobotFileParser()
+ parser.set_url(url)
+ try:
+ parser.read()
+ except IOError:
+ self.skipTest('%s is unavailable' % url)
+ self.assertEqual(parser.can_fetch("*", url+"/robots.txt"), False)
+
+ def testPythonOrg(self):
+ test_support.requires('network')
+ with test_support.transient_internet('www.python.org'):
+ parser = robotparser.RobotFileParser(
+ "http://www.python.org/robots.txt")
+ parser.read()
+ self.assertTrue(
+ parser.can_fetch("*", "http://www.python.org/robots.txt"))
+
def test_main():
test_support.run_unittest(tests)
- TestCase().run()
+ test_support.run_unittest(NetworkTestCase)
if __name__=='__main__':
test_support.verbose = 1
diff --git a/Lib/test/test_shutil.py b/Lib/test/test_shutil.py
--- a/Lib/test/test_shutil.py
+++ b/Lib/test/test_shutil.py
@@ -7,10 +7,71 @@
import stat
import os
import os.path
+from os.path import splitdrive
+from distutils.spawn import find_executable, spawn
+from shutil import (_make_tarball, _make_zipfile, make_archive,
+ register_archive_format, unregister_archive_format,
+ get_archive_formats)
+import tarfile
+import warnings
+
from test import test_support
-from test.test_support import TESTFN
+from test.test_support import TESTFN, check_warnings, captured_stdout
+
+TESTFN2 = TESTFN + "2"
+
+try:
+ import grp
+ import pwd
+ UID_GID_SUPPORT = True
+except ImportError:
+ UID_GID_SUPPORT = False
+
+try:
+ import zlib
+except ImportError:
+ zlib = None
+
+try:
+ import zipfile
+ ZIP_SUPPORT = True
+except ImportError:
+ ZIP_SUPPORT = find_executable('zip')
class TestShutil(unittest.TestCase):
+
+ def setUp(self):
+ super(TestShutil, self).setUp()
+ self.tempdirs = []
+
+ def tearDown(self):
+ super(TestShutil, self).tearDown()
+ while self.tempdirs:
+ d = self.tempdirs.pop()
+ shutil.rmtree(d, os.name in ('nt', 'cygwin'))
+
+ def write_file(self, path, content='xxx'):
+ """Writes a file in the given path.
+
+
+ path can be a string or a sequence.
+ """
+ if isinstance(path, (list, tuple)):
+ path = os.path.join(*path)
+ f = open(path, 'w')
+ try:
+ f.write(content)
+ finally:
+ f.close()
+
+ def mkdtemp(self):
+ """Create a temporary directory that will be cleaned up.
+
+ Returns the path of the directory.
+ """
+ d = tempfile.mkdtemp()
+ self.tempdirs.append(d)
+ return d
def test_rmtree_errors(self):
# filename is guaranteed not to exist
filename = tempfile.mktemp()
@@ -48,15 +109,29 @@
shutil.rmtree(TESTFN)
def check_args_to_onerror(self, func, arg, exc):
+ # test_rmtree_errors deliberately runs rmtree
+ # on a directory that is chmod 400, which will fail.
+ # This function is run when shutil.rmtree fails.
+ # 99.9% of the time it initially fails to remove
+ # a file in the directory, so the first time through
+ # func is os.remove.
+ # However, some Linux machines running ZFS on
+ # FUSE experienced a failure earlier in the process
+ # at os.listdir. The first failure may legally
+ # be either.
if self.errorState == 0:
- self.assertEqual(func, os.remove)
- self.assertEqual(arg, self.childpath)
- self.failUnless(issubclass(exc[0], OSError))
+ if func is os.remove:
+ self.assertEqual(arg, self.childpath)
+ else:
+ self.assertIs(func, os.listdir,
+ "func must be either os.remove or os.listdir")
+ self.assertEqual(arg, TESTFN)
+ self.assertTrue(issubclass(exc[0], OSError))
self.errorState = 1
else:
self.assertEqual(func, os.rmdir)
self.assertEqual(arg, TESTFN)
- self.failUnless(issubclass(exc[0], OSError))
+ self.assertTrue(issubclass(exc[0], OSError))
self.errorState = 2
def test_rmtree_dont_delete_file(self):
@@ -66,17 +141,6 @@
self.assertRaises(OSError, shutil.rmtree, path)
os.remove(path)
- def test_dont_move_dir_in_itself(self):
- src_dir = tempfile.mkdtemp()
- try:
- dst = os.path.join(src_dir, 'foo')
- self.assertRaises(shutil.Error, shutil.move, src_dir, dst)
- finally:
- try:
- os.rmdir(src_dir)
- except:
- pass
-
def test_copytree_simple(self):
def write_data(path, data):
f = open(path, "w")
@@ -116,13 +180,92 @@
):
if os.path.exists(path):
os.remove(path)
- for path in (
- os.path.join(src_dir, 'test_dir'),
- os.path.join(dst_dir, 'test_dir'),
+ for path in (src_dir,
+ os.path.dirname(dst_dir)
):
if os.path.exists(path):
- os.removedirs(path)
+ shutil.rmtree(path)
+    def test_copytree_with_exclude(self):
+
+        def write_data(path, data):
+            f = open(path, "w")
+            f.write(data)
+            f.close()
+
+        def read_data(path):
+            f = open(path)
+            data = f.read()
+            f.close()
+            return data
+
+        # creating data
+        join = os.path.join
+        exists = os.path.exists
+        src_dir = tempfile.mkdtemp()
+        try:
+            dst_dir = join(tempfile.mkdtemp(), 'destination')
+            write_data(join(src_dir, 'test.txt'), '123')
+            write_data(join(src_dir, 'test.tmp'), '123')
+            os.mkdir(join(src_dir, 'test_dir'))
+            write_data(join(src_dir, 'test_dir', 'test.txt'), '456')
+            os.mkdir(join(src_dir, 'test_dir2'))
+            write_data(join(src_dir, 'test_dir2', 'test.txt'), '456')
+            os.mkdir(join(src_dir, 'test_dir2', 'subdir'))
+            os.mkdir(join(src_dir, 'test_dir2', 'subdir2'))
+            write_data(join(src_dir, 'test_dir2', 'subdir', 'test.txt'), '456')
+            write_data(join(src_dir, 'test_dir2', 'subdir2', 'test.py'), '456')
+
+
+            # testing glob-like patterns
+            try:
+                patterns = shutil.ignore_patterns('*.tmp', 'test_dir2')
+                shutil.copytree(src_dir, dst_dir, ignore=patterns)
+                # checking the result: some elements should not be copied
+                self.assertTrue(exists(join(dst_dir, 'test.txt')))
+                self.assertFalse(exists(join(dst_dir, 'test.tmp')))
+                self.assertFalse(exists(join(dst_dir, 'test_dir2')))
+            finally:
+                if os.path.exists(dst_dir):
+                    shutil.rmtree(dst_dir)
+            try:
+                patterns = shutil.ignore_patterns('*.tmp', 'subdir*')
+                shutil.copytree(src_dir, dst_dir, ignore=patterns)
+                # checking the result: some elements should not be copied
+                self.assertFalse(exists(join(dst_dir, 'test.tmp')))
+                self.assertFalse(exists(join(dst_dir, 'test_dir2', 'subdir2')))
+                self.assertFalse(exists(join(dst_dir, 'test_dir2', 'subdir')))
+            finally:
+                if os.path.exists(dst_dir):
+                    shutil.rmtree(dst_dir)
+
+            # testing callable-style
+            try:
+                def _filter(src, names):
+                    res = []
+                    for name in names:
+                        path = os.path.join(src, name)
+                        # ignore directories named 'subdir' and any *.py file
+                        if (os.path.isdir(path) and
+                            name == 'subdir'):
+                            res.append(name)
+                        elif os.path.splitext(name)[1] == '.py':
+                            res.append(name)
+                    return res
+
+                shutil.copytree(src_dir, dst_dir, ignore=_filter)
+
+                # checking the result: some elements should not be copied
+                self.assertFalse(exists(join(dst_dir, 'test_dir2', 'subdir2',
+                                             'test.py')))
+                self.assertFalse(exists(join(dst_dir, 'test_dir2', 'subdir')))
+                self.assertTrue(exists(join(dst_dir, 'test_dir2', 'subdir2')))
+            finally:
+                if os.path.exists(dst_dir):
+                    shutil.rmtree(dst_dir)
+        finally:
+            shutil.rmtree(src_dir)
+            shutil.rmtree(os.path.dirname(dst_dir))
if hasattr(os, "symlink"):
def test_dont_copy_file_onto_link_to_itself(self):
@@ -137,7 +280,8 @@
os.link(src, dst)
self.assertRaises(shutil.Error, shutil.copyfile, src, dst)
- self.assertEqual(open(src,'r').read(), 'cheddar')
+ with open(src, 'r') as f:
+ self.assertEqual(f.read(), 'cheddar')
os.remove(dst)
# Using `src` here would mean we end up with a symlink pointing
@@ -145,7 +289,8 @@
# TESTFN/cheese.
os.symlink('cheese', dst)
self.assertRaises(shutil.Error, shutil.copyfile, src, dst)
- self.assertEqual(open(src,'r').read(), 'cheddar')
+ with open(src, 'r') as f:
+ self.assertEqual(f.read(), 'cheddar')
os.remove(dst)
finally:
try:
@@ -153,8 +298,519 @@
except OSError:
pass
+ def test_rmtree_on_symlink(self):
+ # bug 1669.
+ os.mkdir(TESTFN)
+ try:
+ src = os.path.join(TESTFN, 'cheese')
+ dst = os.path.join(TESTFN, 'shop')
+ os.mkdir(src)
+ os.symlink(src, dst)
+ self.assertRaises(OSError, shutil.rmtree, dst)
+ finally:
+ shutil.rmtree(TESTFN, ignore_errors=True)
+
+ if hasattr(os, "mkfifo"):
+ # Issue #3002: copyfile and copytree block indefinitely on named pipes
+ def test_copyfile_named_pipe(self):
+ os.mkfifo(TESTFN)
+ try:
+ self.assertRaises(shutil.SpecialFileError,
+ shutil.copyfile, TESTFN, TESTFN2)
+ self.assertRaises(shutil.SpecialFileError,
+ shutil.copyfile, __file__, TESTFN)
+ finally:
+ os.remove(TESTFN)
+
+ def test_copytree_named_pipe(self):
+ os.mkdir(TESTFN)
+ try:
+ subdir = os.path.join(TESTFN, "subdir")
+ os.mkdir(subdir)
+ pipe = os.path.join(subdir, "mypipe")
+ os.mkfifo(pipe)
+ try:
+ shutil.copytree(TESTFN, TESTFN2)
+ except shutil.Error as e:
+ errors = e.args[0]
+ self.assertEqual(len(errors), 1)
+ src, dst, error_msg = errors[0]
+ self.assertEqual("`%s` is a named pipe" % pipe, error_msg)
+ else:
+ self.fail("shutil.Error should have been raised")
+ finally:
+ shutil.rmtree(TESTFN, ignore_errors=True)
+ shutil.rmtree(TESTFN2, ignore_errors=True)
+
+    @unittest.skipUnless(zlib, "requires zlib")
+    def test_make_tarball(self):
+        # Build a small tree (file1, file2, sub/file3) to archive.
+        tmpdir = self.mkdtemp()
+        self.write_file([tmpdir, 'file1'], 'xxx')
+        self.write_file([tmpdir, 'file2'], 'xxx')
+        os.mkdir(os.path.join(tmpdir, 'sub'))
+        self.write_file([tmpdir, 'sub', 'file3'], 'xxx')
+
+        tmpdir2 = self.mkdtemp()
+        if splitdrive(tmpdir)[0] != splitdrive(tmpdir2)[0]:
+            self.skipTest("source and target should be on same drive")
+        # the archive path below is drive-relative (splitdrive), hence the skip
+        base_name = os.path.join(tmpdir2, 'archive')
+
+        # working with relative paths to avoid tar warnings
+        old_dir = os.getcwd()
+        os.chdir(tmpdir)
+        try:
+            _make_tarball(splitdrive(base_name)[1], '.')
+        finally:
+            os.chdir(old_dir)
+
+        # check if the compressed tarball was created
+        tarball = base_name + '.tar.gz'
+        self.assertTrue(os.path.exists(tarball))
+
+        # trying an uncompressed one
+        base_name = os.path.join(tmpdir2, 'archive')
+        old_dir = os.getcwd()
+        os.chdir(tmpdir)
+        try:
+            _make_tarball(splitdrive(base_name)[1], '.', compress=None)
+        finally:
+            os.chdir(old_dir)
+        tarball = base_name + '.tar'
+        self.assertTrue(os.path.exists(tarball))
+
+ def _tarinfo(self, path):
+ tar = tarfile.open(path)
+ try:
+ names = tar.getnames()
+ names.sort()
+ return tuple(names)
+ finally:
+ tar.close()
+
+ def _create_files(self):
+ # creating something to tar
+ tmpdir = self.mkdtemp()
+ dist = os.path.join(tmpdir, 'dist')
+ os.mkdir(dist)
+ self.write_file([dist, 'file1'], 'xxx')
+ self.write_file([dist, 'file2'], 'xxx')
+ os.mkdir(os.path.join(dist, 'sub'))
+ self.write_file([dist, 'sub', 'file3'], 'xxx')
+ os.mkdir(os.path.join(dist, 'sub2'))
+ tmpdir2 = self.mkdtemp()
+ base_name = os.path.join(tmpdir2, 'archive')
+ return tmpdir, tmpdir2, base_name
+
+    @unittest.skipUnless(zlib, "Requires zlib")
+    @unittest.skipUnless(find_executable('tar') and find_executable('gzip'),
+                         'Need the tar command to run')
+    def test_tarfile_vs_tar(self):
+        tmpdir, tmpdir2, base_name = self._create_files()
+        old_dir = os.getcwd()
+        os.chdir(tmpdir)
+        try:
+            _make_tarball(base_name, 'dist')
+        finally:
+            os.chdir(old_dir)
+
+        # check if the compressed tarball was created
+        tarball = base_name + '.tar.gz'
+        self.assertTrue(os.path.exists(tarball))
+
+        # now create another tarball using `tar`
+        tarball2 = os.path.join(tmpdir, 'archive2.tar.gz')
+        tar_cmd = ['tar', '-cf', 'archive2.tar', 'dist']
+        gzip_cmd = ['gzip', '-f9', 'archive2.tar']
+        old_dir = os.getcwd()
+        os.chdir(tmpdir)
+        try:
+            with captured_stdout():
+                spawn(tar_cmd)
+                spawn(gzip_cmd)
+        finally:
+            os.chdir(old_dir)
+
+        self.assertTrue(os.path.exists(tarball2))
+        # let's compare both tarballs
+        self.assertEqual(self._tarinfo(tarball), self._tarinfo(tarball2))
+
+        # trying an uncompressed one
+        base_name = os.path.join(tmpdir2, 'archive')
+        old_dir = os.getcwd()
+        os.chdir(tmpdir)
+        try:
+            _make_tarball(base_name, 'dist', compress=None)
+        finally:
+            os.chdir(old_dir)
+        tarball = base_name + '.tar'
+        self.assertTrue(os.path.exists(tarball))
+
+        # now for a dry_run: it must not (re)create the archive
+        os.remove(base_name + '.tar')  # drop the one made just above
+        old_dir = os.getcwd()
+        os.chdir(tmpdir)
+        try:
+            _make_tarball(base_name, 'dist', compress=None, dry_run=True)
+        finally:
+            os.chdir(old_dir)
+        tarball = base_name + '.tar'
+        self.assertFalse(os.path.exists(tarball))
+
+ @unittest.skipUnless(zlib, "Requires zlib")
+ @unittest.skipUnless(ZIP_SUPPORT, 'Need zip support to run')
+ def test_make_zipfile(self):
+ # creating something to tar
+ tmpdir = self.mkdtemp()
+ self.write_file([tmpdir, 'file1'], 'xxx')
+ self.write_file([tmpdir, 'file2'], 'xxx')
+
+ tmpdir2 = self.mkdtemp()
+ base_name = os.path.join(tmpdir2, 'archive')
+ _make_zipfile(base_name, tmpdir)
+
+ # check if the compressed tarball was created
+ tarball = base_name + '.zip'
+ self.assertTrue(os.path.exists(tarball))
+
+
+ def test_make_archive(self):
+ tmpdir = self.mkdtemp()
+ base_name = os.path.join(tmpdir, 'archive')
+ self.assertRaises(ValueError, make_archive, base_name, 'xxx')
+
+ @unittest.skipUnless(zlib, "Requires zlib")
+ def test_make_archive_owner_group(self):
+ # testing make_archive with owner and group, with various combinations
+ # this works even if there's not gid/uid support
+ if UID_GID_SUPPORT:
+ group = grp.getgrgid(0)[0]
+ owner = pwd.getpwuid(0)[0]
+ else:
+ group = owner = 'root'
+
+ base_dir, root_dir, base_name = self._create_files()
+ base_name = os.path.join(self.mkdtemp() , 'archive')
+ res = make_archive(base_name, 'zip', root_dir, base_dir, owner=owner,
+ group=group)
+ self.assertTrue(os.path.exists(res))
+
+ res = make_archive(base_name, 'zip', root_dir, base_dir)
+ self.assertTrue(os.path.exists(res))
+
+ res = make_archive(base_name, 'tar', root_dir, base_dir,
+ owner=owner, group=group)
+ self.assertTrue(os.path.exists(res))
+
+ res = make_archive(base_name, 'tar', root_dir, base_dir,
+ owner='kjhkjhkjg', group='oihohoh')
+ self.assertTrue(os.path.exists(res))
+
+ @unittest.skipUnless(zlib, "Requires zlib")
+ @unittest.skipUnless(UID_GID_SUPPORT, "Requires grp and pwd support")
+ def test_tarfile_root_owner(self):
+ tmpdir, tmpdir2, base_name = self._create_files()
+ old_dir = os.getcwd()
+ os.chdir(tmpdir)
+ group = grp.getgrgid(0)[0]
+ owner = pwd.getpwuid(0)[0]
+ try:
+ archive_name = _make_tarball(base_name, 'dist', compress=None,
+ owner=owner, group=group)
+ finally:
+ os.chdir(old_dir)
+
+ # check if the compressed tarball was created
+ self.assertTrue(os.path.exists(archive_name))
+
+ # now checks the rights
+ archive = tarfile.open(archive_name)
+ try:
+ for member in archive.getmembers():
+ self.assertEqual(member.uid, 0)
+ self.assertEqual(member.gid, 0)
+ finally:
+ archive.close()
+
+ def test_make_archive_cwd(self):
+ current_dir = os.getcwd()
+ def _breaks(*args, **kw):
+ raise RuntimeError()
+
+ register_archive_format('xxx', _breaks, [], 'xxx file')
+ try:
+ try:
+ make_archive('xxx', 'xxx', root_dir=self.mkdtemp())
+ except Exception:
+ pass
+ self.assertEqual(os.getcwd(), current_dir)
+ finally:
+ unregister_archive_format('xxx')
+
+ def test_register_archive_format(self):
+
+ self.assertRaises(TypeError, register_archive_format, 'xxx', 1)
+ self.assertRaises(TypeError, register_archive_format, 'xxx', lambda: x,
+ 1)
+ self.assertRaises(TypeError, register_archive_format, 'xxx', lambda: x,
+ [(1, 2), (1, 2, 3)])
+
+ register_archive_format('xxx', lambda: x, [(1, 2)], 'xxx file')
+ formats = [name for name, params in get_archive_formats()]
+ self.assertIn('xxx', formats)
+
+ unregister_archive_format('xxx')
+ formats = [name for name, params in get_archive_formats()]
+ self.assertNotIn('xxx', formats)
+
+
+class TestMove(unittest.TestCase):
+
+ def setUp(self):
+ filename = "foo"
+ self.src_dir = tempfile.mkdtemp()
+ self.dst_dir = tempfile.mkdtemp()
+ self.src_file = os.path.join(self.src_dir, filename)
+ self.dst_file = os.path.join(self.dst_dir, filename)
+ # Try to create a dir in the current directory, hoping that it is
+ # not located on the same filesystem as the system tmp dir.
+ try:
+ self.dir_other_fs = tempfile.mkdtemp(
+ dir=os.path.dirname(__file__))
+ self.file_other_fs = os.path.join(self.dir_other_fs,
+ filename)
+ except OSError:
+ self.dir_other_fs = None
+ with open(self.src_file, "wb") as f:
+ f.write("spam")
+
+ def tearDown(self):
+ for d in (self.src_dir, self.dst_dir, self.dir_other_fs):
+ try:
+ if d:
+ shutil.rmtree(d)
+ except:
+ pass
+
+ def _check_move_file(self, src, dst, real_dst):
+ with open(src, "rb") as f:
+ contents = f.read()
+ shutil.move(src, dst)
+ with open(real_dst, "rb") as f:
+ self.assertEqual(contents, f.read())
+ self.assertFalse(os.path.exists(src))
+
+ def _check_move_dir(self, src, dst, real_dst):
+ contents = sorted(os.listdir(src))
+ shutil.move(src, dst)
+ self.assertEqual(contents, sorted(os.listdir(real_dst)))
+ self.assertFalse(os.path.exists(src))
+
+ def test_move_file(self):
+ # Move a file to another location on the same filesystem.
+ self._check_move_file(self.src_file, self.dst_file, self.dst_file)
+
+ def test_move_file_to_dir(self):
+ # Move a file inside an existing dir on the same filesystem.
+ self._check_move_file(self.src_file, self.dst_dir, self.dst_file)
+
+ def test_move_file_other_fs(self):
+ # Move a file to an existing dir on another filesystem.
+ if not self.dir_other_fs:
+ # skip
+ return
+ self._check_move_file(self.src_file, self.file_other_fs,
+ self.file_other_fs)
+
+ def test_move_file_to_dir_other_fs(self):
+ # Move a file to another location on another filesystem.
+ if not self.dir_other_fs:
+ # skip
+ return
+ self._check_move_file(self.src_file, self.dir_other_fs,
+ self.file_other_fs)
+
+ def test_move_dir(self):
+ # Move a dir to another location on the same filesystem.
+ dst_dir = tempfile.mktemp()
+ try:
+ self._check_move_dir(self.src_dir, dst_dir, dst_dir)
+ finally:
+ try:
+ shutil.rmtree(dst_dir)
+ except:
+ pass
+
+ def test_move_dir_other_fs(self):
+ # Move a dir to another location on another filesystem.
+ if not self.dir_other_fs:
+ # skip
+ return
+ dst_dir = tempfile.mktemp(dir=self.dir_other_fs)
+ try:
+ self._check_move_dir(self.src_dir, dst_dir, dst_dir)
+ finally:
+ try:
+ shutil.rmtree(dst_dir)
+ except:
+ pass
+
+ def test_move_dir_to_dir(self):
+ # Move a dir inside an existing dir on the same filesystem.
+ self._check_move_dir(self.src_dir, self.dst_dir,
+ os.path.join(self.dst_dir, os.path.basename(self.src_dir)))
+
+ def test_move_dir_to_dir_other_fs(self):
+ # Move a dir inside an existing dir on another filesystem.
+ if not self.dir_other_fs:
+ # skip
+ return
+ self._check_move_dir(self.src_dir, self.dir_other_fs,
+ os.path.join(self.dir_other_fs, os.path.basename(self.src_dir)))
+
+ def test_existing_file_inside_dest_dir(self):
+ # A file with the same name inside the destination dir already exists.
+ with open(self.dst_file, "wb"):
+ pass
+ self.assertRaises(shutil.Error, shutil.move, self.src_file, self.dst_dir)
+
+ def test_dont_move_dir_in_itself(self):
+ # Moving a dir inside itself raises an Error.
+ dst = os.path.join(self.src_dir, "bar")
+ self.assertRaises(shutil.Error, shutil.move, self.src_dir, dst)
+
+ def test_destinsrc_false_negative(self):
+ os.mkdir(TESTFN)
+ try:
+ for src, dst in [('srcdir', 'srcdir/dest')]:
+ src = os.path.join(TESTFN, src)
+ dst = os.path.join(TESTFN, dst)
+ self.assertTrue(shutil._destinsrc(src, dst),
+ msg='_destinsrc() wrongly concluded that '
+ 'dst (%s) is not in src (%s)' % (dst, src))
+ finally:
+ shutil.rmtree(TESTFN, ignore_errors=True)
+
+ def test_destinsrc_false_positive(self):
+ os.mkdir(TESTFN)
+ try:
+ for src, dst in [('srcdir', 'src/dest'), ('srcdir', 'srcdir.new')]:
+ src = os.path.join(TESTFN, src)
+ dst = os.path.join(TESTFN, dst)
+ self.assertFalse(shutil._destinsrc(src, dst),
+ msg='_destinsrc() wrongly concluded that '
+ 'dst (%s) is in src (%s)' % (dst, src))
+ finally:
+ shutil.rmtree(TESTFN, ignore_errors=True)
+
+
+class TestCopyFile(unittest.TestCase):
+
+ _delete = False
+
+ class Faux(object):
+ _entered = False
+ _exited_with = None
+ _raised = False
+ def __init__(self, raise_in_exit=False, suppress_at_exit=True):
+ self._raise_in_exit = raise_in_exit
+ self._suppress_at_exit = suppress_at_exit
+ def read(self, *args):
+ return ''
+ def __enter__(self):
+ self._entered = True
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self._exited_with = exc_type, exc_val, exc_tb
+ if self._raise_in_exit:
+ self._raised = True
+ raise IOError("Cannot close")
+ return self._suppress_at_exit
+
+ def tearDown(self):
+ if self._delete:
+ del shutil.open
+
+ def _set_shutil_open(self, func):
+ shutil.open = func
+ self._delete = True
+
+ def test_w_source_open_fails(self):
+ def _open(filename, mode='r'):
+ if filename == 'srcfile':
+ raise IOError('Cannot open "srcfile"')
+ assert 0 # shouldn't reach here.
+
+ self._set_shutil_open(_open)
+
+ self.assertRaises(IOError, shutil.copyfile, 'srcfile', 'destfile')
+
+ def test_w_dest_open_fails(self):
+
+ srcfile = self.Faux()
+
+ def _open(filename, mode='r'):
+ if filename == 'srcfile':
+ return srcfile
+ if filename == 'destfile':
+ raise IOError('Cannot open "destfile"')
+ assert 0 # shouldn't reach here.
+
+ self._set_shutil_open(_open)
+
+ shutil.copyfile('srcfile', 'destfile')
+ self.assertTrue(srcfile._entered)
+ self.assertTrue(srcfile._exited_with[0] is IOError)
+ self.assertEqual(srcfile._exited_with[1].args,
+ ('Cannot open "destfile"',))
+
+ def test_w_dest_close_fails(self):
+
+ srcfile = self.Faux()
+ destfile = self.Faux(True)
+
+ def _open(filename, mode='r'):
+ if filename == 'srcfile':
+ return srcfile
+ if filename == 'destfile':
+ return destfile
+ assert 0 # shouldn't reach here.
+
+ self._set_shutil_open(_open)
+
+ shutil.copyfile('srcfile', 'destfile')
+ self.assertTrue(srcfile._entered)
+ self.assertTrue(destfile._entered)
+ self.assertTrue(destfile._raised)
+ self.assertTrue(srcfile._exited_with[0] is IOError)
+ self.assertEqual(srcfile._exited_with[1].args,
+ ('Cannot close',))
+
+ def test_w_source_close_fails(self):
+
+ srcfile = self.Faux(True)
+ destfile = self.Faux()
+
+ def _open(filename, mode='r'):
+ if filename == 'srcfile':
+ return srcfile
+ if filename == 'destfile':
+ return destfile
+ assert 0 # shouldn't reach here.
+
+ self._set_shutil_open(_open)
+
+ self.assertRaises(IOError,
+ shutil.copyfile, 'srcfile', 'destfile')
+ self.assertTrue(srcfile._entered)
+ self.assertTrue(destfile._entered)
+ self.assertFalse(destfile._raised)
+ self.assertTrue(srcfile._exited_with[0] is None)
+ self.assertTrue(srcfile._raised)
+
+
def test_main():
- test_support.run_unittest(TestShutil)
+ test_support.run_unittest(TestShutil, TestMove, TestCopyFile)
if __name__ == '__main__':
test_main()
diff --git a/Lib/test/test_str.py b/Lib/test/test_str.py
deleted file mode 100644
--- a/Lib/test/test_str.py
+++ /dev/null
@@ -1,382 +0,0 @@
-
-import struct
-import sys
-from test import test_support, string_tests
-
-
-class StrTest(
- string_tests.CommonTest,
- string_tests.MixinStrUnicodeUserStringTest,
- string_tests.MixinStrUserStringTest,
- string_tests.MixinStrUnicodeTest,
- ):
-
- type2test = str
-
- # We don't need to propagate to str
- def fixtype(self, obj):
- return obj
-
- def test_basic_creation(self):
- self.assertEqual(str(''), '')
- self.assertEqual(str(0), '0')
- self.assertEqual(str(0L), '0')
- self.assertEqual(str(()), '()')
- self.assertEqual(str([]), '[]')
- self.assertEqual(str({}), '{}')
- a = []
- a.append(a)
- self.assertEqual(str(a), '[[...]]')
- a = {}
- a[0] = a
- self.assertEqual(str(a), '{0: {...}}')
-
- def test_formatting(self):
- string_tests.MixinStrUnicodeUserStringTest.test_formatting(self)
- self.assertRaises(OverflowError, '%c'.__mod__, 0x1234)
-
- def test_conversion(self):
- # Make sure __str__() behaves properly
- class Foo0:
- def __unicode__(self):
- return u"foo"
-
- class Foo1:
- def __str__(self):
- return "foo"
-
- class Foo2(object):
- def __str__(self):
- return "foo"
-
- class Foo3(object):
- def __str__(self):
- return u"foo"
-
- class Foo4(unicode):
- def __str__(self):
- return u"foo"
-
- class Foo5(str):
- def __str__(self):
- return u"foo"
-
- class Foo6(str):
- def __str__(self):
- return "foos"
-
- def __unicode__(self):
- return u"foou"
-
- class Foo7(unicode):
- def __str__(self):
- return "foos"
- def __unicode__(self):
- return u"foou"
-
- class Foo8(str):
- def __new__(cls, content=""):
- return str.__new__(cls, 2*content)
- def __str__(self):
- return self
-
- class Foo9(str):
- def __str__(self):
- return "string"
- def __unicode__(self):
- return "not unicode"
-
- self.assert_(str(Foo0()).startswith("<")) # this is different from __unicode__
- self.assertEqual(str(Foo1()), "foo")
- self.assertEqual(str(Foo2()), "foo")
- self.assertEqual(str(Foo3()), "foo")
- self.assertEqual(str(Foo4("bar")), "foo")
- self.assertEqual(str(Foo5("bar")), "foo")
- self.assertEqual(str(Foo6("bar")), "foos")
- self.assertEqual(str(Foo7("bar")), "foos")
- self.assertEqual(str(Foo8("foo")), "foofoo")
- self.assertEqual(str(Foo9("foo")), "string")
- self.assertEqual(unicode(Foo9("foo")), u"not unicode")
-
- def test_expandtabs_overflows_gracefully(self):
- # This test only affects 32-bit platforms because expandtabs can only take
- # an int as the max value, not a 64-bit C long. If expandtabs is changed
- # to take a 64-bit long, this test should apply to all platforms.
-
- # Jython uses a different algorithm for which overflows cannot occur;
- # but memory exhaustion of course can. So not applicable.
- if sys.maxint > (1 << 32) or test_support.is_jython or struct.calcsize('P') != 4:
- return
- self.assertRaises(OverflowError, 't\tt\t'.expandtabs, sys.maxint)
-
- def test__format__(self):
- def test(value, format, expected):
- # test both with and without the trailing 's'
- self.assertEqual(value.__format__(format), expected)
- self.assertEqual(value.__format__(format + 's'), expected)
-
- test('', '', '')
- test('abc', '', 'abc')
- test('abc', '.3', 'abc')
- test('ab', '.3', 'ab')
- test('abcdef', '.3', 'abc')
- test('abcdef', '.0', '')
- test('abc', '3.3', 'abc')
- test('abc', '2.3', 'abc')
- test('abc', '2.2', 'ab')
- test('abc', '3.2', 'ab ')
- test('result', 'x<0', 'result')
- test('result', 'x<5', 'result')
- test('result', 'x<6', 'result')
- test('result', 'x<7', 'resultx')
- test('result', 'x<8', 'resultxx')
- test('result', ' <7', 'result ')
- test('result', '<7', 'result ')
- test('result', '>7', ' result')
- test('result', '>8', ' result')
- test('result', '^8', ' result ')
- test('result', '^9', ' result ')
- test('result', '^10', ' result ')
- test('a', '10000', 'a' + ' ' * 9999)
- test('', '10000', ' ' * 10000)
- test('', '10000000', ' ' * 10000000)
-
- def test_format(self):
- self.assertEqual(''.format(), '')
- self.assertEqual('a'.format(), 'a')
- self.assertEqual('ab'.format(), 'ab')
- self.assertEqual('a{{'.format(), 'a{')
- self.assertEqual('a}}'.format(), 'a}')
- self.assertEqual('{{b'.format(), '{b')
- self.assertEqual('}}b'.format(), '}b')
- self.assertEqual('a{{b'.format(), 'a{b')
-
- # examples from the PEP:
- import datetime
- self.assertEqual("My name is {0}".format('Fred'), "My name is Fred")
- self.assertEqual("My name is {0[name]}".format(dict(name='Fred')),
- "My name is Fred")
- self.assertEqual("My name is {0} :-{{}}".format('Fred'),
- "My name is Fred :-{}")
-
- d = datetime.date(2007, 8, 18)
- self.assertEqual("The year is {0.year}".format(d),
- "The year is 2007")
-
- # classes we'll use for testing
- class C:
- def __init__(self, x=100):
- self._x = x
- def __format__(self, spec):
- return spec
-
- class D:
- def __init__(self, x):
- self.x = x
- def __format__(self, spec):
- return str(self.x)
-
- # class with __str__, but no __format__
- class E:
- def __init__(self, x):
- self.x = x
- def __str__(self):
- return 'E(' + self.x + ')'
-
- # class with __repr__, but no __format__ or __str__
- class F:
- def __init__(self, x):
- self.x = x
- def __repr__(self):
- return 'F(' + self.x + ')'
-
- # class with __format__ that forwards to string, for some format_spec's
- class G:
- def __init__(self, x):
- self.x = x
- def __str__(self):
- return "string is " + self.x
- def __format__(self, format_spec):
- if format_spec == 'd':
- return 'G(' + self.x + ')'
- return object.__format__(self, format_spec)
-
- # class that returns a bad type from __format__
- class H:
- def __format__(self, format_spec):
- return 1.0
-
- class I(datetime.date):
- def __format__(self, format_spec):
- return self.strftime(format_spec)
-
- class J(int):
- def __format__(self, format_spec):
- return int.__format__(self * 2, format_spec)
-
-
- self.assertEqual(''.format(), '')
- self.assertEqual('abc'.format(), 'abc')
- self.assertEqual('{0}'.format('abc'), 'abc')
- self.assertEqual('{0:}'.format('abc'), 'abc')
- self.assertEqual('X{0}'.format('abc'), 'Xabc')
- self.assertEqual('{0}X'.format('abc'), 'abcX')
- self.assertEqual('X{0}Y'.format('abc'), 'XabcY')
- self.assertEqual('{1}'.format(1, 'abc'), 'abc')
- self.assertEqual('X{1}'.format(1, 'abc'), 'Xabc')
- self.assertEqual('{1}X'.format(1, 'abc'), 'abcX')
- self.assertEqual('X{1}Y'.format(1, 'abc'), 'XabcY')
- self.assertEqual('{0}'.format(-15), '-15')
- self.assertEqual('{0}{1}'.format(-15, 'abc'), '-15abc')
- self.assertEqual('{0}X{1}'.format(-15, 'abc'), '-15Xabc')
- self.assertEqual('{{'.format(), '{')
- self.assertEqual('}}'.format(), '}')
- self.assertEqual('{{}}'.format(), '{}')
- self.assertEqual('{{x}}'.format(), '{x}')
- self.assertEqual('{{{0}}}'.format(123), '{123}')
- self.assertEqual('{{{{0}}}}'.format(), '{{0}}')
- self.assertEqual('}}{{'.format(), '}{')
- self.assertEqual('}}x{{'.format(), '}x{')
-
- # weird field names
- self.assertEqual("{0[foo-bar]}".format({'foo-bar':'baz'}), 'baz')
- self.assertEqual("{0[foo bar]}".format({'foo bar':'baz'}), 'baz')
- self.assertEqual("{0[ ]}".format({' ':3}), '3')
-
- self.assertEqual('{foo._x}'.format(foo=C(20)), '20')
- self.assertEqual('{1}{0}'.format(D(10), D(20)), '2010')
- self.assertEqual('{0._x.x}'.format(C(D('abc'))), 'abc')
- self.assertEqual('{0[0]}'.format(['abc', 'def']), 'abc')
- self.assertEqual('{0[1]}'.format(['abc', 'def']), 'def')
- self.assertEqual('{0[1][0]}'.format(['abc', ['def']]), 'def')
- self.assertEqual('{0[1][0].x}'.format(['abc', [D('def')]]), 'def')
-
- # strings
- self.assertEqual('{0:.3s}'.format('abc'), 'abc')
- self.assertEqual('{0:.3s}'.format('ab'), 'ab')
- self.assertEqual('{0:.3s}'.format('abcdef'), 'abc')
- self.assertEqual('{0:.0s}'.format('abcdef'), '')
- self.assertEqual('{0:3.3s}'.format('abc'), 'abc')
- self.assertEqual('{0:2.3s}'.format('abc'), 'abc')
- self.assertEqual('{0:2.2s}'.format('abc'), 'ab')
- self.assertEqual('{0:3.2s}'.format('abc'), 'ab ')
- self.assertEqual('{0:x<0s}'.format('result'), 'result')
- self.assertEqual('{0:x<5s}'.format('result'), 'result')
- self.assertEqual('{0:x<6s}'.format('result'), 'result')
- self.assertEqual('{0:x<7s}'.format('result'), 'resultx')
- self.assertEqual('{0:x<8s}'.format('result'), 'resultxx')
- self.assertEqual('{0: <7s}'.format('result'), 'result ')
- self.assertEqual('{0:<7s}'.format('result'), 'result ')
- self.assertEqual('{0:>7s}'.format('result'), ' result')
- self.assertEqual('{0:>8s}'.format('result'), ' result')
- self.assertEqual('{0:^8s}'.format('result'), ' result ')
- self.assertEqual('{0:^9s}'.format('result'), ' result ')
- self.assertEqual('{0:^10s}'.format('result'), ' result ')
- self.assertEqual('{0:10000}'.format('a'), 'a' + ' ' * 9999)
- self.assertEqual('{0:10000}'.format(''), ' ' * 10000)
- self.assertEqual('{0:10000000}'.format(''), ' ' * 10000000)
-
- # format specifiers for user defined type
- self.assertEqual('{0:abc}'.format(C()), 'abc')
-
- # !r and !s coersions
- self.assertEqual('{0!s}'.format('Hello'), 'Hello')
- self.assertEqual('{0!s:}'.format('Hello'), 'Hello')
- self.assertEqual('{0!s:15}'.format('Hello'), 'Hello ')
- self.assertEqual('{0!s:15s}'.format('Hello'), 'Hello ')
- self.assertEqual('{0!r}'.format('Hello'), "'Hello'")
- self.assertEqual('{0!r:}'.format('Hello'), "'Hello'")
- self.assertEqual('{0!r}'.format(F('Hello')), 'F(Hello)')
-
- # test fallback to object.__format__
- self.assertEqual('{0}'.format({}), '{}')
- self.assertEqual('{0}'.format([]), '[]')
- self.assertEqual('{0}'.format([1]), '[1]')
- self.assertEqual('{0}'.format(E('data')), 'E(data)')
- self.assertEqual('{0:^10}'.format(E('data')), ' E(data) ')
- self.assertEqual('{0:^10s}'.format(E('data')), ' E(data) ')
- self.assertEqual('{0:d}'.format(G('data')), 'G(data)')
- self.assertEqual('{0:>15s}'.format(G('data')), ' string is data')
- self.assertEqual('{0!s}'.format(G('data')), 'string is data')
-
- self.assertEqual("{0:date: %Y-%m-%d}".format(I(year=2007,
- month=8,
- day=27)),
- "date: 2007-08-27")
-
- # test deriving from a builtin type and overriding __format__
- self.assertEqual("{0}".format(J(10)), "20")
-
-
- # string format specifiers
- self.assertEqual('{0:}'.format('a'), 'a')
-
- # computed format specifiers
- self.assertEqual("{0:.{1}}".format('hello world', 5), 'hello')
- self.assertEqual("{0:.{1}s}".format('hello world', 5), 'hello')
- self.assertEqual("{0:.{precision}s}".format('hello world', precision=5), 'hello')
- self.assertEqual("{0:{width}.{precision}s}".format('hello world', width=10, precision=5), 'hello ')
- self.assertEqual("{0:{width}.{precision}s}".format('hello world', width='10', precision='5'), 'hello ')
-
- # test various errors
- self.assertRaises(ValueError, '{'.format)
- self.assertRaises(ValueError, '}'.format)
- self.assertRaises(ValueError, 'a{'.format)
- self.assertRaises(ValueError, 'a}'.format)
- self.assertRaises(ValueError, '{a'.format)
- self.assertRaises(ValueError, '}a'.format)
- self.assertRaises(IndexError, '{0}'.format)
- self.assertRaises(IndexError, '{1}'.format, 'abc')
- self.assertRaises(KeyError, '{x}'.format)
- self.assertRaises(ValueError, "}{".format)
- self.assertRaises(ValueError, "{".format)
- self.assertRaises(ValueError, "}".format)
- self.assertRaises(ValueError, "abc{0:{}".format)
- self.assertRaises(ValueError, "{0".format)
- self.assertRaises(IndexError, "{0.}".format)
- self.assertRaises(ValueError, "{0.}".format, 0)
- self.assertRaises(IndexError, "{0[}".format)
- self.assertRaises(ValueError, "{0[}".format, [])
- self.assertRaises(KeyError, "{0]}".format)
- self.assertRaises(ValueError, "{0.[]}".format, 0)
- self.assertRaises(ValueError, "{0..foo}".format, 0)
- self.assertRaises(ValueError, "{0[0}".format, 0)
- self.assertRaises(ValueError, "{0[0:foo}".format, 0)
- self.assertRaises(KeyError, "{c]}".format)
- self.assertRaises(ValueError, "{{ {{{0}}".format, 0)
- self.assertRaises(ValueError, "{0}}".format, 0)
- self.assertRaises(KeyError, "{foo}".format, bar=3)
- self.assertRaises(ValueError, "{0!x}".format, 3)
- self.assertRaises(ValueError, "{0!}".format, 0)
- self.assertRaises(ValueError, "{0!rs}".format, 0)
- self.assertRaises(ValueError, "{!}".format)
- self.assertRaises(ValueError, "{:}".format)
- self.assertRaises(ValueError, "{:s}".format)
- self.assertRaises(ValueError, "{}".format)
-
- # issue 6089
- self.assertRaises(ValueError, "{0[0]x}".format, [None])
- self.assertRaises(ValueError, "{0[0](10)}".format, [None])
-
- # can't have a replacement on the field name portion
- self.assertRaises(TypeError, '{0[{1}]}'.format, 'abcdefg', 4)
-
- # exceed maximum recursion depth
- self.assertRaises(ValueError, "{0:{1:{2}}}".format, 'abc', 's', '')
- self.assertRaises(ValueError, "{0:{1:{2:{3:{4:{5:{6}}}}}}}".format,
- 0, 1, 2, 3, 4, 5, 6, 7)
-
- # string format spec errors
- self.assertRaises(ValueError, "{0:-s}".format, '')
- self.assertRaises(ValueError, format, "", "-")
- self.assertRaises(ValueError, "{0:=s}".format, '')
-
- def test_buffer_is_readonly(self):
- self.assertRaises(TypeError, sys.stdin.readinto, b"")
-
-
-def test_main():
- test_support.run_unittest(StrTest)
-
-if __name__ == "__main__":
- test_main()
diff --git a/Lib/test/test_tempfile.py b/Lib/test/test_tempfile.py
--- a/Lib/test/test_tempfile.py
+++ b/Lib/test/test_tempfile.py
@@ -1,6 +1,5 @@
# From Python 2.5.1
# tempfile.py unit tests.
-from __future__ import with_statement
import tempfile
import os
import sys
@@ -82,7 +81,8 @@
"gettempprefix" : 1,
"gettempdir" : 1,
"tempdir" : 1,
- "template" : 1
+ "template" : 1,
+ "SpooledTemporaryFile" : 1
}
unexp = []
@@ -128,7 +128,7 @@
if i == 20:
break
except:
- failOnException("iteration")
+ self.failOnException("iteration")
test_classes.append(test__RandomNameSequence)
@@ -150,13 +150,11 @@
# _candidate_tempdir_list contains the expected directories
# Make sure the interesting environment variables are all set.
- added = []
- try:
+ with test_support.EnvironmentVarGuard() as env:
for envname in 'TMPDIR', 'TEMP', 'TMP':
dirname = os.getenv(envname)
if not dirname:
- os.environ[envname] = os.path.abspath(envname)
- added.append(envname)
+ env.set(envname, os.path.abspath(envname))
cand = tempfile._candidate_tempdir_list()
@@ -174,9 +172,6 @@
# Not practical to try to verify the presence of OS-specific
# paths in this list.
- finally:
- for p in added:
- del os.environ[p]
test_classes.append(test__candidate_tempdir_list)
@@ -581,11 +576,12 @@
class test_NamedTemporaryFile(TC):
"""Test NamedTemporaryFile()."""
- def do_create(self, dir=None, pre="", suf=""):
+ def do_create(self, dir=None, pre="", suf="", delete=True):
if dir is None:
dir = tempfile.gettempdir()
try:
- file = tempfile.NamedTemporaryFile(dir=dir, prefix=pre, suffix=suf)
+ file = tempfile.NamedTemporaryFile(dir=dir, prefix=pre, suffix=suf,
+ delete=delete)
except:
self.failOnException("NamedTemporaryFile")
@@ -619,6 +615,22 @@
finally:
os.rmdir(dir)
+ def test_dis_del_on_close(self):
+ # Tests that delete-on-close can be disabled
+ dir = tempfile.mkdtemp()
+ tmp = None
+ try:
+ f = tempfile.NamedTemporaryFile(dir=dir, delete=False)
+ tmp = f.name
+ f.write('blat')
+ f.close()
+ self.failUnless(os.path.exists(f.name),
+ "NamedTemporaryFile %s missing after close" % f.name)
+ finally:
+ if tmp is not None:
+ os.unlink(tmp)
+ os.rmdir(dir)
+
def test_multiple_close(self):
# A NamedTemporaryFile can be closed many times without error
f = tempfile.NamedTemporaryFile()
@@ -644,6 +656,160 @@
test_classes.append(test_NamedTemporaryFile)
+class test_SpooledTemporaryFile(TC):
+ """Test SpooledTemporaryFile()."""
+
+ def do_create(self, max_size=0, dir=None, pre="", suf=""):
+ if dir is None:
+ dir = tempfile.gettempdir()
+ try:
+ file = tempfile.SpooledTemporaryFile(max_size=max_size, dir=dir, prefix=pre, suffix=suf)
+ except:
+ self.failOnException("SpooledTemporaryFile")
+
+ return file
+
+
+ def test_basic(self):
+ # SpooledTemporaryFile can create files
+ f = self.do_create()
+ self.failIf(f._rolled)
+ f = self.do_create(max_size=100, pre="a", suf=".txt")
+ self.failIf(f._rolled)
+
+ def test_del_on_close(self):
+ # A SpooledTemporaryFile is deleted when closed
+ dir = tempfile.mkdtemp()
+ try:
+ f = tempfile.SpooledTemporaryFile(max_size=10, dir=dir)
+ self.failIf(f._rolled)
+ f.write('blat ' * 5)
+ self.failUnless(f._rolled)
+ filename = f.name
+ f.close()
+ self.failIf(os.path.exists(filename),
+ "SpooledTemporaryFile %s exists after close" % filename)
+ finally:
+ os.rmdir(dir)
+
+ def test_rewrite_small(self):
+ # A SpooledTemporaryFile can be written to multiple within the max_size
+ f = self.do_create(max_size=30)
+ self.failIf(f._rolled)
+ for i in range(5):
+ f.seek(0, 0)
+ f.write('x' * 20)
+ self.failIf(f._rolled)
+
+ def test_write_sequential(self):
+ # A SpooledTemporaryFile should hold exactly max_size bytes, and roll
+ # over afterward
+ f = self.do_create(max_size=30)
+ self.failIf(f._rolled)
+ f.write('x' * 20)
+ self.failIf(f._rolled)
+ f.write('x' * 10)
+ self.failIf(f._rolled)
+ f.write('x')
+ self.failUnless(f._rolled)
+
+ def test_sparse(self):
+ # A SpooledTemporaryFile that is written late in the file will extend
+ # when that occurs
+ f = self.do_create(max_size=30)
+ self.failIf(f._rolled)
+ f.seek(100, 0)
+ self.failIf(f._rolled)
+ f.write('x')
+ self.failUnless(f._rolled)
+
+ def test_fileno(self):
+ # A SpooledTemporaryFile should roll over to a real file on fileno()
+ f = self.do_create(max_size=30)
+ self.failIf(f._rolled)
+ self.failUnless(f.fileno() > 0)
+ self.failUnless(f._rolled)
+
+ def test_multiple_close_before_rollover(self):
+ # A SpooledTemporaryFile can be closed many times without error
+ f = tempfile.SpooledTemporaryFile()
+ f.write('abc\n')
+ self.failIf(f._rolled)
+ f.close()
+ try:
+ f.close()
+ f.close()
+ except:
+ self.failOnException("close")
+
+ def test_multiple_close_after_rollover(self):
+ # A SpooledTemporaryFile can be closed many times without error
+ f = tempfile.SpooledTemporaryFile(max_size=1)
+ f.write('abc\n')
+ self.failUnless(f._rolled)
+ f.close()
+ try:
+ f.close()
+ f.close()
+ except:
+ self.failOnException("close")
+
+ def test_bound_methods(self):
+ # It should be OK to steal a bound method from a SpooledTemporaryFile
+ # and use it independently; when the file rolls over, those bound
+ # methods should continue to function
+ f = self.do_create(max_size=30)
+ read = f.read
+ write = f.write
+ seek = f.seek
+
+ write("a" * 35)
+ write("b" * 35)
+ seek(0, 0)
+ self.failUnless(read(70) == 'a'*35 + 'b'*35)
+
+ def test_context_manager_before_rollover(self):
+ # A SpooledTemporaryFile can be used as a context manager
+ with tempfile.SpooledTemporaryFile(max_size=1) as f:
+ self.failIf(f._rolled)
+ self.failIf(f.closed)
+ self.failUnless(f.closed)
+ def use_closed():
+ with f:
+ pass
+ self.failUnlessRaises(ValueError, use_closed)
+
+ def test_context_manager_during_rollover(self):
+ # A SpooledTemporaryFile can be used as a context manager
+ with tempfile.SpooledTemporaryFile(max_size=1) as f:
+ self.failIf(f._rolled)
+ f.write('abc\n')
+ f.flush()
+ self.failUnless(f._rolled)
+ self.failIf(f.closed)
+ self.failUnless(f.closed)
+ def use_closed():
+ with f:
+ pass
+ self.failUnlessRaises(ValueError, use_closed)
+
+ def test_context_manager_after_rollover(self):
+ # A SpooledTemporaryFile can be used as a context manager
+ f = tempfile.SpooledTemporaryFile(max_size=1)
+ f.write('abc\n')
+ f.flush()
+ self.failUnless(f._rolled)
+ with f:
+ self.failIf(f.closed)
+ self.failUnless(f.closed)
+ def use_closed():
+ with f:
+ pass
+ self.failUnlessRaises(ValueError, use_closed)
+
+
+test_classes.append(test_SpooledTemporaryFile)
+
class test_TemporaryFile(TC):
"""Test TemporaryFile()."""
diff --git a/Lib/test/test_time.py b/Lib/test/test_time.py
--- a/Lib/test/test_time.py
+++ b/Lib/test/test_time.py
@@ -18,9 +18,9 @@
time.clock()
def test_conversions(self):
- self.assert_(time.ctime(self.t)
+ self.assertTrue(time.ctime(self.t)
== time.asctime(time.localtime(self.t)))
- self.assert_(long(time.mktime(time.localtime(self.t)))
+ self.assertTrue(long(time.mktime(time.localtime(self.t)))
== long(self.t))
def test_sleep(self):
@@ -89,11 +89,6 @@
(1900, 1, 1, 0, 0, 0, 0, -1, -1))
self.assertRaises(ValueError, time.strftime, '',
(1900, 1, 1, 0, 0, 0, 0, 367, -1))
- # Check daylight savings flag [-1, 1]
- self.assertRaises(ValueError, time.strftime, '',
- (1900, 1, 1, 0, 0, 0, 0, 1, -2))
- self.assertRaises(ValueError, time.strftime, '',
- (1900, 1, 1, 0, 0, 0, 0, 1, 2))
def test_default_values_for_zero(self):
# Make sure that using all zeros uses the proper default values.
@@ -107,18 +102,22 @@
# and it is not converted to 2000.
expected = "0000 01 01 00 00 00 1 001"
result = time.strftime("%Y %m %d %H %M %S %w %j", (0,)*9)
- self.assertEquals(expected, result)
+ self.assertEqual(expected, result)
def test_strptime(self):
+ # Should be able to go round-trip from strftime to strptime without
+ # throwing an exception.
tt = time.gmtime(self.t)
for directive in ('a', 'A', 'b', 'B', 'c', 'd', 'H', 'I',
'j', 'm', 'M', 'p', 'S',
'U', 'w', 'W', 'x', 'X', 'y', 'Y', 'Z', '%'):
- format = ' %' + directive
+ format = '%' + directive
+ strf_output = time.strftime(format, tt)
try:
- time.strptime(time.strftime(format, tt), format)
+ time.strptime(strf_output, format)
except ValueError:
- self.fail('conversion specifier: %r failed.' % format)
+ self.fail("conversion specifier %r failed with '%s' input." %
+ (format, strf_output))
def test_strptime_empty(self):
try:
@@ -129,10 +128,20 @@
def test_asctime(self):
time.asctime(time.gmtime(self.t))
self.assertRaises(TypeError, time.asctime, 0)
+ self.assertRaises(TypeError, time.asctime, ())
+ # XXX: Posix compiant asctime should refuse to convert
+ # year > 9999, but Linux implementation does not.
+ # self.assertRaises(ValueError, time.asctime,
+ # (12345, 1, 0, 0, 0, 0, 0, 0, 0))
+ # XXX: For now, just make sure we don't have a crash:
+ try:
+ time.asctime((12345, 1, 1, 0, 0, 0, 0, 1, 0))
+ except ValueError:
+ pass
+ @unittest.skipIf(not hasattr(time, "tzset"),
+ "time module has no attribute tzset")
def test_tzset(self):
- if not hasattr(time, "tzset"):
- return # Can't test this; don't want the test suite to fail
from os import environ
@@ -158,36 +167,36 @@
time.tzset()
environ['TZ'] = utc
time.tzset()
- self.failUnlessEqual(
+ self.assertEqual(
time.gmtime(xmas2002), time.localtime(xmas2002)
)
- self.failUnlessEqual(time.daylight, 0)
- self.failUnlessEqual(time.timezone, 0)
- self.failUnlessEqual(time.localtime(xmas2002).tm_isdst, 0)
+ self.assertEqual(time.daylight, 0)
+ self.assertEqual(time.timezone, 0)
+ self.assertEqual(time.localtime(xmas2002).tm_isdst, 0)
# Make sure we can switch to US/Eastern
environ['TZ'] = eastern
time.tzset()
- self.failIfEqual(time.gmtime(xmas2002), time.localtime(xmas2002))
- self.failUnlessEqual(time.tzname, ('EST', 'EDT'))
- self.failUnlessEqual(len(time.tzname), 2)
- self.failUnlessEqual(time.daylight, 1)
- self.failUnlessEqual(time.timezone, 18000)
- self.failUnlessEqual(time.altzone, 14400)
- self.failUnlessEqual(time.localtime(xmas2002).tm_isdst, 0)
- self.failUnlessEqual(len(time.tzname), 2)
+ self.assertNotEqual(time.gmtime(xmas2002), time.localtime(xmas2002))
+ self.assertEqual(time.tzname, ('EST', 'EDT'))
+ self.assertEqual(len(time.tzname), 2)
+ self.assertEqual(time.daylight, 1)
+ self.assertEqual(time.timezone, 18000)
+ self.assertEqual(time.altzone, 14400)
+ self.assertEqual(time.localtime(xmas2002).tm_isdst, 0)
+ self.assertEqual(len(time.tzname), 2)
# Now go to the southern hemisphere.
environ['TZ'] = victoria
time.tzset()
- self.failIfEqual(time.gmtime(xmas2002), time.localtime(xmas2002))
- self.failUnless(time.tzname[0] == 'AEST', str(time.tzname[0]))
- self.failUnless(time.tzname[1] == 'AEDT', str(time.tzname[1]))
- self.failUnlessEqual(len(time.tzname), 2)
- self.failUnlessEqual(time.daylight, 1)
- self.failUnlessEqual(time.timezone, -36000)
- self.failUnlessEqual(time.altzone, -39600)
- self.failUnlessEqual(time.localtime(xmas2002).tm_isdst, 1)
+ self.assertNotEqual(time.gmtime(xmas2002), time.localtime(xmas2002))
+ self.assertTrue(time.tzname[0] == 'AEST', str(time.tzname[0]))
+ self.assertTrue(time.tzname[1] == 'AEDT', str(time.tzname[1]))
+ self.assertEqual(len(time.tzname), 2)
+ self.assertEqual(time.daylight, 1)
+ self.assertEqual(time.timezone, -36000)
+ self.assertEqual(time.altzone, -39600)
+ self.assertEqual(time.localtime(xmas2002).tm_isdst, 1)
finally:
# Repair TZ environment variable in case any other tests
@@ -219,14 +228,25 @@
gt1 = time.gmtime(None)
t0 = time.mktime(gt0)
t1 = time.mktime(gt1)
- self.assert_(0 <= (t1-t0) < 0.2)
+ self.assertTrue(0 <= (t1-t0) < 0.2)
def test_localtime_without_arg(self):
lt0 = time.localtime()
lt1 = time.localtime(None)
t0 = time.mktime(lt0)
t1 = time.mktime(lt1)
- self.assert_(0 <= (t1-t0) < 0.2)
+ self.assertTrue(0 <= (t1-t0) < 0.2)
+
+ def test_mktime(self):
+ # Issue #1726687
+ for t in (-2, -1, 0, 1):
+ try:
+ tt = time.localtime(t)
+ except (OverflowError, ValueError):
+ pass
+ else:
+ self.assertEqual(time.mktime(tt), t)
+
def test_main():
test_support.run_unittest(TimeTestCase)
diff --git a/Lib/test/test_trace.py b/Lib/test/test_trace.py
--- a/Lib/test/test_trace.py
+++ b/Lib/test/test_trace.py
@@ -22,6 +22,7 @@
import unittest
import sys
import difflib
+import gc
# A very basic example. If this fails, we're in deep trouble.
def basic():
@@ -262,6 +263,17 @@
return self.trace
class TraceTestCase(unittest.TestCase):
+
+ # Disable gc collection when tracing, otherwise the
+ # deallocators may be traced as well.
+ def setUp(self):
+ self.using_gc = gc.isenabled()
+ gc.disable()
+
+ def tearDown(self):
+ if self.using_gc:
+ gc.enable()
+
def compare_events(self, line_offset, events, expected_events):
events = [(l - line_offset, e) for (l, e) in events]
if events != expected_events:
@@ -288,6 +300,20 @@
self.compare_events(func.func_code.co_firstlineno,
tracer.events, func.events)
+ def set_and_retrieve_none(self):
+ sys.settrace(None)
+ assert sys.gettrace() is None
+
+ def set_and_retrieve_func(self):
+ def fn(*args):
+ pass
+
+ sys.settrace(fn)
+ try:
+ assert sys.gettrace() is fn
+ finally:
+ sys.settrace(None)
+
def test_01_basic(self):
self.run_test(basic)
def test_02_arigo(self):
@@ -324,7 +350,7 @@
sys.settrace(tracer.traceWithGenexp)
generator_example()
sys.settrace(None)
- self.compare_events(generator_example.func_code.co_firstlineno,
+ self.compare_events(generator_example.__code__.co_firstlineno,
tracer.events, generator_example.events)
def test_14_onliner_if(self):
@@ -393,7 +419,7 @@
we're testing, so that the 'exception' trace event fires."""
if self.raiseOnEvent == 'exception':
x = 0
- y = 1/x
+ y = 1 // x
else:
return 1
@@ -732,6 +758,23 @@
def test_19_no_jump_without_trace_function(self):
no_jump_without_trace_function()
+ def test_20_large_function(self):
+ d = {}
+ exec("""def f(output): # line 0
+ x = 0 # line 1
+ y = 1 # line 2
+ ''' # line 3
+ %s # lines 4-1004
+ ''' # line 1005
+ x += 1 # line 1006
+ output.append(x) # line 1007
+ return""" % ('\n' * 1000,), d)
+ f = d['f']
+
+ f.jump = (2, 1007)
+ f.output = [0]
+ self.run_test(f)
+
def test_main():
tests = [TraceTestCase,
RaisingTraceFuncTestCase]
diff --git a/Lib/test/test_univnewlines.py b/Lib/test/test_univnewlines.py
--- a/Lib/test/test_univnewlines.py
+++ b/Lib/test/test_univnewlines.py
@@ -6,7 +6,7 @@
from test import test_support
if not hasattr(sys.stdin, 'newlines'):
- raise unittest.SkipTest, \
+ raise test_support.TestSkipped, \
"This Python does not have universal newline support"
FATX = 'x' * (2**14)
@@ -38,8 +38,9 @@
WRITEMODE = 'wb'
def setUp(self):
- with open(test_support.TESTFN, self.WRITEMODE) as fp:
- fp.write(self.DATA)
+ fp = open(test_support.TESTFN, self.WRITEMODE)
+ fp.write(self.DATA)
+ fp.close()
def tearDown(self):
try:
@@ -48,40 +49,41 @@
pass
def test_read(self):
- with open(test_support.TESTFN, self.READMODE) as fp:
- data = fp.read()
+ fp = open(test_support.TESTFN, self.READMODE)
+ data = fp.read()
self.assertEqual(data, DATA_LF)
self.assertEqual(repr(fp.newlines), repr(self.NEWLINE))
def test_readlines(self):
- with open(test_support.TESTFN, self.READMODE) as fp:
- data = fp.readlines()
+ fp = open(test_support.TESTFN, self.READMODE)
+ data = fp.readlines()
self.assertEqual(data, DATA_SPLIT)
self.assertEqual(repr(fp.newlines), repr(self.NEWLINE))
def test_readline(self):
- with open(test_support.TESTFN, self.READMODE) as fp:
- data = []
+ fp = open(test_support.TESTFN, self.READMODE)
+ data = []
+ d = fp.readline()
+ while d:
+ data.append(d)
d = fp.readline()
- while d:
- data.append(d)
- d = fp.readline()
self.assertEqual(data, DATA_SPLIT)
self.assertEqual(repr(fp.newlines), repr(self.NEWLINE))
def test_seek(self):
- with open(test_support.TESTFN, self.READMODE) as fp:
- fp.readline()
- pos = fp.tell()
- data = fp.readlines()
- self.assertEqual(data, DATA_SPLIT[1:])
- fp.seek(pos)
- data = fp.readlines()
+ fp = open(test_support.TESTFN, self.READMODE)
+ fp.readline()
+ pos = fp.tell()
+ data = fp.readlines()
+ self.assertEqual(data, DATA_SPLIT[1:])
+ fp.seek(pos)
+ data = fp.readlines()
self.assertEqual(data, DATA_SPLIT[1:])
def test_execfile(self):
namespace = {}
- execfile(test_support.TESTFN, namespace)
+ with test_support._check_py3k_warnings():
+ execfile(test_support.TESTFN, namespace)
func = namespace['line3']
self.assertEqual(func.func_code.co_firstlineno, 3)
self.assertEqual(namespace['line4'], FATX)
@@ -106,10 +108,10 @@
DATA = DATA_CRLF
def test_tell(self):
- with open(test_support.TESTFN, self.READMODE) as fp:
- self.assertEqual(repr(fp.newlines), repr(None))
- data = fp.readline()
- pos = fp.tell()
+ fp = open(test_support.TESTFN, self.READMODE)
+ self.assertEqual(repr(fp.newlines), repr(None))
+ data = fp.readline()
+ pos = fp.tell()
self.assertEqual(repr(fp.newlines), repr(self.NEWLINE))
class TestMixedNewlines(TestGenericUnivNewlines):
diff --git a/Lib/test/test_urllib2.py b/Lib/test/test_urllib2.py
--- a/Lib/test/test_urllib2.py
+++ b/Lib/test/test_urllib2.py
@@ -1,7 +1,8 @@
import unittest
from test import test_support
-import os, socket
+import os
+import socket
import StringIO
import urllib2
@@ -20,18 +21,19 @@
# XXX Name hacking to get this to work on Windows.
fname = os.path.abspath(urllib2.__file__).replace('\\', '/')
- if fname[1:2] == ":":
- fname = fname[2:]
+
# And more hacking to get it to work on MacOS. This assumes
# urllib.pathname2url works, unfortunately...
- if os.name == 'mac':
- fname = '/' + fname.replace(':', '/')
- elif os.name == 'riscos':
+ if os.name == 'riscos':
import string
fname = os.expand(fname)
fname = fname.translate(string.maketrans("/.", "./"))
- file_url = "file://%s" % fname
+ if os.name == 'nt':
+ file_url = "file:///%s" % fname
+ else:
+ file_url = "file://%s" % fname
+
f = urllib2.urlopen(file_url)
buf = f.read()
@@ -43,7 +45,7 @@
('a, b, "c", "d", "e,f", g, h', ['a', 'b', '"c"', '"d"', '"e,f"', 'g', 'h']),
('a="b\\"c", d="e\\,f", g="h\\\\i"', ['a="b"c"', 'd="e,f"', 'g="h\\i"'])]
for string, list in tests:
- self.assertEquals(urllib2.parse_http_list(string), list)
+ self.assertEqual(urllib2.parse_http_list(string), list)
def test_request_headers_dict():
@@ -223,8 +225,8 @@
class MockOpener:
addheaders = []
- def open(self, req, data=None):
- self.req, self.data = req, data
+ def open(self, req, data=None,timeout=socket._GLOBAL_DEFAULT_TIMEOUT):
+ self.req, self.data, self.timeout = req, data, timeout
def error(self, proto, *args):
self.proto, self.args = proto, args
@@ -260,6 +262,51 @@
def __call__(self, *args):
return self.handle(self.meth_name, self.action, *args)
+class MockHTTPResponse:
+ def __init__(self, fp, msg, status, reason):
+ self.fp = fp
+ self.msg = msg
+ self.status = status
+ self.reason = reason
+ def read(self):
+ return ''
+
+class MockHTTPClass:
+ def __init__(self):
+ self.req_headers = []
+ self.data = None
+ self.raise_on_endheaders = False
+ self._tunnel_headers = {}
+
+ def __call__(self, host, timeout=socket._GLOBAL_DEFAULT_TIMEOUT):
+ self.host = host
+ self.timeout = timeout
+ return self
+
+ def set_debuglevel(self, level):
+ self.level = level
+
+ def set_tunnel(self, host, port=None, headers=None):
+ self._tunnel_host = host
+ self._tunnel_port = port
+ if headers:
+ self._tunnel_headers = headers
+ else:
+ self._tunnel_headers.clear()
+ def request(self, method, url, body=None, headers=None):
+ self.method = method
+ self.selector = url
+ if headers is not None:
+ self.req_headers += headers.items()
+ self.req_headers.sort()
+ if body:
+ self.data = body
+ if self.raise_on_endheaders:
+ import socket
+ raise socket.error()
+ def getresponse(self):
+ return MockHTTPResponse(MockFile(), {}, 200, "OK")
+
class MockHandler:
# useful for testing handler machinery
# see add_ordered_mock_handlers() docstring
@@ -367,6 +414,17 @@
msg = mimetools.Message(StringIO("\r\n\r\n"))
return MockResponse(200, "OK", msg, "", req.get_full_url())
+class MockHTTPSHandler(urllib2.AbstractHTTPHandler):
+ # Useful for testing the Proxy-Authorization request by verifying the
+ # properties of httpconn
+
+ def __init__(self):
+ urllib2.AbstractHTTPHandler.__init__(self)
+ self.httpconn = MockHTTPClass()
+
+ def https_open(self, req):
+ return self.do_open(self.httpconn, req)
+
class MockPasswordManager:
def add_password(self, realm, uri, user, password):
self.realm = realm
@@ -520,15 +578,15 @@
# *_request
self.assertEqual((handler, name), calls[i])
self.assertEqual(len(args), 1)
- self.assert_(isinstance(args[0], Request))
+ self.assertIsInstance(args[0], Request)
else:
# *_response
self.assertEqual((handler, name), calls[i])
self.assertEqual(len(args), 2)
- self.assert_(isinstance(args[0], Request))
+ self.assertIsInstance(args[0], Request)
# response from opener.open is None, because there's no
# handler that defines http_open to handle it
- self.assert_(args[1] is None or
+ self.assertTrue(args[1] is None or
isinstance(args[1], MockResponse))
@@ -552,32 +610,45 @@
class NullFTPHandler(urllib2.FTPHandler):
def __init__(self, data): self.data = data
- def connect_ftp(self, user, passwd, host, port, dirs):
+ def connect_ftp(self, user, passwd, host, port, dirs,
+ timeout=socket._GLOBAL_DEFAULT_TIMEOUT):
self.user, self.passwd = user, passwd
self.host, self.port = host, port
self.dirs = dirs
self.ftpwrapper = MockFTPWrapper(self.data)
return self.ftpwrapper
- import ftplib, socket
+ import ftplib
data = "rheum rhaponicum"
h = NullFTPHandler(data)
o = h.parent = MockOpener()
- for url, host, port, type_, dirs, filename, mimetype in [
+ for url, host, port, user, passwd, type_, dirs, filename, mimetype in [
("ftp://localhost/foo/bar/baz.html",
- "localhost", ftplib.FTP_PORT, "I",
+ "localhost", ftplib.FTP_PORT, "", "", "I",
+ ["foo", "bar"], "baz.html", "text/html"),
+ ("ftp://parrot@localhost/foo/bar/baz.html",
+ "localhost", ftplib.FTP_PORT, "parrot", "", "I",
+ ["foo", "bar"], "baz.html", "text/html"),
+ ("ftp://%25parrot@localhost/foo/bar/baz.html",
+ "localhost", ftplib.FTP_PORT, "%parrot", "", "I",
+ ["foo", "bar"], "baz.html", "text/html"),
+ ("ftp://%2542parrot@localhost/foo/bar/baz.html",
+ "localhost", ftplib.FTP_PORT, "%42parrot", "", "I",
["foo", "bar"], "baz.html", "text/html"),
("ftp://localhost:80/foo/bar/",
- "localhost", 80, "D",
+ "localhost", 80, "", "", "D",
["foo", "bar"], "", None),
("ftp://localhost/baz.gif;type=a",
- "localhost", ftplib.FTP_PORT, "A",
+ "localhost", ftplib.FTP_PORT, "", "", "A",
[], "baz.gif", None), # XXX really this should guess image/gif
]:
- r = h.ftp_open(Request(url))
+ req = Request(url)
+ req.timeout = None
+ r = h.ftp_open(req)
# ftp authentication not yet implemented by FTPHandler
- self.assert_(h.user == h.passwd == "")
+ self.assertEqual(h.user, user)
+ self.assertEqual(h.passwd, passwd)
self.assertEqual(h.host, socket.gethostbyname(host))
self.assertEqual(h.port, port)
self.assertEqual(h.dirs, dirs)
@@ -588,7 +659,7 @@
self.assertEqual(int(headers["Content-length"]), len(data))
def test_file(self):
- import time, rfc822, socket
+ import rfc822, socket
h = urllib2.FileHandler()
o = h.parent = MockOpener()
@@ -619,7 +690,7 @@
try:
data = r.read()
headers = r.info()
- newurl = r.geturl()
+ respurl = r.geturl()
finally:
r.close()
stats = os.stat(TESTFN)
@@ -630,14 +701,15 @@
self.assertEqual(headers["Content-type"], "text/plain")
self.assertEqual(headers["Content-length"], "13")
self.assertEqual(headers["Last-modified"], modified)
+ self.assertEqual(respurl, url)
for url in [
"file://localhost:80%s" % urlpath,
-# XXXX bug: these fail with socket.gaierror, should be URLError
-## "file://%s:80%s/%s" % (socket.gethostbyname('localhost'),
-## os.getcwd(), TESTFN),
-## "file://somerandomhost.ontheinternet.com%s/%s" %
-## (os.getcwd(), TESTFN),
+ "file:///file_does_not_exist.txt",
+ "file://%s:80%s/%s" % (socket.gethostbyname('localhost'),
+ os.getcwd(), TESTFN),
+ "file://somerandomhost.ontheinternet.com%s/%s" %
+ (os.getcwd(), TESTFN),
]:
try:
f = open(TESTFN, "wb")
@@ -665,48 +737,21 @@
("file://ftp.example.com///foo.txt", False),
# XXXX bug: fails with OSError, should be URLError
("file://ftp.example.com/foo.txt", False),
+ ("file://somehost//foo/something.txt", True),
+ ("file://localhost//foo/something.txt", False),
]:
req = Request(url)
try:
h.file_open(req)
# XXXX remove OSError when bug fixed
except (urllib2.URLError, OSError):
- self.assert_(not ftp)
+ self.assertTrue(not ftp)
else:
- self.assert_(o.req is req)
+ self.assertTrue(o.req is req)
self.assertEqual(req.type, "ftp")
+ self.assertEqual(req.type == "ftp", ftp)
def test_http(self):
- class MockHTTPResponse:
- def __init__(self, fp, msg, status, reason):
- self.fp = fp
- self.msg = msg
- self.status = status
- self.reason = reason
- def read(self):
- return ''
- class MockHTTPClass:
- def __init__(self):
- self.req_headers = []
- self.data = None
- self.raise_on_endheaders = False
- def __call__(self, host):
- self.host = host
- return self
- def set_debuglevel(self, level):
- self.level = level
- def request(self, method, url, body=None, headers={}):
- self.method = method
- self.selector = url
- self.req_headers += headers.items()
- self.req_headers.sort()
- if body:
- self.data = body
- if self.raise_on_endheaders:
- import socket
- raise socket.error()
- def getresponse(self):
- return MockHTTPResponse(MockFile(), {}, 200, "OK")
h = urllib2.AbstractHTTPHandler()
o = h.parent = MockOpener()
@@ -714,6 +759,7 @@
url = "http://example.com/"
for method, data in [("GET", None), ("POST", "blah")]:
req = Request(url, data, {"Foo": "bar"})
+ req.timeout = None
req.add_unredirected_header("Spam", "eggs")
http = MockHTTPClass()
r = h.do_open(http, req)
@@ -746,8 +792,8 @@
r = MockResponse(200, "OK", {}, "")
newreq = h.do_request_(req)
if data is None: # GET
- self.assert_("Content-length" not in req.unredirected_hdrs)
- self.assert_("Content-type" not in req.unredirected_hdrs)
+ self.assertNotIn("Content-length", req.unredirected_hdrs)
+ self.assertNotIn("Content-type", req.unredirected_hdrs)
else: # POST
self.assertEqual(req.unredirected_hdrs["Content-length"], "0")
self.assertEqual(req.unredirected_hdrs["Content-type"],
@@ -767,22 +813,75 @@
self.assertEqual(req.unredirected_hdrs["Host"], "baz")
self.assertEqual(req.unredirected_hdrs["Spam"], "foo")
+ def test_http_doubleslash(self):
+ # Checks that the presence of an unnecessary double slash in a url doesn't break anything
+ # Previously, a double slash directly after the host could cause incorrect parsing of the url
+ h = urllib2.AbstractHTTPHandler()
+ o = h.parent = MockOpener()
+
+ data = ""
+ ds_urls = [
+ "http://example.com/foo/bar/baz.html",
+ "http://example.com//foo/bar/baz.html",
+ "http://example.com/foo//bar/baz.html",
+ "http://example.com/foo/bar//baz.html",
+ ]
+
+ for ds_url in ds_urls:
+ ds_req = Request(ds_url, data)
+
+ # Check whether host is determined correctly if there is no proxy
+ np_ds_req = h.do_request_(ds_req)
+ self.assertEqual(np_ds_req.unredirected_hdrs["Host"],"example.com")
+
+ # Check whether host is determined correctly if there is a proxy
+ ds_req.set_proxy("someproxy:3128",None)
+ p_ds_req = h.do_request_(ds_req)
+ self.assertEqual(p_ds_req.unredirected_hdrs["Host"],"example.com")
+
+ def test_fixpath_in_weirdurls(self):
+ # Issue4493: urllib2 should supply '/' for urls whose path does not
+ # start with '/'
+
+ h = urllib2.AbstractHTTPHandler()
+ o = h.parent = MockOpener()
+
+ weird_url = 'http://www.python.org?getspam'
+ req = Request(weird_url)
+ newreq = h.do_request_(req)
+ self.assertEqual(newreq.get_host(),'www.python.org')
+ self.assertEqual(newreq.get_selector(),'/?getspam')
+
+ url_without_path = 'http://www.python.org'
+ req = Request(url_without_path)
+ newreq = h.do_request_(req)
+ self.assertEqual(newreq.get_host(),'www.python.org')
+ self.assertEqual(newreq.get_selector(),'')
+
def test_errors(self):
h = urllib2.HTTPErrorProcessor()
o = h.parent = MockOpener()
url = "http://example.com/"
req = Request(url)
- # 200 OK is passed through
+ # all 2xx are passed through
r = MockResponse(200, "OK", {}, "", url)
newr = h.http_response(req, r)
- self.assert_(r is newr)
- self.assert_(not hasattr(o, "proto")) # o.error not called
+ self.assertTrue(r is newr)
+ self.assertTrue(not hasattr(o, "proto")) # o.error not called
+ r = MockResponse(202, "Accepted", {}, "", url)
+ newr = h.http_response(req, r)
+ self.assertTrue(r is newr)
+ self.assertTrue(not hasattr(o, "proto")) # o.error not called
+ r = MockResponse(206, "Partial content", {}, "", url)
+ newr = h.http_response(req, r)
+ self.assertTrue(r is newr)
+ self.assertTrue(not hasattr(o, "proto")) # o.error not called
# anything else calls o.error (and MockOpener returns None, here)
- r = MockResponse(201, "Created", {}, "", url)
- self.assert_(h.http_response(req, r) is None)
+ r = MockResponse(502, "Bad gateway", {}, "", url)
+ self.assertTrue(h.http_response(req, r) is None)
self.assertEqual(o.proto, "http") # o.error called
- self.assertEqual(o.args, (req, r, 201, "Created", {}))
+ self.assertEqual(o.args, (req, r, 502, "Bad gateway", {}))
def test_cookies(self):
cj = MockCookieJar()
@@ -792,12 +891,12 @@
req = Request("http://example.com/")
r = MockResponse(200, "OK", {}, "")
newreq = h.http_request(req)
- self.assert_(cj.ach_req is req is newreq)
- self.assertEquals(req.get_origin_req_host(), "example.com")
- self.assert_(not req.is_unverifiable())
+ self.assertTrue(cj.ach_req is req is newreq)
+ self.assertEqual(req.get_origin_req_host(), "example.com")
+ self.assertTrue(not req.is_unverifiable())
newr = h.http_response(req, r)
- self.assert_(cj.ec_req is req)
- self.assert_(cj.ec_r is r is newr)
+ self.assertTrue(cj.ec_req is req)
+ self.assertTrue(cj.ec_r is r is newr)
def test_redirect(self):
from_url = "http://example.com/a.html"
@@ -811,25 +910,36 @@
method = getattr(h, "http_error_%s" % code)
req = Request(from_url, data)
req.add_header("Nonsense", "viking=withhold")
+ req.timeout = socket._GLOBAL_DEFAULT_TIMEOUT
+ if data is not None:
+ req.add_header("Content-Length", str(len(data)))
req.add_unredirected_header("Spam", "spam")
try:
method(req, MockFile(), code, "Blah",
MockHeaders({"location": to_url}))
except urllib2.HTTPError:
# 307 in response to POST requires user OK
- self.assert_(code == 307 and data is not None)
+ self.assertTrue(code == 307 and data is not None)
self.assertEqual(o.req.get_full_url(), to_url)
try:
self.assertEqual(o.req.get_method(), "GET")
except AttributeError:
- self.assert_(not o.req.has_data())
+ self.assertTrue(not o.req.has_data())
+
+ # now it's a GET, there should not be headers regarding content
+ # (possibly dragged from before being a POST)
+ headers = [x.lower() for x in o.req.headers]
+ self.assertNotIn("content-length", headers)
+ self.assertNotIn("content-type", headers)
+
self.assertEqual(o.req.headers["Nonsense"],
"viking=withhold")
- self.assert_("Spam" not in o.req.headers)
- self.assert_("Spam" not in o.req.unredirected_hdrs)
+ self.assertNotIn("Spam", o.req.headers)
+ self.assertNotIn("Spam", o.req.unredirected_hdrs)
# loop detection
req = Request(from_url)
+ req.timeout = socket._GLOBAL_DEFAULT_TIMEOUT
def redirect(h, req, url=to_url):
h.http_error_302(req, MockFile(), 302, "Blah",
MockHeaders({"location": url}))
@@ -839,6 +949,7 @@
# detect infinite loop redirect of a URL to itself
req = Request(from_url, origin_req_host="example.com")
count = 0
+ req.timeout = socket._GLOBAL_DEFAULT_TIMEOUT
try:
while 1:
redirect(h, req, "http://example.com/")
@@ -850,6 +961,7 @@
# detect endless non-repeating chain of redirects
req = Request(from_url, origin_req_host="example.com")
count = 0
+ req.timeout = socket._GLOBAL_DEFAULT_TIMEOUT
try:
while 1:
redirect(h, req, "http://example.com/%d" % count)
@@ -872,7 +984,7 @@
cp = urllib2.HTTPCookieProcessor(cj)
o = build_test_opener(hh, hdeh, hrh, cp)
o.open("http://www.example.com/")
- self.assert_(not hh.req.has_header("Cookie"))
+ self.assertTrue(not hh.req.has_header("Cookie"))
def test_proxy(self):
o = OpenerDirector()
@@ -891,13 +1003,68 @@
self.assertEqual([(handlers[0], "http_open")],
[tup[0:2] for tup in o.calls])
- def test_basic_auth(self):
+ def test_proxy_no_proxy(self):
+ os.environ['no_proxy'] = 'python.org'
+ o = OpenerDirector()
+ ph = urllib2.ProxyHandler(dict(http="proxy.example.com"))
+ o.add_handler(ph)
+ req = Request("http://www.perl.org/")
+ self.assertEqual(req.get_host(), "www.perl.org")
+ r = o.open(req)
+ self.assertEqual(req.get_host(), "proxy.example.com")
+ req = Request("http://www.python.org")
+ self.assertEqual(req.get_host(), "www.python.org")
+ r = o.open(req)
+ self.assertEqual(req.get_host(), "www.python.org")
+ del os.environ['no_proxy']
+
+
+ def test_proxy_https(self):
+ o = OpenerDirector()
+ ph = urllib2.ProxyHandler(dict(https='proxy.example.com:3128'))
+ o.add_handler(ph)
+ meth_spec = [
+ [("https_open","return response")]
+ ]
+ handlers = add_ordered_mock_handlers(o, meth_spec)
+ req = Request("https://www.example.com/")
+ self.assertEqual(req.get_host(), "www.example.com")
+ r = o.open(req)
+ self.assertEqual(req.get_host(), "proxy.example.com:3128")
+ self.assertEqual([(handlers[0], "https_open")],
+ [tup[0:2] for tup in o.calls])
+
+ def test_proxy_https_proxy_authorization(self):
+ o = OpenerDirector()
+ ph = urllib2.ProxyHandler(dict(https='proxy.example.com:3128'))
+ o.add_handler(ph)
+ https_handler = MockHTTPSHandler()
+ o.add_handler(https_handler)
+ req = Request("https://www.example.com/")
+ req.add_header("Proxy-Authorization","FooBar")
+ req.add_header("User-Agent","Grail")
+ self.assertEqual(req.get_host(), "www.example.com")
+ self.assertIsNone(req._tunnel_host)
+ r = o.open(req)
+ # Verify Proxy-Authorization gets tunneled to request.
+ # httpconn's req_headers do not contain the Proxy-Authorization header,
+ # but the req does.
+ self.assertNotIn(("Proxy-Authorization","FooBar"),
+ https_handler.httpconn.req_headers)
+ self.assertIn(("User-Agent","Grail"),
+ https_handler.httpconn.req_headers)
+ self.assertIsNotNone(req._tunnel_host)
+ self.assertEqual(req.get_host(), "proxy.example.com:3128")
+ self.assertEqual(req.get_header("Proxy-authorization"),"FooBar")
+
+ def test_basic_auth(self, quote_char='"'):
opener = OpenerDirector()
password_manager = MockPasswordManager()
auth_handler = urllib2.HTTPBasicAuthHandler(password_manager)
realm = "ACME Widget Store"
http_handler = MockHTTPHandler(
- 401, 'WWW-Authenticate: Basic realm="%s"\r\n\r\n' % realm)
+ 401, 'WWW-Authenticate: Basic realm=%s%s%s\r\n\r\n' %
+ (quote_char, realm, quote_char) )
opener.add_handler(auth_handler)
opener.add_handler(http_handler)
self._test_basic_auth(opener, auth_handler, "Authorization",
@@ -906,6 +1073,9 @@
"http://acme.example.com/protected",
)
+ def test_basic_auth_with_single_quoted_realm(self):
+ self.test_basic_auth(quote_char="'")
+
def test_proxy_basic_auth(self):
opener = OpenerDirector()
ph = urllib2.ProxyHandler(dict(http="proxy.example.com:3128"))
@@ -973,7 +1143,7 @@
def _test_basic_auth(self, opener, auth_handler, auth_header,
realm, http_handler, password_manager,
request_url, protected_url):
- import base64, httplib
+ import base64
user, password = "wile", "coyote"
# .add_password() fed through to password manager
@@ -996,7 +1166,8 @@
auth_hdr_value = 'Basic '+base64.encodestring(userpass).strip()
self.assertEqual(http_handler.requests[1].get_header(auth_header),
auth_hdr_value)
-
+ self.assertEqual(http_handler.requests[1].unredirected_hdrs[auth_header],
+ auth_hdr_value)
# if the password manager can't find a password, the handler won't
# handle the HTTP auth error
password_manager.user = password_manager.password = None
@@ -1005,7 +1176,6 @@
self.assertEqual(len(http_handler.requests), 1)
self.assertFalse(http_handler.requests[0].has_header(auth_header))
-
class MiscTests(unittest.TestCase):
def test_build_opener(self):
@@ -1050,7 +1220,62 @@
if h.__class__ == handler_class:
break
else:
- self.assert_(False)
+ self.assertTrue(False)
+
+class RequestTests(unittest.TestCase):
+
+ def setUp(self):
+ self.get = urllib2.Request("http://www.python.org/~jeremy/")
+ self.post = urllib2.Request("http://www.python.org/~jeremy/",
+ "data",
+ headers={"X-Test": "test"})
+
+ def test_method(self):
+ self.assertEqual("POST", self.post.get_method())
+ self.assertEqual("GET", self.get.get_method())
+
+ def test_add_data(self):
+ self.assertTrue(not self.get.has_data())
+ self.assertEqual("GET", self.get.get_method())
+ self.get.add_data("spam")
+ self.assertTrue(self.get.has_data())
+ self.assertEqual("POST", self.get.get_method())
+
+ def test_get_full_url(self):
+ self.assertEqual("http://www.python.org/~jeremy/",
+ self.get.get_full_url())
+
+ def test_selector(self):
+ self.assertEqual("/~jeremy/", self.get.get_selector())
+ req = urllib2.Request("http://www.python.org/")
+ self.assertEqual("/", req.get_selector())
+
+ def test_get_type(self):
+ self.assertEqual("http", self.get.get_type())
+
+ def test_get_host(self):
+ self.assertEqual("www.python.org", self.get.get_host())
+
+ def test_get_host_unquote(self):
+ req = urllib2.Request("http://www.%70ython.org/")
+ self.assertEqual("www.python.org", req.get_host())
+
+ def test_proxy(self):
+ self.assertTrue(not self.get.has_proxy())
+ self.get.set_proxy("www.perl.org", "http")
+ self.assertTrue(self.get.has_proxy())
+ self.assertEqual("www.python.org", self.get.get_origin_req_host())
+ self.assertEqual("www.perl.org", self.get.get_host())
+
+ def test_wrapped_url(self):
+ req = Request("<URL:http://www.python.org>")
+ self.assertEqual("www.python.org", req.get_host())
+
+ def test_urlwith_fragment(self):
+ req = Request("http://www.python.org/?qs=query#fragment=true")
+ self.assertEqual("/?qs=query", req.get_selector())
+ req = Request("http://www.python.org/#fun=true")
+ self.assertEqual("/", req.get_selector())
def test_main(verbose=None):
@@ -1060,7 +1285,8 @@
tests = (TrivialTests,
OpenerDirectorTests,
HandlerTests,
- MiscTests)
+ MiscTests,
+ RequestTests)
test_support.run_unittest(*tests)
if __name__ == "__main__":
diff --git a/Lib/test/test_weakref.py b/Lib/test/test_weakref.py
--- a/Lib/test/test_weakref.py
+++ b/Lib/test/test_weakref.py
@@ -3,6 +3,7 @@
import unittest
import UserList
import weakref
+import operator
from test import test_support
@@ -67,10 +68,10 @@
# Live reference:
o = C()
wr = weakref.ref(o)
- `wr`
+ repr(wr)
# Dead reference:
del o
- `wr`
+ repr(wr)
def test_basic_callback(self):
self.check_basic_callback(C)
@@ -191,7 +192,8 @@
p.append(12)
self.assertEqual(len(L), 1)
self.failUnless(p, "proxy for non-empty UserList should be true")
- p[:] = [2, 3]
+ with test_support._check_py3k_warnings():
+ p[:] = [2, 3]
self.assertEqual(len(L), 2)
self.assertEqual(len(p), 2)
self.failUnless(3 in p,
@@ -205,16 +207,48 @@
## self.assertEqual(repr(L2), repr(p2))
L3 = UserList.UserList(range(10))
p3 = weakref.proxy(L3)
- self.assertEqual(L3[:], p3[:])
- self.assertEqual(L3[5:], p3[5:])
- self.assertEqual(L3[:5], p3[:5])
- self.assertEqual(L3[2:5], p3[2:5])
+ with test_support._check_py3k_warnings():
+ self.assertEqual(L3[:], p3[:])
+ self.assertEqual(L3[5:], p3[5:])
+ self.assertEqual(L3[:5], p3[:5])
+ self.assertEqual(L3[2:5], p3[2:5])
+
+ def test_proxy_unicode(self):
+ # See bug 5037
+ class C(object):
+ def __str__(self):
+ return "string"
+ def __unicode__(self):
+ return u"unicode"
+ instance = C()
+ self.assertTrue("__unicode__" in dir(weakref.proxy(instance)))
+ self.assertEqual(unicode(weakref.proxy(instance)), u"unicode")
+
+ def test_proxy_index(self):
+ class C:
+ def __index__(self):
+ return 10
+ o = C()
+ p = weakref.proxy(o)
+ self.assertEqual(operator.index(p), 10)
+
+ def test_proxy_div(self):
+ class C:
+ def __floordiv__(self, other):
+ return 42
+ def __ifloordiv__(self, other):
+ return 21
+ o = C()
+ p = weakref.proxy(o)
+ self.assertEqual(p // 5, 42)
+ p //= 5
+ self.assertEqual(p, 21)
# The PyWeakref_* C API is documented as allowing either NULL or
# None as the value for the callback, where either means "no
# callback". The "no callback" ref and proxy objects are supposed
# to be shared so long as they exist by all callers so long as
- # they are active. In Python 2.3.3 and earlier, this guaranttee
+ # they are active. In Python 2.3.3 and earlier, this guarantee
# was not honored, and was broken in different ways for
# PyWeakref_NewRef() and PyWeakref_NewProxy(). (Two tests.)
@@ -676,8 +710,16 @@
w = Target()
+ def test_init(self):
+ # Issue 3634
+ # <weakref to class>.__init__() doesn't check errors correctly
+ r = weakref.ref(Exception)
+ self.assertRaises(TypeError, r.__init__, 0, 0, 0, 0, 0)
+ # No exception should be raised here
+ gc.collect()
-class SubclassableWeakrefTestCase(unittest.TestCase):
+
+class SubclassableWeakrefTestCase(TestBase):
def test_subclass_refs(self):
class MyRef(weakref.ref):
@@ -741,6 +783,44 @@
self.assertEqual(r.meth(), "abcdef")
self.failIf(hasattr(r, "__dict__"))
+ def test_subclass_refs_with_cycle(self):
+ # Bug #3110
+ # An instance of a weakref subclass can have attributes.
+ # If such a weakref holds the only strong reference to the object,
+ # deleting the weakref will delete the object. In this case,
+ # the callback must not be called, because the ref object is
+ # being deleted.
+ class MyRef(weakref.ref):
+ pass
+
+ # Use a local callback, for "regrtest -R::"
+ # to detect refcounting problems
+ def callback(w):
+ self.cbcalled += 1
+
+ o = C()
+ r1 = MyRef(o, callback)
+ r1.o = o
+ del o
+
+ del r1 # Used to crash here
+
+ self.assertEqual(self.cbcalled, 0)
+
+ # Same test, with two weakrefs to the same object
+ # (since code paths are different)
+ o = C()
+ r1 = MyRef(o, callback)
+ r2 = MyRef(o, callback)
+ r1.r = r2
+ r2.o = o
+ del o
+ del r2
+
+ del r1 # Used to crash here
+
+ self.assertEqual(self.cbcalled, 0)
+
class Object:
def __init__(self, arg):
@@ -789,7 +869,7 @@
def test_weak_keys(self):
#
# This exercises d.copy(), d.items(), d[] = v, d[], del d[],
- # len(d), d.has_key().
+ # len(d), in d.
#
dict, objects = self.make_weak_keyed_dict()
for o in objects:
@@ -813,8 +893,8 @@
"deleting the keys did not clear the dictionary")
o = Object(42)
dict[o] = "What is the meaning of the universe?"
- self.assert_(dict.has_key(o))
- self.assert_(not dict.has_key(34))
+ self.assertTrue(o in dict)
+ self.assertTrue(34 not in dict)
def test_weak_keyed_iters(self):
dict, objects = self.make_weak_keyed_dict()
@@ -826,8 +906,7 @@
objects2 = list(objects)
for wr in refs:
ob = wr()
- self.assert_(dict.has_key(ob))
- self.assert_(ob in dict)
+ self.assertTrue(ob in dict)
self.assertEqual(ob.arg, dict[ob])
objects2.remove(ob)
self.assertEqual(len(objects2), 0)
@@ -837,8 +916,7 @@
self.assertEqual(len(list(dict.iterkeyrefs())), len(objects))
for wr in dict.iterkeyrefs():
ob = wr()
- self.assert_(dict.has_key(ob))
- self.assert_(ob in dict)
+ self.assertTrue(ob in dict)
self.assertEqual(ob.arg, dict[ob])
objects2.remove(ob)
self.assertEqual(len(objects2), 0)
@@ -951,16 +1029,16 @@
" -- value parameters must be distinct objects")
weakdict = klass()
o = weakdict.setdefault(key, value1)
- self.assert_(o is value1)
- self.assert_(weakdict.has_key(key))
- self.assert_(weakdict.get(key) is value1)
- self.assert_(weakdict[key] is value1)
+ self.assertTrue(o is value1)
+ self.assertTrue(key in weakdict)
+ self.assertTrue(weakdict.get(key) is value1)
+ self.assertTrue(weakdict[key] is value1)
o = weakdict.setdefault(key, value2)
- self.assert_(o is value1)
- self.assert_(weakdict.has_key(key))
- self.assert_(weakdict.get(key) is value1)
- self.assert_(weakdict[key] is value1)
+ self.assertTrue(o is value1)
+ self.assertTrue(key in weakdict)
+ self.assertTrue(weakdict.get(key) is value1)
+ self.assertTrue(weakdict[key] is value1)
def test_weak_valued_dict_setdefault(self):
self.check_setdefault(weakref.WeakValueDictionary,
@@ -972,24 +1050,24 @@
def check_update(self, klass, dict):
#
- # This exercises d.update(), len(d), d.keys(), d.has_key(),
+ # This exercises d.update(), len(d), d.keys(), in d,
# d.get(), d[].
#
weakdict = klass()
weakdict.update(dict)
- self.assert_(len(weakdict) == len(dict))
+ self.assertEqual(len(weakdict), len(dict))
for k in weakdict.keys():
- self.assert_(dict.has_key(k),
+ self.assertTrue(k in dict,
"mysterious new key appeared in weak dict")
v = dict.get(k)
- self.assert_(v is weakdict[k])
- self.assert_(v is weakdict.get(k))
+ self.assertTrue(v is weakdict[k])
+ self.assertTrue(v is weakdict.get(k))
for k in dict.keys():
- self.assert_(weakdict.has_key(k),
+ self.assertTrue(k in weakdict,
"original key disappeared in weak dict")
v = dict[k]
- self.assert_(v is weakdict[k])
- self.assert_(v is weakdict.get(k))
+ self.assertTrue(v is weakdict[k])
+ self.assertTrue(v is weakdict.get(k))
def test_weak_valued_dict_update(self):
self.check_update(weakref.WeakValueDictionary,
@@ -1096,7 +1174,7 @@
def _reference(self):
return self.__ref.copy()
-libreftest = """ Doctest for examples in the library reference: libweakref.tex
+libreftest = """ Doctest for examples in the library reference: weakref.rst
>>> import weakref
>>> class Dict(dict):
@@ -1199,6 +1277,7 @@
MappingTestCase,
WeakValueDictionaryTestCase,
WeakKeyDictionaryTestCase,
+ SubclassableWeakrefTestCase,
)
test_support.run_doctest(sys.modules[__name__])
diff --git a/Lib/test/test_xml_etree.py b/Lib/test/test_xml_etree.py
--- a/Lib/test/test_xml_etree.py
+++ b/Lib/test/test_xml_etree.py
@@ -2,7 +2,8 @@
# all included components work as they should. For a more extensive
# test suite, see the selftest script in the ElementTree distribution.
-import doctest, sys
+import doctest
+import sys
from test import test_support
@@ -36,7 +37,7 @@
"""
def check_method(method):
- if not callable(method):
+ if not hasattr(method, '__call__'):
print method, "not callable"
def serialize(ET, elem, encoding=None):
diff --git a/Lib/test/test_xml_etree_c.py b/Lib/test/test_xml_etree_c.py
--- a/Lib/test/test_xml_etree_c.py
+++ b/Lib/test/test_xml_etree_c.py
@@ -1,6 +1,7 @@
# xml.etree test for cElementTree
-import doctest, sys
+import doctest
+import sys
from test import test_support
@@ -34,7 +35,7 @@
"""
def check_method(method):
- if not callable(method):
+ if not hasattr(method, '__call__'):
print method, "not callable"
def serialize(ET, elem, encoding=None):
diff --git a/Lib/test/test_zlib.py b/Lib/test/test_zlib.py
--- a/Lib/test/test_zlib.py
+++ b/Lib/test/test_zlib.py
@@ -1,23 +1,9 @@
import unittest
from test import test_support
import zlib
+import binascii
import random
-
-# print test_support.TESTFN
-
-def getbuf():
- # This was in the original. Avoid non-repeatable sources.
- # Left here (unused) in case something wants to be done with it.
- import imp
- try:
- t = imp.find_module('test_zlib')
- file = t[0]
- except ImportError:
- file = open(__file__)
- buf = file.read() * 8
- file.close()
- return buf
-
+from test.test_support import precisionbigmemtest, _1G
class ChecksumTestCase(unittest.TestCase):
@@ -59,33 +45,91 @@
self.assertEqual(zlib.crc32("penguin"), zlib.crc32("penguin", 0))
self.assertEqual(zlib.adler32("penguin"),zlib.adler32("penguin",1))
+ def test_abcdefghijklmnop(self):
+ """test issue1202 compliance: signed crc32, adler32 in 2.x"""
+ foo = 'abcdefghijklmnop'
+ # explicitly test signed behavior
+ self.assertEqual(zlib.crc32(foo), -1808088941)
+ self.assertEqual(zlib.crc32('spam'), 1138425661)
+ self.assertEqual(zlib.adler32(foo+foo), -721416943)
+ self.assertEqual(zlib.adler32('spam'), 72286642)
+ def test_same_as_binascii_crc32(self):
+ foo = 'abcdefghijklmnop'
+ self.assertEqual(binascii.crc32(foo), zlib.crc32(foo))
+ self.assertEqual(binascii.crc32('spam'), zlib.crc32('spam'))
+
+ def test_negative_crc_iv_input(self):
+ # The range of valid input values for the crc state should be
+ # -2**31 through 2**32-1 to allow inputs artificially constrained
+ # to a signed 32-bit integer.
+ self.assertEqual(zlib.crc32('ham', -1), zlib.crc32('ham', 0xffffffffL))
+ self.assertEqual(zlib.crc32('spam', -3141593),
+ zlib.crc32('spam', 0xffd01027L))
+ self.assertEqual(zlib.crc32('spam', -(2**31)),
+ zlib.crc32('spam', (2**31)))
+
+ def test_decompress_badinput(self):
+ self.assertRaises(zlib.error, zlib.decompress, 'foo')
class ExceptionTestCase(unittest.TestCase):
# make sure we generate some expected errors
- def test_bigbits(self):
- # specifying total bits too large causes an error
- self.assertRaises(zlib.error,
- zlib.compress, 'ERROR', zlib.MAX_WBITS + 1)
+ def test_badlevel(self):
+ # specifying compression level out of range causes an error
+ # (but -1 is Z_DEFAULT_COMPRESSION and apparently the zlib
+ # accepts 0 too)
+ self.assertRaises(zlib.error, zlib.compress, 'ERROR', 10)
def test_badcompressobj(self):
# verify failure on building compress object with bad params
self.assertRaises(ValueError, zlib.compressobj, 1, zlib.DEFLATED, 0)
+ # specifying total bits too large causes an error
+ self.assertRaises(ValueError,
+ zlib.compressobj, 1, zlib.DEFLATED, zlib.MAX_WBITS + 1)
def test_baddecompressobj(self):
# verify failure on building decompress object with bad params
- self.assertRaises(ValueError, zlib.decompressobj, 0)
+ self.assertRaises(ValueError, zlib.decompressobj, -1)
def test_decompressobj_badflush(self):
# verify failure on calling decompressobj.flush with bad params
self.assertRaises(ValueError, zlib.decompressobj().flush, 0)
self.assertRaises(ValueError, zlib.decompressobj().flush, -1)
- def test_decompress_badinput(self):
- self.assertRaises(zlib.error, zlib.decompress, 'foo')
+class BaseCompressTestCase(object):
+ def check_big_compress_buffer(self, size, compress_func):
+ _1M = 1024 * 1024
+ fmt = "%%0%dx" % (2 * _1M)
+ # Generate 10MB worth of random, and expand it by repeating it.
+ # The assumption is that zlib's memory is not big enough to exploit
+ # such spread out redundancy.
+ data = ''.join([binascii.a2b_hex(fmt % random.getrandbits(8 * _1M))
+ for i in range(10)])
+ data = data * (size // len(data) + 1)
+ try:
+ compress_func(data)
+ finally:
+ # Release memory
+ data = None
-class CompressTestCase(unittest.TestCase):
+ def check_big_decompress_buffer(self, size, decompress_func):
+ data = 'x' * size
+ try:
+ compressed = zlib.compress(data, 1)
+ finally:
+ # Release memory
+ data = None
+ data = decompress_func(compressed)
+ # Sanity check
+ try:
+ self.assertEqual(len(data), size)
+ self.assertEqual(len(data.strip('x')), 0)
+ finally:
+ data = None
+
+
+class CompressTestCase(BaseCompressTestCase, unittest.TestCase):
# Test compression in one go (whole message compression)
def test_speech(self):
x = zlib.compress(HAMLET_SCENE)
@@ -97,10 +141,31 @@
x = zlib.compress(data)
self.assertEqual(zlib.decompress(x), data)
+ def test_incomplete_stream(self):
+ # A useful error message is given
+ x = zlib.compress(HAMLET_SCENE)
+ try:
+ zlib.decompress(x[:-1])
+ except zlib.error as e:
+ self.assertTrue(
+ "Error -5 while decompressing data: incomplete or truncated stream"
+ in str(e), str(e))
+ else:
+ self.fail("zlib.error not raised")
+ # Memory use of the following functions takes into account overallocation
+ @precisionbigmemtest(size=_1G + 1024 * 1024, memuse=3)
+ def test_big_compress_buffer(self, size):
+ compress = lambda s: zlib.compress(s, 1)
+ self.check_big_compress_buffer(size, compress)
-class CompressObjectTestCase(unittest.TestCase):
+ @precisionbigmemtest(size=_1G + 1024 * 1024, memuse=2)
+ def test_big_decompress_buffer(self, size):
+ self.check_big_decompress_buffer(size, zlib.decompress)
+
+
+class CompressObjectTestCase(BaseCompressTestCase, unittest.TestCase):
# Test compression object
def test_pair(self):
# straightforward compress/decompress objects
@@ -314,6 +379,19 @@
dco = zlib.decompressobj()
self.assertEqual(dco.flush(), "") # Returns nothing
+    def test_decompress_incomplete_stream(self):
+        # zlib-compressed form of the string 'foo'
+        deflated = 'x\x9cK\xcb\xcf\x07\x00\x02\x82\x01E'
+        # Sanity check: the complete stream round-trips
+        self.assertEqual(zlib.decompress(deflated), 'foo')
+        truncated = deflated[:-5]
+        # The one-shot function rejects a stream missing its end...
+        self.assertRaises(zlib.error, zlib.decompress, truncated)
+        # ...while a decompressor object tolerates it (see issue #8672).
+        dco = zlib.decompressobj()
+        result = dco.decompress(truncated) + dco.flush()
+        self.assertEqual(result, 'foo')
+
if hasattr(zlib.compressobj(), "copy"):
def test_compresscopy(self):
# Test copying a compression object
@@ -374,6 +452,21 @@
d.flush()
self.assertRaises(ValueError, d.copy)
+ # Memory use of the following functions takes into account overallocation
+
+    @precisionbigmemtest(size=_1G + 1024 * 1024, memuse=3)
+    def test_big_compress_buffer(self, size):
+        # Same check as the one-shot test, via a level-1 compressor object
+        compressor = zlib.compressobj(1)
+        def compress(s):
+            return compressor.compress(s) + compressor.flush()
+        self.check_big_compress_buffer(size, compress)
+
+    @precisionbigmemtest(size=_1G + 1024 * 1024, memuse=2)
+    def test_big_decompress_buffer(self, size):
+        # Same check as the one-shot test, via a decompressor object
+        decompressor = zlib.decompressobj()
+        def decompress(s):
+            return decompressor.decompress(s) + decompressor.flush()
+        self.check_big_decompress_buffer(size, decompress)
+
+
def genblock(seed, length, step=1024, generator=random):
"""length-byte stream of random data from a seed (in step-byte blocks)."""
if seed is not None:
@@ -473,21 +566,3 @@
if __name__ == "__main__":
test_main()
-
-def test(tests=''):
- if not tests: tests = 'o'
- testcases = []
- if 'k' in tests: testcases.append(ChecksumTestCase)
- if 'x' in tests: testcases.append(ExceptionTestCase)
- if 'c' in tests: testcases.append(CompressTestCase)
- if 'o' in tests: testcases.append(CompressObjectTestCase)
- test_support.run_unittest(*testcases)
-
-if False:
- import sys
- sys.path.insert(1, '/Py23Src/python/dist/src/Lib/test')
- import test_zlib as tz
- ts, ut = tz.test_support, tz.unittest
- su = ut.TestSuite()
- su.addTest(ut.makeSuite(tz.CompressTestCase))
- ts.run_suite(su)
diff --git a/Lib/timeit.py b/Lib/timeit.py
--- a/Lib/timeit.py
+++ b/Lib/timeit.py
@@ -9,7 +9,7 @@
Library usage: see the Timer class.
Command line usage:
- python timeit.py [-n N] [-r N] [-s S] [-t] [-c] [-h] [statement]
+ python timeit.py [-n N] [-r N] [-s S] [-t] [-c] [-h] [--] [statement]
Options:
-n/--number N: how many times to execute 'statement' (default: see below)
@@ -19,6 +19,7 @@
-c/--clock: use time.clock() (default on Windows)
-v/--verbose: print raw timing results; repeat for more digits precision
-h/--help: print this usage message and exit
+ --: separate options from statement, use when statement starts with -
statement: statement to be timed (default 'pass')
A multi-line statement may be given by specifying each line as a
@@ -90,6 +91,17 @@
"""Helper to reindent a multi-line statement."""
return src.replace("\n", "\n" + " "*indent)
+def _template_func(setup, func):
+    """Build the timing function used when the statement is a callable.
+
+    The returned inner(_it, _timer, _func=func) calls setup() once, calls
+    _func once per item of _it, and returns the difference between two
+    _timer() readings taken around that loop.
+    """
+    def inner(_it, _timer, _func=func):
+        setup()
+        start = _timer()
+        for _ in _it:
+            _func()
+        return _timer() - start
+    return inner
+
class Timer:
"""Class for timing execution speed of small code snippets.
@@ -109,14 +121,32 @@
def __init__(self, stmt="pass", setup="pass", timer=default_timer):
"""Constructor. See class doc string."""
self.timer = timer
- stmt = reindent(stmt, 8)
- setup = reindent(setup, 4)
- src = template % {'stmt': stmt, 'setup': setup}
- self.src = src # Save for traceback display
- code = compile(src, dummy_src_name, "exec")
ns = {}
- exec code in globals(), ns
- self.inner = ns["inner"]
+ if isinstance(stmt, basestring):
+ stmt = reindent(stmt, 8)
+ if isinstance(setup, basestring):
+ setup = reindent(setup, 4)
+ src = template % {'stmt': stmt, 'setup': setup}
+ elif hasattr(setup, '__call__'):
+ src = template % {'stmt': stmt, 'setup': '_setup()'}
+ ns['_setup'] = setup
+ else:
+ raise ValueError("setup is neither a string nor callable")
+ self.src = src # Save for traceback display
+ code = compile(src, dummy_src_name, "exec")
+ exec code in globals(), ns
+ self.inner = ns["inner"]
+ elif hasattr(stmt, '__call__'):
+ self.src = None
+ if isinstance(setup, basestring):
+ _setup = setup
+ def setup():
+ exec _setup in globals(), ns
+ elif not hasattr(setup, '__call__'):
+ raise ValueError("setup is neither a string nor callable")
+ self.inner = _template_func(setup, stmt)
+ else:
+ raise ValueError("stmt is neither a string nor callable")
def print_exc(self, file=None):
"""Helper to print a traceback from the timed code.
@@ -136,10 +166,13 @@
sent; it defaults to sys.stderr.
"""
import linecache, traceback
- linecache.cache[dummy_src_name] = (len(self.src),
- None,
- self.src.split("\n"),
- dummy_src_name)
+ if self.src is not None:
+ linecache.cache[dummy_src_name] = (len(self.src),
+ None,
+ self.src.split("\n"),
+ dummy_src_name)
+ # else the source is already stored somewhere else
+
traceback.print_exc(file=file)
def timeit(self, number=default_number):
@@ -192,6 +225,16 @@
r.append(t)
return r
+def timeit(stmt="pass", setup="pass", timer=default_timer,
+           number=default_number):
+    """Build a Timer for stmt/setup and return its timeit(number) result."""
+    timer_obj = Timer(stmt, setup, timer)
+    return timer_obj.timeit(number)
+
+def repeat(stmt="pass", setup="pass", timer=default_timer,
+           repeat=default_repeat, number=default_number):
+    """Build a Timer for stmt/setup and return its repeat(repeat, number)
+    timings."""
+    timer_obj = Timer(stmt, setup, timer)
+    return timer_obj.repeat(repeat, number)
+
def main(args=None):
"""Main program, used when run as a script.
diff --git a/Lib/types.py b/Lib/types.py
--- a/Lib/types.py
+++ b/Lib/types.py
@@ -43,16 +43,11 @@
def _f(): pass
FunctionType = type(_f)
LambdaType = type(lambda: None) # Same as FunctionType
-try:
- CodeType = type(_f.func_code)
-except RuntimeError:
- # Execution in restricted environment
- pass
+CodeType = type(_f.func_code)
-def g():
+def _g():
yield 1
-GeneratorType = type(g())
-del g
+GeneratorType = type(_g())
class _C:
def _m(self): pass
@@ -74,15 +69,10 @@
try:
raise TypeError
except TypeError:
- try:
- tb = sys.exc_info()[2]
- TracebackType = type(tb)
- FrameType = type(tb.tb_frame)
- except AttributeError:
- # In the restricted environment, exc_info returns (None, None,
- # None) Then, tb.tb_frame gives an attribute error
- pass
- tb = None; del tb
+ tb = sys.exc_info()[2]
+ TracebackType = type(tb)
+ FrameType = type(tb.tb_frame)
+ del tb
SliceType = slice
EllipsisType = type(Ellipsis)
@@ -90,4 +80,8 @@
DictProxyType = type(TypeType.__dict__)
NotImplementedType = type(NotImplemented)
-del sys, _f, _C, _x # Not for export
+# For Jython, the following two types are identical
+GetSetDescriptorType = type(FunctionType.func_code)
+MemberDescriptorType = type(FunctionType.func_globals)
+
+del sys, _f, _g, _C, _x # Not for export
diff --git a/Lib/warnings.py b/Lib/warnings.py
--- a/Lib/warnings.py
+++ b/Lib/warnings.py
@@ -46,7 +46,14 @@
append=0):
"""Insert an entry into the list of warnings filters (at the front).
- Use assertions to check that all arguments have the right type."""
+ 'action' -- one of "error", "ignore", "always", "default", "module",
+ or "once"
+ 'message' -- a regex that the warning message must match
+ 'category' -- a class that the warning must be a subclass of
+ 'module' -- a regex that the module name must match
+ 'lineno' -- an integer line number, 0 matches all warnings
+ 'append' -- if true, append to the list of filters
+ """
import re
assert action in ("error", "ignore", "always", "default", "module",
"once"), "invalid action: %r" % (action,)
@@ -68,6 +75,11 @@
"""Insert a simple entry into the list of warnings filters (at the front).
A simple filter matches all modules and messages.
+ 'action' -- one of "error", "ignore", "always", "default", "module",
+ or "once"
+ 'category' -- a class that the warning must be a subclass of
+ 'lineno' -- an integer line number, 0 matches all warnings
+ 'append' -- if true, append to the list of filters
"""
assert action in ("error", "ignore", "always", "default", "module",
"once"), "invalid action: %r" % (action,)
@@ -264,24 +276,6 @@
raise RuntimeError(
"Unrecognized action (%r) in warnings.filters:\n %s" %
(action, item))
- # Warn if showwarning() does not support the 'line' argument.
- # Don't use 'inspect' as it relies on an extension module, which break the
- # build thanks to 'warnings' being imported by setup.py.
- fxn_code = None
- if hasattr(showwarning, 'func_code'):
- fxn_code = showwarning.func_code
- elif hasattr(showwarning, '__func__'):
- fxn_code = showwarning.__func__.func_code
- if fxn_code:
- args = fxn_code.co_varnames[:fxn_code.co_argcount]
- CO_VARARGS = 0x4
- if 'line' not in args and not fxn_code.co_flags & CO_VARARGS:
- showwarning_msg = ("functions overriding warnings.showwarning() "
- "must support the 'line' argument")
- if message == showwarning_msg:
- _show_warning(message, category, filename, lineno)
- else:
- warn(showwarning_msg, DeprecationWarning)
# Print message and context
showwarning(message, category, filename, lineno)
@@ -391,8 +385,12 @@
# Module initialization
_processoptions(sys.warnoptions)
if not _warnings_defaults:
- simplefilter("ignore", category=PendingDeprecationWarning, append=1)
- simplefilter("ignore", category=ImportWarning, append=1)
+ silence = [ImportWarning, PendingDeprecationWarning]
+ # Don't silence DeprecationWarning if -3 or -Q was used.
+ if not sys.py3kwarning and not sys.flags.division_warning:
+ silence.append(DeprecationWarning)
+ for cls in silence:
+ simplefilter("ignore", category=cls)
bytes_warning = sys.flags.bytes_warning
if bytes_warning > 1:
bytes_action = "error"
diff --git a/Lib/weakref.py b/Lib/weakref.py
--- a/Lib/weakref.py
+++ b/Lib/weakref.py
@@ -2,7 +2,7 @@
This module is an implementation of PEP 205:
-http://python.sourceforge.net/peps/pep-0205.html
+http://www.python.org/dev/peps/pep-0205/
"""
# Naming convention: Variables named "wr" are weak reference objects;
@@ -20,14 +20,16 @@
ProxyType,
ReferenceType)
+from _weakrefset import WeakSet
+
from exceptions import ReferenceError
ProxyTypes = (ProxyType, CallableProxyType)
__all__ = ["ref", "proxy", "getweakrefcount", "getweakrefs",
- "WeakKeyDictionary", "ReferenceType", "ProxyType",
- "CallableProxyType", "ProxyTypes", "WeakValueDictionary"]
+ "WeakKeyDictionary", "ReferenceError", "ReferenceType", "ProxyType",
+ "CallableProxyType", "ProxyTypes", "WeakValueDictionary", 'WeakSet']
class WeakValueDictionary(UserDict.UserDict):
@@ -88,6 +90,17 @@
new[key] = o
return new
+    __copy__ = copy
+
+    def __deepcopy__(self, memo):
+        """Support copy.deepcopy(): keys are deep-copied, values (which are
+        only weakly referenced) are shared, and dead entries are skipped."""
+        from copy import deepcopy
+        clone = self.__class__()
+        for k, weak_value in self.data.items():
+            target = weak_value()
+            if target is None:
+                continue
+            clone[deepcopy(k, memo)] = target
+        return clone
+
def get(self, key, default=None):
try:
wr = self.data[key]
@@ -262,6 +275,17 @@
new[o] = value
return new
+    __copy__ = copy
+
+    def __deepcopy__(self, memo):
+        """Support copy.deepcopy(): keys (which are only weakly referenced)
+        are shared, values are deep-copied, and dead entries are skipped."""
+        from copy import deepcopy
+        clone = self.__class__()
+        for weak_key, value in self.data.items():
+            target = weak_key()
+            if target is None:
+                continue
+            clone[target] = deepcopy(value, memo)
+        return clone
+
def get(self, key, default=None):
return self.data.get(ref(key),default)
diff --git a/Lib/zipfile.py b/Lib/zipfile.py
--- a/Lib/zipfile.py
+++ b/Lib/zipfile.py
@@ -1,13 +1,17 @@
"""
Read and write ZIP files.
"""
-import struct, os, time, sys
-import binascii, cStringIO
+import struct, os, time, sys, shutil
+import binascii, cStringIO, stat
+import io
+import re
try:
import zlib # We may need its compression method
+ crc32 = zlib.crc32
except ImportError:
zlib = None
+ crc32 = binascii.crc32
__all__ = ["BadZipfile", "error", "ZIP_STORED", "ZIP_DEFLATED", "is_zipfile",
"ZipInfo", "ZipFile", "PyZipFile", "LargeZipFile" ]
@@ -26,32 +30,52 @@
error = BadZipfile # The exception raised by this module
-ZIP64_LIMIT= (1 << 31) - 1
+ZIP64_LIMIT = (1 << 31) - 1
+ZIP_FILECOUNT_LIMIT = 1 << 16
+ZIP_MAX_COMMENT = (1 << 16) - 1
# constants for Zip file compression methods
ZIP_STORED = 0
ZIP_DEFLATED = 8
# Other ZIP compression methods not supported
-# Here are some struct module formats for reading headers
-structEndArchive = "<4s4H2LH" # 9 items, end of archive, 22 bytes
-stringEndArchive = "PK\005\006" # magic number for end of archive record
-structCentralDir = "<4s4B4HlLL5HLL"# 19 items, central directory, 46 bytes
-stringCentralDir = "PK\001\002" # magic number for central directory
-structFileHeader = "<4s2B4HlLL2H" # 12 items, file header record, 30 bytes
-stringFileHeader = "PK\003\004" # magic number for file header
-structEndArchive64Locator = "<4slql" # 4 items, locate Zip64 header, 20 bytes
-stringEndArchive64Locator = "PK\x06\x07" # magic token for locator header
-structEndArchive64 = "<4sqhhllqqqq" # 10 items, end of archive (Zip64), 56 bytes
-stringEndArchive64 = "PK\x06\x06" # magic token for Zip64 header
+# Below are some formats and associated data for reading/writing headers using
+# the struct module. The names and structures of headers/records are those used
+# in the PKWARE description of the ZIP file format:
+# http://www.pkware.com/documents/casestudies/APPNOTE.TXT
+# (URL valid as of January 2008)
+# The "end of central directory" structure, magic number, size, and indices
+# (section V.I in the format document)
+structEndArchive = "<4s4H2LH"
+stringEndArchive = "PK\005\006"
+sizeEndCentDir = struct.calcsize(structEndArchive)
+
+_ECD_SIGNATURE = 0
+_ECD_DISK_NUMBER = 1
+_ECD_DISK_START = 2
+_ECD_ENTRIES_THIS_DISK = 3
+_ECD_ENTRIES_TOTAL = 4
+_ECD_SIZE = 5
+_ECD_OFFSET = 6
+_ECD_COMMENT_SIZE = 7
+# These last two indices are not part of the structure as defined in the
+# spec, but they are used internally by this module as a convenience
+_ECD_COMMENT = 8
+_ECD_LOCATION = 9
+
+# The "central directory" structure, magic number, size, and indices
+# of entries in the structure (section V.F in the format document)
+structCentralDir = "<4s4B4HL2L5H2L"
+stringCentralDir = "PK\001\002"
+sizeCentralDir = struct.calcsize(structCentralDir)
# indexes of entries in the central directory structure
_CD_SIGNATURE = 0
_CD_CREATE_VERSION = 1
_CD_CREATE_SYSTEM = 2
_CD_EXTRACT_VERSION = 3
-_CD_EXTRACT_SYSTEM = 4 # is this meaningful?
+_CD_EXTRACT_SYSTEM = 4
_CD_FLAG_BITS = 5
_CD_COMPRESS_TYPE = 6
_CD_TIME = 7
@@ -67,10 +91,15 @@
_CD_EXTERNAL_FILE_ATTRIBUTES = 17
_CD_LOCAL_HEADER_OFFSET = 18
-# indexes of entries in the local file header structure
+# The "local file header" structure, magic number, size, and indices
+# (section V.A in the format document)
+structFileHeader = "<4s2B4HL2L2H"
+stringFileHeader = "PK\003\004"
+sizeFileHeader = struct.calcsize(structFileHeader)
+
_FH_SIGNATURE = 0
_FH_EXTRACT_VERSION = 1
-_FH_EXTRACT_SYSTEM = 2 # is this meaningful?
+_FH_EXTRACT_SYSTEM = 2
_FH_GENERAL_PURPOSE_FLAG_BITS = 3
_FH_COMPRESSION_METHOD = 4
_FH_LAST_MOD_TIME = 5
@@ -81,25 +110,64 @@
_FH_FILENAME_LENGTH = 10
_FH_EXTRA_FIELD_LENGTH = 11
-def is_zipfile(filename):
- """Quickly see if file is a ZIP file by checking the magic number."""
+# The "Zip64 end of central directory locator" structure, magic number, and size
+structEndArchive64Locator = "<4sLQL"
+stringEndArchive64Locator = "PK\x06\x07"
+sizeEndCentDir64Locator = struct.calcsize(structEndArchive64Locator)
+
+# The "Zip64 end of central directory" record, magic number, size, and indices
+# (section V.G in the format document)
+structEndArchive64 = "<4sQ2H2L4Q"
+stringEndArchive64 = "PK\x06\x06"
+sizeEndCentDir64 = struct.calcsize(structEndArchive64)
+
+_CD64_SIGNATURE = 0
+_CD64_DIRECTORY_RECSIZE = 1
+_CD64_CREATE_VERSION = 2
+_CD64_EXTRACT_VERSION = 3
+_CD64_DISK_NUMBER = 4
+_CD64_DISK_NUMBER_START = 5
+_CD64_NUMBER_ENTRIES_THIS_DISK = 6
+_CD64_NUMBER_ENTRIES_TOTAL = 7
+_CD64_DIRECTORY_SIZE = 8
+_CD64_OFFSET_START_CENTDIR = 9
+
+def _check_zipfile(fp):
+    """Return True if _EndRecData() locates an end-of-archive record in the
+    already-open file object fp, otherwise False.
+
+    IOError while probing fp is swallowed and reported as False; other
+    exceptions from _EndRecData (e.g. BadZipfile for multi-disk archives)
+    propagate.
+    """
    try:
-        fpin = open(filename, "rb")
-        endrec = _EndRecData(fpin)
-        fpin.close()
-        if endrec:
-            return True # file has correct magic number
+        if _EndRecData(fp):
+            return True # file has correct magic number
    except IOError:
        pass
    return False
+def is_zipfile(filename):
+    """Quickly see if a file is a ZIP file by checking the magic number.
+
+    `filename` may be a path or an already-open file-like object (anything
+    with a read() method).  An IOError while opening or reading the file
+    yields False.
+    """
+    found = False
+    try:
+        if not hasattr(filename, "read"):
+            with open(filename, "rb") as fp:
+                found = _check_zipfile(fp)
+        else:
+            found = _check_zipfile(fp=filename)
+    except IOError:
+        pass
+    return found
+
def _EndRecData64(fpin, offset, endrec):
"""
Read the ZIP64 end-of-archive records and use that to update endrec
"""
- locatorSize = struct.calcsize(structEndArchive64Locator)
- fpin.seek(offset - locatorSize, 2)
- data = fpin.read(locatorSize)
+ try:
+ fpin.seek(offset - sizeEndCentDir64Locator, 2)
+ except IOError:
+ # If the seek fails, the file is not large enough to contain a ZIP64
+ # end-of-archive record, so just return the end record we were given.
+ return endrec
+
+ data = fpin.read(sizeEndCentDir64Locator)
sig, diskno, reloff, disks = struct.unpack(structEndArchive64Locator, data)
if sig != stringEndArchive64Locator:
return endrec
@@ -108,9 +176,8 @@
raise BadZipfile("zipfiles that span multiple disks are not supported")
# Assume no 'zip64 extensible data'
- endArchiveSize = struct.calcsize(structEndArchive64)
- fpin.seek(offset - locatorSize - endArchiveSize, 2)
- data = fpin.read(endArchiveSize)
+ fpin.seek(offset - sizeEndCentDir64Locator - sizeEndCentDir64, 2)
+ data = fpin.read(sizeEndCentDir64)
sig, sz, create_version, read_version, disk_num, disk_dir, \
dircount, dircount2, dirsize, diroffset = \
struct.unpack(structEndArchive64, data)
@@ -118,12 +185,13 @@
return endrec
# Update the original endrec using data from the ZIP64 record
- endrec[1] = disk_num
- endrec[2] = disk_dir
- endrec[3] = dircount
- endrec[4] = dircount2
- endrec[5] = dirsize
- endrec[6] = diroffset
+ endrec[_ECD_SIGNATURE] = sig
+ endrec[_ECD_DISK_NUMBER] = disk_num
+ endrec[_ECD_DISK_START] = disk_dir
+ endrec[_ECD_ENTRIES_THIS_DISK] = dircount
+ endrec[_ECD_ENTRIES_TOTAL] = dircount2
+ endrec[_ECD_SIZE] = dirsize
+ endrec[_ECD_OFFSET] = diroffset
return endrec
@@ -132,38 +200,57 @@
The data is a list of the nine items in the ZIP "End of central dir"
record followed by a tenth item, the file seek offset of this record."""
- fpin.seek(-22, 2) # Assume no archive comment.
- filesize = fpin.tell() + 22 # Get file size
+
+ # Determine file size
+ fpin.seek(0, 2)
+ filesize = fpin.tell()
+
+ # Check to see if this is ZIP file with no archive comment (the
+ # "end of central directory" structure should be the last item in the
+ # file if this is the case).
+ try:
+ fpin.seek(-sizeEndCentDir, 2)
+ except IOError:
+ return None
data = fpin.read()
if data[0:4] == stringEndArchive and data[-2:] == "\000\000":
+ # the signature is correct and there's no comment, unpack structure
endrec = struct.unpack(structEndArchive, data)
- endrec = list(endrec)
- endrec.append("") # Append the archive comment
- endrec.append(filesize - 22) # Append the record start offset
- if endrec[-4] == -1 or endrec[-4] == 0xffffffff:
- return _EndRecData64(fpin, -22, endrec)
- return endrec
- # Search the last END_BLOCK bytes of the file for the record signature.
- # The comment is appended to the ZIP file and has a 16 bit length.
- # So the comment may be up to 64K long. We limit the search for the
- # signature to a few Kbytes at the end of the file for efficiency.
- # also, the signature must not appear in the comment.
- END_BLOCK = min(filesize, 1024 * 4)
- fpin.seek(filesize - END_BLOCK, 0)
+ endrec=list(endrec)
+
+ # Append a blank comment and record start offset
+ endrec.append("")
+ endrec.append(filesize - sizeEndCentDir)
+
+ # Try to read the "Zip64 end of central directory" structure
+ return _EndRecData64(fpin, -sizeEndCentDir, endrec)
+
+ # Either this is not a ZIP file, or it is a ZIP file with an archive
+ # comment. Search the end of the file for the "end of central directory"
+ # record signature. The comment is the last item in the ZIP file and may be
+ # up to 64K long. It is assumed that the "end of central directory" magic
+ # number does not appear in the comment.
+ maxCommentStart = max(filesize - (1 << 16) - sizeEndCentDir, 0)
+ fpin.seek(maxCommentStart, 0)
data = fpin.read()
start = data.rfind(stringEndArchive)
- if start >= 0: # Correct signature string was found
- endrec = struct.unpack(structEndArchive, data[start:start+22])
- endrec = list(endrec)
- comment = data[start+22:]
- if endrec[7] == len(comment): # Comment length checks out
+ if start >= 0:
+ # found the magic number; attempt to unpack and interpret
+ recData = data[start:start+sizeEndCentDir]
+ endrec = list(struct.unpack(structEndArchive, recData))
+ comment = data[start+sizeEndCentDir:]
+ # check that comment length is correct
+ if endrec[_ECD_COMMENT_SIZE] == len(comment):
# Append the archive comment and start offset
endrec.append(comment)
- endrec.append(filesize - END_BLOCK + start)
- if endrec[-4] == -1 or endrec[-4] == 0xffffffff:
- return _EndRecData64(fpin, - END_BLOCK + start, endrec)
- return endrec
- return # Error, return None
+ endrec.append(maxCommentStart + start)
+
+ # Try to read the "Zip64 end of central directory" structure
+ return _EndRecData64(fpin, maxCommentStart + start - filesize,
+ endrec)
+
+ # Unable to find a valid end of central directory structure
+ return
class ZipInfo (object):
@@ -188,6 +275,7 @@
'CRC',
'compress_size',
'file_size',
+ '_raw_time',
)
def __init__(self, filename="NoName", date_time=(1980,1,1,0,0,0)):
@@ -246,34 +334,50 @@
if file_size > ZIP64_LIMIT or compress_size > ZIP64_LIMIT:
# File is larger than what fits into a 4 byte integer,
# fall back to the ZIP64 extension
- fmt = '<hhqq'
+ fmt = '<HHQQ'
extra = extra + struct.pack(fmt,
1, struct.calcsize(fmt)-4, file_size, compress_size)
- file_size = 0xffffffff # -1
- compress_size = 0xffffffff # -1
+ file_size = 0xffffffff
+ compress_size = 0xffffffff
self.extract_version = max(45, self.extract_version)
self.create_version = max(45, self.extract_version)
+ filename, flag_bits = self._encodeFilenameFlags()
header = struct.pack(structFileHeader, stringFileHeader,
- self.extract_version, self.reserved, self.flag_bits,
+ self.extract_version, self.reserved, flag_bits,
self.compress_type, dostime, dosdate, CRC,
compress_size, file_size,
- len(self.filename), len(extra))
- return header + self.filename + extra
+ len(filename), len(extra))
+ return header + filename + extra
+
+    def _encodeFilenameFlags(self):
+        """Return (filename bytes, flag bits) for writing a header.
+
+        Byte-string names pass through unchanged.  Unicode names are
+        encoded as ASCII when possible; otherwise they are encoded as UTF-8
+        and bit 0x800 is set in the returned flags.
+        """
+        name = self.filename
+        if not isinstance(name, unicode):
+            return name, self.flag_bits
+        try:
+            return name.encode('ascii'), self.flag_bits
+        except UnicodeEncodeError:
+            return name.encode('utf-8'), self.flag_bits | 0x800
+
+    def _decodeFilename(self):
+        """Return the filename, UTF-8 decoded if flag bit 0x800 is set."""
+        if not self.flag_bits & 0x800:
+            return self.filename
+        return self.filename.decode('utf-8')
def _decodeExtra(self):
# Try to decode the extra field.
extra = self.extra
unpack = struct.unpack
while extra:
- tp, ln = unpack('<hh', extra[:4])
+ tp, ln = unpack('<HH', extra[:4])
if tp == 1:
if ln >= 24:
- counts = unpack('<qqq', extra[4:28])
+ counts = unpack('<QQQ', extra[4:28])
elif ln == 16:
- counts = unpack('<qq', extra[4:20])
+ counts = unpack('<QQ', extra[4:20])
elif ln == 8:
- counts = unpack('<q', extra[4:12])
+ counts = unpack('<Q', extra[4:12])
elif ln == 0:
counts = ()
else:
@@ -282,15 +386,15 @@
idx = 0
# ZIP64 extension (large files and/or large archives)
- if self.file_size == -1 or self.file_size == 0xFFFFFFFFL:
+ if self.file_size in (0xffffffffffffffffL, 0xffffffffL):
self.file_size = counts[idx]
idx += 1
- if self.compress_size == -1 or self.compress_size == 0xFFFFFFFFL:
+ if self.compress_size == 0xFFFFFFFFL:
self.compress_size = counts[idx]
idx += 1
- if self.header_offset == -1 or self.header_offset == 0xffffffffL:
+ if self.header_offset == 0xffffffffL:
old = self.header_offset
self.header_offset = counts[idx]
idx+=1
@@ -298,10 +402,259 @@
extra = extra[ln+4:]
+class _ZipDecrypter:
+    """Class to handle decryption of files stored within a ZIP archive.
+
+    ZIP supports a password-based form of encryption. Even though known
+    plaintext attacks have been found against it, it is still useful
+    to be able to get data out of such a file.
+
+    Usage:
+        zd = _ZipDecrypter(mypwd)
+        plain_char = zd(cypher_char)
+        plain_text = map(zd, cypher_text)
+    """
+
+    # Plain function (no self), called once while the class body executes
+    # to build the shared lookup table below.
+    def _GenerateCRCTable():
+        """Generate a CRC-32 table.
+
+        ZIP encryption uses the CRC32 one-byte primitive for scrambling some
+        internal keys. We noticed that a direct implementation is faster than
+        relying on binascii.crc32().
+        """
+        poly = 0xedb88320
+        table = [0] * 256
+        for i in range(256):
+            crc = i
+            for j in range(8):
+                if crc & 1:
+                    # the & 0x7FFFFFFF keeps the shift unsigned
+                    crc = ((crc >> 1) & 0x7FFFFFFF) ^ poly
+                else:
+                    crc = ((crc >> 1) & 0x7FFFFFFF)
+            table[i] = crc
+        return table
+    # 256-entry table shared by all instances
+    crctable = _GenerateCRCTable()
+
+    def _crc32(self, ch, crc):
+        """Compute the CRC32 primitive on one byte."""
+        return ((crc >> 8) & 0xffffff) ^ self.crctable[(crc ^ ord(ch)) & 0xff]
+
+    def __init__(self, pwd):
+        # Initial key state (0x12345678, 0x23456789, 0x34567890), then mixed
+        # with every character of the password.
+        self.key0 = 305419896
+        self.key1 = 591751049
+        self.key2 = 878082192
+        for p in pwd:
+            self._UpdateKeys(p)
+
+    def _UpdateKeys(self, c):
+        # Advance the three 32-bit keys with one character c; 4294967295 is
+        # 0xffffffff, truncating key1 to 32 bits.
+        self.key0 = self._crc32(c, self.key0)
+        self.key1 = (self.key1 + (self.key0 & 255)) & 4294967295
+        self.key1 = (self.key1 * 134775813 + 1) & 4294967295
+        self.key2 = self._crc32(chr((self.key1 >> 24) & 255), self.key2)
+
+    def __call__(self, c):
+        """Decrypt a single character."""
+        c = ord(c)
+        # Keystream byte derived from key2
+        k = self.key2 | 2
+        c = c ^ (((k * (k^1)) >> 8) & 255)
+        c = chr(c)
+        # The keys are advanced with the decrypted (plaintext) character.
+        self._UpdateKeys(c)
+        return c
+
+class ZipExtFile(io.BufferedIOBase):
+    """File-like object for reading an archive member.
+       Is returned by ZipFile.open().
+
+    Data is read from `fileobj` at its current position, decrypted when a
+    decrypter is given, decompressed for ZIP_DEFLATED members, and checked
+    against the member's CRC-32 once the compressed data is exhausted.
+    """
+
+    # Max size supported by decompressor (INT_MAX).  The parentheses matter:
+    # `<<` binds more loosely than `-`, so `1 << 31 - 1` would be 1 << 30.
+    MAX_N = (1 << 31) - 1
+
+    # Read from compressed files in 4k blocks.
+    MIN_READ_SIZE = 4096
+
+    # Search for universal newlines or line chunks.
+    PATTERN = re.compile(r'^(?P<chunk>[^\r\n]+)|(?P<newline>\n|\r\n?)')
+
+    def __init__(self, fileobj, mode, zipinfo, decrypter=None):
+        self._fileobj = fileobj
+        self._decrypter = decrypter
+
+        self._compress_type = zipinfo.compress_type
+        self._compress_size = zipinfo.compress_size
+        self._compress_left = zipinfo.compress_size
+
+        if self._compress_type == ZIP_DEFLATED:
+            # Negative wbits: raw deflate data without a zlib header.
+            self._decompressor = zlib.decompressobj(-15)
+        # Compressed bytes read but not yet decompressed.  read1() inspects
+        # this for every compression type, so it is initialized
+        # unconditionally; setting it only for ZIP_DEFLATED made reads of
+        # ZIP_STORED members fail with AttributeError.
+        self._unconsumed = ''
+
+        self._readbuffer = ''
+        self._offset = 0
+
+        self._universal = 'U' in mode
+        self.newlines = None
+
+        # Adjust read size for encrypted files since the first 12 bytes
+        # are for the encryption/password information.
+        if self._decrypter is not None:
+            self._compress_left -= 12
+
+        self.mode = mode
+        self.name = zipinfo.filename
+
+        if hasattr(zipinfo, 'CRC'):
+            self._expected_crc = zipinfo.CRC
+            self._running_crc = crc32(b'') & 0xffffffff
+        else:
+            self._expected_crc = None
+
+    def readline(self, limit=-1):
+        """Read and return a line from the stream.
+
+        If limit is specified, at most limit bytes will be read.
+        """
+
+        if not self._universal and limit < 0:
+            # Shortcut common case - newline found in buffer.
+            i = self._readbuffer.find('\n', self._offset) + 1
+            if i > 0:
+                line = self._readbuffer[self._offset: i]
+                self._offset = i
+                return line
+
+        if not self._universal:
+            return io.BufferedIOBase.readline(self, limit)
+
+        line = ''
+        while limit < 0 or len(line) < limit:
+            readahead = self.peek(2)
+            if readahead == '':
+                return line
+
+            #
+            # Search for universal newlines or line chunks.
+            #
+            # The pattern returns either a line chunk or a newline, but not
+            # both. Combined with peek(2), we are assured that the sequence
+            # '\r\n' is always retrieved completely and never split into
+            # separate newlines - '\r', '\n' due to coincidental readaheads.
+            #
+            match = self.PATTERN.search(readahead)
+            newline = match.group('newline')
+            if newline is not None:
+                if self.newlines is None:
+                    self.newlines = []
+                if newline not in self.newlines:
+                    self.newlines.append(newline)
+                self._offset += len(newline)
+                return line + '\n'
+
+            chunk = match.group('chunk')
+            if limit >= 0:
+                chunk = chunk[: limit - len(line)]
+
+            self._offset += len(chunk)
+            line += chunk
+
+        return line
+
+    def peek(self, n=1):
+        """Returns buffered bytes without advancing the position."""
+        if n > len(self._readbuffer) - self._offset:
+            chunk = self.read(n)
+            # Un-consume what read() just returned; it is still buffered.
+            self._offset -= len(chunk)
+
+        # Return up to 512 bytes to reduce allocation overhead for tight loops.
+        return self._readbuffer[self._offset: self._offset + 512]
+
+    def readable(self):
+        return True
+
+    def read(self, n=-1):
+        """Read and return up to n bytes.
+        If the argument is omitted, None, or negative, data is read and
+        returned until EOF is reached.
+        """
+        buf = ''
+        if n is None:
+            n = -1
+        while True:
+            if n < 0:
+                data = self.read1(n)
+            elif n > len(buf):
+                data = self.read1(n - len(buf))
+            else:
+                return buf
+            if len(data) == 0:
+                return buf
+            buf += data
+
+    def _update_crc(self, newdata, eof):
+        # Update the CRC using the given data.
+        if self._expected_crc is None:
+            # No need to compute the CRC if we don't have a reference value
+            return
+        self._running_crc = crc32(newdata, self._running_crc) & 0xffffffff
+        # Check the CRC if we're at the end of the file
+        if eof and self._running_crc != self._expected_crc:
+            raise BadZipfile("Bad CRC-32 for file %r" % self.name)
+
+    def read1(self, n):
+        """Read up to n bytes with at most one read() system call."""
+
+        # Simplify algorithm (branching) by transforming negative n to large n.
+        # Test for None first: ordering comparisons against None are not
+        # meaningful (they only happen to work in Python 2).
+        if n is None or n < 0:
+            n = self.MAX_N
+
+        # Bytes available in read buffer.
+        len_readbuffer = len(self._readbuffer) - self._offset
+
+        # Read from file.
+        if self._compress_left > 0 and n > len_readbuffer + len(self._unconsumed):
+            nbytes = n - len_readbuffer - len(self._unconsumed)
+            nbytes = max(nbytes, self.MIN_READ_SIZE)
+            nbytes = min(nbytes, self._compress_left)
+
+            data = self._fileobj.read(nbytes)
+            self._compress_left -= len(data)
+
+            if data and self._decrypter is not None:
+                data = ''.join(map(self._decrypter, data))
+
+            if self._compress_type == ZIP_STORED:
+                self._update_crc(data, eof=(self._compress_left==0))
+                self._readbuffer = self._readbuffer[self._offset:] + data
+                self._offset = 0
+            else:
+                # Prepare deflated bytes for decompression.
+                self._unconsumed += data
+
+        # Handle unconsumed data.
+        if (len(self._unconsumed) > 0 and n > len_readbuffer and
+            self._compress_type == ZIP_DEFLATED):
+            data = self._decompressor.decompress(
+                self._unconsumed,
+                max(n - len_readbuffer, self.MIN_READ_SIZE)
+            )
+
+            self._unconsumed = self._decompressor.unconsumed_tail
+            eof = len(self._unconsumed) == 0 and self._compress_left == 0
+            if eof:
+                data += self._decompressor.flush()
+
+            self._update_crc(data, eof=eof)
+            self._readbuffer = self._readbuffer[self._offset:] + data
+            self._offset = 0
+
+        # Read from buffer.
+        data = self._readbuffer[self._offset: self._offset + n]
+        self._offset += len(data)
+        return data
+
+
+
class ZipFile:
""" Class with methods to open, read, write, close, list zip files.
- z = ZipFile(file, mode="r", compression=ZIP_STORED, allowZip64=True)
+ z = ZipFile(file, mode="r", compression=ZIP_STORED, allowZip64=False)
file: Either the path to the file, or a file-like object.
If it is a path, the file will be opened and closed by ZipFile.
@@ -317,8 +670,9 @@
def __init__(self, file, mode="r", compression=ZIP_STORED, allowZip64=False):
"""Open the ZIP file with mode read "r", write "w" or append "a"."""
- self._allowZip64 = allowZip64
- self._didModify = False
+ if mode not in ("r", "w", "a"):
+ raise RuntimeError('ZipFile() requires mode "r", "w", or "a"')
+
if compression == ZIP_STORED:
pass
elif compression == ZIP_DEFLATED:
@@ -327,18 +681,30 @@
"Compression requires the (missing) zlib module"
else:
raise RuntimeError, "That compression method is not supported"
+
+ self._allowZip64 = allowZip64
+ self._didModify = False
self.debug = 0 # Level of printing: 0 through 3
self.NameToInfo = {} # Find file info given name
self.filelist = [] # List of ZipInfo instances for archive
self.compression = compression # Method of compression
self.mode = key = mode.replace('b', '')[0]
+ self.pwd = None
+ self.comment = ''
# Check if we were passed a file-like object
if isinstance(file, basestring):
self._filePassed = 0
self.filename = file
modeDict = {'r' : 'rb', 'w': 'wb', 'a' : 'r+b'}
- self.fp = open(file, modeDict[mode])
+ try:
+ self.fp = open(file, modeDict[mode])
+ except IOError:
+ if mode == 'a':
+ mode = key = 'w'
+ self.fp = open(file, modeDict[mode])
+ else:
+ raise
else:
self._filePassed = 1
self.fp = file
@@ -347,20 +713,34 @@
if key == 'r':
self._GetContents()
elif key == 'w':
- pass
+ # set the modified flag so central directory gets written
+ # even if no files are added to the archive
+ self._didModify = True
elif key == 'a':
- try: # See if file is a zip file
+ try:
+ # See if file is a zip file
self._RealGetContents()
# seek to start of directory and overwrite
self.fp.seek(self.start_dir, 0)
- except BadZipfile: # file is not a zip file, just append
+ except BadZipfile:
+ # file is not a zip file, just append
self.fp.seek(0, 2)
+
+ # set the modified flag so central directory gets written
+ # even if no files are added to the archive
+ self._didModify = True
else:
if not self._filePassed:
self.fp.close()
self.fp = None
raise RuntimeError, 'Mode must be "r", "w" or "a"'
+ def __enter__(self):
+ return self
+
+ def __exit__(self, type, value, traceback):
+ self.close()
+
def _GetContents(self):
"""Read the directory, making sure we close the file if the format
is bad."""
@@ -375,23 +755,27 @@
def _RealGetContents(self):
"""Read in the table of contents for the ZIP file."""
fp = self.fp
- endrec = _EndRecData(fp)
+ try:
+ endrec = _EndRecData(fp)
+ except IOError:
+ raise BadZipfile("File is not a zip file")
if not endrec:
raise BadZipfile, "File is not a zip file"
if self.debug > 1:
print endrec
- size_cd = endrec[5] # bytes in central directory
- offset_cd = endrec[6] # offset of central directory
- self.comment = endrec[8] # archive comment
- # endrec[9] is the offset of the "End of Central Dir" record
- if endrec[9] > ZIP64_LIMIT:
- x = endrec[9] - size_cd - 56 - 20
- else:
- x = endrec[9] - size_cd
+ size_cd = endrec[_ECD_SIZE] # bytes in central directory
+ offset_cd = endrec[_ECD_OFFSET] # offset of central directory
+ self.comment = endrec[_ECD_COMMENT] # archive comment
+
# "concat" is zero, unless zip was concatenated to another file
- concat = x - offset_cd
+ concat = endrec[_ECD_LOCATION] - size_cd - offset_cd
+ if endrec[_ECD_SIGNATURE] == stringEndArchive64:
+ # If Zip64 extension structures are present, account for them
+ concat -= (sizeEndCentDir64 + sizeEndCentDir64Locator)
+
if self.debug > 2:
- print "given, inferred, offset", offset_cd, x, concat
+ inferred = concat + offset_cd
+ print "given, inferred, offset", offset_cd, inferred, concat
# self.start_dir: Position of start of central directory
self.start_dir = offset_cd + concat
fp.seek(self.start_dir, 0)
@@ -399,8 +783,7 @@
fp = cStringIO.StringIO(data)
total = 0
while total < size_cd:
- centdir = fp.read(46)
- total = total + 46
+ centdir = fp.read(sizeCentralDir)
if centdir[0:4] != stringCentralDir:
raise BadZipfile, "Bad magic number for central directory"
centdir = struct.unpack(structCentralDir, centdir)
@@ -411,22 +794,27 @@
x = ZipInfo(filename)
x.extra = fp.read(centdir[_CD_EXTRA_FIELD_LENGTH])
x.comment = fp.read(centdir[_CD_COMMENT_LENGTH])
- total = (total + centdir[_CD_FILENAME_LENGTH]
- + centdir[_CD_EXTRA_FIELD_LENGTH]
- + centdir[_CD_COMMENT_LENGTH])
x.header_offset = centdir[_CD_LOCAL_HEADER_OFFSET]
(x.create_version, x.create_system, x.extract_version, x.reserved,
x.flag_bits, x.compress_type, t, d,
x.CRC, x.compress_size, x.file_size) = centdir[1:12]
x.volume, x.internal_attr, x.external_attr = centdir[15:18]
# Convert date/time code to (year, month, day, hour, min, sec)
+ x._raw_time = t
x.date_time = ( (d>>9)+1980, (d>>5)&0xF, d&0x1F,
t>>11, (t>>5)&0x3F, (t&0x1F) * 2 )
x._decodeExtra()
x.header_offset = x.header_offset + concat
+ x.filename = x._decodeFilename()
self.filelist.append(x)
self.NameToInfo[x.filename] = x
+
+ # update total bytes read from central directory
+ total = (total + sizeCentralDir + centdir[_CD_FILENAME_LENGTH]
+ + centdir[_CD_EXTRA_FIELD_LENGTH]
+ + centdir[_CD_COMMENT_LENGTH])
+
if self.debug > 2:
print "total", total
@@ -452,67 +840,165 @@
def testzip(self):
"""Read all the files and check the CRC."""
+ chunk_size = 2 ** 20
for zinfo in self.filelist:
try:
- self.read(zinfo.filename) # Check CRC-32
+ # Read by chunks, to avoid an OverflowError or a
+ # MemoryError with very large embedded files.
+ f = self.open(zinfo.filename, "r")
+ while f.read(chunk_size): # Check CRC-32
+ pass
except BadZipfile:
return zinfo.filename
-
def getinfo(self, name):
"""Return the instance of ZipInfo given 'name'."""
- return self.NameToInfo[name]
+ info = self.NameToInfo.get(name)
+ if info is None:
+ raise KeyError(
+ 'There is no item named %r in the archive' % name)
- def read(self, name):
+ return info
+
+ def setpassword(self, pwd):
+ """Set default password for encrypted files."""
+ self.pwd = pwd
+
+ def read(self, name, pwd=None):
"""Return file bytes (as a string) for name."""
- if self.mode not in ("r", "a"):
- raise RuntimeError, 'read() requires mode "r" or "a"'
+ return self.open(name, "r", pwd).read()
+
+ def open(self, name, mode="r", pwd=None):
+ """Return file-like object for 'name'."""
+ if mode not in ("r", "U", "rU"):
+ raise RuntimeError, 'open() requires mode "r", "U", or "rU"'
if not self.fp:
raise RuntimeError, \
"Attempt to read ZIP archive that was already closed"
- zinfo = self.getinfo(name)
- filepos = self.fp.tell()
- self.fp.seek(zinfo.header_offset, 0)
+ # Only open a new file for instances where we were not
+ # given a file object in the constructor
+ if self._filePassed:
+ zef_file = self.fp
+ else:
+ zef_file = open(self.filename, 'rb')
+
+ # Make sure we have an info object
+ if isinstance(name, ZipInfo):
+ # 'name' is already an info object
+ zinfo = name
+ else:
+ # Get info object for name
+ zinfo = self.getinfo(name)
+
+ zef_file.seek(zinfo.header_offset, 0)
# Skip the file header:
- fheader = self.fp.read(30)
+ fheader = zef_file.read(sizeFileHeader)
if fheader[0:4] != stringFileHeader:
raise BadZipfile, "Bad magic number for file header"
fheader = struct.unpack(structFileHeader, fheader)
- fname = self.fp.read(fheader[_FH_FILENAME_LENGTH])
+ fname = zef_file.read(fheader[_FH_FILENAME_LENGTH])
if fheader[_FH_EXTRA_FIELD_LENGTH]:
- self.fp.read(fheader[_FH_EXTRA_FIELD_LENGTH])
+ zef_file.read(fheader[_FH_EXTRA_FIELD_LENGTH])
if fname != zinfo.orig_filename:
raise BadZipfile, \
'File name in directory "%s" and header "%s" differ.' % (
zinfo.orig_filename, fname)
- bytes = self.fp.read(zinfo.compress_size)
- self.fp.seek(filepos, 0)
- if zinfo.compress_type == ZIP_STORED:
- pass
- elif zinfo.compress_type == ZIP_DEFLATED:
- if not zlib:
- raise RuntimeError, \
- "De-compression requires the (missing) zlib module"
- # zlib compress/decompress code by Jeremy Hylton of CNRI
- dc = zlib.decompressobj(-15)
- bytes = dc.decompress(bytes)
- # need to feed in unused pad byte so that zlib won't choke
- ex = dc.decompress('Z') + dc.flush()
- if ex:
- bytes = bytes + ex
+ # check for encrypted flag & handle password
+ is_encrypted = zinfo.flag_bits & 0x1
+ zd = None
+ if is_encrypted:
+ if not pwd:
+ pwd = self.pwd
+ if not pwd:
+ raise RuntimeError, "File %s is encrypted, " \
+ "password required for extraction" % name
+
+ zd = _ZipDecrypter(pwd)
+ # The first 12 bytes in the cypher stream is an encryption header
+ # used to strengthen the algorithm. The first 11 bytes are
+ # completely random, while the 12th contains the MSB of the CRC,
+ # or the MSB of the file time depending on the header type
+ # and is used to check the correctness of the password.
+ bytes = zef_file.read(12)
+ h = map(zd, bytes[0:12])
+ if zinfo.flag_bits & 0x8:
+ # compare against the file type from extended local headers
+ check_byte = (zinfo._raw_time >> 8) & 0xff
+ else:
+ # compare against the CRC otherwise
+ check_byte = (zinfo.CRC >> 24) & 0xff
+ if ord(h[11]) != check_byte:
+ raise RuntimeError("Bad password for file", name)
+
+ return ZipExtFile(zef_file, mode, zinfo, zd)
+
+ def extract(self, member, path=None, pwd=None):
+ """Extract a member from the archive to the current working directory,
+ using its full name. Its file information is extracted as accurately
+ as possible. `member' may be a filename or a ZipInfo object. You can
+ specify a different directory using `path'.
+ """
+ if not isinstance(member, ZipInfo):
+ member = self.getinfo(member)
+
+ if path is None:
+ path = os.getcwd()
+
+ return self._extract_member(member, path, pwd)
+
+ def extractall(self, path=None, members=None, pwd=None):
+ """Extract all members from the archive to the current working
+ directory. `path' specifies a different directory to extract to.
+ `members' is optional and must be a subset of the list returned
+ by namelist().
+ """
+ if members is None:
+ members = self.namelist()
+
+ for zipinfo in members:
+ self.extract(zipinfo, path, pwd)
+
+ def _extract_member(self, member, targetpath, pwd):
+ """Extract the ZipInfo object 'member' to a physical
+ file on the path targetpath.
+ """
+ # build the destination pathname, replacing
+ # forward slashes to platform specific separators.
+ # Strip trailing path separator, unless it represents the root.
+ if (targetpath[-1:] in (os.path.sep, os.path.altsep)
+ and len(os.path.splitdrive(targetpath)[1]) > 1):
+ targetpath = targetpath[:-1]
+
+ # don't include leading "/" from file name if present
+ if member.filename[0] == '/':
+ targetpath = os.path.join(targetpath, member.filename[1:])
else:
- raise BadZipfile, \
- "Unsupported compression method %d for file %s" % \
- (zinfo.compress_type, name)
- crc = binascii.crc32(bytes)
- if crc != zinfo.CRC:
- raise BadZipfile, "Bad CRC-32 for file %s" % name
- return bytes
+ targetpath = os.path.join(targetpath, member.filename)
+
+ targetpath = os.path.normpath(targetpath)
+
+ # Create all upper directories if necessary.
+ upperdirs = os.path.dirname(targetpath)
+ if upperdirs and not os.path.exists(upperdirs):
+ os.makedirs(upperdirs)
+
+ if member.filename[-1] == '/':
+ if not os.path.isdir(targetpath):
+ os.mkdir(targetpath)
+ return targetpath
+
+ source = self.open(member, pwd=pwd)
+ target = file(targetpath, "wb")
+ shutil.copyfileobj(source, target)
+ source.close()
+ target.close()
+
+ return targetpath
def _writecheck(self, zinfo):
"""Check for errors before writing a file to the archive."""
@@ -540,7 +1026,12 @@
def write(self, filename, arcname=None, compress_type=None):
"""Put the bytes from filename into the archive under the name
arcname."""
+ if not self.fp:
+ raise RuntimeError(
+ "Attempt to write to ZIP archive that was already closed")
+
st = os.stat(filename)
+ isdir = stat.S_ISDIR(st.st_mode)
mtime = time.localtime(st.st_mtime)
date_time = mtime[0:6]
# Create ZipInfo instance to store file information
@@ -549,6 +1040,8 @@
arcname = os.path.normpath(os.path.splitdrive(arcname)[1])
while arcname[0] in (os.sep, os.altsep):
arcname = arcname[1:]
+ if isdir:
+ arcname += '/'
zinfo = ZipInfo(arcname, date_time)
zinfo.external_attr = (st[0] & 0xFFFF) << 16L # Unix attributes
if compress_type is None:
@@ -562,28 +1055,37 @@
self._writecheck(zinfo)
self._didModify = True
- fp = open(filename, "rb")
- # Must overwrite CRC and sizes with correct data later
- zinfo.CRC = CRC = 0
- zinfo.compress_size = compress_size = 0
- zinfo.file_size = file_size = 0
- self.fp.write(zinfo.FileHeader())
- if zinfo.compress_type == ZIP_DEFLATED:
- cmpr = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
- zlib.DEFLATED, -15)
- else:
- cmpr = None
- while 1:
- buf = fp.read(1024 * 8)
- if not buf:
- break
- file_size = file_size + len(buf)
- CRC = binascii.crc32(buf, CRC)
- if cmpr:
- buf = cmpr.compress(buf)
- compress_size = compress_size + len(buf)
- self.fp.write(buf)
- fp.close()
+
+ if isdir:
+ zinfo.file_size = 0
+ zinfo.compress_size = 0
+ zinfo.CRC = 0
+ self.filelist.append(zinfo)
+ self.NameToInfo[zinfo.filename] = zinfo
+ self.fp.write(zinfo.FileHeader())
+ return
+
+ with open(filename, "rb") as fp:
+ # Must overwrite CRC and sizes with correct data later
+ zinfo.CRC = CRC = 0
+ zinfo.compress_size = compress_size = 0
+ zinfo.file_size = file_size = 0
+ self.fp.write(zinfo.FileHeader())
+ if zinfo.compress_type == ZIP_DEFLATED:
+ cmpr = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
+ zlib.DEFLATED, -15)
+ else:
+ cmpr = None
+ while 1:
+ buf = fp.read(1024 * 8)
+ if not buf:
+ break
+ file_size = file_size + len(buf)
+ CRC = crc32(buf, CRC) & 0xffffffff
+ if cmpr:
+ buf = cmpr.compress(buf)
+ compress_size = compress_size + len(buf)
+ self.fp.write(buf)
if cmpr:
buf = cmpr.flush()
compress_size = compress_size + len(buf)
@@ -596,27 +1098,37 @@
# Seek backwards and write CRC and file sizes
position = self.fp.tell() # Preserve current position in file
self.fp.seek(zinfo.header_offset + 14, 0)
- self.fp.write(struct.pack("<lLL", zinfo.CRC, zinfo.compress_size,
+ self.fp.write(struct.pack("<LLL", zinfo.CRC, zinfo.compress_size,
zinfo.file_size))
self.fp.seek(position, 0)
self.filelist.append(zinfo)
self.NameToInfo[zinfo.filename] = zinfo
- def writestr(self, zinfo_or_arcname, bytes):
+ def writestr(self, zinfo_or_arcname, bytes, compress_type=None):
"""Write a file into the archive. The contents is the string
'bytes'. 'zinfo_or_arcname' is either a ZipInfo instance or
the name of the file in the archive."""
if not isinstance(zinfo_or_arcname, ZipInfo):
zinfo = ZipInfo(filename=zinfo_or_arcname,
date_time=time.localtime(time.time())[:6])
+
zinfo.compress_type = self.compression
+ zinfo.external_attr = 0600 << 16
else:
zinfo = zinfo_or_arcname
+
+ if not self.fp:
+ raise RuntimeError(
+ "Attempt to write to ZIP archive that was already closed")
+
+ if compress_type is not None:
+ zinfo.compress_type = compress_type
+
zinfo.file_size = len(bytes) # Uncompressed size
zinfo.header_offset = self.fp.tell() # Start of header bytes
self._writecheck(zinfo)
self._didModify = True
- zinfo.CRC = binascii.crc32(bytes) # CRC-32 checksum
+ zinfo.CRC = crc32(bytes) & 0xffffffff # CRC-32 checksum
if zinfo.compress_type == ZIP_DEFLATED:
co = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
zlib.DEFLATED, -15)
@@ -630,7 +1142,7 @@
self.fp.flush()
if zinfo.flag_bits & 0x08:
# Write CRC and file sizes after the file data
- self.fp.write(struct.pack("<lLL", zinfo.CRC, zinfo.compress_size,
+ self.fp.write(struct.pack("<LLL", zinfo.CRC, zinfo.compress_size,
zinfo.file_size))
self.filelist.append(zinfo)
self.NameToInfo[zinfo.filename] = zinfo
@@ -658,15 +1170,15 @@
or zinfo.compress_size > ZIP64_LIMIT:
extra.append(zinfo.file_size)
extra.append(zinfo.compress_size)
- file_size = 0xffffffff #-1
- compress_size = 0xffffffff #-1
+ file_size = 0xffffffff
+ compress_size = 0xffffffff
else:
file_size = zinfo.file_size
compress_size = zinfo.compress_size
if zinfo.header_offset > ZIP64_LIMIT:
extra.append(zinfo.header_offset)
- header_offset = -1 # struct "l" format: 32 one bits
+ header_offset = 0xffffffffL
else:
header_offset = zinfo.header_offset
@@ -674,7 +1186,7 @@
if extra:
# Append a ZIP64 field to the extra's
extra_data = struct.pack(
- '<hh' + 'q'*len(extra),
+ '<HH' + 'Q'*len(extra),
1, 8*len(extra), *extra) + extra_data
extract_version = max(45, zinfo.extract_version)
@@ -683,44 +1195,68 @@
extract_version = zinfo.extract_version
create_version = zinfo.create_version
- centdir = struct.pack(structCentralDir,
- stringCentralDir, create_version,
- zinfo.create_system, extract_version, zinfo.reserved,
- zinfo.flag_bits, zinfo.compress_type, dostime, dosdate,
- zinfo.CRC, compress_size, file_size,
- len(zinfo.filename), len(extra_data), len(zinfo.comment),
- 0, zinfo.internal_attr, zinfo.external_attr,
- header_offset)
+ try:
+ filename, flag_bits = zinfo._encodeFilenameFlags()
+ centdir = struct.pack(structCentralDir,
+ stringCentralDir, create_version,
+ zinfo.create_system, extract_version, zinfo.reserved,
+ flag_bits, zinfo.compress_type, dostime, dosdate,
+ zinfo.CRC, compress_size, file_size,
+ len(filename), len(extra_data), len(zinfo.comment),
+ 0, zinfo.internal_attr, zinfo.external_attr,
+ header_offset)
+ except DeprecationWarning:
+ print >>sys.stderr, (structCentralDir,
+ stringCentralDir, create_version,
+ zinfo.create_system, extract_version, zinfo.reserved,
+ zinfo.flag_bits, zinfo.compress_type, dostime, dosdate,
+ zinfo.CRC, compress_size, file_size,
+ len(zinfo.filename), len(extra_data), len(zinfo.comment),
+ 0, zinfo.internal_attr, zinfo.external_attr,
+ header_offset)
+ raise
self.fp.write(centdir)
- self.fp.write(zinfo.filename)
+ self.fp.write(filename)
self.fp.write(extra_data)
self.fp.write(zinfo.comment)
pos2 = self.fp.tell()
# Write end-of-zip-archive record
- if pos1 > ZIP64_LIMIT:
+ centDirCount = count
+ centDirSize = pos2 - pos1
+ centDirOffset = pos1
+ if (centDirCount >= ZIP_FILECOUNT_LIMIT or
+ centDirOffset > ZIP64_LIMIT or
+ centDirSize > ZIP64_LIMIT):
# Need to write the ZIP64 end-of-archive records
zip64endrec = struct.pack(
structEndArchive64, stringEndArchive64,
- 44, 45, 45, 0, 0, count, count, pos2 - pos1, pos1)
+ 44, 45, 45, 0, 0, centDirCount, centDirCount,
+ centDirSize, centDirOffset)
self.fp.write(zip64endrec)
zip64locrec = struct.pack(
structEndArchive64Locator,
stringEndArchive64Locator, 0, pos2, 1)
self.fp.write(zip64locrec)
+ centDirCount = min(centDirCount, 0xFFFF)
+ centDirSize = min(centDirSize, 0xFFFFFFFF)
+ centDirOffset = min(centDirOffset, 0xFFFFFFFF)
- # XXX Why is `pos3` computed next? It's never referenced.
- pos3 = self.fp.tell()
- endrec = struct.pack(structEndArchive, stringEndArchive,
- 0, 0, count, count, pos2 - pos1, -1, 0)
- self.fp.write(endrec)
+ # check for valid comment length
+ if len(self.comment) >= ZIP_MAX_COMMENT:
+ if self.debug > 0:
+ msg = 'Archive comment is too long; truncating to %d bytes' \
+ % ZIP_MAX_COMMENT
+ self.comment = self.comment[:ZIP_MAX_COMMENT]
- else:
- endrec = struct.pack(structEndArchive, stringEndArchive,
- 0, 0, count, count, pos2 - pos1, pos1, 0)
- self.fp.write(endrec)
+ endrec = struct.pack(structEndArchive, stringEndArchive,
+ 0, 0, centDirCount, centDirCount,
+ centDirSize, centDirOffset, len(self.comment))
+ self.fp.write(endrec)
+ self.fp.write(self.comment)
self.fp.flush()
+
if not self._filePassed:
self.fp.close()
self.fp = None
@@ -854,7 +1390,9 @@
print USAGE
sys.exit(1)
zf = ZipFile(args[1], 'r')
- zf.testzip()
+ badfile = zf.testzip()
+ if badfile:
+ print("The following enclosed file is corrupted: {!r}".format(badfile))
print "Done testing"
elif args[0] == '-e':
@@ -873,9 +1411,8 @@
tgtdir = os.path.dirname(tgt)
if not os.path.exists(tgtdir):
os.makedirs(tgtdir)
- fp = open(tgt, 'wb')
- fp.write(zf.read(path))
- fp.close()
+ with open(tgt, 'wb') as fp:
+ fp.write(zf.read(path))
zf.close()
elif args[0] == '-c':
diff --git a/src/org/python/core/ArgParser.java b/src/org/python/core/ArgParser.java
--- a/src/org/python/core/ArgParser.java
+++ b/src/org/python/core/ArgParser.java
@@ -1,14 +1,14 @@
package org.python.core;
+import org.python.antlr.AST;
+
import java.util.HashSet;
import java.util.Set;
-import org.python.antlr.AST;
-
/**
* A utility class for handling mixed positional and keyword arguments.
*
- * A typical usage:
+ * Typical usage:
*
* <pre>
* public MatchObject search(PyObject[] args, String[] kws) {
@@ -51,12 +51,12 @@
}
/**
- * Create an ArgParser with one method argument
+ * Create an ArgParser for a one-argument function.
*
- * @param funcname Name of the method. Used in error messages.
+ * @param funcname Name of the function. Used in error messages.
* @param args The actual call arguments supplied in the call.
* @param kws The actual keyword names supplied in the call.
- * @param p0 The expected argument in the method definition.
+ * @param p0 The expected argument in the function definition.
*/
public ArgParser(String funcname, PyObject[] args, String[] kws, String p0) {
this(funcname, args, kws);
@@ -65,13 +65,13 @@
}
/**
- * Create an ArgParser with two method argument
+ * Create an ArgParser for a two-argument function.
*
- * @param funcname Name of the method. Used in error messages.
+ * @param funcname Name of the function. Used in error messages.
* @param args The actual call arguments supplied in the call.
* @param kws The actual keyword names supplied in the call.
- * @param p0 The first expected argument in the method definition.
- * @param p1 The second expected argument in the method definition.
+ * @param p0 The first expected argument in the function definition.
+ * @param p1 The second expected argument in the function definition.
*/
public ArgParser(String funcname, PyObject[] args, String[] kws, String p0,
String p1) {
@@ -81,14 +81,14 @@
}
/**
- * Create an ArgParser with three method argument
+ * Create an ArgParser for a three-argument function.
*
- * @param funcname Name of the method. Used in error messages.
+ * @param funcname Name of the function. Used in error messages.
* @param args The actual call arguments supplied in the call.
* @param kws The actual keyword names supplied in the call.
- * @param p0 The first expected argument in the method definition.
- * @param p1 The second expected argument in the method definition.
- * @param p2 The third expected argument in the method definition.
+ * @param p0 The first expected argument in the function definition.
+ * @param p1 The second expected argument in the function definition.
+ * @param p2 The third expected argument in the function definition.
*/
public ArgParser(String funcname, PyObject[] args, String[] kws, String p0,
String p1, String p2) {
@@ -98,12 +98,12 @@
}
/**
- * Create an ArgParser with three method argument
+ * Create an ArgParser for a multi-argument function.
*
- * @param funcname Name of the method. Used in error messages.
+ * @param funcname Name of the function. Used in error messages.
* @param args The actual call arguments supplied in the call.
* @param kws The actual keyword names supplied in the call.
- * @param paramnames The list of expected argument in the method definition.
+ * @param paramnames The list of expected arguments in the function definition.
*/
public ArgParser(String funcname, PyObject[] args, String[] kws,
String[] paramnames) {
diff --git a/src/org/python/core/PyString.java b/src/org/python/core/PyString.java
--- a/src/org/python/core/PyString.java
+++ b/src/org/python/core/PyString.java
@@ -1,8 +1,6 @@
/// Copyright (c) Corporation for National Research Initiatives
package org.python.core;
-import java.math.BigInteger;
-
import org.python.core.stringlib.FieldNameIterator;
import org.python.core.stringlib.InternalFormatSpec;
import org.python.core.stringlib.InternalFormatSpecParser;
@@ -14,6 +12,8 @@
import org.python.expose.ExposedType;
import org.python.expose.MethodType;
+import java.math.BigInteger;
+
/**
* A builtin python string.
*/
@@ -2492,37 +2492,43 @@
}
public String encode() {
- return str_encode(null, null);
+ return encode(null, null);
}
public String encode(String encoding) {
- return str_encode(encoding, null);
+ return encode(encoding, null);
}
public String encode(String encoding, String errors) {
- return str_encode(encoding, errors);
- }
-
- @ExposedMethod(defaults = {"null", "null"}, doc = BuiltinDocs.str_encode_doc)
- final String str_encode(String encoding, String errors) {
return codecs.encode(this, encoding, errors);
}
+ @ExposedMethod(doc = BuiltinDocs.str_encode_doc)
+ final String str_encode(PyObject[] args, String[] keywords) {
+ ArgParser ap = new ArgParser("encode", args, keywords, "encoding", "errors");
+ String encoding = ap.getString(0, null);
+ String errors = ap.getString(1, null);
+ return encode(encoding, errors);
+ }
+
public PyObject decode() {
- return str_decode(null, null);
+ return decode(null, null);
}
public PyObject decode(String encoding) {
- return str_decode(encoding, null);
+ return decode(encoding, null);
}
public PyObject decode(String encoding, String errors) {
- return str_decode(encoding, errors);
+ return codecs.decode(this, encoding, errors);
}
- @ExposedMethod(defaults = {"null", "null"}, doc = BuiltinDocs.str_decode_doc)
- final PyObject str_decode(String encoding, String errors) {
- return codecs.decode(this, encoding, errors);
+ @ExposedMethod(doc = BuiltinDocs.str_decode_doc)
+ final PyObject str_decode(PyObject[] args, String[] keywords) {
+ ArgParser ap = new ArgParser("decode", args, keywords, "encoding", "errors");
+ String encoding = ap.getString(0, null);
+ String errors = ap.getString(1, null);
+ return decode(encoding, errors);
}
@ExposedMethod(doc = BuiltinDocs.str__formatter_parser_doc)
diff --git a/src/org/python/modules/struct.java b/src/org/python/modules/struct.java
--- a/src/org/python/modules/struct.java
+++ b/src/org/python/modules/struct.java
@@ -562,7 +562,26 @@
return Py.newInteger(buf.readByte());
}
}
+
+ static class PointerFormatDef extends FormatDef {
+ FormatDef init(char name) {
+ String dataModel = System.getProperty("sun.arch.data.model");
+ if (dataModel == null)
+ throw Py.NotImplementedError("Can't determine if JVM is 32- or 64-bit");
+ int length = dataModel.equals("64") ? 8 : 4;
+ super.init(name, length, length);
+ return this;
+ }
+
+ void pack(ByteStream buf, PyObject value) {
+ throw Py.NotImplementedError("Pointer packing/unpacking not implemented in Jython");
+ }
+ Object unpack(ByteStream buf) {
+ throw Py.NotImplementedError("Pointer packing/unpacking not implemented in Jython");
+ }
+ }
+
static class LEShortFormatDef extends FormatDef {
void pack(ByteStream buf, PyObject value) {
int v = get_int(value);
@@ -876,6 +895,7 @@
new BEUnsignedLongFormatDef() .init('Q', 8, 8),
new BEFloatFormatDef() .init('f', 4, 4),
new BEDoubleFormatDef() .init('d', 8, 8),
+ new PointerFormatDef() .init('P')
};
--
Repository URL: http://hg.python.org/jython
More information about the Jython-checkins mailing list