[Jython-checkins] jython: Synced the Jython versions of all stdlib modules with CPythonLib 2.7 that could
alex.gronholm
jython-checkins at python.org
Wed Mar 14 20:07:40 CET 2012
http://hg.python.org/jython/rev/119341d62e11
changeset: 6362:119341d62e11
user: Alex Grönholm <alex.gronholm at nextday.fi>
date: Wed Mar 14 11:59:07 2012 -0700
summary:
Synced the Jython versions of all stdlib modules with CPythonLib 2.7 that could be automatically patched
files:
Lib/decimal.py | 1432 +++++++++----
Lib/distutils/ccompiler.py | 92 +-
Lib/distutils/command/bdist.py | 10 +-
Lib/distutils/command/bdist_dumb.py | 4 +-
Lib/distutils/sysconfig.py | 157 +-
Lib/distutils/tests/test_build_py.py | 19 +-
Lib/distutils/util.py | 89 +-
Lib/filecmp.py | 7 +-
Lib/fileinput.py | 2 +-
Lib/gettext.py | 4 +-
Lib/mailbox.py | 62 +-
Lib/netrc.py | 5 +-
Lib/new.py | 4 +
Lib/py_compile.py | 8 +-
Lib/test/list_tests.py | 32 +-
Lib/test/test_array.py | 79 +-
Lib/test/test_code.py | 4 +
Lib/test/test_codeccallbacks.py | 30 +-
Lib/test/test_compile.py | 70 +-
Lib/test/test_copy.py | 5 +-
Lib/test/test_descrtut.py | 7 +-
Lib/test/test_dumbdbm.py | 24 +
Lib/test/test_genexps.py | 2 +-
Lib/test/test_hashlib.py | 22 +-
Lib/test/test_hmac.py | 39 +-
Lib/test/test_iter.py | 24 +-
Lib/test/test_logging.py | 1344 +++++++-----
Lib/test/test_operator.py | 63 +-
Lib/test/test_os.py | 270 ++-
Lib/test/test_pkgimport.py | 8 +-
Lib/test/test_pprint.py | 285 ++-
Lib/test/test_profilehooks.py | 43 +-
Lib/test/test_random.py | 40 +-
Lib/test/test_repr.py | 14 +-
Lib/test/test_shutil.py | 354 +++-
Lib/test/test_tempfile.py | 190 +-
Lib/test/test_time.py | 20 +-
Lib/test/test_trace.py | 47 +-
Lib/test/test_univnewlines.py | 50 +-
Lib/test/test_urllib2.py | 321 ++-
Lib/test/test_weakref.py | 145 +-
Lib/test/test_xml_etree.py | 5 +-
Lib/test/test_xml_etree_c.py | 5 +-
Lib/test/test_zlib.py | 161 +-
Lib/timeit.py | 67 +-
Lib/types.py | 11 +-
Lib/weakref.py | 4 +-
Lib/zipfile.py | 805 ++++++-
48 files changed, 4728 insertions(+), 1757 deletions(-)
diff --git a/Lib/decimal.py b/Lib/decimal.py
--- a/Lib/decimal.py
+++ b/Lib/decimal.py
@@ -35,26 +35,26 @@
useful for financial applications or for contexts where users have
expectations that are at odds with binary floating point (for instance,
in binary floating point, 1.00 % 0.1 gives 0.09999999999999995 instead
-of the expected Decimal("0.00") returned by decimal floating point).
+of the expected Decimal('0.00') returned by decimal floating point).
Here are some examples of using the decimal module:
>>> from decimal import *
>>> setcontext(ExtendedContext)
>>> Decimal(0)
-Decimal("0")
->>> Decimal("1")
-Decimal("1")
->>> Decimal("-.0123")
-Decimal("-0.0123")
+Decimal('0')
+>>> Decimal('1')
+Decimal('1')
+>>> Decimal('-.0123')
+Decimal('-0.0123')
>>> Decimal(123456)
-Decimal("123456")
->>> Decimal("123.45e12345678901234567890")
-Decimal("1.2345E+12345678901234567892")
->>> Decimal("1.33") + Decimal("1.27")
-Decimal("2.60")
->>> Decimal("12.34") + Decimal("3.87") - Decimal("18.41")
-Decimal("-2.20")
+Decimal('123456')
+>>> Decimal('123.45e12345678901234567890')
+Decimal('1.2345E+12345678901234567892')
+>>> Decimal('1.33') + Decimal('1.27')
+Decimal('2.60')
+>>> Decimal('12.34') + Decimal('3.87') - Decimal('18.41')
+Decimal('-2.20')
>>> dig = Decimal(1)
>>> print dig / Decimal(3)
0.333333333
@@ -91,7 +91,7 @@
>>> print c.flags[InvalidOperation]
0
>>> c.divide(Decimal(0), Decimal(0))
-Decimal("NaN")
+Decimal('NaN')
>>> c.traps[InvalidOperation] = 1
>>> print c.flags[InvalidOperation]
1
@@ -135,6 +135,13 @@
]
import copy as _copy
+import numbers as _numbers
+
+try:
+ from collections import namedtuple as _namedtuple
+ DecimalTuple = _namedtuple('DecimalTuple', 'sign digits exponent')
+except ImportError:
+ DecimalTuple = lambda *args: args
# Rounding
ROUND_DOWN = 'ROUND_DOWN'
@@ -158,7 +165,7 @@
anything, though.
handle -- Called when context._raise_error is called and the
- trap_enabler is set. First argument is self, second is the
+ trap_enabler is not set. First argument is self, second is the
context. More arguments can be given, those being after
the explanation in _raise_error (For example,
context._raise_error(NewError, '(-x)!', self._sign) would
@@ -210,7 +217,7 @@
if args:
ans = _dec_from_triple(args[0]._sign, args[0]._int, 'n', True)
return ans._fix_nan(context)
- return NaN
+ return _NaN
class ConversionSyntax(InvalidOperation):
"""Trying to convert badly formed string.
@@ -220,7 +227,7 @@
syntax. The result is [0,qNaN].
"""
def handle(self, context, *args):
- return NaN
+ return _NaN
class DivisionByZero(DecimalException, ZeroDivisionError):
"""Division by 0.
@@ -236,7 +243,7 @@
"""
def handle(self, context, sign, *args):
- return Infsign[sign]
+ return _SignedInfinity[sign]
class DivisionImpossible(InvalidOperation):
"""Cannot perform the division adequately.
@@ -247,7 +254,7 @@
"""
def handle(self, context, *args):
- return NaN
+ return _NaN
class DivisionUndefined(InvalidOperation, ZeroDivisionError):
"""Undefined result of division.
@@ -258,7 +265,7 @@
"""
def handle(self, context, *args):
- return NaN
+ return _NaN
class Inexact(DecimalException):
"""Had to round, losing information.
@@ -284,7 +291,7 @@
"""
def handle(self, context, *args):
- return NaN
+ return _NaN
class Rounded(DecimalException):
"""Number got rounded (not necessarily changed during rounding).
@@ -334,15 +341,15 @@
def handle(self, context, sign, *args):
if context.rounding in (ROUND_HALF_UP, ROUND_HALF_EVEN,
ROUND_HALF_DOWN, ROUND_UP):
- return Infsign[sign]
+ return _SignedInfinity[sign]
if sign == 0:
if context.rounding == ROUND_CEILING:
- return Infsign[sign]
+ return _SignedInfinity[sign]
return _dec_from_triple(sign, '9'*context.prec,
context.Emax-context.prec+1)
if sign == 1:
if context.rounding == ROUND_FLOOR:
- return Infsign[sign]
+ return _SignedInfinity[sign]
return _dec_from_triple(sign, '9'*context.prec,
context.Emax-context.prec+1)
@@ -471,11 +478,7 @@
# General Decimal Arithmetic Specification
return +s # Convert result to normal context
- """
- # The string below can't be included in the docstring until Python 2.6
- # as the doctest module doesn't understand __future__ statements
- """
- >>> from __future__ import with_statement
+ >>> setcontext(DefaultContext)
>>> print getcontext().prec
28
>>> with localcontext():
@@ -510,13 +513,15 @@
"""Create a decimal point instance.
>>> Decimal('3.14') # string input
- Decimal("3.14")
+ Decimal('3.14')
>>> Decimal((0, (3, 1, 4), -2)) # tuple (sign, digit_tuple, exponent)
- Decimal("3.14")
+ Decimal('3.14')
>>> Decimal(314) # int or long
- Decimal("314")
+ Decimal('314')
>>> Decimal(Decimal(314)) # another decimal instance
- Decimal("314")
+ Decimal('314')
+ >>> Decimal(' 3.14 \\n') # leading and trailing whitespace okay
+ Decimal('3.14')
"""
# Note that the coefficient, self._int, is actually stored as
@@ -532,7 +537,7 @@
# From a string
# REs insist on real strings, so we can too.
if isinstance(value, basestring):
- m = _parser(value)
+ m = _parser(value.strip())
if m is None:
if context is None:
context = getcontext()
@@ -546,20 +551,16 @@
intpart = m.group('int')
if intpart is not None:
# finite number
- fracpart = m.group('frac')
+ fracpart = m.group('frac') or ''
exp = int(m.group('exp') or '0')
- if fracpart is not None:
- self._int = str((intpart+fracpart).lstrip('0') or '0')
- self._exp = exp - len(fracpart)
- else:
- self._int = str(intpart.lstrip('0') or '0')
- self._exp = exp
+ self._int = str(int(intpart+fracpart))
+ self._exp = exp - len(fracpart)
self._is_special = False
else:
diag = m.group('diag')
if diag is not None:
# NaN
- self._int = str(diag.lstrip('0'))
+ self._int = str(int(diag or '0')).lstrip('0')
if m.group('signal'):
self._exp = 'N'
else:
@@ -709,6 +710,39 @@
return other._fix_nan(context)
return 0
+ def _compare_check_nans(self, other, context):
+ """Version of _check_nans used for the signaling comparisons
+ compare_signal, __le__, __lt__, __ge__, __gt__.
+
+ Signal InvalidOperation if either self or other is a (quiet
+ or signaling) NaN. Signaling NaNs take precedence over quiet
+ NaNs.
+
+ Return 0 if neither operand is a NaN.
+
+ """
+ if context is None:
+ context = getcontext()
+
+ if self._is_special or other._is_special:
+ if self.is_snan():
+ return context._raise_error(InvalidOperation,
+ 'comparison involving sNaN',
+ self)
+ elif other.is_snan():
+ return context._raise_error(InvalidOperation,
+ 'comparison involving sNaN',
+ other)
+ elif self.is_qnan():
+ return context._raise_error(InvalidOperation,
+ 'comparison involving NaN',
+ self)
+ elif other.is_qnan():
+ return context._raise_error(InvalidOperation,
+ 'comparison involving NaN',
+ other)
+ return 0
+
def __nonzero__(self):
"""Return True if self is nonzero; otherwise return False.
@@ -716,21 +750,23 @@
"""
return self._is_special or self._int != '0'
- def __cmp__(self, other):
- other = _convert_other(other)
- if other is NotImplemented:
- # Never return NotImplemented
- return 1
+ def _cmp(self, other):
+ """Compare the two non-NaN decimal instances self and other.
+
+ Returns -1 if self < other, 0 if self == other and 1
+ if self > other. This routine is for internal use only."""
if self._is_special or other._is_special:
- # check for nans, without raising on a signaling nan
- if self._isnan() or other._isnan():
- return 1 # Comparison involving NaN's always reports self > other
-
- # INF = INF
- return cmp(self._isinfinity(), other._isinfinity())
-
- # check for zeros; note that cmp(0, -0) should return 0
+ self_inf = self._isinfinity()
+ other_inf = other._isinfinity()
+ if self_inf == other_inf:
+ return 0
+ elif self_inf < other_inf:
+ return -1
+ else:
+ return 1
+
+ # check for zeros; Decimal('0') == Decimal('-0')
if not self:
if not other:
return 0
@@ -750,21 +786,82 @@
if self_adjusted == other_adjusted:
self_padded = self._int + '0'*(self._exp - other._exp)
other_padded = other._int + '0'*(other._exp - self._exp)
- return cmp(self_padded, other_padded) * (-1)**self._sign
+ if self_padded == other_padded:
+ return 0
+ elif self_padded < other_padded:
+ return -(-1)**self._sign
+ else:
+ return (-1)**self._sign
elif self_adjusted > other_adjusted:
return (-1)**self._sign
else: # self_adjusted < other_adjusted
return -((-1)**self._sign)
+ # Note: The Decimal standard doesn't cover rich comparisons for
+ # Decimals. In particular, the specification is silent on the
+ # subject of what should happen for a comparison involving a NaN.
+ # We take the following approach:
+ #
+ # == comparisons involving a NaN always return False
+ # != comparisons involving a NaN always return True
+ # <, >, <= and >= comparisons involving a (quiet or signaling)
+ # NaN signal InvalidOperation, and return False if the
+ # InvalidOperation is not trapped.
+ #
+ # This behavior is designed to conform as closely as possible to
+ # that specified by IEEE 754.
+
def __eq__(self, other):
- if not isinstance(other, (Decimal, int, long)):
- return NotImplemented
- return self.__cmp__(other) == 0
+ other = _convert_other(other)
+ if other is NotImplemented:
+ return other
+ if self.is_nan() or other.is_nan():
+ return False
+ return self._cmp(other) == 0
def __ne__(self, other):
- if not isinstance(other, (Decimal, int, long)):
- return NotImplemented
- return self.__cmp__(other) != 0
+ other = _convert_other(other)
+ if other is NotImplemented:
+ return other
+ if self.is_nan() or other.is_nan():
+ return True
+ return self._cmp(other) != 0
+
+ def __lt__(self, other, context=None):
+ other = _convert_other(other)
+ if other is NotImplemented:
+ return other
+ ans = self._compare_check_nans(other, context)
+ if ans:
+ return False
+ return self._cmp(other) < 0
+
+ def __le__(self, other, context=None):
+ other = _convert_other(other)
+ if other is NotImplemented:
+ return other
+ ans = self._compare_check_nans(other, context)
+ if ans:
+ return False
+ return self._cmp(other) <= 0
+
+ def __gt__(self, other, context=None):
+ other = _convert_other(other)
+ if other is NotImplemented:
+ return other
+ ans = self._compare_check_nans(other, context)
+ if ans:
+ return False
+ return self._cmp(other) > 0
+
+ def __ge__(self, other, context=None):
+ other = _convert_other(other)
+ if other is NotImplemented:
+ return other
+ ans = self._compare_check_nans(other, context)
+ if ans:
+ return False
+ return self._cmp(other) >= 0
def compare(self, other, context=None):
"""Compares one to another.
@@ -783,7 +880,7 @@
if ans:
return ans
- return Decimal(self.__cmp__(other))
+ return Decimal(self._cmp(other))
def __hash__(self):
"""x.__hash__() <==> hash(x)"""
@@ -791,7 +888,7 @@
#
# The hash of a nonspecial noninteger Decimal must depend only
# on the value of that Decimal, and not on its representation.
- # For example: hash(Decimal("100E-1")) == hash(Decimal("10")).
+ # For example: hash(Decimal('100E-1')) == hash(Decimal('10')).
if self._is_special:
if self._isnan():
raise TypeError('Cannot hash a NaN value.')
@@ -800,7 +897,13 @@
return 0
if self._isinteger():
op = _WorkRep(self.to_integral_value())
- return hash((-1)**op.sign*op.int*10**op.exp)
+ # to make computation feasible for Decimals with large
+ # exponent, we use the fact that hash(n) == hash(m) for
+ # any two nonzero integers n and m such that (i) n and m
+ # have the same sign, and (ii) n is congruent to m modulo
+ # 2**64-1. So we can replace hash((-1)**s*c*10**e) with
+ # hash((-1)**s*c*pow(10, e, 2**64-1).
+ return hash((-1)**op.sign*op.int*pow(10, op.exp, 2**64-1))
# The value of a nonzero nonspecial Decimal instance is
# faithfully represented by the triple consisting of its sign,
# its adjusted exponent, and its coefficient with trailing
@@ -814,12 +917,12 @@
To show the internals exactly as they are.
"""
- return (self._sign, tuple(map(int, self._int)), self._exp)
+ return DecimalTuple(self._sign, tuple(map(int, self._int)), self._exp)
def __repr__(self):
"""Represents the number as an instance of Decimal."""
# Invariant: eval(repr(d)) == d
- return 'Decimal("%s")' % str(self)
+ return "Decimal('%s')" % str(self)
def __str__(self, eng=False, context=None):
"""Return string representation of the number in scientific notation.
@@ -1077,12 +1180,12 @@
if self._isinfinity():
if not other:
return context._raise_error(InvalidOperation, '(+-)INF * 0')
- return Infsign[resultsign]
+ return _SignedInfinity[resultsign]
if other._isinfinity():
if not self:
return context._raise_error(InvalidOperation, '0 * (+-)INF')
- return Infsign[resultsign]
+ return _SignedInfinity[resultsign]
resultexp = self._exp + other._exp
@@ -1112,7 +1215,7 @@
return ans
__rmul__ = __mul__
- def __div__(self, other, context=None):
+ def __truediv__(self, other, context=None):
"""Return self / other."""
other = _convert_other(other)
if other is NotImplemented:
@@ -1132,7 +1235,7 @@
return context._raise_error(InvalidOperation, '(+-)INF/(+-)INF')
if self._isinfinity():
- return Infsign[sign]
+ return _SignedInfinity[sign]
if other._isinfinity():
context._raise_error(Clamped, 'Division by infinity')
@@ -1171,8 +1274,6 @@
ans = _dec_from_triple(sign, str(coeff), exp)
return ans._fix(context)
- __truediv__ = __div__
-
def _divide(self, other, context):
"""Return (self // other, self % other), to context.prec precision.
@@ -1206,13 +1307,15 @@
'quotient too large in //, % or divmod')
return ans, ans
- def __rdiv__(self, other, context=None):
- """Swaps self/other and returns __div__."""
+ def __rtruediv__(self, other, context=None):
+ """Swaps self/other and returns __truediv__."""
other = _convert_other(other)
if other is NotImplemented:
return other
- return other.__div__(self, context=context)
- __rtruediv__ = __rdiv__
+ return other.__truediv__(self, context=context)
+
+ __div__ = __truediv__
+ __rdiv__ = __rtruediv__
def __divmod__(self, other, context=None):
"""
@@ -1235,7 +1338,7 @@
ans = context._raise_error(InvalidOperation, 'divmod(INF, INF)')
return ans, ans
else:
- return (Infsign[sign],
+ return (_SignedInfinity[sign],
context._raise_error(InvalidOperation, 'INF % x'))
if not other:
@@ -1383,7 +1486,7 @@
if other._isinfinity():
return context._raise_error(InvalidOperation, 'INF // INF')
else:
- return Infsign[self._sign ^ other._sign]
+ return _SignedInfinity[self._sign ^ other._sign]
if not other:
if self:
@@ -1409,16 +1512,31 @@
"""Converts self to an int, truncating if necessary."""
if self._is_special:
if self._isnan():
- context = getcontext()
- return context._raise_error(InvalidContext)
+ raise ValueError("Cannot convert NaN to integer")
elif self._isinfinity():
- raise OverflowError("Cannot convert infinity to long")
+ raise OverflowError("Cannot convert infinity to integer")
s = (-1)**self._sign
if self._exp >= 0:
return s*int(self._int)*10**self._exp
else:
return s*int(self._int[:self._exp] or '0')
+ __trunc__ = __int__
+
+ def real(self):
+ return self
+ real = property(real)
+
+ def imag(self):
+ return Decimal(0)
+ imag = property(imag)
+
+ def conjugate(self):
+ return self
+
+ def __complex__(self):
+ return complex(float(self))
+
def __long__(self):
"""Converts to a long.
@@ -1474,47 +1592,53 @@
exp_min = len(self._int) + self._exp - context.prec
if exp_min > Etop:
# overflow: exp_min > Etop iff self.adjusted() > Emax
+ ans = context._raise_error(Overflow, 'above Emax', self._sign)
context._raise_error(Inexact)
context._raise_error(Rounded)
- return context._raise_error(Overflow, 'above Emax', self._sign)
+ return ans
+
self_is_subnormal = exp_min < Etiny
if self_is_subnormal:
- context._raise_error(Subnormal)
exp_min = Etiny
# round if self has too many digits
if self._exp < exp_min:
- context._raise_error(Rounded)
digits = len(self._int) + self._exp - exp_min
if digits < 0:
self = _dec_from_triple(self._sign, '1', exp_min-1)
digits = 0
- this_function = getattr(self, self._pick_rounding_function[context.rounding])
- changed = this_function(digits)
+ rounding_method = self._pick_rounding_function[context.rounding]
+ changed = getattr(self, rounding_method)(digits)
coeff = self._int[:digits] or '0'
- if changed == 1:
+ if changed > 0:
coeff = str(int(coeff)+1)
- ans = _dec_from_triple(self._sign, coeff, exp_min)
-
+ if len(coeff) > context.prec:
+ coeff = coeff[:-1]
+ exp_min += 1
+
+ # check whether the rounding pushed the exponent out of range
+ if exp_min > Etop:
+ ans = context._raise_error(Overflow, 'above Emax', self._sign)
+ else:
+ ans = _dec_from_triple(self._sign, coeff, exp_min)
+
+ # raise the appropriate signals, taking care to respect
+ # the precedence described in the specification
+ if changed and self_is_subnormal:
+ context._raise_error(Underflow)
+ if self_is_subnormal:
+ context._raise_error(Subnormal)
if changed:
context._raise_error(Inexact)
- if self_is_subnormal:
- context._raise_error(Underflow)
- if not ans:
- # raise Clamped on underflow to 0
- context._raise_error(Clamped)
- elif len(ans._int) == context.prec+1:
- # we get here only if rescaling rounds the
- # cofficient up to exactly 10**context.prec
- if ans._exp < Etop:
- ans = _dec_from_triple(ans._sign,
- ans._int[:-1], ans._exp+1)
- else:
- # Inexact and Rounded have already been raised
- ans = context._raise_error(Overflow, 'above Emax',
- self._sign)
+ context._raise_error(Rounded)
+ if not ans:
+ # raise Clamped on underflow to 0
+ context._raise_error(Clamped)
return ans
+ if self_is_subnormal:
+ context._raise_error(Subnormal)
+
# fold down if _clamp == 1 and self has too few digits
if context._clamp == 1 and self._exp > Etop:
context._raise_error(Clamped)
@@ -1622,12 +1746,12 @@
if not other:
return context._raise_error(InvalidOperation,
'INF * 0 in fma')
- product = Infsign[self._sign ^ other._sign]
+ product = _SignedInfinity[self._sign ^ other._sign]
elif other._exp == 'F':
if not self:
return context._raise_error(InvalidOperation,
'0 * INF in fma')
- product = Infsign[self._sign ^ other._sign]
+ product = _SignedInfinity[self._sign ^ other._sign]
else:
product = _dec_from_triple(self._sign ^ other._sign,
str(int(self._int) * int(other._int)),
@@ -1794,12 +1918,14 @@
# case where xc == 1: result is 10**(xe*y), with xe*y
# required to be an integer
if xc == 1:
- if ye >= 0:
- exponent = xe*yc*10**ye
- else:
- exponent, remainder = divmod(xe*yc, 10**-ye)
- if remainder:
- return None
+ xe *= yc
+ # result is now 10**(xe * 10**ye); xe * 10**ye must be integral
+ while xe % 10 == 0:
+ xe //= 10
+ ye += 1
+ if ye < 0:
+ return None
+ exponent = xe * 10**ye
if y.sign == 1:
exponent = -exponent
# if other is a nonnegative integer, use ideal exponent
@@ -1977,7 +2103,7 @@
if not self:
return context._raise_error(InvalidOperation, '0 ** 0')
else:
- return Dec_p1
+ return _One
# result has sign 1 iff self._sign is 1 and other is an odd integer
result_sign = 0
@@ -1999,19 +2125,19 @@
if other._sign == 0:
return _dec_from_triple(result_sign, '0', 0)
else:
- return Infsign[result_sign]
+ return _SignedInfinity[result_sign]
# Inf**(+ve or Inf) = Inf; Inf**(-ve or -Inf) = 0
if self._isinfinity():
if other._sign == 0:
- return Infsign[result_sign]
+ return _SignedInfinity[result_sign]
else:
return _dec_from_triple(result_sign, '0', 0)
# 1**other = 1, but the choice of exponent and the flags
# depend on the exponent of self, and on whether other is a
# positive integer, a negative integer, or neither
- if self == Dec_p1:
+ if self == _One:
if other._isinteger():
# exp = max(self._exp*max(int(other), 0),
# 1-context.prec) but evaluating int(other) directly
@@ -2044,11 +2170,12 @@
if (other._sign == 0) == (self_adj < 0):
return _dec_from_triple(result_sign, '0', 0)
else:
- return Infsign[result_sign]
+ return _SignedInfinity[result_sign]
# from here on, the result always goes through the call
# to _fix at the end of this function.
ans = None
+ exact = False
# crude test to catch cases of extreme overflow/underflow. If
# log10(self)*other >= 10**bound and bound >= len(str(Emax))
@@ -2071,8 +2198,10 @@
# try for an exact result with precision +1
if ans is None:
ans = self._power_exact(other, context.prec + 1)
- if ans is not None and result_sign == 1:
- ans = _dec_from_triple(1, ans._int, ans._exp)
+ if ans is not None:
+ if result_sign == 1:
+ ans = _dec_from_triple(1, ans._int, ans._exp)
+ exact = True
# usual case: inexact result, x**y computed directly as exp(y*log(x))
if ans is None:
@@ -2095,24 +2224,55 @@
ans = _dec_from_triple(result_sign, str(coeff), exp)
- # the specification says that for non-integer other we need to
- # raise Inexact, even when the result is actually exact. In
- # the same way, we need to raise Underflow here if the result
- # is subnormal. (The call to _fix will take care of raising
- # Rounded and Subnormal, as usual.)
- if not other._isinteger():
- context._raise_error(Inexact)
- # pad with zeros up to length context.prec+1 if necessary
+ # unlike exp, ln and log10, the power function respects the
+ # rounding mode; no need to switch to ROUND_HALF_EVEN here
+
+ # There's a difficulty here when 'other' is not an integer and
+ # the result is exact. In this case, the specification
+ # requires that the Inexact flag be raised (in spite of
+ # exactness), but since the result is exact _fix won't do this
+ # for us. (Correspondingly, the Underflow signal should also
+ # be raised for subnormal results.) We can't directly raise
+ # these signals either before or after calling _fix, since
+ # that would violate the precedence for signals. So we wrap
+ # the ._fix call in a temporary context, and reraise
+ # afterwards.
+ if exact and not other._isinteger():
+ # pad with zeros up to length context.prec+1 if necessary; this
+ # ensures that the Rounded signal will be raised.
if len(ans._int) <= context.prec:
- expdiff = context.prec+1 - len(ans._int)
+ expdiff = context.prec + 1 - len(ans._int)
ans = _dec_from_triple(ans._sign, ans._int+'0'*expdiff,
ans._exp-expdiff)
- if ans.adjusted() < context.Emin:
- context._raise_error(Underflow)
-
- # unlike exp, ln and log10, the power function respects the
- # rounding mode; no need to use ROUND_HALF_EVEN here
- ans = ans._fix(context)
+
+ # create a copy of the current context, with cleared flags/traps
+ newcontext = context.copy()
+ newcontext.clear_flags()
+ for exception in _signals:
+ newcontext.traps[exception] = 0
+
+ # round in the new context
+ ans = ans._fix(newcontext)
+
+ # raise Inexact, and if necessary, Underflow
+ newcontext._raise_error(Inexact)
+ if newcontext.flags[Subnormal]:
+ newcontext._raise_error(Underflow)
+
+ # propagate signals to the original context; _fix could
+ # have raised any of Overflow, Underflow, Subnormal,
+ # Inexact, Rounded, Clamped. Overflow needs the correct
+ # arguments. Note that the order of the exceptions is
+ # important here.
+ if newcontext.flags[Overflow]:
+ context._raise_error(Overflow, 'above Emax', ans._sign)
+ for exception in Underflow, Subnormal, Inexact, Rounded, Clamped:
+ if newcontext.flags[exception]:
+ context._raise_error(exception)
+
+ else:
+ ans = ans._fix(context)
+
return ans
def __rpow__(self, other, context=None):
@@ -2206,14 +2366,15 @@
'quantize result has too many digits for current context')
# raise appropriate flags
+ if ans and ans.adjusted() < context.Emin:
+ context._raise_error(Subnormal)
if ans._exp > self._exp:
- context._raise_error(Rounded)
if ans != self:
context._raise_error(Inexact)
- if ans and ans.adjusted() < context.Emin:
- context._raise_error(Subnormal)
-
- # call to fix takes care of any necessary folddown
+ context._raise_error(Rounded)
+
+ # call to fix takes care of any necessary folddown, and
+ # signals Clamped if necessary
ans = ans._fix(context)
return ans
@@ -2266,6 +2427,29 @@
coeff = str(int(coeff)+1)
return _dec_from_triple(self._sign, coeff, exp)
+ def _round(self, places, rounding):
+ """Round a nonzero, nonspecial Decimal to a fixed number of
+ significant figures, using the given rounding mode.
+
+ Infinities, NaNs and zeros are returned unaltered.
+
+ This operation is quiet: it raises no flags, and uses no
+ information from the context.
+
+ """
+ if places <= 0:
+ raise ValueError("argument should be at least 1 in _round")
+ if self._is_special or not self:
+ return Decimal(self)
+ ans = self._rescale(self.adjusted()+1-places, rounding)
+ # it can happen that the rescale alters the adjusted exponent;
+ # for example when rounding 99.97 to 3 significant figures.
+ # When this happens we end up with an extra 0 at the end of
+ # the number; a second rescale fixes this.
+ if ans.adjusted() != self.adjusted():
+ ans = ans._rescale(ans.adjusted()+1-places, rounding)
+ return ans
+
def to_integral_exact(self, rounding=None, context=None):
"""Rounds to a nearby integer.
@@ -2289,10 +2473,10 @@
context = getcontext()
if rounding is None:
rounding = context.rounding
- context._raise_error(Rounded)
ans = self._rescale(0, rounding)
if ans != self:
context._raise_error(Inexact)
+ context._raise_error(Rounded)
return ans
def to_integral_value(self, rounding=None, context=None):
@@ -2436,7 +2620,7 @@
return other._fix(context)
return self._check_nans(other, context)
- c = self.__cmp__(other)
+ c = self._cmp(other)
if c == 0:
# If both operands are finite and equal in numerical value
# then an ordering is applied:
@@ -2478,7 +2662,7 @@
return other._fix(context)
return self._check_nans(other, context)
- c = self.__cmp__(other)
+ c = self._cmp(other)
if c == 0:
c = self.compare_total(other)
@@ -2526,23 +2710,10 @@
It's pretty much like compare(), but all NaNs signal, with signaling
NaNs taking precedence over quiet NaNs.
"""
- if context is None:
- context = getcontext()
-
- self_is_nan = self._isnan()
- other_is_nan = other._isnan()
- if self_is_nan == 2:
- return context._raise_error(InvalidOperation, 'sNaN',
- self)
- if other_is_nan == 2:
- return context._raise_error(InvalidOperation, 'sNaN',
- other)
- if self_is_nan:
- return context._raise_error(InvalidOperation, 'NaN in compare_signal',
- self)
- if other_is_nan:
- return context._raise_error(InvalidOperation, 'NaN in compare_signal',
- other)
+ other = _convert_other(other, raiseit = True)
+ ans = self._compare_check_nans(other, context)
+ if ans:
+ return ans
return self.compare(other, context=context)
def compare_total(self, other):
@@ -2552,11 +2723,13 @@
value. Note that a total ordering is defined for all possible abstract
representations.
"""
+ other = _convert_other(other, raiseit=True)
+
# if one is negative and the other is positive, it's easy
if self._sign and not other._sign:
- return Dec_n1
+ return _NegativeOne
if not self._sign and other._sign:
- return Dec_p1
+ return _One
sign = self._sign
# let's handle both NaN types
@@ -2564,53 +2737,56 @@
other_nan = other._isnan()
if self_nan or other_nan:
if self_nan == other_nan:
- if self._int < other._int:
+ # compare payloads as though they're integers
+ self_key = len(self._int), self._int
+ other_key = len(other._int), other._int
+ if self_key < other_key:
if sign:
- return Dec_p1
+ return _One
else:
- return Dec_n1
- if self._int > other._int:
+ return _NegativeOne
+ if self_key > other_key:
if sign:
- return Dec_n1
+ return _NegativeOne
else:
- return Dec_p1
- return Dec_0
+ return _One
+ return _Zero
if sign:
if self_nan == 1:
- return Dec_n1
+ return _NegativeOne
if other_nan == 1:
- return Dec_p1
+ return _One
if self_nan == 2:
- return Dec_n1
+ return _NegativeOne
if other_nan == 2:
- return Dec_p1
+ return _One
else:
if self_nan == 1:
- return Dec_p1
+ return _One
if other_nan == 1:
- return Dec_n1
+ return _NegativeOne
if self_nan == 2:
- return Dec_p1
+ return _One
if other_nan == 2:
- return Dec_n1
+ return _NegativeOne
if self < other:
- return Dec_n1
+ return _NegativeOne
if self > other:
- return Dec_p1
+ return _One
if self._exp < other._exp:
if sign:
- return Dec_p1
+ return _One
else:
- return Dec_n1
+ return _NegativeOne
if self._exp > other._exp:
if sign:
- return Dec_n1
+ return _NegativeOne
else:
- return Dec_p1
- return Dec_0
+ return _One
+ return _Zero
def compare_total_mag(self, other):
@@ -2618,6 +2794,8 @@
Like compare_total, but with operand's sign ignored and assumed to be 0.
"""
+ other = _convert_other(other, raiseit=True)
+
s = self.copy_abs()
o = other.copy_abs()
return s.compare_total(o)
@@ -2651,11 +2829,11 @@
# exp(-Infinity) = 0
if self._isinfinity() == -1:
- return Dec_0
+ return _Zero
# exp(0) = 1
if not self:
- return Dec_p1
+ return _One
# exp(Infinity) = Infinity
if self._isinfinity() == 1:
@@ -2743,7 +2921,7 @@
return False
if context is None:
context = getcontext()
- return context.Emin <= self.adjusted() <= context.Emax
+ return context.Emin <= self.adjusted()
def is_qnan(self):
"""Return True if self is a quiet NaN; otherwise return False."""
@@ -2807,15 +2985,15 @@
# ln(0.0) == -Infinity
if not self:
- return negInf
+ return _NegativeInfinity
# ln(Infinity) = Infinity
if self._isinfinity() == 1:
- return Inf
+ return _Infinity
# ln(1.0) == 0.0
- if self == Dec_p1:
- return Dec_0
+ if self == _One:
+ return _Zero
# ln(negative) raises InvalidOperation
if self._sign == 1:
@@ -2887,11 +3065,11 @@
# log10(0.0) == -Infinity
if not self:
- return negInf
+ return _NegativeInfinity
# log10(Infinity) = Infinity
if self._isinfinity() == 1:
- return Inf
+ return _Infinity
# log10(negative or -Infinity) raises InvalidOperation
if self._sign == 1:
@@ -2943,7 +3121,7 @@
# logb(+/-Inf) = +Inf
if self._isinfinity():
- return Inf
+ return _Infinity
# logb(0) = -Inf, DivisionByZero
if not self:
@@ -2952,12 +3130,13 @@
# otherwise, simply return the adjusted exponent of self, as a
# Decimal. Note that no attempt is made to fit the result
# into the current context.
- return Decimal(self.adjusted())
+ ans = Decimal(self.adjusted())
+ return ans._fix(context)
def _islogical(self):
"""Return True if self is a logical operand.
- For being logical, it must be a finite numbers with a sign of 0,
+ For being logical, it must be a finite number with a sign of 0,
an exponent of 0, and a coefficient whose digits must all be
either 0 or 1.
"""
@@ -2985,6 +3164,9 @@
"""Applies an 'and' operation between self and other's digits."""
if context is None:
context = getcontext()
+
+ other = _convert_other(other, raiseit=True)
+
if not self._islogical() or not other._islogical():
return context._raise_error(InvalidOperation)
@@ -3006,6 +3188,9 @@
"""Applies an 'or' operation between self and other's digits."""
if context is None:
context = getcontext()
+
+ other = _convert_other(other, raiseit=True)
+
if not self._islogical() or not other._islogical():
return context._raise_error(InvalidOperation)
@@ -3013,13 +3198,16 @@
(opa, opb) = self._fill_logical(context, self._int, other._int)
# make the operation, and clean starting zeroes
- result = "".join(str(int(a)|int(b)) for a,b in zip(opa,opb))
+ result = "".join([str(int(a)|int(b)) for a,b in zip(opa,opb)])
return _dec_from_triple(0, result.lstrip('0') or '0', 0)
def logical_xor(self, other, context=None):
"""Applies an 'xor' operation between self and other's digits."""
if context is None:
context = getcontext()
+
+ other = _convert_other(other, raiseit=True)
+
if not self._islogical() or not other._islogical():
return context._raise_error(InvalidOperation)
@@ -3027,7 +3215,7 @@
(opa, opb) = self._fill_logical(context, self._int, other._int)
# make the operation, and clean starting zeroes
- result = "".join(str(int(a)^int(b)) for a,b in zip(opa,opb))
+ result = "".join([str(int(a)^int(b)) for a,b in zip(opa,opb)])
return _dec_from_triple(0, result.lstrip('0') or '0', 0)
def max_mag(self, other, context=None):
@@ -3049,7 +3237,7 @@
return other._fix(context)
return self._check_nans(other, context)
- c = self.copy_abs().__cmp__(other.copy_abs())
+ c = self.copy_abs()._cmp(other.copy_abs())
if c == 0:
c = self.compare_total(other)
@@ -3079,7 +3267,7 @@
return other._fix(context)
return self._check_nans(other, context)
- c = self.copy_abs().__cmp__(other.copy_abs())
+ c = self.copy_abs()._cmp(other.copy_abs())
if c == 0:
c = self.compare_total(other)
@@ -3100,7 +3288,7 @@
return ans
if self._isinfinity() == -1:
- return negInf
+ return _NegativeInfinity
if self._isinfinity() == 1:
return _dec_from_triple(0, '9'*context.prec, context.Etop())
@@ -3123,7 +3311,7 @@
return ans
if self._isinfinity() == 1:
- return Inf
+ return _Infinity
if self._isinfinity() == -1:
return _dec_from_triple(1, '9'*context.prec, context.Etop())
@@ -3154,7 +3342,7 @@
if ans:
return ans
- comparison = self.__cmp__(other)
+ comparison = self._cmp(other)
if comparison == 0:
return self.copy_sign(other)
@@ -3168,13 +3356,13 @@
context._raise_error(Overflow,
'Infinite result from next_toward',
ans._sign)
+ context._raise_error(Inexact)
context._raise_error(Rounded)
- context._raise_error(Inexact)
elif ans.adjusted() < context.Emin:
context._raise_error(Underflow)
context._raise_error(Subnormal)
+ context._raise_error(Inexact)
context._raise_error(Rounded)
- context._raise_error(Inexact)
# if precision == 1 then we don't raise Clamped for a
# result 0E-Etiny.
if not ans:
@@ -3233,6 +3421,8 @@
if context is None:
context = getcontext()
+ other = _convert_other(other, raiseit=True)
+
ans = self._check_nans(other, context)
if ans:
return ans
@@ -3249,19 +3439,23 @@
torot = int(other)
rotdig = self._int
topad = context.prec - len(rotdig)
- if topad:
+ if topad > 0:
rotdig = '0'*topad + rotdig
+ elif topad < 0:
+ rotdig = rotdig[-topad:]
# let's rotate!
rotated = rotdig[torot:] + rotdig[:torot]
return _dec_from_triple(self._sign,
rotated.lstrip('0') or '0', self._exp)
- def scaleb (self, other, context=None):
+ def scaleb(self, other, context=None):
"""Returns self operand after adding the second value to its exp."""
if context is None:
context = getcontext()
+ other = _convert_other(other, raiseit=True)
+
ans = self._check_nans(other, context)
if ans:
return ans
@@ -3285,6 +3479,8 @@
if context is None:
context = getcontext()
+ other = _convert_other(other, raiseit=True)
+
ans = self._check_nans(other, context)
if ans:
return ans
@@ -3299,22 +3495,22 @@
# get values, pad if necessary
torot = int(other)
- if not torot:
- return Decimal(self)
rotdig = self._int
topad = context.prec - len(rotdig)
- if topad:
+ if topad > 0:
rotdig = '0'*topad + rotdig
+ elif topad < 0:
+ rotdig = rotdig[-topad:]
# let's shift!
if torot < 0:
- rotated = rotdig[:torot]
+ shifted = rotdig[:torot]
else:
- rotated = rotdig + '0'*torot
- rotated = rotated[-context.prec:]
+ shifted = rotdig + '0'*torot
+ shifted = shifted[-context.prec:]
return _dec_from_triple(self._sign,
- rotated.lstrip('0') or '0', self._exp)
+ shifted.lstrip('0') or '0', self._exp)
# Support for pickling, copy, and deepcopy
def __reduce__(self):
@@ -3330,14 +3526,94 @@
return self # My components are also immutable
return self.__class__(str(self))
- # support for Jython __tojava__:
- def __tojava__(self, java_class):
- from java.lang import Object
- from java.math import BigDecimal
- from org.python.core import Py
- if java_class not in (BigDecimal, Object):
- return Py.NoConversion
- return BigDecimal(str(self))
+ # PEP 3101 support. See also _parse_format_specifier and _format_align
+ def __format__(self, specifier, context=None):
+ """Format a Decimal instance according to the given specifier.
+
+ The specifier should be a standard format specifier, with the
+ form described in PEP 3101. Formatting types 'e', 'E', 'f',
+ 'F', 'g', 'G', and '%' are supported. If the formatting type
+ is omitted it defaults to 'g' or 'G', depending on the value
+ of context.capitals.
+
+ At this time the 'n' format specifier type (which is supposed
+ to use the current locale) is not supported.
+ """
+
+ # Note: PEP 3101 says that if the type is not present then
+ # there should be at least one digit after the decimal point.
+ # We take the liberty of ignoring this requirement for
+ # Decimal---it's presumably there to make sure that
+ # format(float, '') behaves similarly to str(float).
+ if context is None:
+ context = getcontext()
+
+ spec = _parse_format_specifier(specifier)
+
+ # special values don't care about the type or precision...
+ if self._is_special:
+ return _format_align(str(self), spec)
+
+ # a type of None defaults to 'g' or 'G', depending on context
+ # if type is '%', adjust exponent of self accordingly
+ if spec['type'] is None:
+ spec['type'] = ['g', 'G'][context.capitals]
+ elif spec['type'] == '%':
+ self = _dec_from_triple(self._sign, self._int, self._exp+2)
+
+ # round if necessary, taking rounding mode from the context
+ rounding = context.rounding
+ precision = spec['precision']
+ if precision is not None:
+ if spec['type'] in 'eE':
+ self = self._round(precision+1, rounding)
+ elif spec['type'] in 'gG':
+ if len(self._int) > precision:
+ self = self._round(precision, rounding)
+ elif spec['type'] in 'fF%':
+ self = self._rescale(-precision, rounding)
+ # special case: zeros with a positive exponent can't be
+ # represented in fixed point; rescale them to 0e0.
+ elif not self and self._exp > 0 and spec['type'] in 'fF%':
+ self = self._rescale(0, rounding)
+
+ # figure out placement of the decimal point
+ leftdigits = self._exp + len(self._int)
+ if spec['type'] in 'fF%':
+ dotplace = leftdigits
+ elif spec['type'] in 'eE':
+ if not self and precision is not None:
+ dotplace = 1 - precision
+ else:
+ dotplace = 1
+ elif spec['type'] in 'gG':
+ if self._exp <= 0 and leftdigits > -6:
+ dotplace = leftdigits
+ else:
+ dotplace = 1
+
+ # figure out main part of numeric string...
+ if dotplace <= 0:
+ num = '0.' + '0'*(-dotplace) + self._int
+ elif dotplace >= len(self._int):
+ # make sure we're not padding a '0' with extra zeros on the right
+ assert dotplace==len(self._int) or self._int != '0'
+ num = self._int + '0'*(dotplace-len(self._int))
+ else:
+ num = self._int[:dotplace] + '.' + self._int[dotplace:]
+
+ # ...then the trailing exponent, or trailing '%'
+ if leftdigits != dotplace or spec['type'] in 'eE':
+ echar = {'E': 'E', 'e': 'e', 'G': 'E', 'g': 'e'}[spec['type']]
+ num = num + "{0}{1:+}".format(echar, leftdigits-dotplace)
+ elif spec['type'] == '%':
+ num = num + '%'
+
+ # add sign
+ if self._sign == 1:
+ num = '-' + num
+ return _format_align(num, spec)
+
def _dec_from_triple(sign, coefficient, exponent, special=False):
"""Create a decimal instance directly, without any validation,
@@ -3355,6 +3631,12 @@
return self
+# Register Decimal as a kind of Number (an abstract base class).
+# However, do not register it as Real (because Decimals are not
+# interoperable with floats).
+_numbers.Number.register(Decimal)
+
+
##### Context class #######################################################
@@ -3393,7 +3675,7 @@
traps - If traps[exception] = 1, then the exception is
raised when it is caused. Otherwise, a value is
substituted in.
- flags - When an exception is caused, flags[exception] is incremented.
+ flags - When an exception is caused, flags[exception] is set.
(Whether or not the trap_enabler is set)
Should be reset by user of Decimal instance.
Emin - Minimum exponent
@@ -3408,22 +3690,38 @@
Emin=None, Emax=None,
capitals=None, _clamp=0,
_ignored_flags=None):
+ # Set defaults; for everything except flags and _ignored_flags,
+ # inherit from DefaultContext.
+ try:
+ dc = DefaultContext
+ except NameError:
+ pass
+
+ self.prec = prec if prec is not None else dc.prec
+ self.rounding = rounding if rounding is not None else dc.rounding
+ self.Emin = Emin if Emin is not None else dc.Emin
+ self.Emax = Emax if Emax is not None else dc.Emax
+ self.capitals = capitals if capitals is not None else dc.capitals
+ self._clamp = _clamp if _clamp is not None else dc._clamp
+
+ if _ignored_flags is None:
+ self._ignored_flags = []
+ else:
+ self._ignored_flags = _ignored_flags
+
+ if traps is None:
+ self.traps = dc.traps.copy()
+ elif not isinstance(traps, dict):
+ self.traps = dict((s, int(s in traps)) for s in _signals)
+ else:
+ self.traps = traps
+
if flags is None:
- flags = []
- if _ignored_flags is None:
- _ignored_flags = []
- if not isinstance(flags, dict):
- flags = dict([(s,s in flags) for s in _signals])
- del s
- if traps is not None and not isinstance(traps, dict):
- traps = dict([(s,s in traps) for s in _signals])
- del s
- for name, val in locals().items():
- if val is None:
- setattr(self, name, _copy.copy(getattr(DefaultContext, name)))
- else:
- setattr(self, name, val)
- del self.self
+ self.flags = dict.fromkeys(_signals, 0)
+ elif not isinstance(flags, dict):
+ self.flags = dict((s, int(s in flags)) for s in _signals)
+ else:
+ self.flags = flags
def __repr__(self):
"""Show the current context."""
@@ -3461,23 +3759,23 @@
"""Handles an error
If the flag is in _ignored_flags, returns the default response.
- Otherwise, it increments the flag, then, if the corresponding
- trap_enabler is set, it reaises the exception. Otherwise, it returns
- the default value after incrementing the flag.
+ Otherwise, it sets the flag, then, if the corresponding
+ trap_enabler is set, it reraises the exception. Otherwise, it returns
+ the default value after setting the flag.
"""
error = _condition_map.get(condition, condition)
if error in self._ignored_flags:
# Don't touch the flag
return error().handle(self, *args)
- self.flags[error] += 1
+ self.flags[error] = 1
if not self.traps[error]:
# The errors define how to handle themselves.
return condition().handle(self, *args)
# Errors should only be risked on copies of the context
# self._ignored_flags = []
- raise error, explanation
+ raise error(explanation)
def _ignore_all_flags(self):
"""Ignore all flags, if they are raised"""
@@ -3497,10 +3795,8 @@
for flag in flags:
self._ignored_flags.remove(flag)
- def __hash__(self):
- """A Context cannot be hashed."""
- # We inherit object.__hash__, so we must deny this explicitly
- raise TypeError("Cannot hash a Context.")
+ # We inherit object.__hash__, so we must deny this explicitly
+ __hash__ = None
def Etiny(self):
"""Returns Etiny (= Emin - prec + 1)"""
@@ -3530,7 +3826,16 @@
return rounding
def create_decimal(self, num='0'):
- """Creates a new Decimal instance but using self as context."""
+ """Creates a new Decimal instance but using self as context.
+
+ This method implements the to-number operation of the
+ IBM Decimal specification."""
+
+ if isinstance(num, basestring) and num != num.strip():
+ return self._raise_error(ConversionSyntax,
+ "no trailing or leading whitespace is "
+ "permitted.")
+
d = Decimal(num, context=self)
if d._isnan() and len(d._int) > self.prec - self._clamp:
return self._raise_error(ConversionSyntax,
@@ -3546,13 +3851,13 @@
the plus operation on the operand.
>>> ExtendedContext.abs(Decimal('2.1'))
- Decimal("2.1")
+ Decimal('2.1')
>>> ExtendedContext.abs(Decimal('-100'))
- Decimal("100")
+ Decimal('100')
>>> ExtendedContext.abs(Decimal('101.5'))
- Decimal("101.5")
+ Decimal('101.5')
>>> ExtendedContext.abs(Decimal('-101.5'))
- Decimal("101.5")
+ Decimal('101.5')
"""
return a.__abs__(context=self)
@@ -3560,9 +3865,9 @@
"""Return the sum of the two operands.
>>> ExtendedContext.add(Decimal('12'), Decimal('7.00'))
- Decimal("19.00")
+ Decimal('19.00')
>>> ExtendedContext.add(Decimal('1E+2'), Decimal('1.01E+4'))
- Decimal("1.02E+4")
+ Decimal('1.02E+4')
"""
return a.__add__(b, context=self)
@@ -3576,7 +3881,7 @@
received object already is in its canonical form.
>>> ExtendedContext.canonical(Decimal('2.50'))
- Decimal("2.50")
+ Decimal('2.50')
"""
return a.canonical(context=self)
@@ -3595,17 +3900,17 @@
zero or negative zero, or '1' if the result is greater than zero.
>>> ExtendedContext.compare(Decimal('2.1'), Decimal('3'))
- Decimal("-1")
+ Decimal('-1')
>>> ExtendedContext.compare(Decimal('2.1'), Decimal('2.1'))
- Decimal("0")
+ Decimal('0')
>>> ExtendedContext.compare(Decimal('2.1'), Decimal('2.10'))
- Decimal("0")
+ Decimal('0')
>>> ExtendedContext.compare(Decimal('3'), Decimal('2.1'))
- Decimal("1")
+ Decimal('1')
>>> ExtendedContext.compare(Decimal('2.1'), Decimal('-3'))
- Decimal("1")
+ Decimal('1')
>>> ExtendedContext.compare(Decimal('-3'), Decimal('2.1'))
- Decimal("-1")
+ Decimal('-1')
"""
return a.compare(b, context=self)
@@ -3617,21 +3922,21 @@
>>> c = ExtendedContext
>>> c.compare_signal(Decimal('2.1'), Decimal('3'))
- Decimal("-1")
+ Decimal('-1')
>>> c.compare_signal(Decimal('2.1'), Decimal('2.1'))
- Decimal("0")
+ Decimal('0')
>>> c.flags[InvalidOperation] = 0
>>> print c.flags[InvalidOperation]
0
>>> c.compare_signal(Decimal('NaN'), Decimal('2.1'))
- Decimal("NaN")
+ Decimal('NaN')
>>> print c.flags[InvalidOperation]
1
>>> c.flags[InvalidOperation] = 0
>>> print c.flags[InvalidOperation]
0
>>> c.compare_signal(Decimal('sNaN'), Decimal('2.1'))
- Decimal("NaN")
+ Decimal('NaN')
>>> print c.flags[InvalidOperation]
1
"""
@@ -3645,17 +3950,17 @@
representations.
>>> ExtendedContext.compare_total(Decimal('12.73'), Decimal('127.9'))
- Decimal("-1")
+ Decimal('-1')
>>> ExtendedContext.compare_total(Decimal('-127'), Decimal('12'))
- Decimal("-1")
+ Decimal('-1')
>>> ExtendedContext.compare_total(Decimal('12.30'), Decimal('12.3'))
- Decimal("-1")
+ Decimal('-1')
>>> ExtendedContext.compare_total(Decimal('12.30'), Decimal('12.30'))
- Decimal("0")
+ Decimal('0')
>>> ExtendedContext.compare_total(Decimal('12.3'), Decimal('12.300'))
- Decimal("1")
+ Decimal('1')
>>> ExtendedContext.compare_total(Decimal('12.3'), Decimal('NaN'))
- Decimal("-1")
+ Decimal('-1')
"""
return a.compare_total(b)
@@ -3670,9 +3975,9 @@
"""Returns a copy of the operand with the sign set to 0.
>>> ExtendedContext.copy_abs(Decimal('2.1'))
- Decimal("2.1")
+ Decimal('2.1')
>>> ExtendedContext.copy_abs(Decimal('-100'))
- Decimal("100")
+ Decimal('100')
"""
return a.copy_abs()
@@ -3680,9 +3985,9 @@
"""Returns a copy of the decimal objet.
>>> ExtendedContext.copy_decimal(Decimal('2.1'))
- Decimal("2.1")
+ Decimal('2.1')
>>> ExtendedContext.copy_decimal(Decimal('-1.00'))
- Decimal("-1.00")
+ Decimal('-1.00')
"""
return Decimal(a)
@@ -3690,9 +3995,9 @@
"""Returns a copy of the operand with the sign inverted.
>>> ExtendedContext.copy_negate(Decimal('101.5'))
- Decimal("-101.5")
+ Decimal('-101.5')
>>> ExtendedContext.copy_negate(Decimal('-101.5'))
- Decimal("101.5")
+ Decimal('101.5')
"""
return a.copy_negate()
@@ -3703,13 +4008,13 @@
equal to the sign of the second operand.
>>> ExtendedContext.copy_sign(Decimal( '1.50'), Decimal('7.33'))
- Decimal("1.50")
+ Decimal('1.50')
>>> ExtendedContext.copy_sign(Decimal('-1.50'), Decimal('7.33'))
- Decimal("1.50")
+ Decimal('1.50')
>>> ExtendedContext.copy_sign(Decimal( '1.50'), Decimal('-7.33'))
- Decimal("-1.50")
+ Decimal('-1.50')
>>> ExtendedContext.copy_sign(Decimal('-1.50'), Decimal('-7.33'))
- Decimal("-1.50")
+ Decimal('-1.50')
"""
return a.copy_sign(b)
@@ -3717,25 +4022,25 @@
"""Decimal division in a specified context.
>>> ExtendedContext.divide(Decimal('1'), Decimal('3'))
- Decimal("0.333333333")
+ Decimal('0.333333333')
>>> ExtendedContext.divide(Decimal('2'), Decimal('3'))
- Decimal("0.666666667")
+ Decimal('0.666666667')
>>> ExtendedContext.divide(Decimal('5'), Decimal('2'))
- Decimal("2.5")
+ Decimal('2.5')
>>> ExtendedContext.divide(Decimal('1'), Decimal('10'))
- Decimal("0.1")
+ Decimal('0.1')
>>> ExtendedContext.divide(Decimal('12'), Decimal('12'))
- Decimal("1")
+ Decimal('1')
>>> ExtendedContext.divide(Decimal('8.00'), Decimal('2'))
- Decimal("4.00")
+ Decimal('4.00')
>>> ExtendedContext.divide(Decimal('2.400'), Decimal('2.0'))
- Decimal("1.20")
+ Decimal('1.20')
>>> ExtendedContext.divide(Decimal('1000'), Decimal('100'))
- Decimal("10")
+ Decimal('10')
>>> ExtendedContext.divide(Decimal('1000'), Decimal('1'))
- Decimal("1000")
+ Decimal('1000')
>>> ExtendedContext.divide(Decimal('2.40E+6'), Decimal('2'))
- Decimal("1.20E+6")
+ Decimal('1.20E+6')
"""
return a.__div__(b, context=self)
@@ -3743,15 +4048,22 @@
"""Divides two numbers and returns the integer part of the result.
>>> ExtendedContext.divide_int(Decimal('2'), Decimal('3'))
- Decimal("0")
+ Decimal('0')
>>> ExtendedContext.divide_int(Decimal('10'), Decimal('3'))
- Decimal("3")
+ Decimal('3')
>>> ExtendedContext.divide_int(Decimal('1'), Decimal('0.3'))
- Decimal("3")
+ Decimal('3')
"""
return a.__floordiv__(b, context=self)
def divmod(self, a, b):
+ """Return (a // b, a % b)
+
+ >>> ExtendedContext.divmod(Decimal(8), Decimal(3))
+ (Decimal('2'), Decimal('2'))
+ >>> ExtendedContext.divmod(Decimal(8), Decimal(4))
+ (Decimal('2'), Decimal('0'))
+ """
return a.__divmod__(b, context=self)
def exp(self, a):
@@ -3761,17 +4073,17 @@
>>> c.Emin = -999
>>> c.Emax = 999
>>> c.exp(Decimal('-Infinity'))
- Decimal("0")
+ Decimal('0')
>>> c.exp(Decimal('-1'))
- Decimal("0.367879441")
+ Decimal('0.367879441')
>>> c.exp(Decimal('0'))
- Decimal("1")
+ Decimal('1')
>>> c.exp(Decimal('1'))
- Decimal("2.71828183")
+ Decimal('2.71828183')
>>> c.exp(Decimal('0.693147181'))
- Decimal("2.00000000")
+ Decimal('2.00000000')
>>> c.exp(Decimal('+Infinity'))
- Decimal("Infinity")
+ Decimal('Infinity')
"""
return a.exp(context=self)
@@ -3783,11 +4095,11 @@
multiplication, using add, all with only one final rounding.
>>> ExtendedContext.fma(Decimal('3'), Decimal('5'), Decimal('7'))
- Decimal("22")
+ Decimal('22')
>>> ExtendedContext.fma(Decimal('3'), Decimal('-5'), Decimal('7'))
- Decimal("-8")
+ Decimal('-8')
>>> ExtendedContext.fma(Decimal('888565290'), Decimal('1557.96930'), Decimal('-86087.7578'))
- Decimal("1.38435736E+12")
+ Decimal('1.38435736E+12')
"""
return a.fma(b, c, context=self)
@@ -3941,15 +4253,15 @@
>>> c.Emin = -999
>>> c.Emax = 999
>>> c.ln(Decimal('0'))
- Decimal("-Infinity")
+ Decimal('-Infinity')
>>> c.ln(Decimal('1.000'))
- Decimal("0")
+ Decimal('0')
>>> c.ln(Decimal('2.71828183'))
- Decimal("1.00000000")
+ Decimal('1.00000000')
>>> c.ln(Decimal('10'))
- Decimal("2.30258509")
+ Decimal('2.30258509')
>>> c.ln(Decimal('+Infinity'))
- Decimal("Infinity")
+ Decimal('Infinity')
"""
return a.ln(context=self)
@@ -3960,19 +4272,19 @@
>>> c.Emin = -999
>>> c.Emax = 999
>>> c.log10(Decimal('0'))
- Decimal("-Infinity")
+ Decimal('-Infinity')
>>> c.log10(Decimal('0.001'))
- Decimal("-3")
+ Decimal('-3')
>>> c.log10(Decimal('1.000'))
- Decimal("0")
+ Decimal('0')
>>> c.log10(Decimal('2'))
- Decimal("0.301029996")
+ Decimal('0.301029996')
>>> c.log10(Decimal('10'))
- Decimal("1")
+ Decimal('1')
>>> c.log10(Decimal('70'))
- Decimal("1.84509804")
+ Decimal('1.84509804')
>>> c.log10(Decimal('+Infinity'))
- Decimal("Infinity")
+ Decimal('Infinity')
"""
return a.log10(context=self)
@@ -3985,13 +4297,13 @@
value of that digit and without limiting the resulting exponent).
>>> ExtendedContext.logb(Decimal('250'))
- Decimal("2")
+ Decimal('2')
>>> ExtendedContext.logb(Decimal('2.50'))
- Decimal("0")
+ Decimal('0')
>>> ExtendedContext.logb(Decimal('0.03'))
- Decimal("-2")
+ Decimal('-2')
>>> ExtendedContext.logb(Decimal('0'))
- Decimal("-Infinity")
+ Decimal('-Infinity')
"""
return a.logb(context=self)
@@ -4001,17 +4313,17 @@
The operands must be both logical numbers.
>>> ExtendedContext.logical_and(Decimal('0'), Decimal('0'))
- Decimal("0")
+ Decimal('0')
>>> ExtendedContext.logical_and(Decimal('0'), Decimal('1'))
- Decimal("0")
+ Decimal('0')
>>> ExtendedContext.logical_and(Decimal('1'), Decimal('0'))
- Decimal("0")
+ Decimal('0')
>>> ExtendedContext.logical_and(Decimal('1'), Decimal('1'))
- Decimal("1")
+ Decimal('1')
>>> ExtendedContext.logical_and(Decimal('1100'), Decimal('1010'))
- Decimal("1000")
+ Decimal('1000')
>>> ExtendedContext.logical_and(Decimal('1111'), Decimal('10'))
- Decimal("10")
+ Decimal('10')
"""
return a.logical_and(b, context=self)
@@ -4021,13 +4333,13 @@
The operand must be a logical number.
>>> ExtendedContext.logical_invert(Decimal('0'))
- Decimal("111111111")
+ Decimal('111111111')
>>> ExtendedContext.logical_invert(Decimal('1'))
- Decimal("111111110")
+ Decimal('111111110')
>>> ExtendedContext.logical_invert(Decimal('111111111'))
- Decimal("0")
+ Decimal('0')
>>> ExtendedContext.logical_invert(Decimal('101010101'))
- Decimal("10101010")
+ Decimal('10101010')
"""
return a.logical_invert(context=self)
@@ -4037,17 +4349,17 @@
The operands must be both logical numbers.
>>> ExtendedContext.logical_or(Decimal('0'), Decimal('0'))
- Decimal("0")
+ Decimal('0')
>>> ExtendedContext.logical_or(Decimal('0'), Decimal('1'))
- Decimal("1")
+ Decimal('1')
>>> ExtendedContext.logical_or(Decimal('1'), Decimal('0'))
- Decimal("1")
+ Decimal('1')
>>> ExtendedContext.logical_or(Decimal('1'), Decimal('1'))
- Decimal("1")
+ Decimal('1')
>>> ExtendedContext.logical_or(Decimal('1100'), Decimal('1010'))
- Decimal("1110")
+ Decimal('1110')
>>> ExtendedContext.logical_or(Decimal('1110'), Decimal('10'))
- Decimal("1110")
+ Decimal('1110')
"""
return a.logical_or(b, context=self)
@@ -4057,17 +4369,17 @@
The operands must be both logical numbers.
>>> ExtendedContext.logical_xor(Decimal('0'), Decimal('0'))
- Decimal("0")
+ Decimal('0')
>>> ExtendedContext.logical_xor(Decimal('0'), Decimal('1'))
- Decimal("1")
+ Decimal('1')
>>> ExtendedContext.logical_xor(Decimal('1'), Decimal('0'))
- Decimal("1")
+ Decimal('1')
>>> ExtendedContext.logical_xor(Decimal('1'), Decimal('1'))
- Decimal("0")
+ Decimal('0')
>>> ExtendedContext.logical_xor(Decimal('1100'), Decimal('1010'))
- Decimal("110")
+ Decimal('110')
>>> ExtendedContext.logical_xor(Decimal('1111'), Decimal('10'))
- Decimal("1101")
+ Decimal('1101')
"""
return a.logical_xor(b, context=self)
@@ -4075,19 +4387,19 @@
"""max compares two values numerically and returns the maximum.
If either operand is a NaN then the general rules apply.
- Otherwise, the operands are compared as as though by the compare
+ Otherwise, the operands are compared as though by the compare
operation. If they are numerically equal then the left-hand operand
is chosen as the result. Otherwise the maximum (closer to positive
infinity) of the two operands is chosen as the result.
>>> ExtendedContext.max(Decimal('3'), Decimal('2'))
- Decimal("3")
+ Decimal('3')
>>> ExtendedContext.max(Decimal('-10'), Decimal('3'))
- Decimal("3")
+ Decimal('3')
>>> ExtendedContext.max(Decimal('1.0'), Decimal('1'))
- Decimal("1")
+ Decimal('1')
>>> ExtendedContext.max(Decimal('7'), Decimal('NaN'))
- Decimal("7")
+ Decimal('7')
"""
return a.max(b, context=self)
@@ -4099,19 +4411,19 @@
"""min compares two values numerically and returns the minimum.
If either operand is a NaN then the general rules apply.
- Otherwise, the operands are compared as as though by the compare
+ Otherwise, the operands are compared as though by the compare
operation. If they are numerically equal then the left-hand operand
is chosen as the result. Otherwise the minimum (closer to negative
infinity) of the two operands is chosen as the result.
>>> ExtendedContext.min(Decimal('3'), Decimal('2'))
- Decimal("2")
+ Decimal('2')
>>> ExtendedContext.min(Decimal('-10'), Decimal('3'))
- Decimal("-10")
+ Decimal('-10')
>>> ExtendedContext.min(Decimal('1.0'), Decimal('1'))
- Decimal("1.0")
+ Decimal('1.0')
>>> ExtendedContext.min(Decimal('7'), Decimal('NaN'))
- Decimal("7")
+ Decimal('7')
"""
return a.min(b, context=self)
@@ -4127,9 +4439,9 @@
has the same exponent as the operand.
>>> ExtendedContext.minus(Decimal('1.3'))
- Decimal("-1.3")
+ Decimal('-1.3')
>>> ExtendedContext.minus(Decimal('-1.3'))
- Decimal("1.3")
+ Decimal('1.3')
"""
return a.__neg__(context=self)
@@ -4142,15 +4454,15 @@
of the two operands.
>>> ExtendedContext.multiply(Decimal('1.20'), Decimal('3'))
- Decimal("3.60")
+ Decimal('3.60')
>>> ExtendedContext.multiply(Decimal('7'), Decimal('3'))
- Decimal("21")
+ Decimal('21')
>>> ExtendedContext.multiply(Decimal('0.9'), Decimal('0.8'))
- Decimal("0.72")
+ Decimal('0.72')
>>> ExtendedContext.multiply(Decimal('0.9'), Decimal('-0'))
- Decimal("-0.0")
+ Decimal('-0.0')
>>> ExtendedContext.multiply(Decimal('654321'), Decimal('654321'))
- Decimal("4.28135971E+11")
+ Decimal('4.28135971E+11')
"""
return a.__mul__(b, context=self)
@@ -4161,13 +4473,13 @@
>>> c.Emin = -999
>>> c.Emax = 999
>>> ExtendedContext.next_minus(Decimal('1'))
- Decimal("0.999999999")
+ Decimal('0.999999999')
>>> c.next_minus(Decimal('1E-1007'))
- Decimal("0E-1007")
+ Decimal('0E-1007')
>>> ExtendedContext.next_minus(Decimal('-1.00000003'))
- Decimal("-1.00000004")
+ Decimal('-1.00000004')
>>> c.next_minus(Decimal('Infinity'))
- Decimal("9.99999999E+999")
+ Decimal('9.99999999E+999')
"""
return a.next_minus(context=self)
@@ -4178,13 +4490,13 @@
>>> c.Emin = -999
>>> c.Emax = 999
>>> ExtendedContext.next_plus(Decimal('1'))
- Decimal("1.00000001")
+ Decimal('1.00000001')
>>> c.next_plus(Decimal('-1E-1007'))
- Decimal("-0E-1007")
+ Decimal('-0E-1007')
>>> ExtendedContext.next_plus(Decimal('-1.00000003'))
- Decimal("-1.00000002")
+ Decimal('-1.00000002')
>>> c.next_plus(Decimal('-Infinity'))
- Decimal("-9.99999999E+999")
+ Decimal('-9.99999999E+999')
"""
return a.next_plus(context=self)
@@ -4200,19 +4512,19 @@
>>> c.Emin = -999
>>> c.Emax = 999
>>> c.next_toward(Decimal('1'), Decimal('2'))
- Decimal("1.00000001")
+ Decimal('1.00000001')
>>> c.next_toward(Decimal('-1E-1007'), Decimal('1'))
- Decimal("-0E-1007")
+ Decimal('-0E-1007')
>>> c.next_toward(Decimal('-1.00000003'), Decimal('0'))
- Decimal("-1.00000002")
+ Decimal('-1.00000002')
>>> c.next_toward(Decimal('1'), Decimal('0'))
- Decimal("0.999999999")
+ Decimal('0.999999999')
>>> c.next_toward(Decimal('1E-1007'), Decimal('-100'))
- Decimal("0E-1007")
+ Decimal('0E-1007')
>>> c.next_toward(Decimal('-1.00000003'), Decimal('-10'))
- Decimal("-1.00000004")
+ Decimal('-1.00000004')
>>> c.next_toward(Decimal('0.00'), Decimal('-0.0000'))
- Decimal("-0.00")
+ Decimal('-0.00')
"""
return a.next_toward(b, context=self)
@@ -4223,17 +4535,17 @@
result.
>>> ExtendedContext.normalize(Decimal('2.1'))
- Decimal("2.1")
+ Decimal('2.1')
>>> ExtendedContext.normalize(Decimal('-2.0'))
- Decimal("-2")
+ Decimal('-2')
>>> ExtendedContext.normalize(Decimal('1.200'))
- Decimal("1.2")
+ Decimal('1.2')
>>> ExtendedContext.normalize(Decimal('-120'))
- Decimal("-1.2E+2")
+ Decimal('-1.2E+2')
>>> ExtendedContext.normalize(Decimal('120.00'))
- Decimal("1.2E+2")
+ Decimal('1.2E+2')
>>> ExtendedContext.normalize(Decimal('0.00'))
- Decimal("0")
+ Decimal('0')
"""
return a.normalize(context=self)
@@ -4292,9 +4604,9 @@
has the same exponent as the operand.
>>> ExtendedContext.plus(Decimal('1.3'))
- Decimal("1.3")
+ Decimal('1.3')
>>> ExtendedContext.plus(Decimal('-1.3'))
- Decimal("-1.3")
+ Decimal('-1.3')
"""
return a.__pos__(context=self)
@@ -4324,46 +4636,46 @@
>>> c.Emin = -999
>>> c.Emax = 999
>>> c.power(Decimal('2'), Decimal('3'))
- Decimal("8")
+ Decimal('8')
>>> c.power(Decimal('-2'), Decimal('3'))
- Decimal("-8")
+ Decimal('-8')
>>> c.power(Decimal('2'), Decimal('-3'))
- Decimal("0.125")
+ Decimal('0.125')
>>> c.power(Decimal('1.7'), Decimal('8'))
- Decimal("69.7575744")
+ Decimal('69.7575744')
>>> c.power(Decimal('10'), Decimal('0.301029996'))
- Decimal("2.00000000")
+ Decimal('2.00000000')
>>> c.power(Decimal('Infinity'), Decimal('-1'))
- Decimal("0")
+ Decimal('0')
>>> c.power(Decimal('Infinity'), Decimal('0'))
- Decimal("1")
+ Decimal('1')
>>> c.power(Decimal('Infinity'), Decimal('1'))
- Decimal("Infinity")
+ Decimal('Infinity')
>>> c.power(Decimal('-Infinity'), Decimal('-1'))
- Decimal("-0")
+ Decimal('-0')
>>> c.power(Decimal('-Infinity'), Decimal('0'))
- Decimal("1")
+ Decimal('1')
>>> c.power(Decimal('-Infinity'), Decimal('1'))
- Decimal("-Infinity")
+ Decimal('-Infinity')
>>> c.power(Decimal('-Infinity'), Decimal('2'))
- Decimal("Infinity")
+ Decimal('Infinity')
>>> c.power(Decimal('0'), Decimal('0'))
- Decimal("NaN")
+ Decimal('NaN')
>>> c.power(Decimal('3'), Decimal('7'), Decimal('16'))
- Decimal("11")
+ Decimal('11')
>>> c.power(Decimal('-3'), Decimal('7'), Decimal('16'))
- Decimal("-11")
+ Decimal('-11')
>>> c.power(Decimal('-3'), Decimal('8'), Decimal('16'))
- Decimal("1")
+ Decimal('1')
>>> c.power(Decimal('3'), Decimal('7'), Decimal('-16'))
- Decimal("11")
+ Decimal('11')
>>> c.power(Decimal('23E12345'), Decimal('67E189'), Decimal('123456789'))
- Decimal("11729830")
+ Decimal('11729830')
>>> c.power(Decimal('-0'), Decimal('17'), Decimal('1729'))
- Decimal("-0")
+ Decimal('-0')
>>> c.power(Decimal('-23'), Decimal('0'), Decimal('65537'))
- Decimal("1")
+ Decimal('1')
"""
return a.__pow__(b, modulo, context=self)
@@ -4386,35 +4698,35 @@
if the result is subnormal and inexact.
>>> ExtendedContext.quantize(Decimal('2.17'), Decimal('0.001'))
- Decimal("2.170")
+ Decimal('2.170')
>>> ExtendedContext.quantize(Decimal('2.17'), Decimal('0.01'))
- Decimal("2.17")
+ Decimal('2.17')
>>> ExtendedContext.quantize(Decimal('2.17'), Decimal('0.1'))
- Decimal("2.2")
+ Decimal('2.2')
>>> ExtendedContext.quantize(Decimal('2.17'), Decimal('1e+0'))
- Decimal("2")
+ Decimal('2')
>>> ExtendedContext.quantize(Decimal('2.17'), Decimal('1e+1'))
- Decimal("0E+1")
+ Decimal('0E+1')
>>> ExtendedContext.quantize(Decimal('-Inf'), Decimal('Infinity'))
- Decimal("-Infinity")
+ Decimal('-Infinity')
>>> ExtendedContext.quantize(Decimal('2'), Decimal('Infinity'))
- Decimal("NaN")
+ Decimal('NaN')
>>> ExtendedContext.quantize(Decimal('-0.1'), Decimal('1'))
- Decimal("-0")
+ Decimal('-0')
>>> ExtendedContext.quantize(Decimal('-0'), Decimal('1e+5'))
- Decimal("-0E+5")
+ Decimal('-0E+5')
>>> ExtendedContext.quantize(Decimal('+35236450.6'), Decimal('1e-2'))
- Decimal("NaN")
+ Decimal('NaN')
>>> ExtendedContext.quantize(Decimal('-35236450.6'), Decimal('1e-2'))
- Decimal("NaN")
+ Decimal('NaN')
>>> ExtendedContext.quantize(Decimal('217'), Decimal('1e-1'))
- Decimal("217.0")
+ Decimal('217.0')
>>> ExtendedContext.quantize(Decimal('217'), Decimal('1e-0'))
- Decimal("217")
+ Decimal('217')
>>> ExtendedContext.quantize(Decimal('217'), Decimal('1e+1'))
- Decimal("2.2E+2")
+ Decimal('2.2E+2')
>>> ExtendedContext.quantize(Decimal('217'), Decimal('1e+2'))
- Decimal("2E+2")
+ Decimal('2E+2')
"""
return a.quantize(b, context=self)
@@ -4422,7 +4734,7 @@
"""Just returns 10, as this is Decimal, :)
>>> ExtendedContext.radix()
- Decimal("10")
+ Decimal('10')
"""
return Decimal(10)
@@ -4439,17 +4751,17 @@
remainder cannot be calculated).
>>> ExtendedContext.remainder(Decimal('2.1'), Decimal('3'))
- Decimal("2.1")
+ Decimal('2.1')
>>> ExtendedContext.remainder(Decimal('10'), Decimal('3'))
- Decimal("1")
+ Decimal('1')
>>> ExtendedContext.remainder(Decimal('-10'), Decimal('3'))
- Decimal("-1")
+ Decimal('-1')
>>> ExtendedContext.remainder(Decimal('10.2'), Decimal('1'))
- Decimal("0.2")
+ Decimal('0.2')
>>> ExtendedContext.remainder(Decimal('10'), Decimal('0.3'))
- Decimal("0.1")
+ Decimal('0.1')
>>> ExtendedContext.remainder(Decimal('3.6'), Decimal('1.3'))
- Decimal("1.0")
+ Decimal('1.0')
"""
return a.__mod__(b, context=self)
@@ -4464,19 +4776,19 @@
remainder cannot be calculated).
>>> ExtendedContext.remainder_near(Decimal('2.1'), Decimal('3'))
- Decimal("-0.9")
+ Decimal('-0.9')
>>> ExtendedContext.remainder_near(Decimal('10'), Decimal('6'))
- Decimal("-2")
+ Decimal('-2')
>>> ExtendedContext.remainder_near(Decimal('10'), Decimal('3'))
- Decimal("1")
+ Decimal('1')
>>> ExtendedContext.remainder_near(Decimal('-10'), Decimal('3'))
- Decimal("-1")
+ Decimal('-1')
>>> ExtendedContext.remainder_near(Decimal('10.2'), Decimal('1'))
- Decimal("0.2")
+ Decimal('0.2')
>>> ExtendedContext.remainder_near(Decimal('10'), Decimal('0.3'))
- Decimal("0.1")
+ Decimal('0.1')
>>> ExtendedContext.remainder_near(Decimal('3.6'), Decimal('1.3'))
- Decimal("-0.3")
+ Decimal('-0.3')
"""
return a.remainder_near(b, context=self)
@@ -4490,15 +4802,15 @@
positive or to the right otherwise.
>>> ExtendedContext.rotate(Decimal('34'), Decimal('8'))
- Decimal("400000003")
+ Decimal('400000003')
>>> ExtendedContext.rotate(Decimal('12'), Decimal('9'))
- Decimal("12")
+ Decimal('12')
>>> ExtendedContext.rotate(Decimal('123456789'), Decimal('-2'))
- Decimal("891234567")
+ Decimal('891234567')
>>> ExtendedContext.rotate(Decimal('123456789'), Decimal('0'))
- Decimal("123456789")
+ Decimal('123456789')
>>> ExtendedContext.rotate(Decimal('123456789'), Decimal('+2'))
- Decimal("345678912")
+ Decimal('345678912')
"""
return a.rotate(b, context=self)
@@ -4523,11 +4835,11 @@
"""Returns the first operand after adding the second value its exp.
>>> ExtendedContext.scaleb(Decimal('7.50'), Decimal('-2'))
- Decimal("0.0750")
+ Decimal('0.0750')
>>> ExtendedContext.scaleb(Decimal('7.50'), Decimal('0'))
- Decimal("7.50")
+ Decimal('7.50')
>>> ExtendedContext.scaleb(Decimal('7.50'), Decimal('3'))
- Decimal("7.50E+3")
+ Decimal('7.50E+3')
"""
return a.scaleb (b, context=self)
@@ -4542,15 +4854,15 @@
coefficient are zeros.
>>> ExtendedContext.shift(Decimal('34'), Decimal('8'))
- Decimal("400000000")
+ Decimal('400000000')
>>> ExtendedContext.shift(Decimal('12'), Decimal('9'))
- Decimal("0")
+ Decimal('0')
>>> ExtendedContext.shift(Decimal('123456789'), Decimal('-2'))
- Decimal("1234567")
+ Decimal('1234567')
>>> ExtendedContext.shift(Decimal('123456789'), Decimal('0'))
- Decimal("123456789")
+ Decimal('123456789')
>>> ExtendedContext.shift(Decimal('123456789'), Decimal('+2'))
- Decimal("345678900")
+ Decimal('345678900')
"""
return a.shift(b, context=self)
@@ -4561,23 +4873,23 @@
algorithm.
>>> ExtendedContext.sqrt(Decimal('0'))
- Decimal("0")
+ Decimal('0')
>>> ExtendedContext.sqrt(Decimal('-0'))
- Decimal("-0")
+ Decimal('-0')
>>> ExtendedContext.sqrt(Decimal('0.39'))
- Decimal("0.624499800")
+ Decimal('0.624499800')
>>> ExtendedContext.sqrt(Decimal('100'))
- Decimal("10")
+ Decimal('10')
>>> ExtendedContext.sqrt(Decimal('1'))
- Decimal("1")
+ Decimal('1')
>>> ExtendedContext.sqrt(Decimal('1.0'))
- Decimal("1.0")
+ Decimal('1.0')
>>> ExtendedContext.sqrt(Decimal('1.00'))
- Decimal("1.0")
+ Decimal('1.0')
>>> ExtendedContext.sqrt(Decimal('7'))
- Decimal("2.64575131")
+ Decimal('2.64575131')
>>> ExtendedContext.sqrt(Decimal('10'))
- Decimal("3.16227766")
+ Decimal('3.16227766')
>>> ExtendedContext.prec
9
"""
@@ -4587,11 +4899,11 @@
"""Return the difference between the two operands.
>>> ExtendedContext.subtract(Decimal('1.3'), Decimal('1.07'))
- Decimal("0.23")
+ Decimal('0.23')
>>> ExtendedContext.subtract(Decimal('1.3'), Decimal('1.30'))
- Decimal("0.00")
+ Decimal('0.00')
>>> ExtendedContext.subtract(Decimal('1.3'), Decimal('2.07'))
- Decimal("-0.77")
+ Decimal('-0.77')
"""
return a.__sub__(b, context=self)
@@ -4620,21 +4932,21 @@
context.
>>> ExtendedContext.to_integral_exact(Decimal('2.1'))
- Decimal("2")
+ Decimal('2')
>>> ExtendedContext.to_integral_exact(Decimal('100'))
- Decimal("100")
+ Decimal('100')
>>> ExtendedContext.to_integral_exact(Decimal('100.0'))
- Decimal("100")
+ Decimal('100')
>>> ExtendedContext.to_integral_exact(Decimal('101.5'))
- Decimal("102")
+ Decimal('102')
>>> ExtendedContext.to_integral_exact(Decimal('-101.5'))
- Decimal("-102")
+ Decimal('-102')
>>> ExtendedContext.to_integral_exact(Decimal('10E+5'))
- Decimal("1.0E+6")
+ Decimal('1.0E+6')
>>> ExtendedContext.to_integral_exact(Decimal('7.89E+77'))
- Decimal("7.89E+77")
+ Decimal('7.89E+77')
>>> ExtendedContext.to_integral_exact(Decimal('-Inf'))
- Decimal("-Infinity")
+ Decimal('-Infinity')
"""
return a.to_integral_exact(context=self)
@@ -4648,21 +4960,21 @@
be set. The rounding mode is taken from the context.
>>> ExtendedContext.to_integral_value(Decimal('2.1'))
- Decimal("2")
+ Decimal('2')
>>> ExtendedContext.to_integral_value(Decimal('100'))
- Decimal("100")
+ Decimal('100')
>>> ExtendedContext.to_integral_value(Decimal('100.0'))
- Decimal("100")
+ Decimal('100')
>>> ExtendedContext.to_integral_value(Decimal('101.5'))
- Decimal("102")
+ Decimal('102')
>>> ExtendedContext.to_integral_value(Decimal('-101.5'))
- Decimal("-102")
+ Decimal('-102')
>>> ExtendedContext.to_integral_value(Decimal('10E+5'))
- Decimal("1.0E+6")
+ Decimal('1.0E+6')
>>> ExtendedContext.to_integral_value(Decimal('7.89E+77'))
- Decimal("7.89E+77")
+ Decimal('7.89E+77')
>>> ExtendedContext.to_integral_value(Decimal('-Inf'))
- Decimal("-Infinity")
+ Decimal('-Infinity')
"""
return a.to_integral_value(context=self)
@@ -4854,7 +5166,7 @@
log_tenpower = f*M # exact
else:
log_d = 0 # error < 2.31
- log_tenpower = div_nearest(f, 10**-p) # error < 0.5
+ log_tenpower = _div_nearest(f, 10**-p) # error < 0.5
return _div_nearest(log_tenpower+log_d, 100)
@@ -5111,8 +5423,7 @@
##### crud for parsing strings #############################################
-import re
-
+#
# Regular expression used for parsing numeric strings. Additional
# comments:
#
@@ -5124,47 +5435,172 @@
# number between the optional sign and the optional exponent must have
# at least one decimal digit, possibly after the decimal point. The
# lookahead expression '(?=\d|\.\d)' checks this.
-#
-# As the flag UNICODE is not enabled here, we're explicitly avoiding any
-# other meaning for \d than the numbers [0-9].
import re
-_parser = re.compile(r""" # A numeric string consists of:
+_parser = re.compile(r""" # A numeric string consists of:
# \s*
- (?P<sign>[-+])? # an optional sign, followed by either...
+ (?P<sign>[-+])? # an optional sign, followed by either...
(
- (?=\d|\.\d) # ...a number (with at least one digit)
- (?P<int>\d*) # consisting of a (possibly empty) integer part
- (\.(?P<frac>\d*))? # followed by an optional fractional part
- (E(?P<exp>[-+]?\d+))? # followed by an optional exponent, or...
+ (?=\d|\.\d) # ...a number (with at least one digit)
+ (?P<int>\d*) # having a (possibly empty) integer part
+ (\.(?P<frac>\d*))? # followed by an optional fractional part
+ (E(?P<exp>[-+]?\d+))? # followed by an optional exponent, or...
|
- Inf(inity)? # ...an infinity, or...
+ Inf(inity)? # ...an infinity, or...
|
- (?P<signal>s)? # ...an (optionally signaling)
- NaN # NaN
- (?P<diag>\d*) # with (possibly empty) diagnostic information.
+ (?P<signal>s)? # ...an (optionally signaling)
+ NaN # NaN
+ (?P<diag>\d*) # with (possibly empty) diagnostic info.
)
# \s*
- $
-""", re.VERBOSE | re.IGNORECASE).match
+ \Z
+""", re.VERBOSE | re.IGNORECASE | re.UNICODE).match
_all_zeros = re.compile('0*$').match
_exact_half = re.compile('50*$').match
+
+##### PEP3101 support functions ##############################################
+# The functions parse_format_specifier and format_align have little to do
+# with the Decimal class, and could potentially be reused for other pure
+# Python numeric classes that want to implement __format__
+#
+# A format specifier for Decimal looks like:
+#
+# [[fill]align][sign][0][minimumwidth][.precision][type]
+#
+
+_parse_format_specifier_regex = re.compile(r"""\A
+(?:
+ (?P<fill>.)?
+ (?P<align>[<>=^])
+)?
+(?P<sign>[-+ ])?
+(?P<zeropad>0)?
+(?P<minimumwidth>(?!0)\d+)?
+(?:\.(?P<precision>0|(?!0)\d+))?
+(?P<type>[eEfFgG%])?
+\Z
+""", re.VERBOSE)
+
del re
+def _parse_format_specifier(format_spec):
+ """Parse and validate a format specifier.
+
+ Turns a standard numeric format specifier into a dict, with the
+ following entries:
+
+ fill: fill character to pad field to minimum width
+ align: alignment type, either '<', '>', '=' or '^'
+ sign: either '+', '-' or ' '
+ minimumwidth: nonnegative integer giving minimum width
+ precision: nonnegative integer giving precision, or None
+ type: one of the characters 'eEfFgG%', or None
+ unicode: either True or False (always True for Python 3.x)
+
+ """
+ m = _parse_format_specifier_regex.match(format_spec)
+ if m is None:
+ raise ValueError("Invalid format specifier: " + format_spec)
+
+ # get the dictionary
+ format_dict = m.groupdict()
+
+ # defaults for fill and alignment
+ fill = format_dict['fill']
+ align = format_dict['align']
+ if format_dict.pop('zeropad') is not None:
+ # in the face of conflict, refuse the temptation to guess
+ if fill is not None and fill != '0':
+ raise ValueError("Fill character conflicts with '0'"
+ " in format specifier: " + format_spec)
+ if align is not None and align != '=':
+ raise ValueError("Alignment conflicts with '0' in "
+ "format specifier: " + format_spec)
+ fill = '0'
+ align = '='
+ format_dict['fill'] = fill or ' '
+ format_dict['align'] = align or '<'
+
+ if format_dict['sign'] is None:
+ format_dict['sign'] = '-'
+
+ # turn minimumwidth and precision entries into integers.
+ # minimumwidth defaults to 0; precision remains None if not given
+ format_dict['minimumwidth'] = int(format_dict['minimumwidth'] or '0')
+ if format_dict['precision'] is not None:
+ format_dict['precision'] = int(format_dict['precision'])
+
+ # if format type is 'g' or 'G' then a precision of 0 makes little
+ # sense; convert it to 1. Same if format type is unspecified.
+ if format_dict['precision'] == 0:
+ if format_dict['type'] is None or format_dict['type'] in 'gG':
+ format_dict['precision'] = 1
+
+ # record whether return type should be str or unicode
+ format_dict['unicode'] = isinstance(format_spec, unicode)
+
+ return format_dict
+
+def _format_align(body, spec_dict):
+ """Given an unpadded, non-aligned numeric string, add padding and
+ aligment to conform with the given format specifier dictionary (as
+ output from parse_format_specifier).
+
+ It's assumed that if body is negative then it starts with '-'.
+ Any leading sign ('-' or '+') is stripped from the body before
+ applying the alignment and padding rules, and replaced in the
+ appropriate position.
+
+ """
+ # figure out the sign; we only examine the first character, so if
+ # body has leading whitespace the results may be surprising.
+ if len(body) > 0 and body[0] in '-+':
+ sign = body[0]
+ body = body[1:]
+ else:
+ sign = ''
+
+ if sign != '-':
+ if spec_dict['sign'] in ' +':
+ sign = spec_dict['sign']
+ else:
+ sign = ''
+
+ # how much extra space do we have to play with?
+ minimumwidth = spec_dict['minimumwidth']
+ fill = spec_dict['fill']
+ padding = fill*(max(minimumwidth - (len(sign+body)), 0))
+
+ align = spec_dict['align']
+ if align == '<':
+ result = sign + body + padding
+ elif align == '>':
+ result = padding + sign + body
+ elif align == '=':
+ result = sign + padding + body
+ else: #align == '^'
+ half = len(padding)//2
+ result = padding[:half] + sign + body + padding[half:]
+
+ # make sure that result is unicode if necessary
+ if spec_dict['unicode']:
+ result = unicode(result)
+
+ return result
##### Useful Constants (internal use only) ################################
# Reusable defaults
-Inf = Decimal('Inf')
-negInf = Decimal('-Inf')
-NaN = Decimal('NaN')
-Dec_0 = Decimal(0)
-Dec_p1 = Decimal(1)
-Dec_n1 = Decimal(-1)
-
-# Infsign[sign] is infinity w/ that sign
-Infsign = (Inf, negInf)
+_Infinity = Decimal('Inf')
+_NegativeInfinity = Decimal('-Inf')
+_NaN = Decimal('NaN')
+_Zero = Decimal(0)
+_One = Decimal(1)
+_NegativeOne = Decimal(-1)
+
+# _SignedInfinity[sign] is infinity w/ that sign
+_SignedInfinity = (_Infinity, _NegativeInfinity)
diff --git a/Lib/distutils/ccompiler.py b/Lib/distutils/ccompiler.py
--- a/Lib/distutils/ccompiler.py
+++ b/Lib/distutils/ccompiler.py
@@ -5,7 +5,7 @@
# This module should be kept compatible with Python 2.1.
-__revision__ = "$Id: ccompiler.py 46331 2006-05-26 14:07:23Z bob.ippolito $"
+__revision__ = "$Id: ccompiler.py 77425 2010-01-11 22:54:57Z tarek.ziade $"
import sys, os, re
from types import *
@@ -159,7 +159,7 @@
# basically the same things with Unix C compilers.
for key in args.keys():
- if not self.executables.has_key(key):
+ if key not in self.executables:
raise ValueError, \
"unknown executable '%s' for class %s" % \
(key, self.__class__.__name__)
@@ -338,10 +338,7 @@
def _setup_compile(self, outdir, macros, incdirs, sources, depends,
extra):
- """Process arguments and decide which source files to compile.
-
- Merges _fix_compile_args() and _prep_compile().
- """
+ """Process arguments and decide which source files to compile."""
if outdir is None:
outdir = self.output_dir
elif type(outdir) is not StringType:
@@ -371,41 +368,6 @@
output_dir=outdir)
assert len(objects) == len(sources)
- # XXX should redo this code to eliminate skip_source entirely.
- # XXX instead create build and issue skip messages inline
-
- if self.force:
- skip_source = {} # rebuild everything
- for source in sources:
- skip_source[source] = 0
- elif depends is None:
- # If depends is None, figure out which source files we
- # have to recompile according to a simplistic check. We
- # just compare the source and object file, no deep
- # dependency checking involving header files.
- skip_source = {} # rebuild everything
- for source in sources: # no wait, rebuild nothing
- skip_source[source] = 1
-
- n_sources, n_objects = newer_pairwise(sources, objects)
- for source in n_sources: # no really, only rebuild what's
- skip_source[source] = 0 # out-of-date
- else:
- # If depends is a list of files, then do a different
- # simplistic check. Assume that each object depends on
- # its source and all files in the depends list.
- skip_source = {}
- # L contains all the depends plus a spot at the end for a
- # particular source file
- L = depends[:] + [None]
- for i in range(len(objects)):
- source = sources[i]
- L[-1] = source
- if newer_group(L, objects[i]):
- skip_source[source] = 0
- else:
- skip_source[source] = 1
-
pp_opts = gen_preprocess_options(macros, incdirs)
build = {}
@@ -414,10 +376,7 @@
obj = objects[i]
ext = os.path.splitext(src)[1]
self.mkpath(os.path.dirname(obj))
- if skip_source[src]:
- log.debug("skipping %s (%s up-to-date)", src, obj)
- else:
- build[obj] = src, ext
+ build[obj] = (src, ext)
return macros, objects, extra, pp_opts, build
@@ -464,7 +423,6 @@
# _fix_compile_args ()
-
def _prep_compile(self, sources, output_dir, depends=None):
"""Decide which souce files must be recompiled.
@@ -477,42 +435,9 @@
objects = self.object_filenames(sources, output_dir=output_dir)
assert len(objects) == len(sources)
- if self.force:
- skip_source = {} # rebuild everything
- for source in sources:
- skip_source[source] = 0
- elif depends is None:
- # If depends is None, figure out which source files we
- # have to recompile according to a simplistic check. We
- # just compare the source and object file, no deep
- # dependency checking involving header files.
- skip_source = {} # rebuild everything
- for source in sources: # no wait, rebuild nothing
- skip_source[source] = 1
-
- n_sources, n_objects = newer_pairwise(sources, objects)
- for source in n_sources: # no really, only rebuild what's
- skip_source[source] = 0 # out-of-date
- else:
- # If depends is a list of files, then do a different
- # simplistic check. Assume that each object depends on
- # its source and all files in the depends list.
- skip_source = {}
- # L contains all the depends plus a spot at the end for a
- # particular source file
- L = depends[:] + [None]
- for i in range(len(objects)):
- source = sources[i]
- L[-1] = source
- if newer_group(L, objects[i]):
- skip_source[source] = 0
- else:
- skip_source[source] = 1
-
- return objects, skip_source
-
- # _prep_compile ()
-
+ # Return an empty dict for the "which source files can be skipped"
+ # return value to preserve API compatibility.
+ return objects, {}
def _fix_object_args (self, objects, output_dir):
"""Typecheck and fix up some arguments supplied to various methods.
@@ -680,7 +605,6 @@
Raises CompileError on failure.
"""
-
# A concrete compiler class can either override this method
# entirely or implement _compile().
@@ -1041,7 +965,7 @@
return move_file (src, dst, dry_run=self.dry_run)
def mkpath (self, name, mode=0777):
- mkpath (name, mode, self.dry_run)
+ mkpath (name, mode, dry_run=self.dry_run)
# class CCompiler
diff --git a/Lib/distutils/command/bdist.py b/Lib/distutils/command/bdist.py
--- a/Lib/distutils/command/bdist.py
+++ b/Lib/distutils/command/bdist.py
@@ -5,9 +5,9 @@
# This module should be kept compatible with Python 2.1.
-__revision__ = "$Id: bdist.py 37828 2004-11-10 22:23:15Z loewis $"
+__revision__ = "$Id: bdist.py 62197 2008-04-07 01:53:39Z mark.hammond $"
-import os, string
+import os
from types import *
from distutils.core import Command
from distutils.errors import *
@@ -98,7 +98,10 @@
def finalize_options (self):
# have to finalize 'plat_name' before 'bdist_base'
if self.plat_name is None:
- self.plat_name = get_platform()
+ if self.skip_build:
+ self.plat_name = get_platform()
+ else:
+ self.plat_name = self.get_finalized_command('build').plat_name
# 'bdist_base' -- parent of per-built-distribution-format
# temporary directories (eg. we'll probably have
@@ -122,7 +125,6 @@
# finalize_options()
-
def run (self):
# Figure out which sub-commands we need to run.
diff --git a/Lib/distutils/command/bdist_dumb.py b/Lib/distutils/command/bdist_dumb.py
--- a/Lib/distutils/command/bdist_dumb.py
+++ b/Lib/distutils/command/bdist_dumb.py
@@ -6,12 +6,12 @@
# This module should be kept compatible with Python 2.1.
-__revision__ = "$Id: bdist_dumb.py 38697 2005-03-23 18:54:36Z loewis $"
+__revision__ = "$Id: bdist_dumb.py 61000 2008-02-23 17:40:11Z christian.heimes $"
import os
from distutils.core import Command
from distutils.util import get_platform
-from distutils.dir_util import create_tree, remove_tree, ensure_relative
+from distutils.dir_util import remove_tree, ensure_relative
from distutils.errors import *
from distutils.sysconfig import get_python_version
from distutils import log
diff --git a/Lib/distutils/sysconfig.py b/Lib/distutils/sysconfig.py
--- a/Lib/distutils/sysconfig.py
+++ b/Lib/distutils/sysconfig.py
@@ -9,7 +9,7 @@
Email: <fdrake at acm.org>
"""
-__revision__ = "$Id: sysconfig.py 52234 2006-10-08 17:50:26Z ronald.oussoren $"
+__revision__ = "$Id: sysconfig.py 83688 2010-08-03 21:18:06Z mark.dickinson $"
import os
import re
@@ -22,16 +22,32 @@
PREFIX = os.path.normpath(sys.prefix)
EXEC_PREFIX = os.path.normpath(sys.exec_prefix)
+# Path to the base directory of the project. On Windows the binary may
+# live in project/PCBuild9. If we're dealing with an x64 Windows build,
+# it'll live in project/PCbuild/amd64.
+project_base = os.path.dirname(os.path.realpath(sys.executable))
+if os.name == "nt" and "pcbuild" in project_base[-8:].lower():
+ project_base = os.path.abspath(os.path.join(project_base, os.path.pardir))
+# PC/VS7.1
+if os.name == "nt" and "\\pc\\v" in project_base[-10:].lower():
+ project_base = os.path.abspath(os.path.join(project_base, os.path.pardir,
+ os.path.pardir))
+# PC/AMD64
+if os.name == "nt" and "\\pcbuild\\amd64" in project_base[-14:].lower():
+ project_base = os.path.abspath(os.path.join(project_base, os.path.pardir,
+ os.path.pardir))
+
# python_build: (Boolean) if true, we're either building Python or
# building an extension with an un-installed Python, so we use
# different (hard-wired) directories.
-
-argv0_path = os.path.dirname(os.path.abspath(sys.executable))
-landmark = os.path.join(argv0_path, "Modules", "Setup")
-
-python_build = os.path.isfile(landmark)
-
-del landmark
+# Setup.local is available for Makefile builds including VPATH builds,
+# Setup.dist is available on Windows
+def _python_build():
+ for fn in ("Setup.dist", "Setup.local"):
+ if os.path.isfile(os.path.join(project_base, "Modules", fn)):
+ return True
+ return False
+python_build = _python_build()
def get_python_version():
@@ -55,15 +71,19 @@
"""
if prefix is None:
prefix = plat_specific and EXEC_PREFIX or PREFIX
+
if os.name == "posix":
if python_build:
- base = os.path.dirname(os.path.abspath(sys.executable))
+ buildir = os.path.dirname(os.path.realpath(sys.executable))
if plat_specific:
- inc_dir = base
+ # python.h is located in the buildir
+ inc_dir = buildir
else:
- inc_dir = os.path.join(base, "Include")
- if not os.path.exists(inc_dir):
- inc_dir = os.path.join(os.path.dirname(base), "Include")
+ # the source dir is relative to the buildir
+ srcdir = os.path.abspath(os.path.join(buildir,
+ get_config_var('srcdir')))
+ # Include is located in the srcdir
+ inc_dir = os.path.join(srcdir, "Include")
return inc_dir
return os.path.join(prefix, "include", "python" + get_python_version())
elif os.name == "nt":
@@ -113,7 +133,7 @@
if get_python_version() < "2.2":
return prefix
else:
- return os.path.join(PREFIX, "Lib", "site-packages")
+ return os.path.join(prefix, "Lib", "site-packages")
elif os.name == "mac":
if plat_specific:
@@ -129,9 +149,9 @@
elif os.name == "os2" or os.name == "java":
if standard_lib:
- return os.path.join(PREFIX, "Lib")
+ return os.path.join(prefix, "Lib")
else:
- return os.path.join(PREFIX, "Lib", "site-packages")
+ return os.path.join(prefix, "Lib", "site-packages")
else:
raise DistutilsPlatformError(
@@ -150,22 +170,22 @@
get_config_vars('CC', 'CXX', 'OPT', 'CFLAGS',
'CCSHARED', 'LDSHARED', 'SO')
- if os.environ.has_key('CC'):
+ if 'CC' in os.environ:
cc = os.environ['CC']
- if os.environ.has_key('CXX'):
+ if 'CXX' in os.environ:
cxx = os.environ['CXX']
- if os.environ.has_key('LDSHARED'):
+ if 'LDSHARED' in os.environ:
ldshared = os.environ['LDSHARED']
- if os.environ.has_key('CPP'):
+ if 'CPP' in os.environ:
cpp = os.environ['CPP']
else:
cpp = cc + " -E" # not always
- if os.environ.has_key('LDFLAGS'):
+ if 'LDFLAGS' in os.environ:
ldshared = ldshared + ' ' + os.environ['LDFLAGS']
- if os.environ.has_key('CFLAGS'):
+ if 'CFLAGS' in os.environ:
cflags = opt + ' ' + os.environ['CFLAGS']
ldshared = ldshared + ' ' + os.environ['CFLAGS']
- if os.environ.has_key('CPPFLAGS'):
+ if 'CPPFLAGS' in os.environ:
cpp = cpp + ' ' + os.environ['CPPFLAGS']
cflags = cflags + ' ' + os.environ['CPPFLAGS']
ldshared = ldshared + ' ' + os.environ['CPPFLAGS']
@@ -185,7 +205,10 @@
def get_config_h_filename():
"""Return full pathname of installed pyconfig.h file."""
if python_build:
- inc_dir = argv0_path
+ if os.name == "nt":
+ inc_dir = os.path.join(project_base, "PC")
+ else:
+ inc_dir = project_base
else:
inc_dir = get_python_inc(plat_specific=1)
if get_python_version() < '2.2':
@@ -199,7 +222,8 @@
def get_makefile_filename():
"""Return full pathname of installed Makefile from the Python build."""
if python_build:
- return os.path.join(os.path.dirname(sys.executable), "Makefile")
+ return os.path.join(os.path.dirname(os.path.realpath(sys.executable)),
+ "Makefile")
lib_dir = get_python_lib(plat_specific=1, standard_lib=1)
return os.path.join(lib_dir, "config", "Makefile")
@@ -256,18 +280,25 @@
while 1:
line = fp.readline()
- if line is None: # eof
+ if line is None: # eof
break
m = _variable_rx.match(line)
if m:
n, v = m.group(1, 2)
- v = string.strip(v)
- if "$" in v:
+ v = v.strip()
+ # `$$' is a literal `$' in make
+ tmpv = v.replace('$$', '')
+
+ if "$" in tmpv:
notdone[n] = v
else:
- try: v = int(v)
- except ValueError: pass
- done[n] = v
+ try:
+ v = int(v)
+ except ValueError:
+ # insert literal `$'
+ done[n] = v.replace('$$', '$')
+ else:
+ done[n] = v
# do variable interpolation here
while notdone:
@@ -277,12 +308,12 @@
if m:
n = m.group(1)
found = True
- if done.has_key(n):
+ if n in done:
item = str(done[n])
- elif notdone.has_key(n):
+ elif n in notdone:
# get it on a subsequent round
found = False
- elif os.environ.has_key(n):
+ elif n in os.environ:
# do it like make: fall back to environment
item = os.environ[n]
else:
@@ -295,7 +326,7 @@
else:
try: value = int(value)
except ValueError:
- done[name] = string.strip(value)
+ done[name] = value.strip()
else:
done[name] = value
del notdone[name]
@@ -366,7 +397,7 @@
# MACOSX_DEPLOYMENT_TARGET: configure bases some choices on it so
# it needs to be compatible.
# If it isn't set we set it to the configure-time value
- if sys.platform == 'darwin' and g.has_key('MACOSX_DEPLOYMENT_TARGET'):
+ if sys.platform == 'darwin' and 'MACOSX_DEPLOYMENT_TARGET' in g:
cfg_target = g['MACOSX_DEPLOYMENT_TARGET']
cur_target = os.getenv('MACOSX_DEPLOYMENT_TARGET', '')
if cur_target == '':
@@ -428,6 +459,8 @@
g['SO'] = '.pyd'
g['EXE'] = ".exe"
+ g['VERSION'] = get_python_version().replace(".", "")
+ g['BINDIR'] = os.path.dirname(os.path.realpath(sys.executable))
global _config_vars
_config_vars = g
@@ -521,15 +554,57 @@
# are in CFLAGS or LDFLAGS and remove them if they are.
# This is needed when building extensions on a 10.3 system
# using a universal build of python.
- for key in ('LDFLAGS', 'BASECFLAGS',
+ for key in ('LDFLAGS', 'BASECFLAGS', 'LDSHARED',
+ # a number of derived variables. These need to be
+ # patched up as well.
+ 'CFLAGS', 'PY_CFLAGS', 'BLDSHARED'):
+ flags = _config_vars[key]
+ flags = re.sub('-arch\s+\w+\s', ' ', flags)
+ flags = re.sub('-isysroot [^ \t]*', ' ', flags)
+ _config_vars[key] = flags
+
+ else:
+
+ # Allow the user to override the architecture flags using
+ # an environment variable.
+ # NOTE: This name was introduced by Apple in OSX 10.5 and
+ # is used by several scripting languages distributed with
+ # that OS release.
+
+ if 'ARCHFLAGS' in os.environ:
+ arch = os.environ['ARCHFLAGS']
+ for key in ('LDFLAGS', 'BASECFLAGS', 'LDSHARED',
# a number of derived variables. These need to be
# patched up as well.
'CFLAGS', 'PY_CFLAGS', 'BLDSHARED'):
- flags = _config_vars[key]
- flags = re.sub('-arch\s+\w+\s', ' ', flags)
- flags = re.sub('-isysroot [^ \t]*', ' ', flags)
- _config_vars[key] = flags
+ flags = _config_vars[key]
+ flags = re.sub('-arch\s+\w+\s', ' ', flags)
+ flags = flags + ' ' + arch
+ _config_vars[key] = flags
+
+ # If we're on OSX 10.5 or later and the user tries to
+ # compiles an extension using an SDK that is not present
+ # on the current machine it is better to not use an SDK
+ # than to fail.
+ #
+ # The major usecase for this is users using a Python.org
+ # binary installer on OSX 10.6: that installer uses
+ # the 10.4u SDK, but that SDK is not installed by default
+ # when you install Xcode.
+ #
+ m = re.search('-isysroot\s+(\S+)', _config_vars['CFLAGS'])
+ if m is not None:
+ sdk = m.group(1)
+ if not os.path.exists(sdk):
+ for key in ('LDFLAGS', 'BASECFLAGS', 'LDSHARED',
+ # a number of derived variables. These need to be
+ # patched up as well.
+ 'CFLAGS', 'PY_CFLAGS', 'BLDSHARED'):
+
+ flags = _config_vars[key]
+ flags = re.sub('-isysroot\s+\S+(\s|$)', ' ', flags)
+ _config_vars[key] = flags
if args:
vals = []
diff --git a/Lib/distutils/tests/test_build_py.py b/Lib/distutils/tests/test_build_py.py
--- a/Lib/distutils/tests/test_build_py.py
+++ b/Lib/distutils/tests/test_build_py.py
@@ -72,6 +72,7 @@
open(os.path.join(testdir, "testfile"), "w").close()
os.chdir(sources)
+ old_stdout = sys.stdout
sys.stdout = StringIO.StringIO()
try:
@@ -90,7 +91,23 @@
finally:
# Restore state.
os.chdir(cwd)
- sys.stdout = sys.__stdout__
+ sys.stdout = old_stdout
+
+ def test_dont_write_bytecode(self):
+ # makes sure byte_compile is not used
+ pkg_dir, dist = self.create_dist()
+ cmd = build_py(dist)
+ cmd.compile = 1
+ cmd.optimize = 1
+
+ old_dont_write_bytecode = sys.dont_write_bytecode
+ sys.dont_write_bytecode = True
+ try:
+ cmd.byte_compile([])
+ finally:
+ sys.dont_write_bytecode = old_dont_write_bytecode
+
+ self.assertTrue('byte-compiling is disabled' in self.logs[0][1])
def test_suite():
return unittest.makeSuite(BuildPyTestCase)
diff --git a/Lib/distutils/util.py b/Lib/distutils/util.py
--- a/Lib/distutils/util.py
+++ b/Lib/distutils/util.py
@@ -4,13 +4,14 @@
one of the other *util.py modules.
"""
-__revision__ = "$Id: util.py 59116 2007-11-22 10:14:26Z ronald.oussoren $"
+__revision__ = "$Id: util.py 83588 2010-08-02 21:35:06Z ezio.melotti $"
import sys, os, string, re
from distutils.errors import DistutilsPlatformError
from distutils.dep_util import newer
from distutils.spawn import spawn
from distutils import log
+from distutils.errors import DistutilsByteCompileError
def get_platform ():
"""Return a string that identifies the current platform. This is used
@@ -29,8 +30,27 @@
irix-5.3
irix64-6.2
- For non-POSIX platforms, currently just returns 'sys.platform'.
+ Windows will return one of:
+ win-amd64 (64bit Windows on AMD64 (aka x86_64, Intel64, EM64T, etc)
+ win-ia64 (64bit Windows on Itanium)
+ win32 (all others - specifically, sys.platform is returned)
+
+ For other non-POSIX platforms, currently just returns 'sys.platform'.
"""
+ if os.name == 'nt':
+ # sniff sys.version for architecture.
+ prefix = " bit ("
+ i = string.find(sys.version, prefix)
+ if i == -1:
+ return sys.platform
+ j = string.find(sys.version, ")", i)
+ look = sys.version[i+len(prefix):j].lower()
+ if look=='amd64':
+ return 'win-amd64'
+ if look=='itanium':
+ return 'win-ia64'
+ return sys.platform
+
if os.name != "posix" or not hasattr(os, 'uname'):
# XXX what about the architecture? NT is Intel or Alpha,
# Mac OS is M68k or PPC, etc.
@@ -81,7 +101,11 @@
if not macver:
macver = cfgvars.get('MACOSX_DEPLOYMENT_TARGET')
- if not macver:
+ if 1:
+ # Always calculate the release of the running machine,
+ # needed to determine if we can build fat binaries or not.
+
+ macrelease = macver
# Get the system version. Reading this plist is a documented
# way to get the system version (see the documentation for
# the Gestalt Manager)
@@ -97,25 +121,62 @@
r'<string>(.*?)</string>', f.read())
f.close()
if m is not None:
- macver = '.'.join(m.group(1).split('.')[:2])
+ macrelease = '.'.join(m.group(1).split('.')[:2])
# else: fall back to the default behaviour
+ if not macver:
+ macver = macrelease
+
if macver:
from distutils.sysconfig import get_config_vars
release = macver
osname = "macosx"
-
- if (release + '.') >= '10.4.' and \
- get_config_vars().get('UNIVERSALSDK', '').strip():
+ if (macrelease + '.') >= '10.4.' and \
+ '-arch' in get_config_vars().get('CFLAGS', '').strip():
# The universal build will build fat binaries, but not on
# systems before 10.4
+ #
+ # Try to detect 4-way universal builds, those have machine-type
+ # 'universal' instead of 'fat'.
+
machine = 'fat'
+ cflags = get_config_vars().get('CFLAGS')
+
+ archs = re.findall('-arch\s+(\S+)', cflags)
+ archs = tuple(sorted(set(archs)))
+
+ if len(archs) == 1:
+ machine = archs[0]
+ elif archs == ('i386', 'ppc'):
+ machine = 'fat'
+ elif archs == ('i386', 'x86_64'):
+ machine = 'intel'
+ elif archs == ('i386', 'ppc', 'x86_64'):
+ machine = 'fat3'
+ elif archs == ('ppc64', 'x86_64'):
+ machine = 'fat64'
+ elif archs == ('i386', 'ppc', 'ppc64', 'x86_64'):
+ machine = 'universal'
+ else:
+ raise ValueError(
+ "Don't know machine value for archs=%r"%(archs,))
+
+ elif machine == 'i386':
+ # On OSX the machine type returned by uname is always the
+ # 32-bit variant, even if the executable architecture is
+ # the 64-bit variant
+ if sys.maxint >= 2**32:
+ machine = 'x86_64'
elif machine in ('PowerPC', 'Power_Macintosh'):
# Pick a sane name for the PPC architecture.
machine = 'ppc'
+ # See 'i386' case
+ if sys.maxint >= 2**32:
+ machine = 'ppc64'
+
return "%s-%s-%s" % (osname, release, machine)
# get_platform ()
@@ -144,7 +205,7 @@
paths.remove('.')
if not paths:
return os.curdir
- return apply(os.path.join, paths)
+ return os.path.join(*paths)
# convert_path ()
@@ -201,11 +262,11 @@
if _environ_checked:
return
- if os.name == 'posix' and not os.environ.has_key('HOME'):
+ if os.name == 'posix' and 'HOME' not in os.environ:
import pwd
os.environ['HOME'] = pwd.getpwuid(os.getuid())[5]
- if not os.environ.has_key('PLAT'):
+ if 'PLAT' not in os.environ:
os.environ['PLAT'] = get_platform()
_environ_checked = 1
@@ -223,7 +284,7 @@
check_environ()
def _subst (match, local_vars=local_vars):
var_name = match.group(1)
- if local_vars.has_key(var_name):
+ if var_name in local_vars:
return str(local_vars[var_name])
else:
return os.environ[var_name]
@@ -345,7 +406,7 @@
log.info(msg)
if not dry_run:
- apply(func, args)
+ func(*args)
def strtobool (val):
@@ -397,6 +458,9 @@
generated in indirect mode; unless you know what you're doing, leave
it set to None.
"""
+ # nothing is done if sys.dont_write_bytecode is True
+ if sys.dont_write_bytecode:
+ raise DistutilsByteCompileError('byte-compiling is disabled.')
# First, if the caller didn't force us into direct or indirect mode,
# figure out which mode we should be in. We take a conservative
@@ -512,6 +576,5 @@
RFC-822 header, by ensuring there are 8 spaces space after each newline.
"""
lines = string.split(header, '\n')
- lines = map(string.strip, lines)
header = string.join(lines, '\n' + 8*' ')
return header
diff --git a/Lib/filecmp.py b/Lib/filecmp.py
--- a/Lib/filecmp.py
+++ b/Lib/filecmp.py
@@ -11,7 +11,6 @@
import os
import stat
-import warnings
from itertools import ifilter, ifilterfalse, imap, izip
__all__ = ["cmp","dircmp","cmpfiles"]
@@ -136,9 +135,9 @@
def phase1(self): # Compute common names
a = dict(izip(imap(os.path.normcase, self.left_list), self.left_list))
b = dict(izip(imap(os.path.normcase, self.right_list), self.right_list))
- self.common = map(a.__getitem__, ifilter(b.has_key, a))
- self.left_only = map(a.__getitem__, ifilterfalse(b.has_key, a))
- self.right_only = map(b.__getitem__, ifilterfalse(a.has_key, b))
+ self.common = map(a.__getitem__, ifilter(b.__contains__, a))
+ self.left_only = map(a.__getitem__, ifilterfalse(b.__contains__, a))
+ self.right_only = map(b.__getitem__, ifilterfalse(a.__contains__, b))
def phase2(self): # Distinguish files, directories, funnies
self.common_dirs = []
diff --git a/Lib/fileinput.py b/Lib/fileinput.py
--- a/Lib/fileinput.py
+++ b/Lib/fileinput.py
@@ -226,7 +226,7 @@
self._mode = mode
if inplace and openhook:
raise ValueError("FileInput cannot use an opening hook in inplace mode")
- elif openhook and not callable(openhook):
+ elif openhook and not hasattr(openhook, '__call__'):
raise ValueError("FileInput openhook must be callable")
self._openhook = openhook
diff --git a/Lib/gettext.py b/Lib/gettext.py
--- a/Lib/gettext.py
+++ b/Lib/gettext.py
@@ -468,6 +468,7 @@
if fallback:
return NullTranslations()
raise IOError(ENOENT, 'No translation file found for domain', domain)
+ # TBD: do we need to worry about the file pointer getting collected?
# Avoid opening, reading, and parsing the .mo file after it's been done
# once.
result = None
@@ -475,8 +476,7 @@
key = os.path.abspath(mofile)
t = _translations.get(key)
if t is None:
- with open(mofile, 'rb') as fp:
- t = _translations.setdefault(key, class_(fp))
+ t = _translations.setdefault(key, class_(open(mofile, 'rb')))
# Copy the translation object to allow setting fallbacks and
# output charset. All other instance data is shared with the
# cached object.
diff --git a/Lib/mailbox.py b/Lib/mailbox.py
--- a/Lib/mailbox.py
+++ b/Lib/mailbox.py
@@ -16,9 +16,8 @@
import errno
import copy
import email
-import email.Message
-import email.Generator
-import rfc822
+import email.message
+import email.generator
import StringIO
try:
if sys.platform == 'os2emx':
@@ -28,6 +27,13 @@
except ImportError:
fcntl = None
+import warnings
+with warnings.catch_warnings():
+ if sys.py3kwarning:
+ warnings.filterwarnings("ignore", ".*rfc822 has been removed",
+ DeprecationWarning)
+ import rfc822
+
__all__ = [ 'Mailbox', 'Maildir', 'mbox', 'MH', 'Babyl', 'MMDF',
'Message', 'MaildirMessage', 'mboxMessage', 'MHMessage',
'BabylMessage', 'MMDFMessage', 'UnixMailbox',
@@ -196,9 +202,9 @@
# To get native line endings on disk, the user-friendly \n line endings
# used in strings and by email.Message are translated here.
"""Dump message contents to target file."""
- if isinstance(message, email.Message.Message):
+ if isinstance(message, email.message.Message):
buffer = StringIO.StringIO()
- gen = email.Generator.Generator(buffer, mangle_from_, 0)
+ gen = email.generator.Generator(buffer, mangle_from_, 0)
gen.flatten(message)
buffer.seek(0)
target.write(buffer.read().replace('\n', os.linesep))
@@ -398,7 +404,8 @@
result = Maildir(path, factory=self._factory)
maildirfolder_path = os.path.join(path, 'maildirfolder')
if not os.path.exists(maildirfolder_path):
- os.close(os.open(maildirfolder_path, os.O_CREAT | os.O_WRONLY))
+ os.close(os.open(maildirfolder_path, os.O_CREAT | os.O_WRONLY,
+ 0666))
return result
def remove_folder(self, folder):
@@ -520,6 +527,7 @@
self._next_key = 0
self._pending = False # No changes require rewriting the file.
self._locked = False
+ self._file_length = None # Used to record mailbox size
def add(self, message):
"""Add message and return assigned key."""
@@ -573,7 +581,21 @@
"""Write any pending changes to disk."""
if not self._pending:
return
- self._lookup()
+
+ # In order to be writing anything out at all, self._toc must
+ # already have been generated (and presumably has been modified
+ # by adding or deleting an item).
+ assert self._toc is not None
+
+ # Check length of self._file; if it's changed, some other process
+ # has modified the mailbox since we scanned it.
+ self._file.seek(0, 2)
+ cur_len = self._file.tell()
+ if cur_len != self._file_length:
+ raise ExternalClashError('Size of mailbox file changed '
+ '(expected %i, found %i)' %
+ (self._file_length, cur_len))
+
new_file = _create_temporary(self._path)
try:
new_toc = {}
@@ -649,6 +671,7 @@
offsets = self._install_message(message)
self._post_message_hook(self._file)
self._file.flush()
+ self._file_length = self._file.tell() # Record current length of mailbox
return offsets
@@ -698,7 +721,7 @@
message = ''
elif isinstance(message, _mboxMMDFMessage):
from_line = 'From ' + message.get_from()
- elif isinstance(message, email.Message.Message):
+ elif isinstance(message, email.message.Message):
from_line = message.get_unixfrom() # May be None.
if from_line is None:
from_line = 'From MAILER-DAEMON %s' % time.asctime(time.gmtime())
@@ -740,6 +763,7 @@
break
self._toc = dict(enumerate(zip(starts, stops)))
self._next_key = len(self._toc)
+ self._file_length = self._file.tell()
class MMDF(_mboxMMDF):
@@ -783,6 +807,8 @@
break
self._toc = dict(enumerate(zip(starts, stops)))
self._next_key = len(self._toc)
+ self._file.seek(0, 2)
+ self._file_length = self._file.tell()
class MH(Mailbox):
@@ -891,7 +917,7 @@
_unlock_file(f)
finally:
f.close()
- for name, key_list in self.get_sequences():
+ for name, key_list in self.get_sequences().iteritems():
if key in key_list:
msg.add_sequence(name)
return msg
@@ -1209,6 +1235,8 @@
self._toc = dict(enumerate(zip(starts, stops)))
self._labels = dict(enumerate(label_lists))
self._next_key = len(self._toc)
+ self._file.seek(0, 2)
+ self._file_length = self._file.tell()
def _pre_mailbox_hook(self, f):
"""Called before writing the mailbox to file f."""
@@ -1244,9 +1272,9 @@
self._file.write(os.linesep)
else:
self._file.write('1,,' + os.linesep)
- if isinstance(message, email.Message.Message):
+ if isinstance(message, email.message.Message):
orig_buffer = StringIO.StringIO()
- orig_generator = email.Generator.Generator(orig_buffer, False, 0)
+ orig_generator = email.generator.Generator(orig_buffer, False, 0)
orig_generator.flatten(message)
orig_buffer.seek(0)
while True:
@@ -1257,7 +1285,7 @@
self._file.write('*** EOOH ***' + os.linesep)
if isinstance(message, BabylMessage):
vis_buffer = StringIO.StringIO()
- vis_generator = email.Generator.Generator(vis_buffer, False, 0)
+ vis_generator = email.generator.Generator(vis_buffer, False, 0)
vis_generator.flatten(message.get_visible())
while True:
line = vis_buffer.readline()
@@ -1313,12 +1341,12 @@
return (start, stop)
-class Message(email.Message.Message):
+class Message(email.message.Message):
"""Message with mailbox-format-specific properties."""
def __init__(self, message=None):
"""Initialize a Message instance."""
- if isinstance(message, email.Message.Message):
+ if isinstance(message, email.message.Message):
self._become_message(copy.deepcopy(message))
if isinstance(message, Message):
message._explain_to(self)
@@ -1327,7 +1355,7 @@
elif hasattr(message, "read"):
self._become_message(email.message_from_file(message))
elif message is None:
- email.Message.Message.__init__(self)
+ email.message.Message.__init__(self)
else:
raise TypeError('Invalid message type: %s' % type(message))
@@ -1458,7 +1486,7 @@
def __init__(self, message=None):
"""Initialize an mboxMMDFMessage instance."""
self.set_from('MAILER-DAEMON', True)
- if isinstance(message, email.Message.Message):
+ if isinstance(message, email.message.Message):
unixfrom = message.get_unixfrom()
if unixfrom is not None and unixfrom.startswith('From '):
self.set_from(unixfrom[5:])
@@ -1881,7 +1909,7 @@
def _create_carefully(path):
"""Create a file if it doesn't exist and open for reading and writing."""
- fd = os.open(path, os.O_CREAT | os.O_EXCL | os.O_RDWR)
+ fd = os.open(path, os.O_CREAT | os.O_EXCL | os.O_RDWR, 0666)
try:
return open(path, 'rb+')
finally:
diff --git a/Lib/netrc.py b/Lib/netrc.py
--- a/Lib/netrc.py
+++ b/Lib/netrc.py
@@ -27,12 +27,9 @@
file = os.path.join(os.environ['HOME'], ".netrc")
except KeyError:
raise IOError("Could not find .netrc: $HOME is not set")
+ fp = open(file)
self.hosts = {}
self.macros = {}
- with open(file) as fp:
- self._parse(file, fp)
-
- def _parse(self, file, fp):
lexer = shlex.shlex(fp)
lexer.wordchars += r"""!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~"""
while 1:
diff --git a/Lib/new.py b/Lib/new.py
--- a/Lib/new.py
+++ b/Lib/new.py
@@ -3,6 +3,10 @@
This module is no longer required except for backward compatibility.
Objects of most types can now be created by calling the type object.
"""
+from warnings import warnpy3k
+warnpy3k("The 'new' module has been removed in Python 3.0; use the 'types' "
+ "module instead.", stacklevel=2)
+del warnpy3k
from types import ClassType as classobj
from types import FunctionType as function
diff --git a/Lib/py_compile.py b/Lib/py_compile.py
--- a/Lib/py_compile.py
+++ b/Lib/py_compile.py
@@ -114,11 +114,15 @@
"""
if args is None:
args = sys.argv[1:]
+ rv = 0
for filename in args:
try:
compile(filename, doraise=True)
- except PyCompileError,err:
+ except PyCompileError, err:
+ # return value to indicate at least one failure
+ rv = 1
sys.stderr.write(err.msg)
+ return rv
if __name__ == "__main__":
- main()
+ sys.exit(main())
diff --git a/Lib/test/list_tests.py b/Lib/test/list_tests.py
--- a/Lib/test/list_tests.py
+++ b/Lib/test/list_tests.py
@@ -5,7 +5,6 @@
import sys
import os
-import unittest
from test import test_support, seq_tests
class CommonTest(seq_tests.CommonTest):
@@ -37,7 +36,7 @@
self.assertEqual(str(a0), str(l0))
self.assertEqual(repr(a0), repr(l0))
- self.assertEqual(`a2`, `l2`)
+ self.assertEqual(repr(a2), repr(l2))
self.assertEqual(str(a2), "[0, 1, 2]")
self.assertEqual(repr(a2), "[0, 1, 2]")
@@ -46,6 +45,11 @@
self.assertEqual(str(a2), "[0, 1, 2, [...], 3]")
self.assertEqual(repr(a2), "[0, 1, 2, [...], 3]")
+ l0 = []
+ for i in xrange(sys.getrecursionlimit() + 100):
+ l0 = [l0]
+ self.assertRaises(RuntimeError, repr, l0)
+
def test_print(self):
d = self.type2test(xrange(200))
d.append(d)
@@ -53,13 +57,11 @@
d.append(d)
d.append(400)
try:
- fo = open(test_support.TESTFN, "wb")
- print >> fo, d,
- fo.close()
- fo = open(test_support.TESTFN, "rb")
- self.assertEqual(fo.read(), repr(d))
+ with open(test_support.TESTFN, "wb") as fo:
+ print >> fo, d,
+ with open(test_support.TESTFN, "rb") as fo:
+ self.assertEqual(fo.read(), repr(d))
finally:
- fo.close()
os.remove(test_support.TESTFN)
def test_set_subscript(self):
@@ -80,6 +82,8 @@
self.assertRaises(StopIteration, r.next)
self.assertEqual(list(reversed(self.type2test())),
self.type2test())
+ # Bug 3689: make sure list-reversed-iterator doesn't have __len__
+ self.assertRaises(TypeError, len, reversed([1,2,3]))
def test_setitem(self):
a = self.type2test([0, 1])
@@ -179,8 +183,10 @@
self.assertEqual(a, self.type2test(range(10)))
self.assertRaises(TypeError, a.__setslice__, 0, 1, 5)
+ self.assertRaises(TypeError, a.__setitem__, slice(0, 1, 5))
self.assertRaises(TypeError, a.__setslice__)
+ self.assertRaises(TypeError, a.__setitem__)
def test_delslice(self):
a = self.type2test([0, 1])
@@ -413,6 +419,11 @@
self.assertRaises(TypeError, u.reverse, 42)
def test_sort(self):
+ with test_support._check_py3k_warnings(
+ ("the cmp argument is not supported", DeprecationWarning)):
+ self._test_sort()
+
+ def _test_sort(self):
u = self.type2test([1, 0])
u.sort()
self.assertEqual(u, [0, 1])
@@ -515,13 +526,14 @@
a = self.type2test(range(10))
a[::2] = tuple(range(5))
self.assertEqual(a, self.type2test([0, 1, 1, 3, 2, 5, 3, 7, 4, 9]))
+ # test issue7788
+ a = self.type2test(range(10))
+ del a[9::1<<333]
# XXX: CPython specific, PyList doesn't len() during init
def _test_constructor_exception_handling(self):
# Bug #1242657
class F(object):
def __iter__(self):
- yield 23
- def __len__(self):
raise KeyboardInterrupt
self.assertRaises(KeyboardInterrupt, list, F())
diff --git a/Lib/test/test_array.py b/Lib/test/test_array.py
--- a/Lib/test/test_array.py
+++ b/Lib/test/test_array.py
@@ -6,8 +6,8 @@
import unittest
from test import test_support
from weakref import proxy
-import array, cStringIO, math
-from cPickle import loads, dumps
+import array, cStringIO
+from cPickle import loads, dumps, HIGHEST_PROTOCOL
if test_support.is_jython:
import operator
@@ -105,7 +105,7 @@
self.assertEqual(a, b)
def test_pickle(self):
- for protocol in (0, 1, 2):
+ for protocol in range(HIGHEST_PROTOCOL + 1):
a = array.array(self.typecode, self.example)
b = loads(dumps(a, protocol))
self.assertNotEqual(id(a), id(b))
@@ -120,7 +120,7 @@
self.assertEqual(type(a), type(b))
def test_pickle_for_empty_array(self):
- for protocol in (0, 1, 2):
+ for protocol in range(HIGHEST_PROTOCOL + 1):
a = array.array(self.typecode)
b = loads(dumps(a, protocol))
self.assertNotEqual(id(a), id(b))
@@ -171,6 +171,7 @@
a = array.array(self.typecode, 2*self.example)
self.assertRaises(TypeError, a.tofile)
self.assertRaises(TypeError, a.tofile, cStringIO.StringIO())
+ test_support.unlink(test_support.TESTFN)
f = open(test_support.TESTFN, 'wb')
try:
a.tofile(f)
@@ -195,6 +196,17 @@
f.close()
test_support.unlink(test_support.TESTFN)
+ def test_fromfile_ioerror(self):
+ # Issue #5395: Check if fromfile raises a proper IOError
+ # instead of EOFError.
+ a = array.array(self.typecode)
+ f = open(test_support.TESTFN, 'wb')
+ try:
+ self.assertRaises(IOError, a.fromfile, f, len(self.example))
+ finally:
+ f.close()
+ test_support.unlink(test_support.TESTFN)
+
def test_tofromlist(self):
a = array.array(self.typecode, 2*self.example)
b = array.array(self.typecode)
@@ -513,6 +525,18 @@
array.array(self.typecode)
)
+ def test_extended_getslice(self):
+ # Test extended slicing by comparing with list slicing
+ # (Assumes list conversion works correctly, too)
+ a = array.array(self.typecode, self.example)
+ indices = (0, None, 1, 3, 19, 100, -1, -2, -31, -100)
+ for start in indices:
+ for stop in indices:
+ # Everything except the initial 0 (invalid step)
+ for step in indices[1:]:
+ self.assertEqual(list(a[start:stop:step]),
+ list(a)[start:stop:step])
+
def test_setslice(self):
a = array.array(self.typecode, self.example)
a[:1] = a
@@ -596,12 +620,34 @@
a = array.array(self.typecode, self.example)
self.assertRaises(TypeError, a.__setslice__, 0, 0, None)
+ self.assertRaises(TypeError, a.__setitem__, slice(0, 0), None)
self.assertRaises(TypeError, a.__setitem__, slice(0, 1), None)
b = array.array(self.badtypecode())
self.assertRaises(TypeError, a.__setslice__, 0, 0, b)
+ self.assertRaises(TypeError, a.__setitem__, slice(0, 0), b)
self.assertRaises(TypeError, a.__setitem__, slice(0, 1), b)
+ def test_extended_set_del_slice(self):
+ indices = (0, None, 1, 3, 19, 100, -1, -2, -31, -100)
+ for start in indices:
+ for stop in indices:
+ # Everything except the initial 0 (invalid step)
+ for step in indices[1:]:
+ a = array.array(self.typecode, self.example)
+ L = list(a)
+ # Make sure we have a slice of exactly the right length,
+ # but with (hopefully) different data.
+ data = L[start:stop:step]
+ data.reverse()
+ L[start:stop:step] = data
+ a[start:stop:step] = array.array(self.typecode, data)
+ self.assertEquals(a, array.array(self.typecode, L))
+
+ del L[start:stop:step]
+ del a[start:stop:step]
+ self.assertEquals(a, array.array(self.typecode, L))
+
def test_index(self):
example = 2*self.example
a = array.array(self.typecode, example)
@@ -721,7 +767,8 @@
def test_buffer(self):
a = array.array(self.typecode, self.example)
- b = buffer(a)
+ with test_support._check_py3k_warnings():
+ b = buffer(a)
self.assertEqual(b[0], a.tostring()[0])
def test_weakref(self):
@@ -769,7 +816,6 @@
return array.array.__new__(cls, 'c', s)
def __init__(self, s, color='blue'):
- array.array.__init__(self, 'c', s)
self.color = color
def strip(self):
@@ -857,6 +903,9 @@
a = array.array(self.typecode, range(10))
del a[::1000]
self.assertEqual(a, array.array(self.typecode, [1,2,3,4,5,6,7,8,9]))
+ # test issue7788
+ a = array.array(self.typecode, range(10))
+ del a[9::1<<333]
def test_assignment(self):
a = array.array(self.typecode, range(10))
@@ -1023,6 +1072,24 @@
class DoubleTest(FPTest):
typecode = 'd'
minitemsize = 8
+
+ def test_alloc_overflow(self):
+ from sys import maxsize
+ a = array.array('d', [-1]*65536)
+ try:
+ a *= maxsize//65536 + 1
+ except MemoryError:
+ pass
+ else:
+ self.fail("Array of size > maxsize created - MemoryError expected")
+ b = array.array('d', [ 2.71828183, 3.14159265, -1])
+ try:
+ b * (maxsize//3 + 1)
+ except MemoryError:
+ pass
+ else:
+ self.fail("Array of size > maxsize created - MemoryError expected")
+
tests.append(DoubleTest)
def test_main(verbose=None):
diff --git a/Lib/test/test_code.py b/Lib/test/test_code.py
--- a/Lib/test/test_code.py
+++ b/Lib/test/test_code.py
@@ -89,3 +89,7 @@
from test.test_support import run_doctest
from test import test_code
run_doctest(test_code, verbose)
+
+
+if __name__ == '__main__':
+ test_main()
diff --git a/Lib/test/test_codeccallbacks.py b/Lib/test/test_codeccallbacks.py
--- a/Lib/test/test_codeccallbacks.py
+++ b/Lib/test/test_codeccallbacks.py
@@ -109,7 +109,7 @@
# useful that the error handler is not called for every single
# unencodable character, but for a complete sequence of
# unencodable characters, otherwise we would output many
- # unneccessary escape sequences.
+ # unnecessary escape sequences.
def uninamereplace(exc):
if not isinstance(exc, UnicodeEncodeError):
@@ -153,28 +153,30 @@
sout += "\\U%08x" % sys.maxunicode
self.assertEqual(sin.encode("iso-8859-15", "backslashreplace"), sout)
- def test_decoderelaxedutf8(self):
- # This is the test for a decoding callback handler,
- # that relaxes the UTF-8 minimal encoding restriction.
- # A null byte that is encoded as "\xc0\x80" will be
- # decoded as a null byte. All other illegal sequences
- # will be handled strictly.
+ def test_decoding_callbacks(self):
+ # This is a test for a decoding callback handler
+ # that allows the decoding of the invalid sequence
+ # "\xc0\x80" and returns "\x00" instead of raising an error.
+ # All other illegal sequences will be handled strictly.
def relaxedutf8(exc):
if not isinstance(exc, UnicodeDecodeError):
raise TypeError("don't know how to handle %r" % exc)
- if exc.object[exc.start:exc.end].startswith("\xc0\x80"):
+ if exc.object[exc.start:exc.start+2] == "\xc0\x80":
return (u"\x00", exc.start+2) # retry after two bytes
else:
raise exc
- codecs.register_error(
- "test.relaxedutf8", relaxedutf8)
+ codecs.register_error("test.relaxedutf8", relaxedutf8)
+ # all the "\xc0\x80" will be decoded to "\x00"
sin = "a\x00b\xc0\x80c\xc3\xbc\xc0\x80\xc0\x80"
sout = u"a\x00b\x00c\xfc\x00\x00"
self.assertEqual(sin.decode("utf-8", "test.relaxedutf8"), sout)
+
+ # "\xc0\x81" is not valid and a UnicodeDecodeError will be raised
sin = "\xc0\x80\xc0\x81"
- self.assertRaises(UnicodeError, sin.decode, "utf-8", "test.relaxedutf8")
+ self.assertRaises(UnicodeDecodeError, sin.decode,
+ "utf-8", "test.relaxedutf8")
def test_charmapencode(self):
# For charmap encodings the replacement string will be
@@ -285,7 +287,8 @@
def test_longstrings(self):
# test long strings to check for memory overflow problems
- errors = [ "strict", "ignore", "replace", "xmlcharrefreplace", "backslashreplace"]
+ errors = [ "strict", "ignore", "replace", "xmlcharrefreplace",
+ "backslashreplace"]
# register the handlers under different names,
# to prevent the codec from recognizing the name
for err in errors:
@@ -293,7 +296,8 @@
l = 1000
errors += [ "test." + err for err in errors ]
for uni in [ s*l for s in (u"x", u"\u3042", u"a\xe4") ]:
- for enc in ("ascii", "latin-1", "iso-8859-1", "iso-8859-15", "utf-8", "utf-7", "utf-16"):
+ for enc in ("ascii", "latin-1", "iso-8859-1", "iso-8859-15",
+ "utf-8", "utf-7", "utf-16", "utf-32"):
for err in errors:
try:
uni.encode(enc, err)
diff --git a/Lib/test/test_compile.py b/Lib/test/test_compile.py
--- a/Lib/test/test_compile.py
+++ b/Lib/test/test_compile.py
@@ -1,7 +1,8 @@
import unittest
-import warnings
import sys
+import _ast
from test import test_support
+import textwrap
class TestSpecifics(unittest.TestCase):
@@ -137,6 +138,9 @@
def test_complex_args(self):
+ with test_support._check_py3k_warnings(
+ ("tuple parameter unpacking has been removed", SyntaxWarning)):
+ exec textwrap.dedent('''
def comp_args((a, b)):
return a,b
self.assertEqual(comp_args((1, 2)), (1, 2))
@@ -154,6 +158,7 @@
return a, b, c
self.assertEqual(comp_args(1, (2, 3)), (1, 2, 3))
self.assertEqual(comp_args(), (2, 3, 4))
+ ''')
def test_argument_order(self):
try:
@@ -190,7 +195,9 @@
def test_literals_with_leading_zeroes(self):
for arg in ["077787", "0xj", "0x.", "0e", "090000000000000",
- "080000000000000", "000000000000009", "000000000000008"]:
+ "080000000000000", "000000000000009", "000000000000008",
+ "0b42", "0BADCAFE", "0o123456789", "0b1.1", "0o4.2",
+ "0b101j2", "0o153j2", "0b100e1", "0o777e1", "0o8", "0o78"]:
self.assertRaises(SyntaxError, eval, arg)
self.assertEqual(eval("0777"), 511)
@@ -218,6 +225,10 @@
self.assertEqual(eval("000000000000007"), 7)
self.assertEqual(eval("000000000000008."), 8.)
self.assertEqual(eval("000000000000009."), 9.)
+ self.assertEqual(eval("0b101010"), 42)
+ self.assertEqual(eval("-0b000000000010"), -2)
+ self.assertEqual(eval("0o777"), 511)
+ self.assertEqual(eval("-0o0000010"), -8)
self.assertEqual(eval("020000000000.0"), 20000000000.0)
self.assertEqual(eval("037777777777e0"), 37777777777.0)
self.assertEqual(eval("01000000000000000000000.0"),
@@ -417,9 +428,58 @@
del d[..., ...]
self.assertEqual((Ellipsis, Ellipsis) in d, False)
- def test_nested_classes(self):
- # Verify that it does not leak
- compile("class A:\n class B: pass", 'tmp', 'exec')
+ def test_mangling(self):
+ class A:
+ def f():
+ __mangled = 1
+ __not_mangled__ = 2
+ import __mangled_mod
+ import __package__.module
+
+ self.assert_("_A__mangled" in A.f.func_code.co_varnames)
+ self.assert_("__not_mangled__" in A.f.func_code.co_varnames)
+ self.assert_("_A__mangled_mod" in A.f.func_code.co_varnames)
+ self.assert_("__package__" in A.f.func_code.co_varnames)
+
+ def test_compile_ast(self):
+ fname = __file__
+ if fname.lower().endswith(('pyc', 'pyo')):
+ fname = fname[:-1]
+ with open(fname, 'r') as f:
+ fcontents = f.read()
+ sample_code = [
+ ['<assign>', 'x = 5'],
+ ['<print1>', 'print 1'],
+ ['<printv>', 'print v'],
+ ['<printTrue>', 'print True'],
+ ['<printList>', 'print []'],
+ ['<ifblock>', """if True:\n pass\n"""],
+ ['<forblock>', """for n in [1, 2, 3]:\n print n\n"""],
+ ['<deffunc>', """def foo():\n pass\nfoo()\n"""],
+ [fname, fcontents],
+ ]
+
+ for fname, code in sample_code:
+ co1 = compile(code, '%s1' % fname, 'exec')
+ ast = compile(code, '%s2' % fname, 'exec', _ast.PyCF_ONLY_AST)
+ self.assert_(type(ast) == _ast.Module)
+ co2 = compile(ast, '%s3' % fname, 'exec')
+ self.assertEqual(co1, co2)
+ # the code object's filename comes from the second compilation step
+ self.assertEqual(co2.co_filename, '%s3' % fname)
+
+ # raise exception when node type doesn't match with compile mode
+ co1 = compile('print 1', '<string>', 'exec', _ast.PyCF_ONLY_AST)
+ self.assertRaises(TypeError, compile, co1, '<ast>', 'eval')
+
+ # raise exception when node type is no start node
+ self.assertRaises(TypeError, compile, _ast.If(), '<ast>', 'exec')
+
+ # raise exception when node has invalid children
+ ast = _ast.Module()
+ ast.body = [_ast.BoolOp()]
+ self.assertRaises(TypeError, compile, ast, '<ast>', 'exec')
+
def test_main():
test_support.run_unittest(TestSpecifics)
diff --git a/Lib/test/test_copy.py b/Lib/test/test_copy.py
--- a/Lib/test/test_copy.py
+++ b/Lib/test/test_copy.py
@@ -1,6 +1,5 @@
"""Unit tests for the copy module."""
-import sys
import copy
import copy_reg
@@ -439,6 +438,7 @@
return (C, (), self.__dict__)
def __cmp__(self, other):
return cmp(self.__dict__, other.__dict__)
+ __hash__ = None # Silence Py3k warning
x = C()
x.foo = [42]
y = copy.copy(x)
@@ -455,6 +455,7 @@
self.__dict__.update(state)
def __cmp__(self, other):
return cmp(self.__dict__, other.__dict__)
+ __hash__ = None # Silence Py3k warning
x = C()
x.foo = [42]
y = copy.copy(x)
@@ -481,6 +482,7 @@
def __cmp__(self, other):
return (cmp(list(self), list(other)) or
cmp(self.__dict__, other.__dict__))
+ __hash__ = None # Silence Py3k warning
x = C([[1, 2], 3])
y = copy.copy(x)
self.assertEqual(x, y)
@@ -498,6 +500,7 @@
def __cmp__(self, other):
return (cmp(dict(self), list(dict)) or
cmp(self.__dict__, other.__dict__))
+ __hash__ = None # Silence Py3k warning
x = C([("foo", [1, 2]), ("bar", 3)])
y = copy.copy(x)
self.assertEqual(x, y)
diff --git a/Lib/test/test_descrtut.py b/Lib/test/test_descrtut.py
--- a/Lib/test/test_descrtut.py
+++ b/Lib/test/test_descrtut.py
@@ -54,7 +54,7 @@
{1: 3.25}
>>> print a[1] # show the new item
3.25
- >>> print a[0] # a non-existant item
+ >>> print a[0] # a non-existent item
0.0
>>> a.merge({1:100, 2:200}) # use a dict method
>>> print sortdict(a) # show the result
@@ -66,7 +66,7 @@
statement or the built-in function eval():
>>> def sorted(seq):
- ... seq.sort()
+ ... seq.sort(key=str)
... return seq
>>> print sorted(a.keys())
[1, 2]
@@ -183,6 +183,7 @@
'__delslice__',
'__doc__',
'__eq__',
+ '__format__',
'__ge__',
'__getattribute__',
'__getitem__',
@@ -207,7 +208,9 @@
'__setattr__',
'__setitem__',
'__setslice__',
+ '__sizeof__',
'__str__',
+ '__subclasshook__',
'append',
'count',
'extend',
diff --git a/Lib/test/test_dumbdbm.py b/Lib/test/test_dumbdbm.py
--- a/Lib/test/test_dumbdbm.py
+++ b/Lib/test/test_dumbdbm.py
@@ -38,6 +38,30 @@
self.read_helper(f)
f.close()
+ def test_dumbdbm_creation_mode(self):
+ # On platforms without chmod, don't do anything.
+ if not (hasattr(os, 'chmod') and hasattr(os, 'umask')):
+ return
+
+ try:
+ old_umask = os.umask(0002)
+ f = dumbdbm.open(_fname, 'c', 0637)
+ f.close()
+ finally:
+ os.umask(old_umask)
+
+ expected_mode = 0635
+ if os.name != 'posix':
+ # Windows only supports setting the read-only attribute.
+ # This shouldn't fail, but doesn't work like Unix either.
+ expected_mode = 0666
+
+ import stat
+ st = os.stat(_fname + '.dat')
+ self.assertEqual(stat.S_IMODE(st.st_mode), expected_mode)
+ st = os.stat(_fname + '.dir')
+ self.assertEqual(stat.S_IMODE(st.st_mode), expected_mode)
+
def test_close_twice(self):
f = dumbdbm.open(_fname)
f['a'] = 'b'
diff --git a/Lib/test/test_genexps.py b/Lib/test/test_genexps.py
--- a/Lib/test/test_genexps.py
+++ b/Lib/test/test_genexps.py
@@ -98,7 +98,7 @@
Verify that parenthesis are required when used as a keyword argument value
>>> dict(a = (i for i in xrange(10))) #doctest: +ELLIPSIS
- {'a': <generator object at ...>}
+ {'a': <generator object <genexpr> at ...>}
Verify early binding for the outermost for-expression
diff --git a/Lib/test/test_hashlib.py b/Lib/test/test_hashlib.py
--- a/Lib/test/test_hashlib.py
+++ b/Lib/test/test_hashlib.py
@@ -1,15 +1,15 @@
# Test hashlib module
#
-# $Id: test_hashlib.py 39316 2005-08-21 18:45:59Z greg $
+# $Id: test_hashlib.py 79216 2010-03-21 19:16:28Z georg.brandl $
#
-# Copyright (C) 2005 Gregory P. Smith (greg at electricrain.com)
+# Copyright (C) 2005-2010 Gregory P. Smith (greg at krypto.org)
# Licensed to PSF under a Contributor Agreement.
#
import hashlib
import unittest
from test import test_support
-
+from test.test_support import _4G, precisionbigmemtest
def hexstr(s):
import string
@@ -55,7 +55,6 @@
m2.update(aas + bees + cees)
self.assertEqual(m1.digest(), m2.digest())
-
def check(self, name, data, digest):
# test the direct constructors
computed = getattr(hashlib, name)(data).hexdigest()
@@ -75,6 +74,21 @@
self.check('md5', 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789',
'd174ab98d277d9f5a5611c2c9f419d9f')
+ @precisionbigmemtest(size=_4G + 5, memuse=1)
+ def test_case_md5_huge(self, size):
+ if size == _4G + 5:
+ try:
+ self.check('md5', 'A'*size, 'c9af2dff37468ce5dfee8f2cfc0a9c6d')
+ except OverflowError:
+ pass # 32-bit arch
+
+ @precisionbigmemtest(size=_4G - 1, memuse=1)
+ def test_case_md5_uintmax(self, size):
+ if size == _4G - 1:
+ try:
+ self.check('md5', 'A'*size, '28138d306ff1b8281f1a9067e1a1a2b3')
+ except OverflowError:
+ pass # 32-bit arch
# use the three examples from Federal Information Processing Standards
# Publication 180-1, Secure Hash Standard, 1995 April 17
diff --git a/Lib/test/test_hmac.py b/Lib/test/test_hmac.py
--- a/Lib/test/test_hmac.py
+++ b/Lib/test/test_hmac.py
@@ -1,7 +1,7 @@
import hmac
-import sha
import hashlib
import unittest
+import warnings
from test import test_support
class TestVectorsTestCase(unittest.TestCase):
@@ -44,7 +44,7 @@
def test_sha_vectors(self):
def shatest(key, data, digest):
- h = hmac.HMAC(key, data, digestmod=sha)
+ h = hmac.HMAC(key, data, digestmod=hashlib.sha1)
self.assertEqual(h.hexdigest().upper(), digest.upper())
shatest(chr(0x0b) * 20,
@@ -200,6 +200,35 @@
def test_sha512_rfc4231(self):
self._rfc4231_test_cases(hashlib.sha512)
+ def test_legacy_block_size_warnings(self):
+ class MockCrazyHash(object):
+ """Ain't no block_size attribute here."""
+ def __init__(self, *args):
+ self._x = hashlib.sha1(*args)
+ self.digest_size = self._x.digest_size
+ def update(self, v):
+ self._x.update(v)
+ def digest(self):
+ return self._x.digest()
+
+ with warnings.catch_warnings():
+ warnings.simplefilter('error', RuntimeWarning)
+ try:
+ hmac.HMAC('a', 'b', digestmod=MockCrazyHash)
+ except RuntimeWarning:
+ pass
+ else:
+ self.fail('Expected warning about missing block_size')
+
+ MockCrazyHash.block_size = 1
+ try:
+ hmac.HMAC('a', 'b', digestmod=MockCrazyHash)
+ except RuntimeWarning:
+ pass
+ else:
+ self.fail('Expected warning about small block_size')
+
+
class ConstructorTestCase(unittest.TestCase):
@@ -220,18 +249,16 @@
def test_withmodule(self):
# Constructor call with text and digest module.
- import sha
try:
- h = hmac.HMAC("key", "", sha)
+ h = hmac.HMAC("key", "", hashlib.sha1)
except:
- self.fail("Constructor call with sha module raised exception.")
+ self.fail("Constructor call with hashlib.sha1 raised exception.")
class SanityTestCase(unittest.TestCase):
def test_default_is_md5(self):
# Testing if HMAC defaults to MD5 algorithm.
# NOTE: this whitebox test depends on the hmac class internals
- import hashlib
h = hmac.HMAC("key")
self.failUnless(h.digest_cons == hashlib.md5)
diff --git a/Lib/test/test_iter.py b/Lib/test/test_iter.py
--- a/Lib/test/test_iter.py
+++ b/Lib/test/test_iter.py
@@ -1,7 +1,8 @@
# Test iterators.
import unittest
-from test.test_support import run_unittest, TESTFN, unlink, have_unicode
+from test.test_support import run_unittest, TESTFN, unlink, have_unicode, \
+ _check_py3k_warnings
# Test result of triple loop (too big to inline)
TRIPLETS = [(0, 0, 0), (0, 0, 1), (0, 0, 2),
@@ -389,21 +390,24 @@
# Test map()'s use of iterators.
def test_builtin_map(self):
- self.assertEqual(map(None, SequenceClass(5)), range(5))
self.assertEqual(map(lambda x: x+1, SequenceClass(5)), range(1, 6))
d = {"one": 1, "two": 2, "three": 3}
- self.assertEqual(map(None, d), d.keys())
self.assertEqual(map(lambda k, d=d: (k, d[k]), d), d.items())
dkeys = d.keys()
expected = [(i < len(d) and dkeys[i] or None,
i,
i < len(d) and dkeys[i] or None)
for i in range(5)]
- self.assertEqual(map(None, d,
- SequenceClass(5),
- iter(d.iterkeys())),
- expected)
+
+ # Deprecated map(None, ...)
+ with _check_py3k_warnings():
+ self.assertEqual(map(None, SequenceClass(5)), range(5))
+ self.assertEqual(map(None, d), d.keys())
+ self.assertEqual(map(None, d,
+ SequenceClass(5),
+ iter(d.iterkeys())),
+ expected)
f = open(TESTFN, "w")
try:
@@ -499,7 +503,11 @@
self.assertEqual(zip(x, y), expected)
# Test reduces()'s use of iterators.
- def test_builtin_reduce(self):
+ def test_deprecated_builtin_reduce(self):
+ with _check_py3k_warnings():
+ self._test_builtin_reduce()
+
+ def _test_builtin_reduce(self):
from operator import add
self.assertEqual(reduce(add, SequenceClass(5)), 10)
self.assertEqual(reduce(add, SequenceClass(5), 42), 52)
diff --git a/Lib/test/test_logging.py b/Lib/test/test_logging.py
--- a/Lib/test/test_logging.py
+++ b/Lib/test/test_logging.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
#
-# Copyright 2001-2004 by Vinay Sajip. All Rights Reserved.
+# Copyright 2001-2010 by Vinay Sajip. All Rights Reserved.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose and without fee is hereby granted,
@@ -15,201 +15,288 @@
# ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER
# IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
-#
-# This file is part of the Python logging distribution. See
-# http://www.red-dove.com/python_logging.html
-#
+
"""Test harness for the logging module. Run all tests.
-Copyright (C) 2001-2002 Vinay Sajip. All Rights Reserved.
+Copyright (C) 2001-2010 Vinay Sajip. All Rights Reserved.
"""
+import logging
+import logging.handlers
+import logging.config
+
+import codecs
+import copy
+import cPickle
+import cStringIO
+import gc
+import os
+import re
import select
-import os, sys, string, struct, types, cPickle, cStringIO
-import socket, tempfile, threading, time
-import logging, logging.handlers, logging.config
-from test.test_support import run_with_locale
+import socket
+from SocketServer import ThreadingTCPServer, StreamRequestHandler
+import string
+import struct
+import sys
+import tempfile
+from test.test_support import captured_stdout, run_with_locale, run_unittest
+import textwrap
+import threading
+import time
+import types
+import unittest
+import weakref
-BANNER = "-- %-10s %-6s ---------------------------------------------------\n"
-FINISH_UP = "Finish up, it's closing time. Messages should bear numbers 0 through 24."
+class BaseTest(unittest.TestCase):
-#----------------------------------------------------------------------------
-# Log receiver
-#----------------------------------------------------------------------------
+ """Base class for logging tests."""
-TIMEOUT = 10
+ log_format = "%(name)s -> %(levelname)s: %(message)s"
+ expected_log_pat = r"^([\w.]+) -> ([\w]+): ([\d]+)$"
+ message_num = 0
-from SocketServer import ThreadingTCPServer, StreamRequestHandler
+ def setUp(self):
+ """Setup the default logging stream to an internal StringIO instance,
+ so that we can examine log output as we want."""
+ logger_dict = logging.getLogger().manager.loggerDict
+ logging._acquireLock()
+ try:
+ self.saved_handlers = logging._handlers.copy()
+ self.saved_handler_list = logging._handlerList[:]
+ self.saved_loggers = logger_dict.copy()
+ self.saved_level_names = logging._levelNames.copy()
+ finally:
+ logging._releaseLock()
-class LogRecordStreamHandler(StreamRequestHandler):
- """
- Handler for a streaming logging request. It basically logs the record
- using whatever logging policy is configured locally.
- """
+ # Set two unused loggers: one non-ASCII and one Unicode.
+ # This is to test correct operation when sorting existing
+ # loggers in the configuration code. See issues 8201, 9310.
+ logging.getLogger("\xab\xd7\xbb")
+ logging.getLogger(u"\u013f\u00d6\u0047")
- def handle(self):
- """
- Handle multiple requests - each expected to be a 4-byte length,
- followed by the LogRecord in pickle format. Logs the record
- according to whatever policy is configured locally.
- """
- while 1:
- try:
- chunk = self.connection.recv(4)
- if len(chunk) < 4:
- break
- slen = struct.unpack(">L", chunk)[0]
- chunk = self.connection.recv(slen)
- while len(chunk) < slen:
- chunk = chunk + self.connection.recv(slen - len(chunk))
- obj = self.unPickle(chunk)
- record = logging.makeLogRecord(obj)
- self.handleLogRecord(record)
- except:
- raise
+ self.root_logger = logging.getLogger("")
+ self.original_logging_level = self.root_logger.getEffectiveLevel()
- def unPickle(self, data):
- return cPickle.loads(data)
+ self.stream = cStringIO.StringIO()
+ self.root_logger.setLevel(logging.DEBUG)
+ self.root_hdlr = logging.StreamHandler(self.stream)
+ self.root_formatter = logging.Formatter(self.log_format)
+ self.root_hdlr.setFormatter(self.root_formatter)
+ self.root_logger.addHandler(self.root_hdlr)
- def handleLogRecord(self, record):
- logname = "logrecv.tcp." + record.name
- #If the end-of-messages sentinel is seen, tell the server to terminate
- if record.msg == FINISH_UP:
- self.server.abort = 1
- record.msg = record.msg + " (via " + logname + ")"
- logger = logging.getLogger(logname)
- logger.handle(record)
+ def tearDown(self):
+ """Remove our logging stream, and restore the original logging
+ level."""
+ self.stream.close()
+ self.root_logger.removeHandler(self.root_hdlr)
+ self.root_logger.setLevel(self.original_logging_level)
+ logging._acquireLock()
+ try:
+ logging._levelNames.clear()
+ logging._levelNames.update(self.saved_level_names)
+ logging._handlers.clear()
+ logging._handlers.update(self.saved_handlers)
+ logging._handlerList[:] = self.saved_handler_list
+ loggerDict = logging.getLogger().manager.loggerDict
+ loggerDict.clear()
+ loggerDict.update(self.saved_loggers)
+ finally:
+ logging._releaseLock()
-# The server sets socketDataProcessed when it's done.
-socketDataProcessed = threading.Event()
+ def assert_log_lines(self, expected_values, stream=None):
+ """Match the collected log lines against the regular expression
+ self.expected_log_pat, and compare the extracted group values to
+ the expected_values list of tuples."""
+ stream = stream or self.stream
+ pat = re.compile(self.expected_log_pat)
+ try:
+ stream.reset()
+ actual_lines = stream.readlines()
+ except AttributeError:
+ # StringIO.StringIO lacks a reset() method.
+ actual_lines = stream.getvalue().splitlines()
+ self.assertEquals(len(actual_lines), len(expected_values))
+ for actual, expected in zip(actual_lines, expected_values):
+ match = pat.search(actual)
+ if not match:
+ self.fail("Log line does not match expected pattern:\n" +
+ actual)
+ self.assertEquals(tuple(match.groups()), expected)
+ s = stream.read()
+ if s:
+ self.fail("Remaining output at end of log stream:\n" + s)
-class LogRecordSocketReceiver(ThreadingTCPServer):
- """
- A simple-minded TCP socket-based logging receiver suitable for test
- purposes.
- """
+ def next_message(self):
+ """Generate a message consisting solely of an auto-incrementing
+ integer."""
+ self.message_num += 1
+ return "%d" % self.message_num
- allow_reuse_address = 1
- def __init__(self, host='localhost',
- port=logging.handlers.DEFAULT_TCP_LOGGING_PORT,
- handler=LogRecordStreamHandler):
- ThreadingTCPServer.__init__(self, (host, port), handler)
- self.abort = 0
- self.timeout = 1
+class BuiltinLevelsTest(BaseTest):
+ """Test builtin levels and their inheritance."""
- def serve_until_stopped(self):
- if sys.platform.startswith('java'):
- # XXX: There's a problem using cpython_compatibile_select
- # here: it seems to be due to the fact that
- # cpython_compatible_select switches blocking mode on while
- # a separate thread is reading from the same socket, causing
- # a read of 0 in LogRecordStreamHandler.handle (which
- # deadlocks this test)
- self.socket.setblocking(0)
- while not self.abort:
- rd, wr, ex = select.select([self.socket.fileno()], [], [],
- self.timeout)
- if rd:
- self.handle_request()
- #notify the main thread that we're about to exit
- socketDataProcessed.set()
- # close the listen socket
- self.server_close()
+ def test_flat(self):
+ #Logging levels in a flat logger namespace.
+ m = self.next_message
- def process_request(self, request, client_address):
- #import threading
- t = threading.Thread(target = self.finish_request,
- args = (request, client_address))
- t.start()
+ ERR = logging.getLogger("ERR")
+ ERR.setLevel(logging.ERROR)
+ INF = logging.getLogger("INF")
+ INF.setLevel(logging.INFO)
+ DEB = logging.getLogger("DEB")
+ DEB.setLevel(logging.DEBUG)
-def runTCP(tcpserver):
- tcpserver.serve_until_stopped()
+ # These should log.
+ ERR.log(logging.CRITICAL, m())
+ ERR.error(m())
-#----------------------------------------------------------------------------
-# Test 0
-#----------------------------------------------------------------------------
+ INF.log(logging.CRITICAL, m())
+ INF.error(m())
+ INF.warn(m())
+ INF.info(m())
-msgcount = 0
+ DEB.log(logging.CRITICAL, m())
+ DEB.error(m())
+ DEB.warn (m())
+ DEB.info (m())
+ DEB.debug(m())
-def nextmessage():
- global msgcount
- rv = "Message %d" % msgcount
- msgcount = msgcount + 1
- return rv
+ # These should not log.
+ ERR.warn(m())
+ ERR.info(m())
+ ERR.debug(m())
-def test0():
- ERR = logging.getLogger("ERR")
- ERR.setLevel(logging.ERROR)
- INF = logging.getLogger("INF")
- INF.setLevel(logging.INFO)
- INF_ERR = logging.getLogger("INF.ERR")
- INF_ERR.setLevel(logging.ERROR)
- DEB = logging.getLogger("DEB")
- DEB.setLevel(logging.DEBUG)
+ INF.debug(m())
- INF_UNDEF = logging.getLogger("INF.UNDEF")
- INF_ERR_UNDEF = logging.getLogger("INF.ERR.UNDEF")
- UNDEF = logging.getLogger("UNDEF")
+ self.assert_log_lines([
+ ('ERR', 'CRITICAL', '1'),
+ ('ERR', 'ERROR', '2'),
+ ('INF', 'CRITICAL', '3'),
+ ('INF', 'ERROR', '4'),
+ ('INF', 'WARNING', '5'),
+ ('INF', 'INFO', '6'),
+ ('DEB', 'CRITICAL', '7'),
+ ('DEB', 'ERROR', '8'),
+ ('DEB', 'WARNING', '9'),
+ ('DEB', 'INFO', '10'),
+ ('DEB', 'DEBUG', '11'),
+ ])
- GRANDCHILD = logging.getLogger("INF.BADPARENT.UNDEF")
- CHILD = logging.getLogger("INF.BADPARENT")
+ def test_nested_explicit(self):
+ # Logging levels in a nested namespace, all explicitly set.
+ m = self.next_message
- #These should log
- ERR.log(logging.FATAL, nextmessage())
- ERR.error(nextmessage())
+ INF = logging.getLogger("INF")
+ INF.setLevel(logging.INFO)
+ INF_ERR = logging.getLogger("INF.ERR")
+ INF_ERR.setLevel(logging.ERROR)
- INF.log(logging.FATAL, nextmessage())
- INF.error(nextmessage())
- INF.warn(nextmessage())
- INF.info(nextmessage())
+ # These should log.
+ INF_ERR.log(logging.CRITICAL, m())
+ INF_ERR.error(m())
- INF_UNDEF.log(logging.FATAL, nextmessage())
- INF_UNDEF.error(nextmessage())
- INF_UNDEF.warn (nextmessage())
- INF_UNDEF.info (nextmessage())
+ # These should not log.
+ INF_ERR.warn(m())
+ INF_ERR.info(m())
+ INF_ERR.debug(m())
- INF_ERR.log(logging.FATAL, nextmessage())
- INF_ERR.error(nextmessage())
+ self.assert_log_lines([
+ ('INF.ERR', 'CRITICAL', '1'),
+ ('INF.ERR', 'ERROR', '2'),
+ ])
- INF_ERR_UNDEF.log(logging.FATAL, nextmessage())
- INF_ERR_UNDEF.error(nextmessage())
+ def test_nested_inherited(self):
+ #Logging levels in a nested namespace, inherited from parent loggers.
+ m = self.next_message
- DEB.log(logging.FATAL, nextmessage())
- DEB.error(nextmessage())
- DEB.warn (nextmessage())
- DEB.info (nextmessage())
- DEB.debug(nextmessage())
+ INF = logging.getLogger("INF")
+ INF.setLevel(logging.INFO)
+ INF_ERR = logging.getLogger("INF.ERR")
+ INF_ERR.setLevel(logging.ERROR)
+ INF_UNDEF = logging.getLogger("INF.UNDEF")
+ INF_ERR_UNDEF = logging.getLogger("INF.ERR.UNDEF")
+ UNDEF = logging.getLogger("UNDEF")
- UNDEF.log(logging.FATAL, nextmessage())
- UNDEF.error(nextmessage())
- UNDEF.warn (nextmessage())
- UNDEF.info (nextmessage())
+ # These should log.
+ INF_UNDEF.log(logging.CRITICAL, m())
+ INF_UNDEF.error(m())
+ INF_UNDEF.warn(m())
+ INF_UNDEF.info(m())
+ INF_ERR_UNDEF.log(logging.CRITICAL, m())
+ INF_ERR_UNDEF.error(m())
- GRANDCHILD.log(logging.FATAL, nextmessage())
- CHILD.log(logging.FATAL, nextmessage())
+ # These should not log.
+ INF_UNDEF.debug(m())
+ INF_ERR_UNDEF.warn(m())
+ INF_ERR_UNDEF.info(m())
+ INF_ERR_UNDEF.debug(m())
- #These should not log
- ERR.warn(nextmessage())
- ERR.info(nextmessage())
- ERR.debug(nextmessage())
+ self.assert_log_lines([
+ ('INF.UNDEF', 'CRITICAL', '1'),
+ ('INF.UNDEF', 'ERROR', '2'),
+ ('INF.UNDEF', 'WARNING', '3'),
+ ('INF.UNDEF', 'INFO', '4'),
+ ('INF.ERR.UNDEF', 'CRITICAL', '5'),
+ ('INF.ERR.UNDEF', 'ERROR', '6'),
+ ])
- INF.debug(nextmessage())
- INF_UNDEF.debug(nextmessage())
+ def test_nested_with_virtual_parent(self):
+ # Logging levels when some parent does not exist yet.
+ m = self.next_message
- INF_ERR.warn(nextmessage())
- INF_ERR.info(nextmessage())
- INF_ERR.debug(nextmessage())
- INF_ERR_UNDEF.warn(nextmessage())
- INF_ERR_UNDEF.info(nextmessage())
- INF_ERR_UNDEF.debug(nextmessage())
+ INF = logging.getLogger("INF")
+ GRANDCHILD = logging.getLogger("INF.BADPARENT.UNDEF")
+ CHILD = logging.getLogger("INF.BADPARENT")
+ INF.setLevel(logging.INFO)
- INF.info(FINISH_UP)
+ # These should log.
+ GRANDCHILD.log(logging.FATAL, m())
+ GRANDCHILD.info(m())
+ CHILD.log(logging.FATAL, m())
+ CHILD.info(m())
-#----------------------------------------------------------------------------
-# Test 1
-#----------------------------------------------------------------------------
+ # These should not log.
+ GRANDCHILD.debug(m())
+ CHILD.debug(m())
+
+ self.assert_log_lines([
+ ('INF.BADPARENT.UNDEF', 'CRITICAL', '1'),
+ ('INF.BADPARENT.UNDEF', 'INFO', '2'),
+ ('INF.BADPARENT', 'CRITICAL', '3'),
+ ('INF.BADPARENT', 'INFO', '4'),
+ ])
+
+
+class BasicFilterTest(BaseTest):
+
+ """Test the bundled Filter class."""
+
+ def test_filter(self):
+ # Only messages satisfying the specified criteria pass through the
+ # filter.
+ filter_ = logging.Filter("spam.eggs")
+ handler = self.root_logger.handlers[0]
+ try:
+ handler.addFilter(filter_)
+ spam = logging.getLogger("spam")
+ spam_eggs = logging.getLogger("spam.eggs")
+ spam_eggs_fish = logging.getLogger("spam.eggs.fish")
+ spam_bakedbeans = logging.getLogger("spam.bakedbeans")
+
+ spam.info(self.next_message())
+ spam_eggs.info(self.next_message()) # Good.
+ spam_eggs_fish.info(self.next_message()) # Good.
+ spam_bakedbeans.info(self.next_message())
+
+ self.assert_log_lines([
+ ('spam.eggs', 'INFO', '2'),
+ ('spam.eggs.fish', 'INFO', '3'),
+ ])
+ finally:
+ handler.removeFilter(filter_)
+
#
# First, we define our levels. There can be as many as you want - the only
@@ -219,16 +306,16 @@
# mapping dictionary to convert between your application levels and the
# logging system.
#
-SILENT = 10
-TACITURN = 9
-TERSE = 8
-EFFUSIVE = 7
-SOCIABLE = 6
-VERBOSE = 5
-TALKATIVE = 4
-GARRULOUS = 3
-CHATTERBOX = 2
-BORING = 1
+SILENT = 120
+TACITURN = 119
+TERSE = 118
+EFFUSIVE = 117
+SOCIABLE = 116
+VERBOSE = 115
+TALKATIVE = 114
+GARRULOUS = 113
+CHATTERBOX = 112
+BORING = 111
LEVEL_RANGE = range(BORING, SILENT + 1)
@@ -249,445 +336,602 @@
BORING : 'Boring',
}
-#
-# Now, to demonstrate filtering: suppose for some perverse reason we only
-# want to print out all except GARRULOUS messages. Let's create a filter for
-# this purpose...
-#
-class SpecificLevelFilter(logging.Filter):
- def __init__(self, lvl):
- self.level = lvl
+class GarrulousFilter(logging.Filter):
+
+ """A filter which blocks garrulous messages."""
def filter(self, record):
- return self.level != record.levelno
+ return record.levelno != GARRULOUS
-class GarrulousFilter(SpecificLevelFilter):
- def __init__(self):
- SpecificLevelFilter.__init__(self, GARRULOUS)
+class VerySpecificFilter(logging.Filter):
-#
-# Now, let's demonstrate filtering at the logger. This time, use a filter
-# which excludes SOCIABLE and TACITURN messages. Note that GARRULOUS events
-# are still excluded.
-#
-class VerySpecificFilter(logging.Filter):
+ """A filter which blocks sociable and taciturn messages."""
+
def filter(self, record):
return record.levelno not in [SOCIABLE, TACITURN]
-def message(s):
- sys.stdout.write("%s\n" % s)
-SHOULD1 = "This should only be seen at the '%s' logging level (or lower)"
+class CustomLevelsAndFiltersTest(BaseTest):
-def test1():
-#
-# Now, tell the logging system to associate names with our levels.
-#
- for lvl in my_logging_levels.keys():
- logging.addLevelName(lvl, my_logging_levels[lvl])
+ """Test various filtering possibilities with custom logging levels."""
-#
-# Now, define a test function which logs an event at each of our levels.
-#
+ # Skip the logger name group.
+ expected_log_pat = r"^[\w.]+ -> ([\w]+): ([\d]+)$"
- def doLog(log):
+ def setUp(self):
+ BaseTest.setUp(self)
+ for k, v in my_logging_levels.items():
+ logging.addLevelName(k, v)
+
+ def log_at_all_levels(self, logger):
for lvl in LEVEL_RANGE:
- log.log(lvl, SHOULD1, logging.getLevelName(lvl))
+ logger.log(lvl, self.next_message())
- log = logging.getLogger("")
- hdlr = log.handlers[0]
-#
-# Set the logging level to each different value and call the utility
-# function to log events.
-# In the output, you should see that each time round the loop, the number of
-# logging events which are actually output decreases.
-#
- for lvl in LEVEL_RANGE:
- message("-- setting logging level to '%s' -----" %
- logging.getLevelName(lvl))
- log.setLevel(lvl)
- doLog(log)
- #
- # Now, we demonstrate level filtering at the handler level. Tell the
- # handler defined above to filter at level 'SOCIABLE', and repeat the
- # above loop. Compare the output from the two runs.
- #
- hdlr.setLevel(SOCIABLE)
- message("-- Filtering at handler level to SOCIABLE --")
- for lvl in LEVEL_RANGE:
- message("-- setting logging level to '%s' -----" %
- logging.getLevelName(lvl))
- log.setLevel(lvl)
- doLog(log)
+ def test_logger_filter(self):
+ # Filter at logger level.
+ self.root_logger.setLevel(VERBOSE)
+ # Levels >= 'Verbose' are good.
+ self.log_at_all_levels(self.root_logger)
+ self.assert_log_lines([
+ ('Verbose', '5'),
+ ('Sociable', '6'),
+ ('Effusive', '7'),
+ ('Terse', '8'),
+ ('Taciturn', '9'),
+ ('Silent', '10'),
+ ])
- hdlr.setLevel(0) #turn off level filtering at the handler
+ def test_handler_filter(self):
+ # Filter at handler level.
+ self.root_logger.handlers[0].setLevel(SOCIABLE)
+ try:
+ # Levels >= 'Sociable' are good.
+ self.log_at_all_levels(self.root_logger)
+ self.assert_log_lines([
+ ('Sociable', '6'),
+ ('Effusive', '7'),
+ ('Terse', '8'),
+ ('Taciturn', '9'),
+ ('Silent', '10'),
+ ])
+ finally:
+ self.root_logger.handlers[0].setLevel(logging.NOTSET)
- garr = GarrulousFilter()
- hdlr.addFilter(garr)
- message("-- Filtering using GARRULOUS filter --")
- for lvl in LEVEL_RANGE:
- message("-- setting logging level to '%s' -----" %
- logging.getLevelName(lvl))
- log.setLevel(lvl)
- doLog(log)
- spec = VerySpecificFilter()
- log.addFilter(spec)
- message("-- Filtering using specific filter for SOCIABLE, TACITURN --")
- for lvl in LEVEL_RANGE:
- message("-- setting logging level to '%s' -----" %
- logging.getLevelName(lvl))
- log.setLevel(lvl)
- doLog(log)
+ def test_specific_filters(self):
+ # Set a specific filter object on the handler, and then add another
+ # filter object on the logger itself.
+ handler = self.root_logger.handlers[0]
+ specific_filter = None
+ garr = GarrulousFilter()
+ handler.addFilter(garr)
+ try:
+ self.log_at_all_levels(self.root_logger)
+ first_lines = [
+ # Notice how 'Garrulous' is missing
+ ('Boring', '1'),
+ ('Chatterbox', '2'),
+ ('Talkative', '4'),
+ ('Verbose', '5'),
+ ('Sociable', '6'),
+ ('Effusive', '7'),
+ ('Terse', '8'),
+ ('Taciturn', '9'),
+ ('Silent', '10'),
+ ]
+ self.assert_log_lines(first_lines)
- log.removeFilter(spec)
- hdlr.removeFilter(garr)
- #Undo the one level which clashes...for regression tests
- logging.addLevelName(logging.DEBUG, "DEBUG")
+ specific_filter = VerySpecificFilter()
+ self.root_logger.addFilter(specific_filter)
+ self.log_at_all_levels(self.root_logger)
+ self.assert_log_lines(first_lines + [
+ # Not only 'Garrulous' is still missing, but also 'Sociable'
+ # and 'Taciturn'
+ ('Boring', '11'),
+ ('Chatterbox', '12'),
+ ('Talkative', '14'),
+ ('Verbose', '15'),
+ ('Effusive', '17'),
+ ('Terse', '18'),
+ ('Silent', '20'),
+ ])
+ finally:
+ if specific_filter:
+ self.root_logger.removeFilter(specific_filter)
+ handler.removeFilter(garr)
-#----------------------------------------------------------------------------
-# Test 2
-#----------------------------------------------------------------------------
-MSG = "-- logging %d at INFO, messages should be seen every 10 events --"
-def test2():
- logger = logging.getLogger("")
- sh = logger.handlers[0]
- sh.close()
- logger.removeHandler(sh)
- mh = logging.handlers.MemoryHandler(10,logging.WARNING, sh)
- logger.setLevel(logging.DEBUG)
- logger.addHandler(mh)
- message("-- logging at DEBUG, nothing should be seen yet --")
- logger.debug("Debug message")
- message("-- logging at INFO, nothing should be seen yet --")
- logger.info("Info message")
- message("-- logging at WARNING, 3 messages should be seen --")
- logger.warn("Warn message")
- for i in xrange(102):
- message(MSG % i)
- logger.info("Info index = %d", i)
- mh.close()
- logger.removeHandler(mh)
- logger.addHandler(sh)
+class MemoryHandlerTest(BaseTest):
-#----------------------------------------------------------------------------
-# Test 3
-#----------------------------------------------------------------------------
+ """Tests for the MemoryHandler."""
-FILTER = "a.b"
+ # Do not bother with a logger name group.
+ expected_log_pat = r"^[\w.]+ -> ([\w]+): ([\d]+)$"
-def doLog3():
- logging.getLogger("a").info("Info 1")
- logging.getLogger("a.b").info("Info 2")
- logging.getLogger("a.c").info("Info 3")
- logging.getLogger("a.b.c").info("Info 4")
- logging.getLogger("a.b.c.d").info("Info 5")
- logging.getLogger("a.bb.c").info("Info 6")
- logging.getLogger("b").info("Info 7")
- logging.getLogger("b.a").info("Info 8")
- logging.getLogger("c.a.b").info("Info 9")
- logging.getLogger("a.bb").info("Info 10")
+ def setUp(self):
+ BaseTest.setUp(self)
+ self.mem_hdlr = logging.handlers.MemoryHandler(10, logging.WARNING,
+ self.root_hdlr)
+ self.mem_logger = logging.getLogger('mem')
+ self.mem_logger.propagate = 0
+ self.mem_logger.addHandler(self.mem_hdlr)
-def test3():
- root = logging.getLogger()
- root.setLevel(logging.DEBUG)
- hand = root.handlers[0]
- message("Unfiltered...")
- doLog3()
- message("Filtered with '%s'..." % FILTER)
- filt = logging.Filter(FILTER)
- hand.addFilter(filt)
- doLog3()
- hand.removeFilter(filt)
+ def tearDown(self):
+ self.mem_hdlr.close()
+ BaseTest.tearDown(self)
-#----------------------------------------------------------------------------
-# Test 4
-#----------------------------------------------------------------------------
+ def test_flush(self):
+ # The memory handler flushes to its target handler based on specific
+ # criteria (message count and message level).
+ self.mem_logger.debug(self.next_message())
+ self.assert_log_lines([])
+ self.mem_logger.info(self.next_message())
+ self.assert_log_lines([])
+ # This will flush because the level is >= logging.WARNING
+ self.mem_logger.warn(self.next_message())
+ lines = [
+ ('DEBUG', '1'),
+ ('INFO', '2'),
+ ('WARNING', '3'),
+ ]
+ self.assert_log_lines(lines)
+ for n in (4, 14):
+ for i in range(9):
+ self.mem_logger.debug(self.next_message())
+ self.assert_log_lines(lines)
+ # This will flush because it's the 10th message since the last
+ # flush.
+ self.mem_logger.debug(self.next_message())
+ lines = lines + [('DEBUG', str(i)) for i in range(n, n + 10)]
+ self.assert_log_lines(lines)
-# config0 is a standard configuration.
-config0 = """
-[loggers]
-keys=root
+ self.mem_logger.debug(self.next_message())
+ self.assert_log_lines(lines)
-[handlers]
-keys=hand1
-[formatters]
-keys=form1
+class ExceptionFormatter(logging.Formatter):
+ """A special exception formatter."""
+ def formatException(self, ei):
+ return "Got a [%s]" % ei[0].__name__
-[logger_root]
-level=NOTSET
-handlers=hand1
-[handler_hand1]
-class=StreamHandler
-level=NOTSET
-formatter=form1
-args=(sys.stdout,)
+class ConfigFileTest(BaseTest):
-[formatter_form1]
-format=%(levelname)s:%(name)s:%(message)s
-datefmt=
-"""
+ """Reading logging config from a .ini-style config file."""
-# config1 adds a little to the standard configuration.
-config1 = """
-[loggers]
-keys=root,parser
+ expected_log_pat = r"^([\w]+) \+\+ ([\w]+)$"
-[handlers]
-keys=hand1
+ # config0 is a standard configuration.
+ config0 = """
+ [loggers]
+ keys=root
-[formatters]
-keys=form1
+ [handlers]
+ keys=hand1
-[logger_root]
-level=NOTSET
-handlers=hand1
+ [formatters]
+ keys=form1
-[logger_parser]
-level=DEBUG
-handlers=hand1
-propagate=1
-qualname=compiler.parser
+ [logger_root]
+ level=WARNING
+ handlers=hand1
-[handler_hand1]
-class=StreamHandler
-level=NOTSET
-formatter=form1
-args=(sys.stdout,)
+ [handler_hand1]
+ class=StreamHandler
+ level=NOTSET
+ formatter=form1
+ args=(sys.stdout,)
-[formatter_form1]
-format=%(levelname)s:%(name)s:%(message)s
-datefmt=
-"""
+ [formatter_form1]
+ format=%(levelname)s ++ %(message)s
+ datefmt=
+ """
-# config2 has a subtle configuration error that should be reported
-config2 = string.replace(config1, "sys.stdout", "sys.stbout")
+ # config1 adds a little to the standard configuration.
+ config1 = """
+ [loggers]
+ keys=root,parser
-# config3 has a less subtle configuration error
-config3 = string.replace(
- config1, "formatter=form1", "formatter=misspelled_name")
+ [handlers]
+ keys=hand1
-def test4():
- for i in range(4):
- conf = globals()['config%d' % i]
- sys.stdout.write('config%d: ' % i)
- loggerDict = logging.getLogger().manager.loggerDict
- logging._acquireLock()
- try:
- saved_handlers = logging._handlers.copy()
- saved_handler_list = logging._handlerList[:]
- saved_loggers = loggerDict.copy()
- finally:
- logging._releaseLock()
+ [formatters]
+ keys=form1
+
+ [logger_root]
+ level=WARNING
+ handlers=
+
+ [logger_parser]
+ level=DEBUG
+ handlers=hand1
+ propagate=1
+ qualname=compiler.parser
+
+ [handler_hand1]
+ class=StreamHandler
+ level=NOTSET
+ formatter=form1
+ args=(sys.stdout,)
+
+ [formatter_form1]
+ format=%(levelname)s ++ %(message)s
+ datefmt=
+ """
+
+ # config2 has a subtle configuration error that should be reported
+ config2 = config1.replace("sys.stdout", "sys.stbout")
+
+ # config3 has a less subtle configuration error
+ config3 = config1.replace("formatter=form1", "formatter=misspelled_name")
+
+ # config4 specifies a custom formatter class to be loaded
+ config4 = """
+ [loggers]
+ keys=root
+
+ [handlers]
+ keys=hand1
+
+ [formatters]
+ keys=form1
+
+ [logger_root]
+ level=NOTSET
+ handlers=hand1
+
+ [handler_hand1]
+ class=StreamHandler
+ level=NOTSET
+ formatter=form1
+ args=(sys.stdout,)
+
+ [formatter_form1]
+ class=""" + __name__ + """.ExceptionFormatter
+ format=%(levelname)s:%(name)s:%(message)s
+ datefmt=
+ """
+
+ # config5 specifies a custom handler class to be loaded
+ config5 = config1.replace('class=StreamHandler', 'class=logging.StreamHandler')
+
+ # config6 uses ', ' delimiters in the handlers and formatters sections
+ config6 = """
+ [loggers]
+ keys=root,parser
+
+ [handlers]
+ keys=hand1, hand2
+
+ [formatters]
+ keys=form1, form2
+
+ [logger_root]
+ level=WARNING
+ handlers=
+
+ [logger_parser]
+ level=DEBUG
+ handlers=hand1
+ propagate=1
+ qualname=compiler.parser
+
+ [handler_hand1]
+ class=StreamHandler
+ level=NOTSET
+ formatter=form1
+ args=(sys.stdout,)
+
+ [handler_hand2]
+ class=StreamHandler
+ level=NOTSET
+ formatter=form1
+ args=(sys.stderr,)
+
+ [formatter_form1]
+ format=%(levelname)s ++ %(message)s
+ datefmt=
+
+ [formatter_form2]
+ format=%(message)s
+ datefmt=
+ """
+
+ def apply_config(self, conf):
try:
fn = tempfile.mktemp(".ini")
f = open(fn, "w")
- f.write(conf)
+ f.write(textwrap.dedent(conf))
f.close()
+ logging.config.fileConfig(fn)
+ finally:
+ os.remove(fn)
+
+ def test_config0_ok(self):
+ # A simple config file which overrides the default settings.
+ with captured_stdout() as output:
+ self.apply_config(self.config0)
+ logger = logging.getLogger()
+ # Won't output anything
+ logger.info(self.next_message())
+ # Outputs a message
+ logger.error(self.next_message())
+ self.assert_log_lines([
+ ('ERROR', '2'),
+ ], stream=output)
+ # Original logger output is empty.
+ self.assert_log_lines([])
+
+ def test_config1_ok(self, config=config1):
+ # A config file defining a sub-parser as well.
+ with captured_stdout() as output:
+ self.apply_config(config)
+ logger = logging.getLogger("compiler.parser")
+ # Both will output a message
+ logger.info(self.next_message())
+ logger.error(self.next_message())
+ self.assert_log_lines([
+ ('INFO', '1'),
+ ('ERROR', '2'),
+ ], stream=output)
+ # Original logger output is empty.
+ self.assert_log_lines([])
+
+ def test_config2_failure(self):
+ # A simple config file which overrides the default settings.
+ self.assertRaises(StandardError, self.apply_config, self.config2)
+
+ def test_config3_failure(self):
+ # A simple config file which overrides the default settings.
+ self.assertRaises(StandardError, self.apply_config, self.config3)
+
+ def test_config4_ok(self):
+ # A config file specifying a custom formatter class.
+ with captured_stdout() as output:
+ self.apply_config(self.config4)
+ logger = logging.getLogger()
try:
- logging.config.fileConfig(fn)
- #call again to make sure cleanup is correct
- logging.config.fileConfig(fn)
- except:
- t = sys.exc_info()[0]
- message(str(t))
- else:
- message('ok.')
- os.remove(fn)
+ raise RuntimeError()
+ except RuntimeError:
+ logging.exception("just testing")
+ sys.stdout.seek(0)
+ self.assertEquals(output.getvalue(),
+ "ERROR:root:just testing\nGot a [RuntimeError]\n")
+ # Original logger output is empty
+ self.assert_log_lines([])
+
+ def test_config5_ok(self):
+ self.test_config1_ok(config=self.config5)
+
+ def test_config6_ok(self):
+ self.test_config1_ok(config=self.config6)
+
+class LogRecordStreamHandler(StreamRequestHandler):
+
+ """Handler for a streaming logging request. It saves the log message in the
+ TCP server's 'log_output' attribute."""
+
+ TCP_LOG_END = "!!!END!!!"
+
+ def handle(self):
+ """Handle multiple requests - each expected to be of 4-byte length,
+ followed by the LogRecord in pickle format. Logs the record
+ according to whatever policy is configured locally."""
+ while True:
+ chunk = self.connection.recv(4)
+ if len(chunk) < 4:
+ break
+ slen = struct.unpack(">L", chunk)[0]
+ chunk = self.connection.recv(slen)
+ while len(chunk) < slen:
+ chunk = chunk + self.connection.recv(slen - len(chunk))
+ obj = self.unpickle(chunk)
+ record = logging.makeLogRecord(obj)
+ self.handle_log_record(record)
+
+ def unpickle(self, data):
+ return cPickle.loads(data)
+
+ def handle_log_record(self, record):
+ # If the end-of-messages sentinel is seen, tell the server to
+ # terminate.
+ if self.TCP_LOG_END in record.msg:
+ self.server.abort = 1
+ return
+ self.server.log_output += record.msg + "\n"
+
+
+class LogRecordSocketReceiver(ThreadingTCPServer):
+
+ """A simple-minded TCP socket-based logging receiver suitable for test
+ purposes."""
+
+ allow_reuse_address = 1
+ log_output = ""
+
+ def __init__(self, host='localhost',
+ port=logging.handlers.DEFAULT_TCP_LOGGING_PORT,
+ handler=LogRecordStreamHandler):
+ ThreadingTCPServer.__init__(self, (host, port), handler)
+ self.abort = False
+ self.timeout = 0.1
+ self.finished = threading.Event()
+
+ def serve_until_stopped(self):
+ if sys.platform.startswith('java'):
+ # XXX: There's a problem using cpython_compatible_select
+ # here: it seems to be due to the fact that
+ # cpython_compatible_select switches blocking mode on while
+ # a separate thread is reading from the same socket, causing
+ # a read of 0 in LogRecordStreamHandler.handle (which
+ # deadlocks this test)
+ self.socket.setblocking(0)
+ while not self.abort:
+ rd, wr, ex = select.select([self.socket.fileno()], [], [],
+ self.timeout)
+ if rd:
+ self.handle_request()
+ # Notify the main thread that we're about to exit
+ self.finished.set()
+ # close the listen socket
+ self.server_close()
+
+
+class SocketHandlerTest(BaseTest):
+
+ """Test for SocketHandler objects."""
+
+ def setUp(self):
+ """Set up a TCP server to receive log messages, and a SocketHandler
+ pointing to that server's address and port."""
+ BaseTest.setUp(self)
+ self.tcpserver = LogRecordSocketReceiver(port=0)
+ self.port = self.tcpserver.socket.getsockname()[1]
+ self.threads = [
+ threading.Thread(target=self.tcpserver.serve_until_stopped)]
+ for thread in self.threads:
+ thread.start()
+
+ self.sock_hdlr = logging.handlers.SocketHandler('localhost', self.port)
+ self.sock_hdlr.setFormatter(self.root_formatter)
+ self.root_logger.removeHandler(self.root_logger.handlers[0])
+ self.root_logger.addHandler(self.sock_hdlr)
+
+ def tearDown(self):
+ """Shutdown the TCP server."""
+ try:
+ self.tcpserver.abort = True
+ del self.tcpserver
+ self.root_logger.removeHandler(self.sock_hdlr)
+ self.sock_hdlr.close()
+ for thread in self.threads:
+ thread.join(2.0)
finally:
- logging._acquireLock()
+ BaseTest.tearDown(self)
+
+ def get_output(self):
+ """Get the log output as received by the TCP server."""
+ # Signal the TCP receiver and wait for it to terminate.
+ self.root_logger.critical(LogRecordStreamHandler.TCP_LOG_END)
+ self.tcpserver.finished.wait(2.0)
+ return self.tcpserver.log_output
+
+ def test_output(self):
+ # The log message sent to the SocketHandler is properly received.
+ logger = logging.getLogger("tcp")
+ logger.error("spam")
+ logger.debug("eggs")
+ self.assertEquals(self.get_output(), "spam\neggs\n")
+
+
+class MemoryTest(BaseTest):
+
+ """Test memory persistence of logger objects."""
+
+ def setUp(self):
+ """Create a dict to remember potentially destroyed objects."""
+ BaseTest.setUp(self)
+ self._survivors = {}
+
+ def _watch_for_survival(self, *args):
+ """Watch the given objects for survival, by creating weakrefs to
+ them."""
+ for obj in args:
+ key = id(obj), repr(obj)
+ self._survivors[key] = weakref.ref(obj)
+
+ def _assert_survival(self):
+ """Assert that all objects watched for survival have survived."""
+ # Trigger cycle breaking.
+ gc.collect()
+ dead = []
+ for (id_, repr_), ref in self._survivors.items():
+ if ref() is None:
+ dead.append(repr_)
+ if dead:
+ self.fail("%d objects should have survived "
+ "but have been destroyed: %s" % (len(dead), ", ".join(dead)))
+
+ def test_persistent_loggers(self):
+ # Logger objects are persistent and retain their configuration, even
+ # if visible references are destroyed.
+ self.root_logger.setLevel(logging.INFO)
+ foo = logging.getLogger("foo")
+ self._watch_for_survival(foo)
+ foo.setLevel(logging.DEBUG)
+ self.root_logger.debug(self.next_message())
+ foo.debug(self.next_message())
+ self.assert_log_lines([
+ ('foo', 'DEBUG', '2'),
+ ])
+ del foo
+ # foo has survived.
+ self._assert_survival()
+ # foo has retained its settings.
+ bar = logging.getLogger("foo")
+ bar.debug(self.next_message())
+ self.assert_log_lines([
+ ('foo', 'DEBUG', '2'),
+ ('foo', 'DEBUG', '3'),
+ ])
+
+
+class EncodingTest(BaseTest):
+ def test_encoding_plain_file(self):
+ # In Python 2.x, a plain file object is treated as having no encoding.
+ log = logging.getLogger("test")
+ fn = tempfile.mktemp(".log")
+ # the non-ascii data we write to the log.
+ data = "foo\x80"
+ try:
+ handler = logging.FileHandler(fn)
+ log.addHandler(handler)
try:
- logging._handlers.clear()
- logging._handlers.update(saved_handlers)
- logging._handlerList[:] = saved_handler_list
- loggerDict = logging.getLogger().manager.loggerDict
- loggerDict.clear()
- loggerDict.update(saved_loggers)
+ # write non-ascii data to the log.
+ log.warning(data)
finally:
- logging._releaseLock()
+ log.removeHandler(handler)
+ handler.close()
+ # check we wrote exactly those bytes, ignoring trailing \n etc
+ f = open(fn)
+ try:
+ self.failUnlessEqual(f.read().rstrip(), data)
+ finally:
+ f.close()
+ finally:
+ if os.path.isfile(fn):
+ os.remove(fn)
-#----------------------------------------------------------------------------
-# Test 5
-#----------------------------------------------------------------------------
+ def test_encoding_cyrillic_unicode(self):
+ log = logging.getLogger("test")
+ #Get a message in Unicode: Do svidanya in Cyrillic (meaning goodbye)
+ message = u'\u0434\u043e \u0441\u0432\u0438\u0434\u0430\u043d\u0438\u044f'
+ #Ensure it's written in a Cyrillic encoding
+ writer_class = codecs.getwriter('cp1251')
+ writer_class.encoding = 'cp1251'
+ stream = cStringIO.StringIO()
+ writer = writer_class(stream, 'strict')
+ handler = logging.StreamHandler(writer)
+ log.addHandler(handler)
+ try:
+ log.warning(message)
+ finally:
+ log.removeHandler(handler)
+ handler.close()
+ # check we wrote exactly those bytes, ignoring trailing \n etc
+ s = stream.getvalue()
+ #Compare against what the data should be when encoded in CP-1251
+ self.assertEqual(s, '\xe4\xee \xf1\xe2\xe8\xe4\xe0\xed\xe8\xff\n')
-test5_config = """
-[loggers]
-keys=root
-
-[handlers]
-keys=hand1
-
-[formatters]
-keys=form1
-
-[logger_root]
-level=NOTSET
-handlers=hand1
-
-[handler_hand1]
-class=StreamHandler
-level=NOTSET
-formatter=form1
-args=(sys.stdout,)
-
-[formatter_form1]
-class=test.test_logging.FriendlyFormatter
-format=%(levelname)s:%(name)s:%(message)s
-datefmt=
-"""
-
-class FriendlyFormatter (logging.Formatter):
- def formatException(self, ei):
- return "%s... Don't panic!" % str(ei[0])
-
-
-def test5():
- loggerDict = logging.getLogger().manager.loggerDict
- logging._acquireLock()
- try:
- saved_handlers = logging._handlers.copy()
- saved_handler_list = logging._handlerList[:]
- saved_loggers = loggerDict.copy()
- finally:
- logging._releaseLock()
- try:
- fn = tempfile.mktemp(".ini")
- f = open(fn, "w")
- f.write(test5_config)
- f.close()
- logging.config.fileConfig(fn)
- try:
- raise KeyError
- except KeyError:
- logging.exception("just testing")
- os.remove(fn)
- hdlr = logging.getLogger().handlers[0]
- logging.getLogger().handlers.remove(hdlr)
- finally:
- logging._acquireLock()
- try:
- logging._handlers.clear()
- logging._handlers.update(saved_handlers)
- logging._handlerList[:] = saved_handler_list
- loggerDict = logging.getLogger().manager.loggerDict
- loggerDict.clear()
- loggerDict.update(saved_loggers)
- finally:
- logging._releaseLock()
-
-
-#----------------------------------------------------------------------------
-# Test Harness
-#----------------------------------------------------------------------------
-def banner(nm, typ):
- sep = BANNER % (nm, typ)
- sys.stdout.write(sep)
- sys.stdout.flush()
-
-def test_main_inner():
- rootLogger = logging.getLogger("")
- rootLogger.setLevel(logging.DEBUG)
- hdlr = logging.StreamHandler(sys.stdout)
- fmt = logging.Formatter(logging.BASIC_FORMAT)
- hdlr.setFormatter(fmt)
- rootLogger.addHandler(hdlr)
-
- # Find an unused port number
- port = logging.handlers.DEFAULT_TCP_LOGGING_PORT
- while port < logging.handlers.DEFAULT_TCP_LOGGING_PORT+100:
- try:
- tcpserver = LogRecordSocketReceiver(port=port)
- except socket.error:
- port += 1
- else:
- break
- else:
- raise ImportError, "Could not find unused port"
-
-
- #Set up a handler such that all events are sent via a socket to the log
- #receiver (logrecv).
- #The handler will only be added to the rootLogger for some of the tests
- shdlr = logging.handlers.SocketHandler('localhost', port)
-
- #Configure the logger for logrecv so events do not propagate beyond it.
- #The sockLogger output is buffered in memory until the end of the test,
- #and printed at the end.
- sockOut = cStringIO.StringIO()
- sockLogger = logging.getLogger("logrecv")
- sockLogger.setLevel(logging.DEBUG)
- sockhdlr = logging.StreamHandler(sockOut)
- sockhdlr.setFormatter(logging.Formatter(
- "%(name)s -> %(levelname)s: %(message)s"))
- sockLogger.addHandler(sockhdlr)
- sockLogger.propagate = 0
-
- #Set up servers
- threads = []
- #sys.stdout.write("About to start TCP server...\n")
- threads.append(threading.Thread(target=runTCP, args=(tcpserver,)))
-
- for thread in threads:
- thread.start()
- try:
- banner("log_test0", "begin")
-
- rootLogger.addHandler(shdlr)
- test0()
- # XXX(nnorwitz): Try to fix timing related test failures.
- # This sleep gives us some extra time to read messages.
- # The test generally only fails on Solaris without this sleep.
- time.sleep(2.0)
- shdlr.close()
- rootLogger.removeHandler(shdlr)
-
- banner("log_test0", "end")
-
- for t in range(1,6):
- banner("log_test%d" % t, "begin")
- globals()['test%d' % t]()
- banner("log_test%d" % t, "end")
-
- finally:
- #wait for TCP receiver to terminate
- socketDataProcessed.wait()
- # ensure the server dies
- tcpserver.abort = 1
- for thread in threads:
- thread.join(2.0)
- banner("logrecv output", "begin")
- sys.stdout.write(sockOut.getvalue())
- sockOut.close()
- sockLogger.removeHandler(sockhdlr)
- sockhdlr.close()
- banner("logrecv output", "end")
- sys.stdout.flush()
- try:
- hdlr.close()
- except:
- pass
- rootLogger.removeHandler(hdlr)
# Set the locale to the platform-dependent default. I have no idea
# why the test does this, but in any case we save the current locale
# first and restore it at the end.
@run_with_locale('LC_ALL', '')
def test_main():
- # Save and restore the original root logger level across the tests.
- # Otherwise, e.g., if any test using cookielib runs after test_logging,
- # cookielib's debug-level logger tries to log messages, leading to
- # confusing:
- # No handlers could be found for logger "cookielib"
- # output while the tests are running.
- root_logger = logging.getLogger("")
- original_logging_level = root_logger.getEffectiveLevel()
- try:
- test_main_inner()
- finally:
- root_logger.setLevel(original_logging_level)
+ run_unittest(BuiltinLevelsTest, BasicFilterTest,
+ CustomLevelsAndFiltersTest, MemoryHandlerTest,
+ ConfigFileTest, SocketHandlerTest, MemoryTest,
+ EncodingTest)
if __name__ == "__main__":
- sys.stdout.write("test_logging\n")
test_main()
diff --git a/Lib/test/test_operator.py b/Lib/test/test_operator.py
--- a/Lib/test/test_operator.py
+++ b/Lib/test/test_operator.py
@@ -57,6 +57,7 @@
class C(object):
def __eq__(self, other):
raise SyntaxError
+ __hash__ = None # Silence Py3k warning
self.failUnlessRaises(TypeError, operator.eq)
self.failUnlessRaises(SyntaxError, operator.eq, C(), C())
self.failIf(operator.eq(1, 0))
@@ -193,7 +194,9 @@
class C:
pass
def check(self, o, v):
- self.assert_(operator.isCallable(o) == callable(o) == v)
+ self.assertEqual(operator.isCallable(o), v)
+ with test_support._check_py3k_warnings():
+ self.assertEqual(callable(o), v)
check(self, 4, 0)
check(self, operator.isCallable, 1)
check(self, C, 1)
@@ -305,12 +308,12 @@
self.assertRaises(ValueError, operator.rshift, 2, -1)
def test_contains(self):
- self.failUnlessRaises(TypeError, operator.contains)
- self.failUnlessRaises(TypeError, operator.contains, None, None)
- self.failUnless(operator.contains(range(4), 2))
- self.failIf(operator.contains(range(4), 5))
- self.failUnless(operator.sequenceIncludes(range(4), 2))
- self.failIf(operator.sequenceIncludes(range(4), 5))
+ self.assertRaises(TypeError, operator.contains)
+ self.assertRaises(TypeError, operator.contains, None, None)
+ self.assertTrue(operator.contains(range(4), 2))
+ self.assertFalse(operator.contains(range(4), 5))
+ self.assertTrue(operator.sequenceIncludes(range(4), 2))
+ self.assertFalse(operator.sequenceIncludes(range(4), 5))
def test_setitem(self):
a = range(3)
@@ -386,9 +389,29 @@
self.assertRaises(TypeError, operator.attrgetter('x', (), 'y'), record)
class C(object):
- def __getattr(self, name):
+ def __getattr__(self, name):
raise SyntaxError
- self.failUnlessRaises(AttributeError, operator.attrgetter('foo'), C())
+ self.failUnlessRaises(SyntaxError, operator.attrgetter('foo'), C())
+
+ # recursive gets
+ a = A()
+ a.name = 'arthur'
+ a.child = A()
+ a.child.name = 'thomas'
+ f = operator.attrgetter('child.name')
+ self.assertEqual(f(a), 'thomas')
+ self.assertRaises(AttributeError, f, a.child)
+ f = operator.attrgetter('name', 'child.name')
+ self.assertEqual(f(a), ('arthur', 'thomas'))
+ f = operator.attrgetter('name', 'child.name', 'child.child.name')
+ self.assertRaises(AttributeError, f, a)
+
+ a.child.child = A()
+ a.child.child.name = 'johnson'
+ f = operator.attrgetter('child.child.name')
+ self.assertEqual(f(a), 'johnson')
+ f = operator.attrgetter('name', 'child.name', 'child.child.name')
+ self.assertEqual(f(a), ('arthur', 'thomas', 'johnson'))
def test_itemgetter(self):
a = 'ABCDE'
@@ -398,9 +421,9 @@
self.assertRaises(IndexError, f, a)
class C(object):
- def __getitem(self, name):
+ def __getitem__(self, name):
raise SyntaxError
- self.failUnlessRaises(TypeError, operator.itemgetter(42), C())
+ self.failUnlessRaises(SyntaxError, operator.itemgetter(42), C())
f = operator.itemgetter('name')
self.assertRaises(TypeError, f, a)
@@ -424,6 +447,24 @@
self.assertEqual(operator.itemgetter(2,10,5)(data), ('2', '10', '5'))
self.assertRaises(TypeError, operator.itemgetter(2, 'x', 5), data)
+ def test_methodcaller(self):
+ self.assertRaises(TypeError, operator.methodcaller)
+ class A:
+ def foo(self, *args, **kwds):
+ return args[0] + args[1]
+ def bar(self, f=42):
+ return f
+ a = A()
+ f = operator.methodcaller('foo')
+ self.assertRaises(IndexError, f, a)
+ f = operator.methodcaller('foo', 1, 2)
+ self.assertEquals(f(a), 3)
+ f = operator.methodcaller('bar')
+ self.assertEquals(f(a), 42)
+ self.assertRaises(TypeError, f, a, a)
+ f = operator.methodcaller('bar', f=5)
+ self.assertEquals(f(a), 5)
+
def test_inplace(self):
class C(object):
def __iadd__ (self, other): return "iadd"
diff --git a/Lib/test/test_os.py b/Lib/test/test_os.py
--- a/Lib/test/test_os.py
+++ b/Lib/test/test_os.py
@@ -23,6 +23,30 @@
os.close(f)
self.assert_(os.access(test_support.TESTFN, os.W_OK))
+ def test_closerange(self):
+ first = os.open(test_support.TESTFN, os.O_CREAT|os.O_RDWR)
+ # We must allocate two consecutive file descriptors, otherwise
+ # it will mess up other file descriptors (perhaps even the three
+ # standard ones).
+ second = os.dup(first)
+ try:
+ retries = 0
+ while second != first + 1:
+ os.close(first)
+ retries += 1
+ if retries > 10:
+ # XXX test skipped
+ print >> sys.stderr, (
+ "couldn't allocate two consecutive fds, "
+ "skipping test_closerange")
+ return
+ first, second = second, os.dup(second)
+ finally:
+ os.close(second)
+ # close a fd that is open, and one that isn't
+ os.closerange(first, first + 2)
+ self.assertRaises(OSError, os.write, first, "a")
+
def test_rename(self):
path = unicode(test_support.TESTFN)
if not test_support.is_jython:
@@ -223,7 +247,6 @@
if not hasattr(os, "statvfs"):
return
- import statvfs
try:
result = os.statvfs(self.fname)
except OSError, e:
@@ -233,16 +256,13 @@
return
# Make sure direct access works
- self.assertEquals(result.f_bfree, result[statvfs.F_BFREE])
+ self.assertEquals(result.f_bfree, result[3])
- # Make sure all the attributes are there
- members = dir(result)
- for name in dir(statvfs):
- if name[:2] == 'F_':
- attr = name.lower()
- self.assertEquals(getattr(result, attr),
- result[getattr(statvfs, name)])
- self.assert_(attr in members)
+ # Make sure all the attributes are there.
+ members = ('bsize', 'frsize', 'blocks', 'bfree', 'bavail', 'files',
+ 'ffree', 'favail', 'flag', 'namemax')
+ for value, member in enumerate(members):
+ self.assertEquals(getattr(result, 'f_' + member), result[value])
# Make sure that assignment really fails
try:
@@ -270,6 +290,15 @@
except TypeError:
pass
+ def test_utime_dir(self):
+ delta = 1000000
+ st = os.stat(test_support.TESTFN)
+ # round to int, because some systems may support sub-second
+ # time stamps in stat, but not in utime.
+ os.utime(test_support.TESTFN, (st.st_atime, int(st.st_mtime-delta)))
+ st2 = os.stat(test_support.TESTFN)
+ self.assertEquals(st2.st_mtime, int(st.st_mtime-delta))
+
# Restrict test to Win32, since there is no guarantee other
# systems support centiseconds
if sys.platform == 'win32':
@@ -292,7 +321,7 @@
try:
os.stat(r"c:\pagefile.sys")
except WindowsError, e:
- if e == 2: # file does not exist; cannot run test
+ if e.errno == 2: # file does not exist; cannot run test
return
self.fail("Could not stat pagefile.sys")
@@ -328,75 +357,104 @@
from os.path import join
# Build:
- # TESTFN/ a file kid and two directory kids
+ # TESTFN/
+ # TEST1/ a file kid and two directory kids
# tmp1
# SUB1/ a file kid and a directory kid
- # tmp2
- # SUB11/ no kids
- # SUB2/ just a file kid
- # tmp3
- sub1_path = join(test_support.TESTFN, "SUB1")
+ # tmp2
+ # SUB11/ no kids
+ # SUB2/ a file kid and a dirsymlink kid
+ # tmp3
+ # link/ a symlink to TESTFN.2
+ # TEST2/
+ # tmp4 a lone file
+ walk_path = join(test_support.TESTFN, "TEST1")
+ sub1_path = join(walk_path, "SUB1")
sub11_path = join(sub1_path, "SUB11")
- sub2_path = join(test_support.TESTFN, "SUB2")
- tmp1_path = join(test_support.TESTFN, "tmp1")
+ sub2_path = join(walk_path, "SUB2")
+ tmp1_path = join(walk_path, "tmp1")
tmp2_path = join(sub1_path, "tmp2")
tmp3_path = join(sub2_path, "tmp3")
+ link_path = join(sub2_path, "link")
+ t2_path = join(test_support.TESTFN, "TEST2")
+ tmp4_path = join(test_support.TESTFN, "TEST2", "tmp4")
# Create stuff.
os.makedirs(sub11_path)
os.makedirs(sub2_path)
- for path in tmp1_path, tmp2_path, tmp3_path:
+ os.makedirs(t2_path)
+ for path in tmp1_path, tmp2_path, tmp3_path, tmp4_path:
f = file(path, "w")
f.write("I'm " + path + " and proud of it. Blame test_os.\n")
f.close()
+ if hasattr(os, "symlink"):
+ os.symlink(os.path.abspath(t2_path), link_path)
+ sub2_tree = (sub2_path, ["link"], ["tmp3"])
+ else:
+ sub2_tree = (sub2_path, [], ["tmp3"])
# Walk top-down.
- all = list(os.walk(test_support.TESTFN))
+ all = list(os.walk(walk_path))
self.assertEqual(len(all), 4)
# We can't know which order SUB1 and SUB2 will appear in.
# Not flipped: TESTFN, SUB1, SUB11, SUB2
# flipped: TESTFN, SUB2, SUB1, SUB11
flipped = all[0][1][0] != "SUB1"
all[0][1].sort()
- self.assertEqual(all[0], (test_support.TESTFN, ["SUB1", "SUB2"], ["tmp1"]))
+ self.assertEqual(all[0], (walk_path, ["SUB1", "SUB2"], ["tmp1"]))
self.assertEqual(all[1 + flipped], (sub1_path, ["SUB11"], ["tmp2"]))
self.assertEqual(all[2 + flipped], (sub11_path, [], []))
- self.assertEqual(all[3 - 2 * flipped], (sub2_path, [], ["tmp3"]))
+ self.assertEqual(all[3 - 2 * flipped], sub2_tree)
# Prune the search.
all = []
- for root, dirs, files in os.walk(test_support.TESTFN):
+ for root, dirs, files in os.walk(walk_path):
all.append((root, dirs, files))
# Don't descend into SUB1.
if 'SUB1' in dirs:
# Note that this also mutates the dirs we appended to all!
dirs.remove('SUB1')
self.assertEqual(len(all), 2)
- self.assertEqual(all[0], (test_support.TESTFN, ["SUB2"], ["tmp1"]))
- self.assertEqual(all[1], (sub2_path, [], ["tmp3"]))
+ self.assertEqual(all[0], (walk_path, ["SUB2"], ["tmp1"]))
+ self.assertEqual(all[1], sub2_tree)
# Walk bottom-up.
- all = list(os.walk(test_support.TESTFN, topdown=False))
+ all = list(os.walk(walk_path, topdown=False))
self.assertEqual(len(all), 4)
# We can't know which order SUB1 and SUB2 will appear in.
# Not flipped: SUB11, SUB1, SUB2, TESTFN
# flipped: SUB2, SUB11, SUB1, TESTFN
flipped = all[3][1][0] != "SUB1"
all[3][1].sort()
- self.assertEqual(all[3], (test_support.TESTFN, ["SUB1", "SUB2"], ["tmp1"]))
+ self.assertEqual(all[3], (walk_path, ["SUB1", "SUB2"], ["tmp1"]))
self.assertEqual(all[flipped], (sub11_path, [], []))
self.assertEqual(all[flipped + 1], (sub1_path, ["SUB11"], ["tmp2"]))
- self.assertEqual(all[2 - 2 * flipped], (sub2_path, [], ["tmp3"]))
+ self.assertEqual(all[2 - 2 * flipped], sub2_tree)
+ if hasattr(os, "symlink"):
+ # Walk, following symlinks.
+ for root, dirs, files in os.walk(walk_path, followlinks=True):
+ if root == link_path:
+ self.assertEqual(dirs, [])
+ self.assertEqual(files, ["tmp4"])
+ break
+ else:
+ self.fail("Didn't follow symlink with followlinks=True")
+
+ def tearDown(self):
# Tear everything down. This is a decent use for bottom-up on
# Windows, which doesn't have a recursive delete command. The
# (not so) subtlety is that rmdir will fail unless the dir's
# kids are removed first, so bottom up is essential.
for root, dirs, files in os.walk(test_support.TESTFN, topdown=False):
for name in files:
- os.remove(join(root, name))
+ os.remove(os.path.join(root, name))
for name in dirs:
- os.rmdir(join(root, name))
+ dirname = os.path.join(root, name)
+ if not os.path.islink(dirname):
+ os.rmdir(dirname)
+ else:
+ os.remove(dirname)
os.rmdir(test_support.TESTFN)
class MakedirTests (unittest.TestCase):
@@ -444,10 +502,15 @@
class URandomTests (unittest.TestCase):
def test_urandom(self):
try:
- self.assertEqual(len(os.urandom(1)), 1)
- self.assertEqual(len(os.urandom(10)), 10)
- self.assertEqual(len(os.urandom(100)), 100)
- self.assertEqual(len(os.urandom(1000)), 1000)
+ with test_support.check_warnings():
+ self.assertEqual(len(os.urandom(1)), 1)
+ self.assertEqual(len(os.urandom(10)), 10)
+ self.assertEqual(len(os.urandom(100)), 100)
+ self.assertEqual(len(os.urandom(1000)), 1000)
+ # see http://bugs.python.org/issue3708
+ self.assertEqual(len(os.urandom(0.9)), 0)
+ self.assertEqual(len(os.urandom(1.1)), 1)
+ self.assertEqual(len(os.urandom(2.0)), 2)
except NotImplementedError:
pass
@@ -473,10 +536,143 @@
def test_chmod(self):
self.assertRaises(WindowsError, os.utime, test_support.TESTFN, 0)
+class TestInvalidFD(unittest.TestCase):
+ singles = ["fchdir", "fdopen", "dup", "fdatasync", "fstat",
+ "fstatvfs", "fsync", "tcgetpgrp", "ttyname"]
+ #singles.append("close")
+ #We omit close because it doesn't raise an exception on some platforms
+ def get_single(f):
+ def helper(self):
+ if hasattr(os, f):
+ self.check(getattr(os, f))
+ return helper
+ for f in singles:
+ locals()["test_"+f] = get_single(f)
+
+ def check(self, f, *args):
+ self.assertRaises(OSError, f, test_support.make_bad_fd(), *args)
+
+ def test_isatty(self):
+ if hasattr(os, "isatty"):
+ self.assertEqual(os.isatty(test_support.make_bad_fd()), False)
+
+ def test_closerange(self):
+ if hasattr(os, "closerange"):
+ fd = test_support.make_bad_fd()
+ # Make sure none of the descriptors we are about to close are
+ # currently valid (issue 6542).
+ for i in range(10):
+ try: os.fstat(fd+i)
+ except OSError:
+ pass
+ else:
+ break
+ if i < 2:
+ # Unable to acquire a range of invalid file descriptors,
+ # so skip the test (in 2.7+ this is a unittest.SkipTest).
+ return
+ self.assertEqual(os.closerange(fd, fd + i-1), None)
+
+ def test_dup2(self):
+ if hasattr(os, "dup2"):
+ self.check(os.dup2, 20)
+
+ def test_fchmod(self):
+ if hasattr(os, "fchmod"):
+ self.check(os.fchmod, 0)
+
+ def test_fchown(self):
+ if hasattr(os, "fchown"):
+ self.check(os.fchown, -1, -1)
+
+ def test_fpathconf(self):
+ if hasattr(os, "fpathconf"):
+ self.check(os.fpathconf, "PC_NAME_MAX")
+
+ #this is a weird one, it raises IOError unlike the others
+ def test_ftruncate(self):
+ if hasattr(os, "ftruncate"):
+ self.assertRaises(IOError, os.ftruncate, test_support.make_bad_fd(),
+ 0)
+
+ def test_lseek(self):
+ if hasattr(os, "lseek"):
+ self.check(os.lseek, 0, 0)
+
+ def test_read(self):
+ if hasattr(os, "read"):
+ self.check(os.read, 1)
+
+ def test_tcsetpgrpt(self):
+ if hasattr(os, "tcsetpgrp"):
+ self.check(os.tcsetpgrp, 0)
+
+ def test_write(self):
+ if hasattr(os, "write"):
+ self.check(os.write, " ")
+
if sys.platform != 'win32':
class Win32ErrorTests(unittest.TestCase):
pass
+ class PosixUidGidTests(unittest.TestCase):
+ if hasattr(os, 'setuid'):
+ def test_setuid(self):
+ if os.getuid() != 0:
+ self.assertRaises(os.error, os.setuid, 0)
+ self.assertRaises(OverflowError, os.setuid, 1<<32)
+
+ if hasattr(os, 'setgid'):
+ def test_setgid(self):
+ if os.getuid() != 0:
+ self.assertRaises(os.error, os.setgid, 0)
+ self.assertRaises(OverflowError, os.setgid, 1<<32)
+
+ if hasattr(os, 'seteuid'):
+ def test_seteuid(self):
+ if os.getuid() != 0:
+ self.assertRaises(os.error, os.seteuid, 0)
+ self.assertRaises(OverflowError, os.seteuid, 1<<32)
+
+ if hasattr(os, 'setegid'):
+ def test_setegid(self):
+ if os.getuid() != 0:
+ self.assertRaises(os.error, os.setegid, 0)
+ self.assertRaises(OverflowError, os.setegid, 1<<32)
+
+ if hasattr(os, 'setreuid'):
+ def test_setreuid(self):
+ if os.getuid() != 0:
+ self.assertRaises(os.error, os.setreuid, 0, 0)
+ self.assertRaises(OverflowError, os.setreuid, 1<<32, 0)
+ self.assertRaises(OverflowError, os.setreuid, 0, 1<<32)
+
+ def test_setreuid_neg1(self):
+ # Needs to accept -1. We run this in a subprocess to avoid
+ # altering the test runner's process state (issue8045).
+ import subprocess
+ subprocess.check_call([
+ sys.executable, '-c',
+ 'import os,sys;os.setreuid(-1,-1);sys.exit(0)'])
+
+ if hasattr(os, 'setregid'):
+ def test_setregid(self):
+ if os.getuid() != 0:
+ self.assertRaises(os.error, os.setregid, 0, 0)
+ self.assertRaises(OverflowError, os.setregid, 1<<32, 0)
+ self.assertRaises(OverflowError, os.setregid, 0, 1<<32)
+
+ def test_setregid_neg1(self):
+ # Needs to accept -1. We run this in a subprocess to avoid
+ # altering the test runner's process state (issue8045).
+ import subprocess
+ subprocess.check_call([
+ sys.executable, '-c',
+ 'import os,sys;os.setregid(-1,-1);sys.exit(0)'])
+else:
+ class PosixUidGidTests(unittest.TestCase):
+ pass
+
def test_main():
test_support.run_unittest(
FileTests,
@@ -487,7 +683,9 @@
MakedirTests,
DevNullTests,
URandomTests,
- Win32ErrorTests
+ Win32ErrorTests,
+ TestInvalidFD,
+ PosixUidGidTests
)
if __name__ == "__main__":
diff --git a/Lib/test/test_pkgimport.py b/Lib/test/test_pkgimport.py
--- a/Lib/test/test_pkgimport.py
+++ b/Lib/test/test_pkgimport.py
@@ -6,14 +6,14 @@
def __init__(self, *args, **kw):
self.package_name = 'PACKAGE_'
- while sys.modules.has_key(self.package_name):
+ while self.package_name in sys.modules:
self.package_name += random.choose(string.letters)
self.module_name = self.package_name + '.foo'
unittest.TestCase.__init__(self, *args, **kw)
def remove_modules(self):
for module_name in (self.package_name, self.module_name):
- if sys.modules.has_key(module_name):
+ if module_name in sys.modules:
del sys.modules[module_name]
def setUp(self):
@@ -59,8 +59,8 @@
try: __import__(self.module_name)
except SyntaxError: pass
else: raise RuntimeError, 'Failed to induce SyntaxError'
- self.assert_(not sys.modules.has_key(self.module_name) and
- not hasattr(sys.modules[self.package_name], 'foo'))
+ self.assertTrue(self.module_name not in sys.modules)
+ self.assertFalse(hasattr(sys.modules[self.package_name], 'foo'))
# ...make up a variable name that isn't bound in __builtins__
import __builtin__
diff --git a/Lib/test/test_pprint.py b/Lib/test/test_pprint.py
--- a/Lib/test/test_pprint.py
+++ b/Lib/test/test_pprint.py
@@ -1,6 +1,7 @@
import pprint
import test.test_support
import unittest
+import test.test_set
try:
uni = unicode
@@ -39,20 +40,19 @@
def test_basic(self):
# Verify .isrecursive() and .isreadable() w/o recursion
- verify = self.assert_
pp = pprint.PrettyPrinter()
for safe in (2, 2.0, 2j, "abc", [3], (2,2), {3: 3}, uni("yaddayadda"),
self.a, self.b):
# module-level convenience functions
- verify(not pprint.isrecursive(safe),
- "expected not isrecursive for %r" % (safe,))
- verify(pprint.isreadable(safe),
- "expected isreadable for %r" % (safe,))
+ self.assertFalse(pprint.isrecursive(safe),
+ "expected not isrecursive for %r" % (safe,))
+ self.assertTrue(pprint.isreadable(safe),
+ "expected isreadable for %r" % (safe,))
# PrettyPrinter methods
- verify(not pp.isrecursive(safe),
- "expected not isrecursive for %r" % (safe,))
- verify(pp.isreadable(safe),
- "expected isreadable for %r" % (safe,))
+ self.assertFalse(pp.isrecursive(safe),
+ "expected not isrecursive for %r" % (safe,))
+ self.assertTrue(pp.isreadable(safe),
+ "expected isreadable for %r" % (safe,))
def test_knotted(self):
# Verify .isrecursive() and .isreadable() w/ recursion
@@ -62,14 +62,13 @@
self.d = {}
self.d[0] = self.d[1] = self.d[2] = self.d
- verify = self.assert_
pp = pprint.PrettyPrinter()
for icky in self.a, self.b, self.d, (self.d, self.d):
- verify(pprint.isrecursive(icky), "expected isrecursive")
- verify(not pprint.isreadable(icky), "expected not isreadable")
- verify(pp.isrecursive(icky), "expected isrecursive")
- verify(not pp.isreadable(icky), "expected not isreadable")
+ self.assertTrue(pprint.isrecursive(icky), "expected isrecursive")
+ self.assertFalse(pprint.isreadable(icky), "expected not isreadable")
+ self.assertTrue(pp.isrecursive(icky), "expected isrecursive")
+ self.assertFalse(pp.isreadable(icky), "expected not isreadable")
# Break the cycles.
self.d.clear()
@@ -78,31 +77,30 @@
for safe in self.a, self.b, self.d, (self.d, self.d):
# module-level convenience functions
- verify(not pprint.isrecursive(safe),
- "expected not isrecursive for %r" % (safe,))
- verify(pprint.isreadable(safe),
- "expected isreadable for %r" % (safe,))
+ self.assertFalse(pprint.isrecursive(safe),
+ "expected not isrecursive for %r" % (safe,))
+ self.assertTrue(pprint.isreadable(safe),
+ "expected isreadable for %r" % (safe,))
# PrettyPrinter methods
- verify(not pp.isrecursive(safe),
- "expected not isrecursive for %r" % (safe,))
- verify(pp.isreadable(safe),
- "expected isreadable for %r" % (safe,))
+ self.assertFalse(pp.isrecursive(safe),
+ "expected not isrecursive for %r" % (safe,))
+ self.assertTrue(pp.isreadable(safe),
+ "expected isreadable for %r" % (safe,))
def test_unreadable(self):
# Not recursive but not readable anyway
- verify = self.assert_
pp = pprint.PrettyPrinter()
for unreadable in type(3), pprint, pprint.isrecursive:
# module-level convenience functions
- verify(not pprint.isrecursive(unreadable),
- "expected not isrecursive for %r" % (unreadable,))
- verify(not pprint.isreadable(unreadable),
- "expected not isreadable for %r" % (unreadable,))
+ self.assertFalse(pprint.isrecursive(unreadable),
+ "expected not isrecursive for %r" % (unreadable,))
+ self.assertFalse(pprint.isreadable(unreadable),
+ "expected not isreadable for %r" % (unreadable,))
# PrettyPrinter methods
- verify(not pp.isrecursive(unreadable),
- "expected not isrecursive for %r" % (unreadable,))
- verify(not pp.isreadable(unreadable),
- "expected not isreadable for %r" % (unreadable,))
+ self.assertFalse(pp.isrecursive(unreadable),
+ "expected not isrecursive for %r" % (unreadable,))
+ self.assertFalse(pp.isreadable(unreadable),
+ "expected not isreadable for %r" % (unreadable,))
def test_same_as_repr(self):
# Simple objects, small containers and classes that overwrite __repr__
@@ -113,12 +111,11 @@
# it sorted a dict display if and only if the display required
# multiple lines. For that reason, dicts with more than one element
# aren't tested here.
- verify = self.assert_
for simple in (0, 0L, 0+0j, 0.0, "", uni(""),
(), tuple2(), tuple3(),
[], list2(), list3(),
{}, dict2(), dict3(),
- verify, pprint,
+ self.assertTrue, pprint,
-6, -6L, -6-6j, -1.5, "x", uni("x"), (3,), [3], {3: 6},
(1,2), [3,4], {5: 6},
tuple2((1,2)), tuple3((1,2)), tuple3(range(100)),
@@ -130,8 +127,9 @@
for function in "pformat", "saferepr":
f = getattr(pprint, function)
got = f(simple)
- verify(native == got, "expected %s got %s from pprint.%s" %
- (native, got, function))
+ self.assertEqual(native, got,
+ "expected %s got %s from pprint.%s" %
+ (native, got, function))
def test_basic_line_wrap(self):
# verify basic line-wrapping operation
@@ -169,6 +167,17 @@
for type in [list, list2]:
self.assertEqual(pprint.pformat(type(o), indent=4), exp)
+ def test_nested_indentations(self):
+ o1 = list(range(10))
+ o2 = dict(first=1, second=2, third=3)
+ o = [o1, o2]
+ expected = """\
+[ [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
+ { 'first': 1,
+ 'second': 2,
+ 'third': 3}]"""
+ self.assertEqual(pprint.pformat(o, indent=4, width=42), expected)
+
def test_sorted_dict(self):
# Starting in Python 2.5, pprint sorts dict displays by key regardless
# of how small the dictionary may be.
@@ -195,6 +204,212 @@
others.should.not.be: like.this}"""
self.assertEqual(DottedPrettyPrinter().pformat(o), exp)
+ def test_set_reprs(self):
+ self.assertEqual(pprint.pformat(set()), 'set()')
+ self.assertEqual(pprint.pformat(set(range(3))), 'set([0, 1, 2])')
+ self.assertEqual(pprint.pformat(frozenset()), 'frozenset()')
+ self.assertEqual(pprint.pformat(frozenset(range(3))), 'frozenset([0, 1, 2])')
+ cube_repr_tgt = """\
+{frozenset([]): frozenset([frozenset([2]), frozenset([0]), frozenset([1])]),
+ frozenset([0]): frozenset([frozenset(),
+ frozenset([0, 2]),
+ frozenset([0, 1])]),
+ frozenset([1]): frozenset([frozenset(),
+ frozenset([1, 2]),
+ frozenset([0, 1])]),
+ frozenset([2]): frozenset([frozenset(),
+ frozenset([1, 2]),
+ frozenset([0, 2])]),
+ frozenset([1, 2]): frozenset([frozenset([2]),
+ frozenset([1]),
+ frozenset([0, 1, 2])]),
+ frozenset([0, 2]): frozenset([frozenset([2]),
+ frozenset([0]),
+ frozenset([0, 1, 2])]),
+ frozenset([0, 1]): frozenset([frozenset([0]),
+ frozenset([1]),
+ frozenset([0, 1, 2])]),
+ frozenset([0, 1, 2]): frozenset([frozenset([1, 2]),
+ frozenset([0, 2]),
+ frozenset([0, 1])])}"""
+ cube = test.test_set.cube(3)
+ self.assertEqual(pprint.pformat(cube), cube_repr_tgt)
+ cubo_repr_tgt = """\
+{frozenset([frozenset([0, 2]), frozenset([0])]): frozenset([frozenset([frozenset([0,
+ 2]),
+ frozenset([0,
+ 1,
+ 2])]),
+ frozenset([frozenset([0]),
+ frozenset([0,
+ 1])]),
+ frozenset([frozenset(),
+ frozenset([0])]),
+ frozenset([frozenset([2]),
+ frozenset([0,
+ 2])])]),
+ frozenset([frozenset([0, 1]), frozenset([1])]): frozenset([frozenset([frozenset([0,
+ 1]),
+ frozenset([0,
+ 1,
+ 2])]),
+ frozenset([frozenset([0]),
+ frozenset([0,
+ 1])]),
+ frozenset([frozenset([1]),
+ frozenset([1,
+ 2])]),
+ frozenset([frozenset(),
+ frozenset([1])])]),
+ frozenset([frozenset([1, 2]), frozenset([1])]): frozenset([frozenset([frozenset([1,
+ 2]),
+ frozenset([0,
+ 1,
+ 2])]),
+ frozenset([frozenset([2]),
+ frozenset([1,
+ 2])]),
+ frozenset([frozenset(),
+ frozenset([1])]),
+ frozenset([frozenset([1]),
+ frozenset([0,
+ 1])])]),
+ frozenset([frozenset([1, 2]), frozenset([2])]): frozenset([frozenset([frozenset([1,
+ 2]),
+ frozenset([0,
+ 1,
+ 2])]),
+ frozenset([frozenset([1]),
+ frozenset([1,
+ 2])]),
+ frozenset([frozenset([2]),
+ frozenset([0,
+ 2])]),
+ frozenset([frozenset(),
+ frozenset([2])])]),
+ frozenset([frozenset([]), frozenset([0])]): frozenset([frozenset([frozenset([0]),
+ frozenset([0,
+ 1])]),
+ frozenset([frozenset([0]),
+ frozenset([0,
+ 2])]),
+ frozenset([frozenset(),
+ frozenset([1])]),
+ frozenset([frozenset(),
+ frozenset([2])])]),
+ frozenset([frozenset([]), frozenset([1])]): frozenset([frozenset([frozenset(),
+ frozenset([0])]),
+ frozenset([frozenset([1]),
+ frozenset([1,
+ 2])]),
+ frozenset([frozenset(),
+ frozenset([2])]),
+ frozenset([frozenset([1]),
+ frozenset([0,
+ 1])])]),
+ frozenset([frozenset([2]), frozenset([])]): frozenset([frozenset([frozenset([2]),
+ frozenset([1,
+ 2])]),
+ frozenset([frozenset(),
+ frozenset([0])]),
+ frozenset([frozenset(),
+ frozenset([1])]),
+ frozenset([frozenset([2]),
+ frozenset([0,
+ 2])])]),
+ frozenset([frozenset([0, 1, 2]), frozenset([0, 1])]): frozenset([frozenset([frozenset([1,
+ 2]),
+ frozenset([0,
+ 1,
+ 2])]),
+ frozenset([frozenset([0,
+ 2]),
+ frozenset([0,
+ 1,
+ 2])]),
+ frozenset([frozenset([0]),
+ frozenset([0,
+ 1])]),
+ frozenset([frozenset([1]),
+ frozenset([0,
+ 1])])]),
+ frozenset([frozenset([0]), frozenset([0, 1])]): frozenset([frozenset([frozenset(),
+ frozenset([0])]),
+ frozenset([frozenset([0,
+ 1]),
+ frozenset([0,
+ 1,
+ 2])]),
+ frozenset([frozenset([0]),
+ frozenset([0,
+ 2])]),
+ frozenset([frozenset([1]),
+ frozenset([0,
+ 1])])]),
+ frozenset([frozenset([2]), frozenset([0, 2])]): frozenset([frozenset([frozenset([0,
+ 2]),
+ frozenset([0,
+ 1,
+ 2])]),
+ frozenset([frozenset([2]),
+ frozenset([1,
+ 2])]),
+ frozenset([frozenset([0]),
+ frozenset([0,
+ 2])]),
+ frozenset([frozenset(),
+ frozenset([2])])]),
+ frozenset([frozenset([0, 1, 2]), frozenset([0, 2])]): frozenset([frozenset([frozenset([1,
+ 2]),
+ frozenset([0,
+ 1,
+ 2])]),
+ frozenset([frozenset([0,
+ 1]),
+ frozenset([0,
+ 1,
+ 2])]),
+ frozenset([frozenset([0]),
+ frozenset([0,
+ 2])]),
+ frozenset([frozenset([2]),
+ frozenset([0,
+ 2])])]),
+ frozenset([frozenset([1, 2]), frozenset([0, 1, 2])]): frozenset([frozenset([frozenset([0,
+ 2]),
+ frozenset([0,
+ 1,
+ 2])]),
+ frozenset([frozenset([0,
+ 1]),
+ frozenset([0,
+ 1,
+ 2])]),
+ frozenset([frozenset([2]),
+ frozenset([1,
+ 2])]),
+ frozenset([frozenset([1]),
+ frozenset([1,
+ 2])])])}"""
+
+ cubo = test.test_set.linegraph(cube)
+ self.assertEqual(pprint.pformat(cubo), cubo_repr_tgt)
+
+ def test_depth(self):
+ nested_tuple = (1, (2, (3, (4, (5, 6)))))
+ nested_dict = {1: {2: {3: {4: {5: {6: 6}}}}}}
+ nested_list = [1, [2, [3, [4, [5, [6, []]]]]]]
+ self.assertEqual(pprint.pformat(nested_tuple), repr(nested_tuple))
+ self.assertEqual(pprint.pformat(nested_dict), repr(nested_dict))
+ self.assertEqual(pprint.pformat(nested_list), repr(nested_list))
+
+ lv1_tuple = '(1, (...))'
+ lv1_dict = '{1: {...}}'
+ lv1_list = '[1, [...]]'
+ self.assertEqual(pprint.pformat(nested_tuple, depth=1), lv1_tuple)
+ self.assertEqual(pprint.pformat(nested_dict, depth=1), lv1_dict)
+ self.assertEqual(pprint.pformat(nested_list, depth=1), lv1_list)
+
class DottedPrettyPrinter(pprint.PrettyPrinter):
diff --git a/Lib/test/test_profilehooks.py b/Lib/test/test_profilehooks.py
--- a/Lib/test/test_profilehooks.py
+++ b/Lib/test/test_profilehooks.py
@@ -10,6 +10,22 @@
from test import test_support
+class TestGetProfile(unittest.TestCase):
+ def setUp(self):
+ sys.setprofile(None)
+
+ def tearDown(self):
+ sys.setprofile(None)
+
+ def test_empty(self):
+ self.assertTrue(sys.getprofile() is None)
+
+ def test_setget(self):
+ def fn(*args):
+ pass
+
+ sys.setprofile(fn)
+ self.assertEqual(sys.getprofile(), fn)
class HookWatcher:
def __init__(self):
@@ -100,7 +116,7 @@
def test_exception(self):
def f(p):
- 1/0
+ 1./0
f_ident = ident(f)
self.check_events(f, [(1, 'call', f_ident),
(1, 'return', f_ident),
@@ -108,7 +124,7 @@
def test_caught_exception(self):
def f(p):
- try: 1/0
+ try: 1./0
except: pass
f_ident = ident(f)
self.check_events(f, [(1, 'call', f_ident),
@@ -117,7 +133,7 @@
def test_caught_nested_exception(self):
def f(p):
- try: 1/0
+ try: 1./0
except: pass
f_ident = ident(f)
self.check_events(f, [(1, 'call', f_ident),
@@ -126,7 +142,7 @@
def test_nested_exception(self):
def f(p):
- 1/0
+ 1./0
f_ident = ident(f)
self.check_events(f, [(1, 'call', f_ident),
# This isn't what I expected:
@@ -137,7 +153,7 @@
def test_exception_in_except_clause(self):
def f(p):
- 1/0
+ 1./0
def g(p):
try:
f(p)
@@ -156,7 +172,7 @@
def test_exception_propogation(self):
def f(p):
- 1/0
+ 1./0
def g(p):
try: f(p)
finally: p.add_event("falling through")
@@ -171,8 +187,8 @@
def test_raise_twice(self):
def f(p):
- try: 1/0
- except: 1/0
+ try: 1./0
+ except: 1./0
f_ident = ident(f)
self.check_events(f, [(1, 'call', f_ident),
(1, 'return', f_ident),
@@ -180,7 +196,7 @@
def test_raise_reraise(self):
def f(p):
- try: 1/0
+ try: 1./0
except: raise
f_ident = ident(f)
self.check_events(f, [(1, 'call', f_ident),
@@ -197,7 +213,7 @@
def test_distant_exception(self):
def f():
- 1/0
+ 1./0
def g():
f()
def h():
@@ -282,7 +298,7 @@
def test_basic_exception(self):
def f(p):
- 1/0
+ 1./0
f_ident = ident(f)
self.check_events(f, [(1, 'call', f_ident),
(1, 'return', f_ident),
@@ -290,7 +306,7 @@
def test_caught_exception(self):
def f(p):
- try: 1/0
+ try: 1./0
except: pass
f_ident = ident(f)
self.check_events(f, [(1, 'call', f_ident),
@@ -299,7 +315,7 @@
def test_distant_exception(self):
def f():
- 1/0
+ 1./0
def g():
f()
def h():
@@ -379,6 +395,7 @@
del ProfileSimulatorTestCase.test_distant_exception
test_support.run_unittest(
+ TestGetProfile,
ProfileHookTestCase,
ProfileSimulatorTestCase
)
diff --git a/Lib/test/test_random.py b/Lib/test/test_random.py
--- a/Lib/test/test_random.py
+++ b/Lib/test/test_random.py
@@ -5,7 +5,8 @@
import time
import pickle
import warnings
-from math import log, exp, sqrt, pi
+from math import log, exp, sqrt, pi, fsum as msum
+from functools import reduce
from test import test_support
class TestBasicOps(unittest.TestCase):
@@ -52,10 +53,11 @@
state3 = self.gen.getstate() # s/b distinct from state2
self.assertNotEqual(state2, state3)
- self.assertRaises(TypeError, self.gen.jumpahead) # needs an arg
- self.assertRaises(TypeError, self.gen.jumpahead, "ick") # wrong type
- self.assertRaises(TypeError, self.gen.jumpahead, 2.3) # wrong type
- self.assertRaises(TypeError, self.gen.jumpahead, 2, 3) # too many
+ with test_support._check_py3k_warnings(quiet=True):
+ self.assertRaises(TypeError, self.gen.jumpahead) # needs an arg
+ self.assertRaises(TypeError, self.gen.jumpahead, "ick") # wrong type
+ self.assertRaises(TypeError, self.gen.jumpahead, 2.3) # wrong type
+ self.assertRaises(TypeError, self.gen.jumpahead, 2, 3) # too many
def test_sample(self):
# For the entire allowable range of 0 <= k <= N, validate that
@@ -140,6 +142,19 @@
restoredseq = [newgen.random() for i in xrange(10)]
self.assertEqual(origseq, restoredseq)
+ def test_bug_1727780(self):
+ # verify that version-2-pickles can be loaded
+ # fine, whether they are created on 32-bit or 64-bit
+ # platforms, and that version-3-pickles load fine.
+ files = [("randv2_32.pck", 780),
+ ("randv2_64.pck", 866),
+ ("randv3.pck", 343)]
+ for file, value in files:
+ f = open(test_support.findfile(file),"rb")
+ r = pickle.load(f)
+ f.close()
+ self.assertEqual(r.randrange(1000), value)
+
class WichmannHill_TestBasicOps(TestBasicOps):
gen = random.WichmannHill()
@@ -178,10 +193,9 @@
def test_bigrand(self):
# Verify warnings are raised when randrange is too large for random()
- oldfilters = warnings.filters[:]
- warnings.filterwarnings("error", "Underlying random")
- self.assertRaises(UserWarning, self.gen.randrange, 2**60)
- warnings.filters[:] = oldfilters
+ with warnings.catch_warnings():
+ warnings.filterwarnings("error", "Underlying random")
+ self.assertRaises(UserWarning, self.gen.randrange, 2**60)
class SystemRandom_TestBasicOps(TestBasicOps):
gen = random.SystemRandom()
@@ -453,11 +467,9 @@
def gamma(z, cof=_gammacoeff, g=7):
z -= 1.0
- sum = cof[0]
- for i in xrange(1,len(cof)):
- sum += cof[i] / (z+i)
+ s = msum([cof[0]] + [cof[i] / (z+i) for i in range(1,len(cof))])
z += 0.5
- return (z+g)**z / exp(z+g) * sqrt(2*pi) * sum
+ return (z+g)**z / exp(z+g) * sqrt(2.0*pi) * s
class TestDistributions(unittest.TestCase):
def test_zeroinputs(self):
@@ -476,6 +488,7 @@
g.random = x[:].pop; g.gammavariate(1.0, 1.0)
g.random = x[:].pop; g.gammavariate(200.0, 1.0)
g.random = x[:].pop; g.betavariate(3.0, 3.0)
+ g.random = x[:].pop; g.triangular(0.0, 1.0, 1.0/3.0)
def test_avg_std(self):
# Use integration to test distribution average and standard deviation.
@@ -485,6 +498,7 @@
x = [i/float(N) for i in xrange(1,N)]
for variate, args, mu, sigmasqrd in [
(g.uniform, (1.0,10.0), (10.0+1.0)/2, (10.0-1.0)**2/12),
+ (g.triangular, (0.0, 1.0, 1.0/3.0), 4.0/9.0, 7.0/9.0/18.0),
(g.expovariate, (1.5,), 1/1.5, 1/1.5**2),
(g.paretovariate, (5.0,), 5.0/(5.0-1),
5.0/((5.0-1)**2*(5.0-2))),
diff --git a/Lib/test/test_repr.py b/Lib/test/test_repr.py
--- a/Lib/test/test_repr.py
+++ b/Lib/test/test_repr.py
@@ -8,7 +8,7 @@
import shutil
import unittest
-from test.test_support import run_unittest
+from test.test_support import run_unittest, _check_py3k_warnings
from repr import repr as r # Don't shadow builtin repr
from repr import Repr
@@ -149,7 +149,6 @@
'<built-in method split of str object at 0x'))
def test_xrange(self):
- import warnings
eq = self.assertEquals
eq(repr(xrange(1)), 'xrange(1)')
eq(repr(xrange(1, 2)), 'xrange(1, 2)')
@@ -175,7 +174,8 @@
def test_buffer(self):
# XXX doesn't test buffers with no b_base or read-write buffers (see
# bufferobject.c). The test is fairly incomplete too. Sigh.
- x = buffer('foo')
+ with _check_py3k_warnings():
+ x = buffer('foo')
self.failUnless(repr(x).startswith('<read-only buffer for 0x'))
def test_cell(self):
@@ -212,10 +212,6 @@
fp.write(text)
fp.close()
-def zap(actions, dirname, names):
- for name in names:
- actions.append(os.path.join(dirname, name))
-
class LongReprTest(unittest.TestCase):
def setUp(self):
longname = 'areallylongpackageandmodulenametotestreprtruncation'
@@ -234,7 +230,9 @@
def tearDown(self):
actions = []
- os.path.walk(self.pkgname, zap, actions)
+ for dirpath, dirnames, filenames in os.walk(self.pkgname):
+ for name in dirnames + filenames:
+ actions.append(os.path.join(dirpath, name))
actions.append(self.pkgname)
actions.sort()
actions.reverse()
diff --git a/Lib/test/test_shutil.py b/Lib/test/test_shutil.py
--- a/Lib/test/test_shutil.py
+++ b/Lib/test/test_shutil.py
@@ -66,17 +66,6 @@
self.assertRaises(OSError, shutil.rmtree, path)
os.remove(path)
- def test_dont_move_dir_in_itself(self):
- src_dir = tempfile.mkdtemp()
- try:
- dst = os.path.join(src_dir, 'foo')
- self.assertRaises(shutil.Error, shutil.move, src_dir, dst)
- finally:
- try:
- os.rmdir(src_dir)
- except:
- pass
-
def test_copytree_simple(self):
def write_data(path, data):
f = open(path, "w")
@@ -116,13 +105,91 @@
):
if os.path.exists(path):
os.remove(path)
- for path in (
- os.path.join(src_dir, 'test_dir'),
- os.path.join(dst_dir, 'test_dir'),
+ for path in (src_dir,
+ os.path.dirname(dst_dir)
):
if os.path.exists(path):
- os.removedirs(path)
+ shutil.rmtree(path)
+ def test_copytree_with_exclude(self):
+
+ def write_data(path, data):
+ f = open(path, "w")
+ f.write(data)
+ f.close()
+
+ def read_data(path):
+ f = open(path)
+ data = f.read()
+ f.close()
+ return data
+
+ # creating data
+ join = os.path.join
+ exists = os.path.exists
+ src_dir = tempfile.mkdtemp()
+ try:
+ dst_dir = join(tempfile.mkdtemp(), 'destination')
+ write_data(join(src_dir, 'test.txt'), '123')
+ write_data(join(src_dir, 'test.tmp'), '123')
+ os.mkdir(join(src_dir, 'test_dir'))
+ write_data(join(src_dir, 'test_dir', 'test.txt'), '456')
+ os.mkdir(join(src_dir, 'test_dir2'))
+ write_data(join(src_dir, 'test_dir2', 'test.txt'), '456')
+ os.mkdir(join(src_dir, 'test_dir2', 'subdir'))
+ os.mkdir(join(src_dir, 'test_dir2', 'subdir2'))
+ write_data(join(src_dir, 'test_dir2', 'subdir', 'test.txt'), '456')
+ write_data(join(src_dir, 'test_dir2', 'subdir2', 'test.py'), '456')
+
+ # testing glob-like patterns
+ try:
+ patterns = shutil.ignore_patterns('*.tmp', 'test_dir2')
+ shutil.copytree(src_dir, dst_dir, ignore=patterns)
+ # checking the result: some elements should not be copied
+ self.assertTrue(exists(join(dst_dir, 'test.txt')))
+ self.assertTrue(not exists(join(dst_dir, 'test.tmp')))
+ self.assertTrue(not exists(join(dst_dir, 'test_dir2')))
+ finally:
+ if os.path.exists(dst_dir):
+ shutil.rmtree(dst_dir)
+ try:
+ patterns = shutil.ignore_patterns('*.tmp', 'subdir*')
+ shutil.copytree(src_dir, dst_dir, ignore=patterns)
+ # checking the result: some elements should not be copied
+ self.assertTrue(not exists(join(dst_dir, 'test.tmp')))
+ self.assertTrue(not exists(join(dst_dir, 'test_dir2', 'subdir2')))
+ self.assertTrue(not exists(join(dst_dir, 'test_dir2', 'subdir')))
+ finally:
+ if os.path.exists(dst_dir):
+ shutil.rmtree(dst_dir)
+
+ # testing callable-style
+ try:
+ def _filter(src, names):
+ res = []
+ for name in names:
+ path = os.path.join(src, name)
+ # os.path.split, not str.split; ('.py',) is a tuple because '' in '.py' is True
+ if (os.path.isdir(path) and
+ os.path.split(path)[-1] == 'subdir'):
+ res.append(name)
+ elif os.path.splitext(path)[-1] in ('.py',):
+ res.append(name)
+ return res
+
+ shutil.copytree(src_dir, dst_dir, ignore=_filter)
+
+ # checking the result: some elements should not be copied
+ self.assertTrue(not exists(join(dst_dir, 'test_dir2', 'subdir2',
+ 'test.py')))
+ self.assertTrue(not exists(join(dst_dir, 'test_dir2', 'subdir')))
+
+ finally:
+ if os.path.exists(dst_dir):
+ shutil.rmtree(dst_dir)
+ finally:
+ shutil.rmtree(src_dir)
+ shutil.rmtree(os.path.dirname(dst_dir))
if hasattr(os, "symlink"):
def test_dont_copy_file_onto_link_to_itself(self):
@@ -153,8 +220,263 @@
except OSError:
pass
+ def test_rmtree_on_symlink(self):
+ # bug 1669.
+ os.mkdir(TESTFN)
+ try:
+ src = os.path.join(TESTFN, 'cheese')
+ dst = os.path.join(TESTFN, 'shop')
+ os.mkdir(src)
+ os.symlink(src, dst)
+ self.assertRaises(OSError, shutil.rmtree, dst)
+ finally:
+ shutil.rmtree(TESTFN, ignore_errors=True)
+
+
+class TestMove(unittest.TestCase):
+
+ def setUp(self):
+ filename = "foo"
+ self.src_dir = tempfile.mkdtemp()
+ self.dst_dir = tempfile.mkdtemp()
+ self.src_file = os.path.join(self.src_dir, filename)
+ self.dst_file = os.path.join(self.dst_dir, filename)
+ # Try to create a dir in the current directory, hoping that it is
+ # not located on the same filesystem as the system tmp dir.
+ try:
+ self.dir_other_fs = tempfile.mkdtemp(
+ dir=os.path.dirname(__file__))
+ self.file_other_fs = os.path.join(self.dir_other_fs,
+ filename)
+ except OSError:
+ self.dir_other_fs = None
+ with open(self.src_file, "wb") as f:
+ f.write("spam")
+
+ def tearDown(self):
+ for d in (self.src_dir, self.dst_dir, self.dir_other_fs):
+ try:
+ if d:
+ shutil.rmtree(d)
+ except:
+ pass
+
+ def _check_move_file(self, src, dst, real_dst):
+ contents = open(src, "rb").read()
+ shutil.move(src, dst)
+ self.assertEqual(contents, open(real_dst, "rb").read())
+ self.assertFalse(os.path.exists(src))
+
+ def _check_move_dir(self, src, dst, real_dst):
+ contents = sorted(os.listdir(src))
+ shutil.move(src, dst)
+ self.assertEqual(contents, sorted(os.listdir(real_dst)))
+ self.assertFalse(os.path.exists(src))
+
+ def test_move_file(self):
+ # Move a file to another location on the same filesystem.
+ self._check_move_file(self.src_file, self.dst_file, self.dst_file)
+
+ def test_move_file_to_dir(self):
+ # Move a file inside an existing dir on the same filesystem.
+ self._check_move_file(self.src_file, self.dst_dir, self.dst_file)
+
+ def test_move_file_other_fs(self):
+ # Move a file to an existing dir on another filesystem.
+ if not self.dir_other_fs:
+ # skip
+ return
+ self._check_move_file(self.src_file, self.file_other_fs,
+ self.file_other_fs)
+
+ def test_move_file_to_dir_other_fs(self):
+ # Move a file to another location on another filesystem.
+ if not self.dir_other_fs:
+ # skip
+ return
+ self._check_move_file(self.src_file, self.dir_other_fs,
+ self.file_other_fs)
+
+ def test_move_dir(self):
+ # Move a dir to another location on the same filesystem.
+ dst_dir = tempfile.mktemp()
+ try:
+ self._check_move_dir(self.src_dir, dst_dir, dst_dir)
+ finally:
+ try:
+ shutil.rmtree(dst_dir)
+ except:
+ pass
+
+ def test_move_dir_other_fs(self):
+ # Move a dir to another location on another filesystem.
+ if not self.dir_other_fs:
+ # skip
+ return
+ dst_dir = tempfile.mktemp(dir=self.dir_other_fs)
+ try:
+ self._check_move_dir(self.src_dir, dst_dir, dst_dir)
+ finally:
+ try:
+ shutil.rmtree(dst_dir)
+ except:
+ pass
+
+ def test_move_dir_to_dir(self):
+ # Move a dir inside an existing dir on the same filesystem.
+ self._check_move_dir(self.src_dir, self.dst_dir,
+ os.path.join(self.dst_dir, os.path.basename(self.src_dir)))
+
+ def test_move_dir_to_dir_other_fs(self):
+ # Move a dir inside an existing dir on another filesystem.
+ if not self.dir_other_fs:
+ # skip
+ return
+ self._check_move_dir(self.src_dir, self.dir_other_fs,
+ os.path.join(self.dir_other_fs, os.path.basename(self.src_dir)))
+
+ def test_existing_file_inside_dest_dir(self):
+ # A file with the same name inside the destination dir already exists.
+ with open(self.dst_file, "wb"):
+ pass
+ self.assertRaises(shutil.Error, shutil.move, self.src_file, self.dst_dir)
+
+ def test_dont_move_dir_in_itself(self):
+ # Moving a dir inside itself raises an Error.
+ dst = os.path.join(self.src_dir, "bar")
+ self.assertRaises(shutil.Error, shutil.move, self.src_dir, dst)
+
+ def test_destinsrc_false_negative(self):
+ os.mkdir(TESTFN)
+ try:
+ for src, dst in [('srcdir', 'srcdir/dest')]:
+ src = os.path.join(TESTFN, src)
+ dst = os.path.join(TESTFN, dst)
+ self.assert_(shutil.destinsrc(src, dst),
+ msg='destinsrc() wrongly concluded that '
+ 'dst (%s) is not in src (%s)' % (dst, src))
+ finally:
+ shutil.rmtree(TESTFN, ignore_errors=True)
+
+ def test_destinsrc_false_positive(self):
+ os.mkdir(TESTFN)
+ try:
+ for src, dst in [('srcdir', 'src/dest'), ('srcdir', 'srcdir.new')]:
+ src = os.path.join(TESTFN, src)
+ dst = os.path.join(TESTFN, dst)
+ self.failIf(shutil.destinsrc(src, dst),
+ msg='destinsrc() wrongly concluded that '
+ 'dst (%s) is in src (%s)' % (dst, src))
+ finally:
+ shutil.rmtree(TESTFN, ignore_errors=True)
+
+
+class TestCopyFile(unittest.TestCase):
+
+ _delete = False
+
+ class Faux(object):
+ _entered = False
+ _exited_with = None
+ _raised = False
+ def __init__(self, raise_in_exit=False, suppress_at_exit=True):
+ self._raise_in_exit = raise_in_exit
+ self._suppress_at_exit = suppress_at_exit
+ def read(self, *args):
+ return ''
+ def __enter__(self):
+ self._entered = True
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self._exited_with = exc_type, exc_val, exc_tb
+ if self._raise_in_exit:
+ self._raised = True
+ raise IOError("Cannot close")
+ return self._suppress_at_exit
+
+ def tearDown(self):
+ if self._delete:
+ del shutil.open
+
+ def _set_shutil_open(self, func):
+ shutil.open = func
+ self._delete = True
+
+ def test_w_source_open_fails(self):
+ def _open(filename, mode='r'):
+ if filename == 'srcfile':
+ raise IOError('Cannot open "srcfile"')
+ assert 0 # shouldn't reach here.
+
+ self._set_shutil_open(_open)
+
+ self.assertRaises(IOError, shutil.copyfile, 'srcfile', 'destfile')
+
+ def test_w_dest_open_fails(self):
+
+ srcfile = self.Faux()
+
+ def _open(filename, mode='r'):
+ if filename == 'srcfile':
+ return srcfile
+ if filename == 'destfile':
+ raise IOError('Cannot open "destfile"')
+ assert 0 # shouldn't reach here.
+
+ self._set_shutil_open(_open)
+
+ shutil.copyfile('srcfile', 'destfile')
+ self.assertTrue(srcfile._entered)
+ self.assertTrue(srcfile._exited_with[0] is IOError)
+ self.assertEqual(srcfile._exited_with[1].args,
+ ('Cannot open "destfile"',))
+
+ def test_w_dest_close_fails(self):
+
+ srcfile = self.Faux()
+ destfile = self.Faux(True)
+
+ def _open(filename, mode='r'):
+ if filename == 'srcfile':
+ return srcfile
+ if filename == 'destfile':
+ return destfile
+ assert 0 # shouldn't reach here.
+
+ self._set_shutil_open(_open)
+
+ shutil.copyfile('srcfile', 'destfile')
+ self.assertTrue(srcfile._entered)
+ self.assertTrue(destfile._entered)
+ self.assertTrue(destfile._raised)
+ self.assertTrue(srcfile._exited_with[0] is IOError)
+ self.assertEqual(srcfile._exited_with[1].args,
+ ('Cannot close',))
+
+ def test_w_source_close_fails(self):
+
+ srcfile = self.Faux(True)
+ destfile = self.Faux()
+
+ def _open(filename, mode='r'):
+ if filename == 'srcfile':
+ return srcfile
+ if filename == 'destfile':
+ return destfile
+ assert 0 # shouldn't reach here.
+
+ self._set_shutil_open(_open)
+
+ self.assertRaises(IOError,
+ shutil.copyfile, 'srcfile', 'destfile')
+ self.assertTrue(srcfile._entered)
+ self.assertTrue(destfile._entered)
+ self.assertFalse(destfile._raised)
+ self.assertTrue(srcfile._exited_with[0] is None)
+ self.assertTrue(srcfile._raised)
+
+
def test_main():
- test_support.run_unittest(TestShutil)
+ test_support.run_unittest(TestShutil, TestMove, TestCopyFile)
if __name__ == '__main__':
test_main()
diff --git a/Lib/test/test_tempfile.py b/Lib/test/test_tempfile.py
--- a/Lib/test/test_tempfile.py
+++ b/Lib/test/test_tempfile.py
@@ -1,6 +1,5 @@
# From Python 2.5.1
# tempfile.py unit tests.
-from __future__ import with_statement
import tempfile
import os
import sys
@@ -82,7 +81,8 @@
"gettempprefix" : 1,
"gettempdir" : 1,
"tempdir" : 1,
- "template" : 1
+ "template" : 1,
+ "SpooledTemporaryFile" : 1
}
unexp = []
@@ -128,7 +128,7 @@
if i == 20:
break
except:
- failOnException("iteration")
+ self.failOnException("iteration")
test_classes.append(test__RandomNameSequence)
@@ -150,13 +150,11 @@
# _candidate_tempdir_list contains the expected directories
# Make sure the interesting environment variables are all set.
- added = []
- try:
+ with test_support.EnvironmentVarGuard() as env:
for envname in 'TMPDIR', 'TEMP', 'TMP':
dirname = os.getenv(envname)
if not dirname:
- os.environ[envname] = os.path.abspath(envname)
- added.append(envname)
+ env.set(envname, os.path.abspath(envname))
cand = tempfile._candidate_tempdir_list()
@@ -174,9 +172,6 @@
# Not practical to try to verify the presence of OS-specific
# paths in this list.
- finally:
- for p in added:
- del os.environ[p]
test_classes.append(test__candidate_tempdir_list)
@@ -581,11 +576,12 @@
class test_NamedTemporaryFile(TC):
"""Test NamedTemporaryFile()."""
- def do_create(self, dir=None, pre="", suf=""):
+ def do_create(self, dir=None, pre="", suf="", delete=True):
if dir is None:
dir = tempfile.gettempdir()
try:
- file = tempfile.NamedTemporaryFile(dir=dir, prefix=pre, suffix=suf)
+ file = tempfile.NamedTemporaryFile(dir=dir, prefix=pre, suffix=suf,
+ delete=delete)
except:
self.failOnException("NamedTemporaryFile")
@@ -619,6 +615,22 @@
finally:
os.rmdir(dir)
+ def test_dis_del_on_close(self):
+ # Tests that delete-on-close can be disabled
+ dir = tempfile.mkdtemp()
+ tmp = None
+ try:
+ f = tempfile.NamedTemporaryFile(dir=dir, delete=False)
+ tmp = f.name
+ f.write('blat')
+ f.close()
+ self.failUnless(os.path.exists(f.name),
+ "NamedTemporaryFile %s missing after close" % f.name)
+ finally:
+ if tmp is not None:
+ os.unlink(tmp)
+ os.rmdir(dir)
+
def test_multiple_close(self):
# A NamedTemporaryFile can be closed many times without error
f = tempfile.NamedTemporaryFile()
@@ -644,6 +656,160 @@
test_classes.append(test_NamedTemporaryFile)
+class test_SpooledTemporaryFile(TC):
+ """Test SpooledTemporaryFile()."""
+
+ def do_create(self, max_size=0, dir=None, pre="", suf=""):
+ if dir is None:
+ dir = tempfile.gettempdir()
+ try:
+ file = tempfile.SpooledTemporaryFile(max_size=max_size, dir=dir, prefix=pre, suffix=suf)
+ except:
+ self.failOnException("SpooledTemporaryFile")
+
+ return file
+
+
+ def test_basic(self):
+ # SpooledTemporaryFile can create files
+ f = self.do_create()
+ self.failIf(f._rolled)
+ f = self.do_create(max_size=100, pre="a", suf=".txt")
+ self.failIf(f._rolled)
+
+ def test_del_on_close(self):
+ # A SpooledTemporaryFile is deleted when closed
+ dir = tempfile.mkdtemp()
+ try:
+ f = tempfile.SpooledTemporaryFile(max_size=10, dir=dir)
+ self.failIf(f._rolled)
+ f.write('blat ' * 5)
+ self.failUnless(f._rolled)
+ filename = f.name
+ f.close()
+ self.failIf(os.path.exists(filename),
+ "SpooledTemporaryFile %s exists after close" % filename)
+ finally:
+ os.rmdir(dir)
+
+ def test_rewrite_small(self):
+ # A SpooledTemporaryFile can be written to multiple within the max_size
+ f = self.do_create(max_size=30)
+ self.failIf(f._rolled)
+ for i in range(5):
+ f.seek(0, 0)
+ f.write('x' * 20)
+ self.failIf(f._rolled)
+
+ def test_write_sequential(self):
+ # A SpooledTemporaryFile should hold exactly max_size bytes, and roll
+ # over afterward
+ f = self.do_create(max_size=30)
+ self.failIf(f._rolled)
+ f.write('x' * 20)
+ self.failIf(f._rolled)
+ f.write('x' * 10)
+ self.failIf(f._rolled)
+ f.write('x')
+ self.failUnless(f._rolled)
+
+ def test_sparse(self):
+ # A SpooledTemporaryFile that is written late in the file will extend
+ # when that occurs
+ f = self.do_create(max_size=30)
+ self.failIf(f._rolled)
+ f.seek(100, 0)
+ self.failIf(f._rolled)
+ f.write('x')
+ self.failUnless(f._rolled)
+
+ def test_fileno(self):
+ # A SpooledTemporaryFile should roll over to a real file on fileno()
+ f = self.do_create(max_size=30)
+ self.failIf(f._rolled)
+ self.failUnless(f.fileno() > 0)
+ self.failUnless(f._rolled)
+
+ def test_multiple_close_before_rollover(self):
+ # A SpooledTemporaryFile can be closed many times without error
+ f = tempfile.SpooledTemporaryFile()
+ f.write('abc\n')
+ self.failIf(f._rolled)
+ f.close()
+ try:
+ f.close()
+ f.close()
+ except:
+ self.failOnException("close")
+
+ def test_multiple_close_after_rollover(self):
+ # A SpooledTemporaryFile can be closed many times without error
+ f = tempfile.SpooledTemporaryFile(max_size=1)
+ f.write('abc\n')
+ self.failUnless(f._rolled)
+ f.close()
+ try:
+ f.close()
+ f.close()
+ except:
+ self.failOnException("close")
+
+ def test_bound_methods(self):
+ # It should be OK to steal a bound method from a SpooledTemporaryFile
+ # and use it independently; when the file rolls over, those bound
+ # methods should continue to function
+ f = self.do_create(max_size=30)
+ read = f.read
+ write = f.write
+ seek = f.seek
+
+ write("a" * 35)
+ write("b" * 35)
+ seek(0, 0)
+ self.failUnless(read(70) == 'a'*35 + 'b'*35)
+
+ def test_context_manager_before_rollover(self):
+ # A SpooledTemporaryFile can be used as a context manager
+ with tempfile.SpooledTemporaryFile(max_size=1) as f:
+ self.failIf(f._rolled)
+ self.failIf(f.closed)
+ self.failUnless(f.closed)
+ def use_closed():
+ with f:
+ pass
+ self.failUnlessRaises(ValueError, use_closed)
+
+ def test_context_manager_during_rollover(self):
+ # A SpooledTemporaryFile can be used as a context manager
+ with tempfile.SpooledTemporaryFile(max_size=1) as f:
+ self.failIf(f._rolled)
+ f.write('abc\n')
+ f.flush()
+ self.failUnless(f._rolled)
+ self.failIf(f.closed)
+ self.failUnless(f.closed)
+ def use_closed():
+ with f:
+ pass
+ self.failUnlessRaises(ValueError, use_closed)
+
+ def test_context_manager_after_rollover(self):
+ # A SpooledTemporaryFile can be used as a context manager
+ f = tempfile.SpooledTemporaryFile(max_size=1)
+ f.write('abc\n')
+ f.flush()
+ self.failUnless(f._rolled)
+ with f:
+ self.failIf(f.closed)
+ self.failUnless(f.closed)
+ def use_closed():
+ with f:
+ pass
+ self.failUnlessRaises(ValueError, use_closed)
+
+
+test_classes.append(test_SpooledTemporaryFile)
+
class test_TemporaryFile(TC):
"""Test TemporaryFile()."""
diff --git a/Lib/test/test_time.py b/Lib/test/test_time.py
--- a/Lib/test/test_time.py
+++ b/Lib/test/test_time.py
@@ -110,15 +110,19 @@
self.assertEquals(expected, result)
def test_strptime(self):
+ # Should be able to go round-trip from strftime to strptime without
+ # throwing an exception.
tt = time.gmtime(self.t)
for directive in ('a', 'A', 'b', 'B', 'c', 'd', 'H', 'I',
'j', 'm', 'M', 'p', 'S',
'U', 'w', 'W', 'x', 'X', 'y', 'Y', 'Z', '%'):
- format = ' %' + directive
+ format = '%' + directive
+ strf_output = time.strftime(format, tt)
try:
- time.strptime(time.strftime(format, tt), format)
+ time.strptime(strf_output, format)
except ValueError:
- self.fail('conversion specifier: %r failed.' % format)
+ self.fail("conversion specifier %r failed with '%s' input." %
+ (format, strf_output))
def test_strptime_empty(self):
try:
@@ -129,6 +133,16 @@
def test_asctime(self):
time.asctime(time.gmtime(self.t))
self.assertRaises(TypeError, time.asctime, 0)
+ self.assertRaises(TypeError, time.asctime, ())
+ # XXX: Posix compiant asctime should refuse to convert
+ # year > 9999, but Linux implementation does not.
+ # self.assertRaises(ValueError, time.asctime,
+ # (12345, 1, 0, 0, 0, 0, 0, 0, 0))
+ # XXX: For now, just make sure we don't have a crash:
+ try:
+ time.asctime((12345, 1, 0, 0, 0, 0, 0, 0, 0))
+ except ValueError:
+ pass
def test_tzset(self):
if not hasattr(time, "tzset"):
diff --git a/Lib/test/test_trace.py b/Lib/test/test_trace.py
--- a/Lib/test/test_trace.py
+++ b/Lib/test/test_trace.py
@@ -22,6 +22,7 @@
import unittest
import sys
import difflib
+import gc
# A very basic example. If this fails, we're in deep trouble.
def basic():
@@ -262,6 +263,17 @@
return self.trace
class TraceTestCase(unittest.TestCase):
+
+ # Disable gc collection when tracing, otherwise the
+ # deallocators may be traced as well.
+ def setUp(self):
+ self.using_gc = gc.isenabled()
+ gc.disable()
+
+ def tearDown(self):
+ if self.using_gc:
+ gc.enable()
+
def compare_events(self, line_offset, events, expected_events):
events = [(l - line_offset, e) for (l, e) in events]
if events != expected_events:
@@ -288,6 +300,20 @@
self.compare_events(func.func_code.co_firstlineno,
tracer.events, func.events)
+ def set_and_retrieve_none(self):
+ sys.settrace(None)
+ assert sys.gettrace() is None
+
+ def set_and_retrieve_func(self):
+ def fn(*args):
+ pass
+
+ sys.settrace(fn)
+ try:
+ assert sys.gettrace() is fn
+ finally:
+ sys.settrace(None)
+
def test_01_basic(self):
self.run_test(basic)
def test_02_arigo(self):
@@ -324,7 +350,7 @@
sys.settrace(tracer.traceWithGenexp)
generator_example()
sys.settrace(None)
- self.compare_events(generator_example.func_code.co_firstlineno,
+ self.compare_events(generator_example.__code__.co_firstlineno,
tracer.events, generator_example.events)
def test_14_onliner_if(self):
@@ -393,7 +419,7 @@
we're testing, so that the 'exception' trace event fires."""
if self.raiseOnEvent == 'exception':
x = 0
- y = 1/x
+ y = 1 // x
else:
return 1
@@ -732,6 +758,23 @@
def test_19_no_jump_without_trace_function(self):
no_jump_without_trace_function()
+ def test_20_large_function(self):
+ d = {}
+ exec("""def f(output): # line 0
+ x = 0 # line 1
+ y = 1 # line 2
+ ''' # line 3
+ %s # lines 4-1004
+ ''' # line 1005
+ x += 1 # line 1006
+ output.append(x) # line 1007
+ return""" % ('\n' * 1000,), d)
+ f = d['f']
+
+ f.jump = (2, 1007)
+ f.output = [0]
+ self.run_test(f)
+
def test_main():
tests = [TraceTestCase,
RaisingTraceFuncTestCase]
diff --git a/Lib/test/test_univnewlines.py b/Lib/test/test_univnewlines.py
--- a/Lib/test/test_univnewlines.py
+++ b/Lib/test/test_univnewlines.py
@@ -6,7 +6,7 @@
from test import test_support
if not hasattr(sys.stdin, 'newlines'):
- raise unittest.SkipTest, \
+ raise test_support.TestSkipped, \
"This Python does not have universal newline support"
FATX = 'x' * (2**14)
@@ -38,8 +38,9 @@
WRITEMODE = 'wb'
def setUp(self):
- with open(test_support.TESTFN, self.WRITEMODE) as fp:
- fp.write(self.DATA)
+ fp = open(test_support.TESTFN, self.WRITEMODE)
+ fp.write(self.DATA)
+ fp.close()
def tearDown(self):
try:
@@ -48,40 +49,41 @@
pass
def test_read(self):
- with open(test_support.TESTFN, self.READMODE) as fp:
- data = fp.read()
+ fp = open(test_support.TESTFN, self.READMODE)
+ data = fp.read()
self.assertEqual(data, DATA_LF)
self.assertEqual(repr(fp.newlines), repr(self.NEWLINE))
def test_readlines(self):
- with open(test_support.TESTFN, self.READMODE) as fp:
- data = fp.readlines()
+ fp = open(test_support.TESTFN, self.READMODE)
+ data = fp.readlines()
self.assertEqual(data, DATA_SPLIT)
self.assertEqual(repr(fp.newlines), repr(self.NEWLINE))
def test_readline(self):
- with open(test_support.TESTFN, self.READMODE) as fp:
- data = []
+ fp = open(test_support.TESTFN, self.READMODE)
+ data = []
+ d = fp.readline()
+ while d:
+ data.append(d)
d = fp.readline()
- while d:
- data.append(d)
- d = fp.readline()
self.assertEqual(data, DATA_SPLIT)
self.assertEqual(repr(fp.newlines), repr(self.NEWLINE))
def test_seek(self):
- with open(test_support.TESTFN, self.READMODE) as fp:
- fp.readline()
- pos = fp.tell()
- data = fp.readlines()
- self.assertEqual(data, DATA_SPLIT[1:])
- fp.seek(pos)
- data = fp.readlines()
+ fp = open(test_support.TESTFN, self.READMODE)
+ fp.readline()
+ pos = fp.tell()
+ data = fp.readlines()
+ self.assertEqual(data, DATA_SPLIT[1:])
+ fp.seek(pos)
+ data = fp.readlines()
self.assertEqual(data, DATA_SPLIT[1:])
def test_execfile(self):
namespace = {}
- execfile(test_support.TESTFN, namespace)
+ with test_support._check_py3k_warnings():
+ execfile(test_support.TESTFN, namespace)
func = namespace['line3']
self.assertEqual(func.func_code.co_firstlineno, 3)
self.assertEqual(namespace['line4'], FATX)
@@ -106,10 +108,10 @@
DATA = DATA_CRLF
def test_tell(self):
- with open(test_support.TESTFN, self.READMODE) as fp:
- self.assertEqual(repr(fp.newlines), repr(None))
- data = fp.readline()
- pos = fp.tell()
+ fp = open(test_support.TESTFN, self.READMODE)
+ self.assertEqual(repr(fp.newlines), repr(None))
+ data = fp.readline()
+ pos = fp.tell()
self.assertEqual(repr(fp.newlines), repr(self.NEWLINE))
class TestMixedNewlines(TestGenericUnivNewlines):
diff --git a/Lib/test/test_urllib2.py b/Lib/test/test_urllib2.py
--- a/Lib/test/test_urllib2.py
+++ b/Lib/test/test_urllib2.py
@@ -1,7 +1,8 @@
import unittest
from test import test_support
-import os, socket
+import os
+import socket
import StringIO
import urllib2
@@ -20,8 +21,7 @@
# XXX Name hacking to get this to work on Windows.
fname = os.path.abspath(urllib2.__file__).replace('\\', '/')
- if fname[1:2] == ":":
- fname = fname[2:]
+
# And more hacking to get it to work on MacOS. This assumes
# urllib.pathname2url works, unfortunately...
if os.name == 'mac':
@@ -31,7 +31,11 @@
fname = os.expand(fname)
fname = fname.translate(string.maketrans("/.", "./"))
- file_url = "file://%s" % fname
+ if os.name == 'nt':
+ file_url = "file:///%s" % fname
+ else:
+ file_url = "file://%s" % fname
+
f = urllib2.urlopen(file_url)
buf = f.read()
@@ -223,8 +227,8 @@
class MockOpener:
addheaders = []
- def open(self, req, data=None):
- self.req, self.data = req, data
+ def open(self, req, data=None,timeout=socket._GLOBAL_DEFAULT_TIMEOUT):
+ self.req, self.data, self.timeout = req, data, timeout
def error(self, proto, *args):
self.proto, self.args = proto, args
@@ -260,6 +264,51 @@
def __call__(self, *args):
return self.handle(self.meth_name, self.action, *args)
+class MockHTTPResponse:
+ def __init__(self, fp, msg, status, reason):
+ self.fp = fp
+ self.msg = msg
+ self.status = status
+ self.reason = reason
+ def read(self):
+ return ''
+
+class MockHTTPClass:
+ def __init__(self):
+ self.req_headers = []
+ self.data = None
+ self.raise_on_endheaders = False
+ self._tunnel_headers = {}
+
+ def __call__(self, host, timeout=socket._GLOBAL_DEFAULT_TIMEOUT):
+ self.host = host
+ self.timeout = timeout
+ return self
+
+ def set_debuglevel(self, level):
+ self.level = level
+
+ def _set_tunnel(self, host, port=None, headers=None):
+ self._tunnel_host = host
+ self._tunnel_port = port
+ if headers:
+ self._tunnel_headers = headers
+ else:
+ self._tunnel_headers.clear()
+ def request(self, method, url, body=None, headers=None):
+ self.method = method
+ self.selector = url
+ if headers is not None:
+ self.req_headers += headers.items()
+ self.req_headers.sort()
+ if body:
+ self.data = body
+ if self.raise_on_endheaders:
+ import socket
+ raise socket.error()
+ def getresponse(self):
+ return MockHTTPResponse(MockFile(), {}, 200, "OK")
+
class MockHandler:
# useful for testing handler machinery
# see add_ordered_mock_handlers() docstring
@@ -367,6 +416,17 @@
msg = mimetools.Message(StringIO("\r\n\r\n"))
return MockResponse(200, "OK", msg, "", req.get_full_url())
+class MockHTTPSHandler(urllib2.AbstractHTTPHandler):
+ # Useful for testing the Proxy-Authorization request by verifying the
+ # properties of httpcon
+
+ def __init__(self):
+ urllib2.AbstractHTTPHandler.__init__(self)
+ self.httpconn = MockHTTPClass()
+
+ def https_open(self, req):
+ return self.do_open(self.httpconn, req)
+
class MockPasswordManager:
def add_password(self, realm, uri, user, password):
self.realm = realm
@@ -552,14 +612,15 @@
class NullFTPHandler(urllib2.FTPHandler):
def __init__(self, data): self.data = data
- def connect_ftp(self, user, passwd, host, port, dirs):
+ def connect_ftp(self, user, passwd, host, port, dirs,
+ timeout=socket._GLOBAL_DEFAULT_TIMEOUT):
self.user, self.passwd = user, passwd
self.host, self.port = host, port
self.dirs = dirs
self.ftpwrapper = MockFTPWrapper(self.data)
return self.ftpwrapper
- import ftplib, socket
+ import ftplib
data = "rheum rhaponicum"
h = NullFTPHandler(data)
o = h.parent = MockOpener()
@@ -575,7 +636,9 @@
"localhost", ftplib.FTP_PORT, "A",
[], "baz.gif", None), # XXX really this should guess image/gif
]:
- r = h.ftp_open(Request(url))
+ req = Request(url)
+ req.timeout = None
+ r = h.ftp_open(req)
# ftp authentication not yet implemented by FTPHandler
self.assert_(h.user == h.passwd == "")
self.assertEqual(h.host, socket.gethostbyname(host))
@@ -588,7 +651,7 @@
self.assertEqual(int(headers["Content-length"]), len(data))
def test_file(self):
- import time, rfc822, socket
+ import rfc822, socket
h = urllib2.FileHandler()
o = h.parent = MockOpener()
@@ -619,7 +682,7 @@
try:
data = r.read()
headers = r.info()
- newurl = r.geturl()
+ respurl = r.geturl()
finally:
r.close()
stats = os.stat(TESTFN)
@@ -630,14 +693,15 @@
self.assertEqual(headers["Content-type"], "text/plain")
self.assertEqual(headers["Content-length"], "13")
self.assertEqual(headers["Last-modified"], modified)
+ self.assertEqual(respurl, url)
for url in [
"file://localhost:80%s" % urlpath,
-# XXXX bug: these fail with socket.gaierror, should be URLError
-## "file://%s:80%s/%s" % (socket.gethostbyname('localhost'),
-## os.getcwd(), TESTFN),
-## "file://somerandomhost.ontheinternet.com%s/%s" %
-## (os.getcwd(), TESTFN),
+ "file:///file_does_not_exist.txt",
+ "file://%s:80%s/%s" % (socket.gethostbyname('localhost'),
+ os.getcwd(), TESTFN),
+ "file://somerandomhost.ontheinternet.com%s/%s" %
+ (os.getcwd(), TESTFN),
]:
try:
f = open(TESTFN, "wb")
@@ -665,6 +729,8 @@
("file://ftp.example.com///foo.txt", False),
# XXXX bug: fails with OSError, should be URLError
("file://ftp.example.com/foo.txt", False),
+ ("file://somehost//foo/something.txt", True),
+ ("file://localhost//foo/something.txt", False),
]:
req = Request(url)
try:
@@ -675,38 +741,9 @@
else:
self.assert_(o.req is req)
self.assertEqual(req.type, "ftp")
+ self.assertEqual(req.type is "ftp", ftp)
def test_http(self):
- class MockHTTPResponse:
- def __init__(self, fp, msg, status, reason):
- self.fp = fp
- self.msg = msg
- self.status = status
- self.reason = reason
- def read(self):
- return ''
- class MockHTTPClass:
- def __init__(self):
- self.req_headers = []
- self.data = None
- self.raise_on_endheaders = False
- def __call__(self, host):
- self.host = host
- return self
- def set_debuglevel(self, level):
- self.level = level
- def request(self, method, url, body=None, headers={}):
- self.method = method
- self.selector = url
- self.req_headers += headers.items()
- self.req_headers.sort()
- if body:
- self.data = body
- if self.raise_on_endheaders:
- import socket
- raise socket.error()
- def getresponse(self):
- return MockHTTPResponse(MockFile(), {}, 200, "OK")
h = urllib2.AbstractHTTPHandler()
o = h.parent = MockOpener()
@@ -714,6 +751,7 @@
url = "http://example.com/"
for method, data in [("GET", None), ("POST", "blah")]:
req = Request(url, data, {"Foo": "bar"})
+ req.timeout = None
req.add_unredirected_header("Spam", "eggs")
http = MockHTTPClass()
r = h.do_open(http, req)
@@ -767,22 +805,56 @@
self.assertEqual(req.unredirected_hdrs["Host"], "baz")
self.assertEqual(req.unredirected_hdrs["Spam"], "foo")
+ def test_http_doubleslash(self):
+ # Checks that the presence of an unnecessary double slash in a url doesn't break anything
+ # Previously, a double slash directly after the host could cause incorrect parsing of the url
+ h = urllib2.AbstractHTTPHandler()
+ o = h.parent = MockOpener()
+
+ data = ""
+ ds_urls = [
+ "http://example.com/foo/bar/baz.html",
+ "http://example.com//foo/bar/baz.html",
+ "http://example.com/foo//bar/baz.html",
+ "http://example.com/foo/bar//baz.html",
+ ]
+
+ for ds_url in ds_urls:
+ ds_req = Request(ds_url, data)
+
+ # Check whether host is determined correctly if there is no proxy
+ np_ds_req = h.do_request_(ds_req)
+ self.assertEqual(np_ds_req.unredirected_hdrs["Host"],"example.com")
+
+ # Check whether host is determined correctly if there is a proxy
+ ds_req.set_proxy("someproxy:3128",None)
+ p_ds_req = h.do_request_(ds_req)
+ self.assertEqual(p_ds_req.unredirected_hdrs["Host"],"example.com")
+
def test_errors(self):
h = urllib2.HTTPErrorProcessor()
o = h.parent = MockOpener()
url = "http://example.com/"
req = Request(url)
- # 200 OK is passed through
+ # all 2xx are passed through
r = MockResponse(200, "OK", {}, "", url)
newr = h.http_response(req, r)
self.assert_(r is newr)
self.assert_(not hasattr(o, "proto")) # o.error not called
+ r = MockResponse(202, "Accepted", {}, "", url)
+ newr = h.http_response(req, r)
+ self.assert_(r is newr)
+ self.assert_(not hasattr(o, "proto")) # o.error not called
+ r = MockResponse(206, "Partial content", {}, "", url)
+ newr = h.http_response(req, r)
+ self.assert_(r is newr)
+ self.assert_(not hasattr(o, "proto")) # o.error not called
# anything else calls o.error (and MockOpener returns None, here)
- r = MockResponse(201, "Created", {}, "", url)
+ r = MockResponse(502, "Bad gateway", {}, "", url)
self.assert_(h.http_response(req, r) is None)
self.assertEqual(o.proto, "http") # o.error called
- self.assertEqual(o.args, (req, r, 201, "Created", {}))
+ self.assertEqual(o.args, (req, r, 502, "Bad gateway", {}))
def test_cookies(self):
cj = MockCookieJar()
@@ -811,6 +883,9 @@
method = getattr(h, "http_error_%s" % code)
req = Request(from_url, data)
req.add_header("Nonsense", "viking=withhold")
+ req.timeout = socket._GLOBAL_DEFAULT_TIMEOUT
+ if data is not None:
+ req.add_header("Content-Length", str(len(data)))
req.add_unredirected_header("Spam", "spam")
try:
method(req, MockFile(), code, "Blah",
@@ -823,6 +898,13 @@
self.assertEqual(o.req.get_method(), "GET")
except AttributeError:
self.assert_(not o.req.has_data())
+
+ # now it's a GET, there should not be headers regarding content
+ # (possibly dragged from before being a POST)
+ headers = [x.lower() for x in o.req.headers]
+ self.assertTrue("content-length" not in headers)
+ self.assertTrue("content-type" not in headers)
+
self.assertEqual(o.req.headers["Nonsense"],
"viking=withhold")
self.assert_("Spam" not in o.req.headers)
@@ -830,6 +912,7 @@
# loop detection
req = Request(from_url)
+ req.timeout = socket._GLOBAL_DEFAULT_TIMEOUT
def redirect(h, req, url=to_url):
h.http_error_302(req, MockFile(), 302, "Blah",
MockHeaders({"location": url}))
@@ -839,6 +922,7 @@
# detect infinite loop redirect of a URL to itself
req = Request(from_url, origin_req_host="example.com")
count = 0
+ req.timeout = socket._GLOBAL_DEFAULT_TIMEOUT
try:
while 1:
redirect(h, req, "http://example.com/")
@@ -850,6 +934,7 @@
# detect endless non-repeating chain of redirects
req = Request(from_url, origin_req_host="example.com")
count = 0
+ req.timeout = socket._GLOBAL_DEFAULT_TIMEOUT
try:
while 1:
redirect(h, req, "http://example.com/%d" % count)
@@ -858,6 +943,28 @@
self.assertEqual(count,
urllib2.HTTPRedirectHandler.max_redirections)
+ def test_invalid_redirect(self):
+ from_url = "http://example.com/a.html"
+ valid_schemes = ['http', 'https', 'ftp']
+ invalid_schemes = ['file', 'imap', 'ldap']
+ schemeless_url = "example.com/b.html"
+ h = urllib2.HTTPRedirectHandler()
+ o = h.parent = MockOpener()
+ req = Request(from_url)
+ req.timeout = socket._GLOBAL_DEFAULT_TIMEOUT
+
+ for scheme in invalid_schemes:
+ invalid_url = scheme + '://' + schemeless_url
+ self.assertRaises(urllib2.HTTPError, h.http_error_302,
+ req, MockFile(), 302, "Security Loophole",
+ MockHeaders({"location": invalid_url}))
+
+ for scheme in valid_schemes:
+ valid_url = scheme + '://' + schemeless_url
+ h.http_error_302(req, MockFile(), 302, "That's fine",
+ MockHeaders({"location": valid_url}))
+ self.assertEqual(o.req.get_full_url(), valid_url)
+
def test_cookie_redirect(self):
# cookies shouldn't leak into redirected requests
from cookielib import CookieJar
@@ -891,13 +998,68 @@
self.assertEqual([(handlers[0], "http_open")],
[tup[0:2] for tup in o.calls])
- def test_basic_auth(self):
+ def test_proxy_no_proxy(self):
+ os.environ['no_proxy'] = 'python.org'
+ o = OpenerDirector()
+ ph = urllib2.ProxyHandler(dict(http="proxy.example.com"))
+ o.add_handler(ph)
+ req = Request("http://www.perl.org/")
+ self.assertEqual(req.get_host(), "www.perl.org")
+ r = o.open(req)
+ self.assertEqual(req.get_host(), "proxy.example.com")
+ req = Request("http://www.python.org")
+ self.assertEqual(req.get_host(), "www.python.org")
+ r = o.open(req)
+ self.assertEqual(req.get_host(), "www.python.org")
+ del os.environ['no_proxy']
+
+
+ def test_proxy_https(self):
+ o = OpenerDirector()
+ ph = urllib2.ProxyHandler(dict(https='proxy.example.com:3128'))
+ o.add_handler(ph)
+ meth_spec = [
+ [("https_open","return response")]
+ ]
+ handlers = add_ordered_mock_handlers(o, meth_spec)
+ req = Request("https://www.example.com/")
+ self.assertEqual(req.get_host(), "www.example.com")
+ r = o.open(req)
+ self.assertEqual(req.get_host(), "proxy.example.com:3128")
+ self.assertEqual([(handlers[0], "https_open")],
+ [tup[0:2] for tup in o.calls])
+
+ def test_proxy_https_proxy_authorization(self):
+ o = OpenerDirector()
+ ph = urllib2.ProxyHandler(dict(https='proxy.example.com:3128'))
+ o.add_handler(ph)
+ https_handler = MockHTTPSHandler()
+ o.add_handler(https_handler)
+ req = Request("https://www.example.com/")
+ req.add_header("Proxy-Authorization","FooBar")
+ req.add_header("User-Agent","Grail")
+ self.assertEqual(req.get_host(), "www.example.com")
+ self.assertTrue(req._tunnel_host is None)
+ r = o.open(req)
+ # Verify Proxy-Authorization gets tunneled to request.
+ # httpsconn req_headers do not have the Proxy-Authorization header but
+ # the req will have.
+ self.assertFalse(("Proxy-Authorization","FooBar") in
+ https_handler.httpconn.req_headers)
+ self.assertTrue(("User-Agent","Grail") in
+ https_handler.httpconn.req_headers)
+ self.assertFalse(req._tunnel_host is None)
+ self.assertEqual(req.get_host(), "proxy.example.com:3128")
+ self.assertEqual(req.get_header("Proxy-authorization"),"FooBar")
+
+ def test_basic_auth(self, quote_char='"'):
opener = OpenerDirector()
password_manager = MockPasswordManager()
auth_handler = urllib2.HTTPBasicAuthHandler(password_manager)
realm = "ACME Widget Store"
http_handler = MockHTTPHandler(
- 401, 'WWW-Authenticate: Basic realm="%s"\r\n\r\n' % realm)
+ 401, 'WWW-Authenticate: Basic realm=%s%s%s\r\n\r\n' %
+ (quote_char, realm, quote_char) )
opener.add_handler(auth_handler)
opener.add_handler(http_handler)
self._test_basic_auth(opener, auth_handler, "Authorization",
@@ -906,6 +1068,9 @@
"http://acme.example.com/protected",
)
+ def test_basic_auth_with_single_quoted_realm(self):
+ self.test_basic_auth(quote_char="'")
+
def test_proxy_basic_auth(self):
opener = OpenerDirector()
ph = urllib2.ProxyHandler(dict(http="proxy.example.com:3128"))
@@ -973,7 +1138,7 @@
def _test_basic_auth(self, opener, auth_handler, auth_header,
realm, http_handler, password_manager,
request_url, protected_url):
- import base64, httplib
+ import base64
user, password = "wile", "coyote"
# .add_password() fed through to password manager
@@ -996,7 +1161,8 @@
auth_hdr_value = 'Basic '+base64.encodestring(userpass).strip()
self.assertEqual(http_handler.requests[1].get_header(auth_header),
auth_hdr_value)
-
+ self.assertEqual(http_handler.requests[1].unredirected_hdrs[auth_header],
+ auth_hdr_value)
# if the password manager can't find a password, the handler won't
# handle the HTTP auth error
password_manager.user = password_manager.password = None
@@ -1005,7 +1171,6 @@
self.assertEqual(len(http_handler.requests), 1)
self.assertFalse(http_handler.requests[0].has_header(auth_header))
-
class MiscTests(unittest.TestCase):
def test_build_opener(self):
@@ -1052,6 +1217,51 @@
else:
self.assert_(False)
+class RequestTests(unittest.TestCase):
+
+ def setUp(self):
+ self.get = urllib2.Request("http://www.python.org/~jeremy/")
+ self.post = urllib2.Request("http://www.python.org/~jeremy/",
+ "data",
+ headers={"X-Test": "test"})
+
+ def test_method(self):
+ self.assertEqual("POST", self.post.get_method())
+ self.assertEqual("GET", self.get.get_method())
+
+ def test_add_data(self):
+ self.assert_(not self.get.has_data())
+ self.assertEqual("GET", self.get.get_method())
+ self.get.add_data("spam")
+ self.assert_(self.get.has_data())
+ self.assertEqual("POST", self.get.get_method())
+
+ def test_get_full_url(self):
+ self.assertEqual("http://www.python.org/~jeremy/",
+ self.get.get_full_url())
+
+ def test_selector(self):
+ self.assertEqual("/~jeremy/", self.get.get_selector())
+ req = urllib2.Request("http://www.python.org/")
+ self.assertEqual("/", req.get_selector())
+
+ def test_get_type(self):
+ self.assertEqual("http", self.get.get_type())
+
+ def test_get_host(self):
+ self.assertEqual("www.python.org", self.get.get_host())
+
+ def test_get_host_unquote(self):
+ req = urllib2.Request("http://www.%70ython.org/")
+ self.assertEqual("www.python.org", req.get_host())
+
+ def test_proxy(self):
+ self.assert_(not self.get.has_proxy())
+ self.get.set_proxy("www.perl.org", "http")
+ self.assert_(self.get.has_proxy())
+ self.assertEqual("www.python.org", self.get.get_origin_req_host())
+ self.assertEqual("www.perl.org", self.get.get_host())
+
def test_main(verbose=None):
from test import test_urllib2
@@ -1060,7 +1270,8 @@
tests = (TrivialTests,
OpenerDirectorTests,
HandlerTests,
- MiscTests)
+ MiscTests,
+ RequestTests)
test_support.run_unittest(*tests)
if __name__ == "__main__":
diff --git a/Lib/test/test_weakref.py b/Lib/test/test_weakref.py
--- a/Lib/test/test_weakref.py
+++ b/Lib/test/test_weakref.py
@@ -3,6 +3,7 @@
import unittest
import UserList
import weakref
+import operator
from test import test_support
@@ -67,10 +68,10 @@
# Live reference:
o = C()
wr = weakref.ref(o)
- `wr`
+ repr(wr)
# Dead reference:
del o
- `wr`
+ repr(wr)
def test_basic_callback(self):
self.check_basic_callback(C)
@@ -191,7 +192,8 @@
p.append(12)
self.assertEqual(len(L), 1)
self.failUnless(p, "proxy for non-empty UserList should be true")
- p[:] = [2, 3]
+ with test_support._check_py3k_warnings():
+ p[:] = [2, 3]
self.assertEqual(len(L), 2)
self.assertEqual(len(p), 2)
self.failUnless(3 in p,
@@ -205,16 +207,48 @@
## self.assertEqual(repr(L2), repr(p2))
L3 = UserList.UserList(range(10))
p3 = weakref.proxy(L3)
- self.assertEqual(L3[:], p3[:])
- self.assertEqual(L3[5:], p3[5:])
- self.assertEqual(L3[:5], p3[:5])
- self.assertEqual(L3[2:5], p3[2:5])
+ with test_support._check_py3k_warnings():
+ self.assertEqual(L3[:], p3[:])
+ self.assertEqual(L3[5:], p3[5:])
+ self.assertEqual(L3[:5], p3[:5])
+ self.assertEqual(L3[2:5], p3[2:5])
+
+ def test_proxy_unicode(self):
+ # See bug 5037
+ class C(object):
+ def __str__(self):
+ return "string"
+ def __unicode__(self):
+ return u"unicode"
+ instance = C()
+ self.assertTrue("__unicode__" in dir(weakref.proxy(instance)))
+ self.assertEqual(unicode(weakref.proxy(instance)), u"unicode")
+
+ def test_proxy_index(self):
+ class C:
+ def __index__(self):
+ return 10
+ o = C()
+ p = weakref.proxy(o)
+ self.assertEqual(operator.index(p), 10)
+
+ def test_proxy_div(self):
+ class C:
+ def __floordiv__(self, other):
+ return 42
+ def __ifloordiv__(self, other):
+ return 21
+ o = C()
+ p = weakref.proxy(o)
+ self.assertEqual(p // 5, 42)
+ p //= 5
+ self.assertEqual(p, 21)
# The PyWeakref_* C API is documented as allowing either NULL or
# None as the value for the callback, where either means "no
# callback". The "no callback" ref and proxy objects are supposed
# to be shared so long as they exist by all callers so long as
- # they are active. In Python 2.3.3 and earlier, this guaranttee
+ # they are active. In Python 2.3.3 and earlier, this guarantee
# was not honored, and was broken in different ways for
# PyWeakref_NewRef() and PyWeakref_NewProxy(). (Two tests.)
@@ -676,8 +710,16 @@
w = Target()
+ def test_init(self):
+ # Issue 3634
+ # <weakref to class>.__init__() doesn't check errors correctly
+ r = weakref.ref(Exception)
+ self.assertRaises(TypeError, r.__init__, 0, 0, 0, 0, 0)
+ # No exception should be raised here
+ gc.collect()
-class SubclassableWeakrefTestCase(unittest.TestCase):
+
+class SubclassableWeakrefTestCase(TestBase):
def test_subclass_refs(self):
class MyRef(weakref.ref):
@@ -741,6 +783,44 @@
self.assertEqual(r.meth(), "abcdef")
self.failIf(hasattr(r, "__dict__"))
+ def test_subclass_refs_with_cycle(self):
+ # Bug #3110
+ # An instance of a weakref subclass can have attributes.
+ # If such a weakref holds the only strong reference to the object,
+ # deleting the weakref will delete the object. In this case,
+ # the callback must not be called, because the ref object is
+ # being deleted.
+ class MyRef(weakref.ref):
+ pass
+
+ # Use a local callback, for "regrtest -R::"
+ # to detect refcounting problems
+ def callback(w):
+ self.cbcalled += 1
+
+ o = C()
+ r1 = MyRef(o, callback)
+ r1.o = o
+ del o
+
+ del r1 # Used to crash here
+
+ self.assertEqual(self.cbcalled, 0)
+
+ # Same test, with two weakrefs to the same object
+ # (since code paths are different)
+ o = C()
+ r1 = MyRef(o, callback)
+ r2 = MyRef(o, callback)
+ r1.r = r2
+ r2.o = o
+ del o
+ del r2
+
+ del r1 # Used to crash here
+
+ self.assertEqual(self.cbcalled, 0)
+
class Object:
def __init__(self, arg):
@@ -789,7 +869,7 @@
def test_weak_keys(self):
#
# This exercises d.copy(), d.items(), d[] = v, d[], del d[],
- # len(d), d.has_key().
+ # len(d), in d.
#
dict, objects = self.make_weak_keyed_dict()
for o in objects:
@@ -813,8 +893,8 @@
"deleting the keys did not clear the dictionary")
o = Object(42)
dict[o] = "What is the meaning of the universe?"
- self.assert_(dict.has_key(o))
- self.assert_(not dict.has_key(34))
+ self.assertTrue(o in dict)
+ self.assertTrue(34 not in dict)
def test_weak_keyed_iters(self):
dict, objects = self.make_weak_keyed_dict()
@@ -826,8 +906,7 @@
objects2 = list(objects)
for wr in refs:
ob = wr()
- self.assert_(dict.has_key(ob))
- self.assert_(ob in dict)
+ self.assertTrue(ob in dict)
self.assertEqual(ob.arg, dict[ob])
objects2.remove(ob)
self.assertEqual(len(objects2), 0)
@@ -837,8 +916,7 @@
self.assertEqual(len(list(dict.iterkeyrefs())), len(objects))
for wr in dict.iterkeyrefs():
ob = wr()
- self.assert_(dict.has_key(ob))
- self.assert_(ob in dict)
+ self.assertTrue(ob in dict)
self.assertEqual(ob.arg, dict[ob])
objects2.remove(ob)
self.assertEqual(len(objects2), 0)
@@ -951,16 +1029,16 @@
" -- value parameters must be distinct objects")
weakdict = klass()
o = weakdict.setdefault(key, value1)
- self.assert_(o is value1)
- self.assert_(weakdict.has_key(key))
- self.assert_(weakdict.get(key) is value1)
- self.assert_(weakdict[key] is value1)
+ self.assertTrue(o is value1)
+ self.assertTrue(key in weakdict)
+ self.assertTrue(weakdict.get(key) is value1)
+ self.assertTrue(weakdict[key] is value1)
o = weakdict.setdefault(key, value2)
- self.assert_(o is value1)
- self.assert_(weakdict.has_key(key))
- self.assert_(weakdict.get(key) is value1)
- self.assert_(weakdict[key] is value1)
+ self.assertTrue(o is value1)
+ self.assertTrue(key in weakdict)
+ self.assertTrue(weakdict.get(key) is value1)
+ self.assertTrue(weakdict[key] is value1)
def test_weak_valued_dict_setdefault(self):
self.check_setdefault(weakref.WeakValueDictionary,
@@ -972,24 +1050,24 @@
def check_update(self, klass, dict):
#
- # This exercises d.update(), len(d), d.keys(), d.has_key(),
+ # This exercises d.update(), len(d), d.keys(), in d,
# d.get(), d[].
#
weakdict = klass()
weakdict.update(dict)
- self.assert_(len(weakdict) == len(dict))
+ self.assertEqual(len(weakdict), len(dict))
for k in weakdict.keys():
- self.assert_(dict.has_key(k),
+ self.assertTrue(k in dict,
"mysterious new key appeared in weak dict")
v = dict.get(k)
- self.assert_(v is weakdict[k])
- self.assert_(v is weakdict.get(k))
+ self.assertTrue(v is weakdict[k])
+ self.assertTrue(v is weakdict.get(k))
for k in dict.keys():
- self.assert_(weakdict.has_key(k),
+ self.assertTrue(k in weakdict,
"original key disappeared in weak dict")
v = dict[k]
- self.assert_(v is weakdict[k])
- self.assert_(v is weakdict.get(k))
+ self.assertTrue(v is weakdict[k])
+ self.assertTrue(v is weakdict.get(k))
def test_weak_valued_dict_update(self):
self.check_update(weakref.WeakValueDictionary,
@@ -1096,7 +1174,7 @@
def _reference(self):
return self.__ref.copy()
-libreftest = """ Doctest for examples in the library reference: libweakref.tex
+libreftest = """ Doctest for examples in the library reference: weakref.rst
>>> import weakref
>>> class Dict(dict):
@@ -1199,6 +1277,7 @@
MappingTestCase,
WeakValueDictionaryTestCase,
WeakKeyDictionaryTestCase,
+ SubclassableWeakrefTestCase,
)
test_support.run_doctest(sys.modules[__name__])
diff --git a/Lib/test/test_xml_etree.py b/Lib/test/test_xml_etree.py
--- a/Lib/test/test_xml_etree.py
+++ b/Lib/test/test_xml_etree.py
@@ -2,7 +2,8 @@
# all included components work as they should. For a more extensive
# test suite, see the selftest script in the ElementTree distribution.
-import doctest, sys
+import doctest
+import sys
from test import test_support
@@ -36,7 +37,7 @@
"""
def check_method(method):
- if not callable(method):
+ if not hasattr(method, '__call__'):
print method, "not callable"
def serialize(ET, elem, encoding=None):
diff --git a/Lib/test/test_xml_etree_c.py b/Lib/test/test_xml_etree_c.py
--- a/Lib/test/test_xml_etree_c.py
+++ b/Lib/test/test_xml_etree_c.py
@@ -1,6 +1,7 @@
# xml.etree test for cElementTree
-import doctest, sys
+import doctest
+import sys
from test import test_support
@@ -34,7 +35,7 @@
"""
def check_method(method):
- if not callable(method):
+ if not hasattr(method, '__call__'):
print method, "not callable"
def serialize(ET, elem, encoding=None):
diff --git a/Lib/test/test_zlib.py b/Lib/test/test_zlib.py
--- a/Lib/test/test_zlib.py
+++ b/Lib/test/test_zlib.py
@@ -1,23 +1,9 @@
import unittest
from test import test_support
import zlib
+import binascii
import random
-
-# print test_support.TESTFN
-
-def getbuf():
- # This was in the original. Avoid non-repeatable sources.
- # Left here (unused) in case something wants to be done with it.
- import imp
- try:
- t = imp.find_module('test_zlib')
- file = t[0]
- except ImportError:
- file = open(__file__)
- buf = file.read() * 8
- file.close()
- return buf
-
+from test.test_support import precisionbigmemtest, _1G
class ChecksumTestCase(unittest.TestCase):
@@ -59,33 +45,91 @@
self.assertEqual(zlib.crc32("penguin"), zlib.crc32("penguin", 0))
self.assertEqual(zlib.adler32("penguin"),zlib.adler32("penguin",1))
+ def test_abcdefghijklmnop(self):
+ """test issue1202 compliance: signed crc32, adler32 in 2.x"""
+ foo = 'abcdefghijklmnop'
+ # explicitly test signed behavior
+ self.assertEqual(zlib.crc32(foo), -1808088941)
+ self.assertEqual(zlib.crc32('spam'), 1138425661)
+ self.assertEqual(zlib.adler32(foo+foo), -721416943)
+ self.assertEqual(zlib.adler32('spam'), 72286642)
+ def test_same_as_binascii_crc32(self):
+ foo = 'abcdefghijklmnop'
+ self.assertEqual(binascii.crc32(foo), zlib.crc32(foo))
+ self.assertEqual(binascii.crc32('spam'), zlib.crc32('spam'))
+
+ def test_negative_crc_iv_input(self):
+ # The range of valid input values for the crc state should be
+ # -2**31 through 2**32-1 to allow inputs artifically constrained
+ # to a signed 32-bit integer.
+ self.assertEqual(zlib.crc32('ham', -1), zlib.crc32('ham', 0xffffffffL))
+ self.assertEqual(zlib.crc32('spam', -3141593),
+ zlib.crc32('spam', 0xffd01027L))
+ self.assertEqual(zlib.crc32('spam', -(2**31)),
+ zlib.crc32('spam', (2**31)))
+
+ def test_decompress_badinput(self):
+ self.assertRaises(zlib.error, zlib.decompress, 'foo')
class ExceptionTestCase(unittest.TestCase):
# make sure we generate some expected errors
- def test_bigbits(self):
- # specifying total bits too large causes an error
- self.assertRaises(zlib.error,
- zlib.compress, 'ERROR', zlib.MAX_WBITS + 1)
+ def test_badlevel(self):
+ # specifying compression level out of range causes an error
+ # (but -1 is Z_DEFAULT_COMPRESSION and apparently the zlib
+ # accepts 0 too)
+ self.assertRaises(zlib.error, zlib.compress, 'ERROR', 10)
def test_badcompressobj(self):
# verify failure on building compress object with bad params
self.assertRaises(ValueError, zlib.compressobj, 1, zlib.DEFLATED, 0)
+ # specifying total bits too large causes an error
+ self.assertRaises(ValueError,
+ zlib.compressobj, 1, zlib.DEFLATED, zlib.MAX_WBITS + 1)
def test_baddecompressobj(self):
# verify failure on building decompress object with bad params
- self.assertRaises(ValueError, zlib.decompressobj, 0)
+ self.assertRaises(ValueError, zlib.decompressobj, -1)
def test_decompressobj_badflush(self):
# verify failure on calling decompressobj.flush with bad params
self.assertRaises(ValueError, zlib.decompressobj().flush, 0)
self.assertRaises(ValueError, zlib.decompressobj().flush, -1)
- def test_decompress_badinput(self):
- self.assertRaises(zlib.error, zlib.decompress, 'foo')
+class BaseCompressTestCase(object):
+ def check_big_compress_buffer(self, size, compress_func):
+ _1M = 1024 * 1024
+ fmt = "%%0%dx" % (2 * _1M)
+ # Generate 10MB worth of random, and expand it by repeating it.
+ # The assumption is that zlib's memory is not big enough to exploit
+ # such spread out redundancy.
+ data = ''.join([binascii.a2b_hex(fmt % random.getrandbits(8 * _1M))
+ for i in range(10)])
+ data = data * (size // len(data) + 1)
+ try:
+ compress_func(data)
+ finally:
+ # Release memory
+ data = None
-class CompressTestCase(unittest.TestCase):
+ def check_big_decompress_buffer(self, size, decompress_func):
+ data = 'x' * size
+ try:
+ compressed = zlib.compress(data, 1)
+ finally:
+ # Release memory
+ data = None
+ data = decompress_func(compressed)
+ # Sanity check
+ try:
+ self.assertEqual(len(data), size)
+ self.assertEqual(len(data.strip('x')), 0)
+ finally:
+ data = None
+
+
+class CompressTestCase(BaseCompressTestCase, unittest.TestCase):
# Test compression in one go (whole message compression)
def test_speech(self):
x = zlib.compress(HAMLET_SCENE)
@@ -97,10 +141,31 @@
x = zlib.compress(data)
self.assertEqual(zlib.decompress(x), data)
+ def test_incomplete_stream(self):
+ # An useful error message is given
+ x = zlib.compress(HAMLET_SCENE)
+ try:
+ zlib.decompress(x[:-1])
+ except zlib.error as e:
+ self.assertTrue(
+ "Error -5 while decompressing data: incomplete or truncated stream"
+ in str(e), str(e))
+ else:
+ self.fail("zlib.error not raised")
+ # Memory use of the following functions takes into account overallocation
+ @precisionbigmemtest(size=_1G + 1024 * 1024, memuse=3)
+ def test_big_compress_buffer(self, size):
+ compress = lambda s: zlib.compress(s, 1)
+ self.check_big_compress_buffer(size, compress)
-class CompressObjectTestCase(unittest.TestCase):
+ @precisionbigmemtest(size=_1G + 1024 * 1024, memuse=2)
+ def test_big_decompress_buffer(self, size):
+ self.check_big_decompress_buffer(size, zlib.decompress)
+
+
+class CompressObjectTestCase(BaseCompressTestCase, unittest.TestCase):
# Test compression object
def test_pair(self):
# straightforward compress/decompress objects
@@ -314,6 +379,19 @@
dco = zlib.decompressobj()
self.assertEqual(dco.flush(), "") # Returns nothing
+ def test_decompress_incomplete_stream(self):
+ # This is 'foo', deflated
+ x = 'x\x9cK\xcb\xcf\x07\x00\x02\x82\x01E'
+ # For the record
+ self.assertEqual(zlib.decompress(x), 'foo')
+ self.assertRaises(zlib.error, zlib.decompress, x[:-5])
+ # Omitting the stream end works with decompressor objects
+ # (see issue #8672).
+ dco = zlib.decompressobj()
+ y = dco.decompress(x[:-5])
+ y += dco.flush()
+ self.assertEqual(y, 'foo')
+
if hasattr(zlib.compressobj(), "copy"):
def test_compresscopy(self):
# Test copying a compression object
@@ -374,6 +452,21 @@
d.flush()
self.assertRaises(ValueError, d.copy)
+ # Memory use of the following functions takes into account overallocation
+
+ @precisionbigmemtest(size=_1G + 1024 * 1024, memuse=3)
+ def test_big_compress_buffer(self, size):
+ c = zlib.compressobj(1)
+ compress = lambda s: c.compress(s) + c.flush()
+ self.check_big_compress_buffer(size, compress)
+
+ @precisionbigmemtest(size=_1G + 1024 * 1024, memuse=2)
+ def test_big_decompress_buffer(self, size):
+ d = zlib.decompressobj()
+ decompress = lambda s: d.decompress(s) + d.flush()
+ self.check_big_decompress_buffer(size, decompress)
+
+
def genblock(seed, length, step=1024, generator=random):
"""length-byte stream of random data from a seed (in step-byte blocks)."""
if seed is not None:
@@ -473,21 +566,3 @@
if __name__ == "__main__":
test_main()
-
-def test(tests=''):
- if not tests: tests = 'o'
- testcases = []
- if 'k' in tests: testcases.append(ChecksumTestCase)
- if 'x' in tests: testcases.append(ExceptionTestCase)
- if 'c' in tests: testcases.append(CompressTestCase)
- if 'o' in tests: testcases.append(CompressObjectTestCase)
- test_support.run_unittest(*testcases)
-
-if False:
- import sys
- sys.path.insert(1, '/Py23Src/python/dist/src/Lib/test')
- import test_zlib as tz
- ts, ut = tz.test_support, tz.unittest
- su = ut.TestSuite()
- su.addTest(ut.makeSuite(tz.CompressTestCase))
- ts.run_suite(su)
diff --git a/Lib/timeit.py b/Lib/timeit.py
--- a/Lib/timeit.py
+++ b/Lib/timeit.py
@@ -9,7 +9,7 @@
Library usage: see the Timer class.
Command line usage:
- python timeit.py [-n N] [-r N] [-s S] [-t] [-c] [-h] [statement]
+ python timeit.py [-n N] [-r N] [-s S] [-t] [-c] [-h] [--] [statement]
Options:
-n/--number N: how many times to execute 'statement' (default: see below)
@@ -19,6 +19,7 @@
-c/--clock: use time.clock() (default on Windows)
-v/--verbose: print raw timing results; repeat for more digits precision
-h/--help: print this usage message and exit
+ --: separate options from statement, use when statement starts with -
statement: statement to be timed (default 'pass')
A multi-line statement may be given by specifying each line as a
@@ -90,6 +91,17 @@
"""Helper to reindent a multi-line statement."""
return src.replace("\n", "\n" + " "*indent)
+def _template_func(setup, func):
+ """Create a timer function. Used if the "statement" is a callable."""
+ def inner(_it, _timer, _func=func):
+ setup()
+ _t0 = _timer()
+ for _i in _it:
+ _func()
+ _t1 = _timer()
+ return _t1 - _t0
+ return inner
+
class Timer:
"""Class for timing execution speed of small code snippets.
@@ -109,14 +121,32 @@
def __init__(self, stmt="pass", setup="pass", timer=default_timer):
"""Constructor. See class doc string."""
self.timer = timer
- stmt = reindent(stmt, 8)
- setup = reindent(setup, 4)
- src = template % {'stmt': stmt, 'setup': setup}
- self.src = src # Save for traceback display
- code = compile(src, dummy_src_name, "exec")
ns = {}
- exec code in globals(), ns
- self.inner = ns["inner"]
+ if isinstance(stmt, basestring):
+ stmt = reindent(stmt, 8)
+ if isinstance(setup, basestring):
+ setup = reindent(setup, 4)
+ src = template % {'stmt': stmt, 'setup': setup}
+ elif callable(setup):
+ src = template % {'stmt': stmt, 'setup': '_setup()'}
+ ns['_setup'] = setup
+ else:
+ raise ValueError("setup is neither a string nor callable")
+ self.src = src # Save for traceback display
+ code = compile(src, dummy_src_name, "exec")
+ exec code in globals(), ns
+ self.inner = ns["inner"]
+ elif callable(stmt):
+ self.src = None
+ if isinstance(setup, basestring):
+ _setup = setup
+ def setup():
+ exec _setup in globals(), ns
+ elif not callable(setup):
+ raise ValueError("setup is neither a string nor callable")
+ self.inner = _template_func(setup, stmt)
+ else:
+ raise ValueError("stmt is neither a string nor callable")
def print_exc(self, file=None):
"""Helper to print a traceback from the timed code.
@@ -136,10 +166,13 @@
sent; it defaults to sys.stderr.
"""
import linecache, traceback
- linecache.cache[dummy_src_name] = (len(self.src),
- None,
- self.src.split("\n"),
- dummy_src_name)
+ if self.src is not None:
+ linecache.cache[dummy_src_name] = (len(self.src),
+ None,
+ self.src.split("\n"),
+ dummy_src_name)
+ # else the source is already stored somewhere else
+
traceback.print_exc(file=file)
def timeit(self, number=default_number):
@@ -192,6 +225,16 @@
r.append(t)
return r
+def timeit(stmt="pass", setup="pass", timer=default_timer,
+ number=default_number):
+ """Convenience function to create Timer object and call timeit method."""
+ return Timer(stmt, setup, timer).timeit(number)
+
+def repeat(stmt="pass", setup="pass", timer=default_timer,
+ repeat=default_repeat, number=default_number):
+ """Convenience function to create Timer object and call repeat method."""
+ return Timer(stmt, setup, timer).repeat(repeat, number)
+
def main(args=None):
"""Main program, used when run as a script.
diff --git a/Lib/types.py b/Lib/types.py
--- a/Lib/types.py
+++ b/Lib/types.py
@@ -49,10 +49,9 @@
# Execution in restricted environment
pass
-def g():
+def _g():
yield 1
-GeneratorType = type(g())
-del g
+GeneratorType = type(_g())
class _C:
def _m(self): pass
@@ -90,4 +89,8 @@
DictProxyType = type(TypeType.__dict__)
NotImplementedType = type(NotImplemented)
-del sys, _f, _C, _x # Not for export
+# For Jython, the following two types are identical
+GetSetDescriptorType = type(FunctionType.func_code)
+MemberDescriptorType = type(FunctionType.func_globals)
+
+del sys, _f, _g, _C, _x # Not for export
diff --git a/Lib/weakref.py b/Lib/weakref.py
--- a/Lib/weakref.py
+++ b/Lib/weakref.py
@@ -2,7 +2,7 @@
This module is an implementation of PEP 205:
-http://python.sourceforge.net/peps/pep-0205.html
+http://www.python.org/dev/peps/pep-0205/
"""
# Naming convention: Variables named "wr" are weak reference objects;
@@ -26,7 +26,7 @@
ProxyTypes = (ProxyType, CallableProxyType)
__all__ = ["ref", "proxy", "getweakrefcount", "getweakrefs",
- "WeakKeyDictionary", "ReferenceType", "ProxyType",
+ "WeakKeyDictionary", "ReferenceError", "ReferenceType", "ProxyType",
"CallableProxyType", "ProxyTypes", "WeakValueDictionary"]
diff --git a/Lib/zipfile.py b/Lib/zipfile.py
--- a/Lib/zipfile.py
+++ b/Lib/zipfile.py
@@ -1,13 +1,15 @@
"""
Read and write ZIP files.
"""
-import struct, os, time, sys
-import binascii, cStringIO
+import struct, os, time, sys, shutil
+import binascii, cStringIO, stat
try:
import zlib # We may need its compression method
+ crc32 = zlib.crc32
except ImportError:
zlib = None
+ crc32 = binascii.crc32
__all__ = ["BadZipfile", "error", "ZIP_STORED", "ZIP_DEFLATED", "is_zipfile",
"ZipInfo", "ZipFile", "PyZipFile", "LargeZipFile" ]
@@ -26,32 +28,52 @@
error = BadZipfile # The exception raised by this module
-ZIP64_LIMIT= (1 << 31) - 1
+ZIP64_LIMIT = (1 << 31) - 1
+ZIP_FILECOUNT_LIMIT = 1 << 16
+ZIP_MAX_COMMENT = (1 << 16) - 1
# constants for Zip file compression methods
ZIP_STORED = 0
ZIP_DEFLATED = 8
# Other ZIP compression methods not supported
-# Here are some struct module formats for reading headers
-structEndArchive = "<4s4H2LH" # 9 items, end of archive, 22 bytes
-stringEndArchive = "PK\005\006" # magic number for end of archive record
-structCentralDir = "<4s4B4HlLL5HLL"# 19 items, central directory, 46 bytes
-stringCentralDir = "PK\001\002" # magic number for central directory
-structFileHeader = "<4s2B4HlLL2H" # 12 items, file header record, 30 bytes
-stringFileHeader = "PK\003\004" # magic number for file header
-structEndArchive64Locator = "<4slql" # 4 items, locate Zip64 header, 20 bytes
-stringEndArchive64Locator = "PK\x06\x07" # magic token for locator header
-structEndArchive64 = "<4sqhhllqqqq" # 10 items, end of archive (Zip64), 56 bytes
-stringEndArchive64 = "PK\x06\x06" # magic token for Zip64 header
+# Below are some formats and associated data for reading/writing headers using
+# the struct module. The names and structures of headers/records are those used
+# in the PKWARE description of the ZIP file format:
+# http://www.pkware.com/documents/casestudies/APPNOTE.TXT
+# (URL valid as of January 2008)
+# The "end of central directory" structure, magic number, size, and indices
+# (section V.I in the format document)
+structEndArchive = "<4s4H2LH"
+stringEndArchive = "PK\005\006"
+sizeEndCentDir = struct.calcsize(structEndArchive)
+
+_ECD_SIGNATURE = 0
+_ECD_DISK_NUMBER = 1
+_ECD_DISK_START = 2
+_ECD_ENTRIES_THIS_DISK = 3
+_ECD_ENTRIES_TOTAL = 4
+_ECD_SIZE = 5
+_ECD_OFFSET = 6
+_ECD_COMMENT_SIZE = 7
+# These last two indices are not part of the structure as defined in the
+# spec, but they are used internally by this module as a convenience
+_ECD_COMMENT = 8
+_ECD_LOCATION = 9
+
+# The "central directory" structure, magic number, size, and indices
+# of entries in the structure (section V.F in the format document)
+structCentralDir = "<4s4B4HL2L5H2L"
+stringCentralDir = "PK\001\002"
+sizeCentralDir = struct.calcsize(structCentralDir)
# indexes of entries in the central directory structure
_CD_SIGNATURE = 0
_CD_CREATE_VERSION = 1
_CD_CREATE_SYSTEM = 2
_CD_EXTRACT_VERSION = 3
-_CD_EXTRACT_SYSTEM = 4 # is this meaningful?
+_CD_EXTRACT_SYSTEM = 4
_CD_FLAG_BITS = 5
_CD_COMPRESS_TYPE = 6
_CD_TIME = 7
@@ -67,10 +89,15 @@
_CD_EXTERNAL_FILE_ATTRIBUTES = 17
_CD_LOCAL_HEADER_OFFSET = 18
-# indexes of entries in the local file header structure
+# The "local file header" structure, magic number, size, and indices
+# (section V.A in the format document)
+structFileHeader = "<4s2B4HL2L2H"
+stringFileHeader = "PK\003\004"
+sizeFileHeader = struct.calcsize(structFileHeader)
+
_FH_SIGNATURE = 0
_FH_EXTRACT_VERSION = 1
-_FH_EXTRACT_SYSTEM = 2 # is this meaningful?
+_FH_EXTRACT_SYSTEM = 2
_FH_GENERAL_PURPOSE_FLAG_BITS = 3
_FH_COMPRESSION_METHOD = 4
_FH_LAST_MOD_TIME = 5
@@ -81,6 +108,28 @@
_FH_FILENAME_LENGTH = 10
_FH_EXTRA_FIELD_LENGTH = 11
+# The "Zip64 end of central directory locator" structure, magic number, and size
+structEndArchive64Locator = "<4sLQL"
+stringEndArchive64Locator = "PK\x06\x07"
+sizeEndCentDir64Locator = struct.calcsize(structEndArchive64Locator)
+
+# The "Zip64 end of central directory" record, magic number, size, and indices
+# (section V.G in the format document)
+structEndArchive64 = "<4sQ2H2L4Q"
+stringEndArchive64 = "PK\x06\x06"
+sizeEndCentDir64 = struct.calcsize(structEndArchive64)
+
+_CD64_SIGNATURE = 0
+_CD64_DIRECTORY_RECSIZE = 1
+_CD64_CREATE_VERSION = 2
+_CD64_EXTRACT_VERSION = 3
+_CD64_DISK_NUMBER = 4
+_CD64_DISK_NUMBER_START = 5
+_CD64_NUMBER_ENTRIES_THIS_DISK = 6
+_CD64_NUMBER_ENTRIES_TOTAL = 7
+_CD64_DIRECTORY_SIZE = 8
+_CD64_OFFSET_START_CENTDIR = 9
+
def is_zipfile(filename):
"""Quickly see if file is a ZIP file by checking the magic number."""
try:
@@ -97,9 +146,8 @@
"""
Read the ZIP64 end-of-archive records and use that to update endrec
"""
- locatorSize = struct.calcsize(structEndArchive64Locator)
- fpin.seek(offset - locatorSize, 2)
- data = fpin.read(locatorSize)
+ fpin.seek(offset - sizeEndCentDir64Locator, 2)
+ data = fpin.read(sizeEndCentDir64Locator)
sig, diskno, reloff, disks = struct.unpack(structEndArchive64Locator, data)
if sig != stringEndArchive64Locator:
return endrec
@@ -108,9 +156,8 @@
raise BadZipfile("zipfiles that span multiple disks are not supported")
# Assume no 'zip64 extensible data'
- endArchiveSize = struct.calcsize(structEndArchive64)
- fpin.seek(offset - locatorSize - endArchiveSize, 2)
- data = fpin.read(endArchiveSize)
+ fpin.seek(offset - sizeEndCentDir64Locator - sizeEndCentDir64, 2)
+ data = fpin.read(sizeEndCentDir64)
sig, sz, create_version, read_version, disk_num, disk_dir, \
dircount, dircount2, dirsize, diroffset = \
struct.unpack(structEndArchive64, data)
@@ -118,12 +165,13 @@
return endrec
# Update the original endrec using data from the ZIP64 record
- endrec[1] = disk_num
- endrec[2] = disk_dir
- endrec[3] = dircount
- endrec[4] = dircount2
- endrec[5] = dirsize
- endrec[6] = diroffset
+ endrec[_ECD_SIGNATURE] = sig
+ endrec[_ECD_DISK_NUMBER] = disk_num
+ endrec[_ECD_DISK_START] = disk_dir
+ endrec[_ECD_ENTRIES_THIS_DISK] = dircount
+ endrec[_ECD_ENTRIES_TOTAL] = dircount2
+ endrec[_ECD_SIZE] = dirsize
+ endrec[_ECD_OFFSET] = diroffset
return endrec
@@ -132,38 +180,57 @@
The data is a list of the nine items in the ZIP "End of central dir"
record followed by a tenth item, the file seek offset of this record."""
- fpin.seek(-22, 2) # Assume no archive comment.
- filesize = fpin.tell() + 22 # Get file size
+
+ # Determine file size
+ fpin.seek(0, 2)
+ filesize = fpin.tell()
+
+ # Check to see if this is ZIP file with no archive comment (the
+ # "end of central directory" structure should be the last item in the
+ # file if this is the case).
+ try:
+ fpin.seek(-sizeEndCentDir, 2)
+ except IOError:
+ return None
data = fpin.read()
if data[0:4] == stringEndArchive and data[-2:] == "\000\000":
+ # the signature is correct and there's no comment, unpack structure
endrec = struct.unpack(structEndArchive, data)
- endrec = list(endrec)
- endrec.append("") # Append the archive comment
- endrec.append(filesize - 22) # Append the record start offset
- if endrec[-4] == -1 or endrec[-4] == 0xffffffff:
- return _EndRecData64(fpin, -22, endrec)
- return endrec
- # Search the last END_BLOCK bytes of the file for the record signature.
- # The comment is appended to the ZIP file and has a 16 bit length.
- # So the comment may be up to 64K long. We limit the search for the
- # signature to a few Kbytes at the end of the file for efficiency.
- # also, the signature must not appear in the comment.
- END_BLOCK = min(filesize, 1024 * 4)
- fpin.seek(filesize - END_BLOCK, 0)
+ endrec=list(endrec)
+
+ # Append a blank comment and record start offset
+ endrec.append("")
+ endrec.append(filesize - sizeEndCentDir)
+
+ # Try to read the "Zip64 end of central directory" structure
+ return _EndRecData64(fpin, -sizeEndCentDir, endrec)
+
+ # Either this is not a ZIP file, or it is a ZIP file with an archive
+ # comment. Search the end of the file for the "end of central directory"
+ # record signature. The comment is the last item in the ZIP file and may be
+ # up to 64K long. It is assumed that the "end of central directory" magic
+ # number does not appear in the comment.
+ maxCommentStart = max(filesize - (1 << 16) - sizeEndCentDir, 0)
+ fpin.seek(maxCommentStart, 0)
data = fpin.read()
start = data.rfind(stringEndArchive)
- if start >= 0: # Correct signature string was found
- endrec = struct.unpack(structEndArchive, data[start:start+22])
- endrec = list(endrec)
- comment = data[start+22:]
- if endrec[7] == len(comment): # Comment length checks out
+ if start >= 0:
+ # found the magic number; attempt to unpack and interpret
+ recData = data[start:start+sizeEndCentDir]
+ endrec = list(struct.unpack(structEndArchive, recData))
+ comment = data[start+sizeEndCentDir:]
+ # check that comment length is correct
+ if endrec[_ECD_COMMENT_SIZE] == len(comment):
# Append the archive comment and start offset
endrec.append(comment)
- endrec.append(filesize - END_BLOCK + start)
- if endrec[-4] == -1 or endrec[-4] == 0xffffffff:
- return _EndRecData64(fpin, - END_BLOCK + start, endrec)
- return endrec
- return # Error, return None
+ endrec.append(maxCommentStart + start)
+
+ # Try to read the "Zip64 end of central directory" structure
+ return _EndRecData64(fpin, maxCommentStart + start - filesize,
+ endrec)
+
+ # Unable to find a valid end of central directory structure
+ return
class ZipInfo (object):
@@ -188,6 +255,7 @@
'CRC',
'compress_size',
'file_size',
+ '_raw_time',
)
def __init__(self, filename="NoName", date_time=(1980,1,1,0,0,0)):
@@ -246,34 +314,50 @@
if file_size > ZIP64_LIMIT or compress_size > ZIP64_LIMIT:
# File is larger than what fits into a 4 byte integer,
# fall back to the ZIP64 extension
- fmt = '<hhqq'
+ fmt = '<HHQQ'
extra = extra + struct.pack(fmt,
1, struct.calcsize(fmt)-4, file_size, compress_size)
- file_size = 0xffffffff # -1
- compress_size = 0xffffffff # -1
+ file_size = 0xffffffff
+ compress_size = 0xffffffff
self.extract_version = max(45, self.extract_version)
self.create_version = max(45, self.extract_version)
+ filename, flag_bits = self._encodeFilenameFlags()
header = struct.pack(structFileHeader, stringFileHeader,
- self.extract_version, self.reserved, self.flag_bits,
+ self.extract_version, self.reserved, flag_bits,
self.compress_type, dostime, dosdate, CRC,
compress_size, file_size,
- len(self.filename), len(extra))
- return header + self.filename + extra
+ len(filename), len(extra))
+ return header + filename + extra
+
+ def _encodeFilenameFlags(self):
+ if isinstance(self.filename, unicode):
+ try:
+ return self.filename.encode('ascii'), self.flag_bits
+ except UnicodeEncodeError:
+ return self.filename.encode('utf-8'), self.flag_bits | 0x800
+ else:
+ return self.filename, self.flag_bits
+
+ def _decodeFilename(self):
+ if self.flag_bits & 0x800:
+ return self.filename.decode('utf-8')
+ else:
+ return self.filename
def _decodeExtra(self):
# Try to decode the extra field.
extra = self.extra
unpack = struct.unpack
while extra:
- tp, ln = unpack('<hh', extra[:4])
+ tp, ln = unpack('<HH', extra[:4])
if tp == 1:
if ln >= 24:
- counts = unpack('<qqq', extra[4:28])
+ counts = unpack('<QQQ', extra[4:28])
elif ln == 16:
- counts = unpack('<qq', extra[4:20])
+ counts = unpack('<QQ', extra[4:20])
elif ln == 8:
- counts = unpack('<q', extra[4:12])
+ counts = unpack('<Q', extra[4:12])
elif ln == 0:
counts = ()
else:
@@ -282,15 +366,15 @@
idx = 0
# ZIP64 extension (large files and/or large archives)
- if self.file_size == -1 or self.file_size == 0xFFFFFFFFL:
+ if self.file_size in (0xffffffffffffffffL, 0xffffffffL):
self.file_size = counts[idx]
idx += 1
- if self.compress_size == -1 or self.compress_size == 0xFFFFFFFFL:
+ if self.compress_size == 0xFFFFFFFFL:
self.compress_size = counts[idx]
idx += 1
- if self.header_offset == -1 or self.header_offset == 0xffffffffL:
+ if self.header_offset == 0xffffffffL:
old = self.header_offset
self.header_offset = counts[idx]
idx+=1
@@ -298,10 +382,263 @@
extra = extra[ln+4:]
+class _ZipDecrypter:
+ """Class to handle decryption of files stored within a ZIP archive.
+
+ ZIP supports a password-based form of encryption. Even though known
+ plaintext attacks have been found against it, it is still useful
+ to be able to get data out of such a file.
+
+ Usage:
+ zd = _ZipDecrypter(mypwd)
+ plain_char = zd(cypher_char)
+ plain_text = map(zd, cypher_text)
+ """
+
+ def _GenerateCRCTable():
+ """Generate a CRC-32 table.
+
+ ZIP encryption uses the CRC32 one-byte primitive for scrambling some
+ internal keys. We noticed that a direct implementation is faster than
+ relying on binascii.crc32().
+ """
+ poly = 0xedb88320
+ table = [0] * 256
+ for i in range(256):
+ crc = i
+ for j in range(8):
+ if crc & 1:
+ crc = ((crc >> 1) & 0x7FFFFFFF) ^ poly
+ else:
+ crc = ((crc >> 1) & 0x7FFFFFFF)
+ table[i] = crc
+ return table
+ crctable = _GenerateCRCTable()
+
+ def _crc32(self, ch, crc):
+ """Compute the CRC32 primitive on one byte."""
+ return ((crc >> 8) & 0xffffff) ^ self.crctable[(crc ^ ord(ch)) & 0xff]
+
+ def __init__(self, pwd):
+ self.key0 = 305419896
+ self.key1 = 591751049
+ self.key2 = 878082192
+ for p in pwd:
+ self._UpdateKeys(p)
+
+ def _UpdateKeys(self, c):
+ self.key0 = self._crc32(c, self.key0)
+ self.key1 = (self.key1 + (self.key0 & 255)) & 4294967295
+ self.key1 = (self.key1 * 134775813 + 1) & 4294967295
+ self.key2 = self._crc32(chr((self.key1 >> 24) & 255), self.key2)
+
+ def __call__(self, c):
+ """Decrypt a single character."""
+ c = ord(c)
+ k = self.key2 | 2
+ c = c ^ (((k * (k^1)) >> 8) & 255)
+ c = chr(c)
+ self._UpdateKeys(c)
+ return c
+
+class ZipExtFile:
+ """File-like object for reading an archive member.
+ Is returned by ZipFile.open().
+ """
+
+ def __init__(self, fileobj, zipinfo, decrypt=None):
+ self.fileobj = fileobj
+ self.decrypter = decrypt
+ self.bytes_read = 0L
+ self.rawbuffer = ''
+ self.readbuffer = ''
+ self.linebuffer = ''
+ self.eof = False
+ self.univ_newlines = False
+ self.nlSeps = ("\n", )
+ self.lastdiscard = ''
+
+ self.compress_type = zipinfo.compress_type
+ self.compress_size = zipinfo.compress_size
+
+ self.closed = False
+ self.mode = "r"
+ self.name = zipinfo.filename
+
+ # read from compressed files in 64k blocks
+ self.compreadsize = 64*1024
+ if self.compress_type == ZIP_DEFLATED:
+ self.dc = zlib.decompressobj(-15)
+
+ def set_univ_newlines(self, univ_newlines):
+ self.univ_newlines = univ_newlines
+
+ # pick line separator char(s) based on universal newlines flag
+ self.nlSeps = ("\n", )
+ if self.univ_newlines:
+ self.nlSeps = ("\r\n", "\r", "\n")
+
+ def __iter__(self):
+ return self
+
+ def next(self):
+ nextline = self.readline()
+ if not nextline:
+ raise StopIteration()
+
+ return nextline
+
+ def close(self):
+ self.closed = True
+
+ def _checkfornewline(self):
+ nl, nllen = -1, -1
+ if self.linebuffer:
+ # ugly check for cases where half of an \r\n pair was
+ # read on the last pass, and the \r was discarded. In this
+ # case we just throw away the \n at the start of the buffer.
+ if (self.lastdiscard, self.linebuffer[0]) == ('\r','\n'):
+ self.linebuffer = self.linebuffer[1:]
+
+ for sep in self.nlSeps:
+ nl = self.linebuffer.find(sep)
+ if nl >= 0:
+ nllen = len(sep)
+ return nl, nllen
+
+ return nl, nllen
+
+ def readline(self, size = -1):
+ """Read a line with approx. size. If size is negative,
+ read a whole line.
+ """
+ if size < 0:
+ size = sys.maxint
+ elif size == 0:
+ return ''
+
+ # check for a newline already in buffer
+ nl, nllen = self._checkfornewline()
+
+ if nl >= 0:
+ # the next line was already in the buffer
+ nl = min(nl, size)
+ else:
+ # no line break in buffer - try to read more
+ size -= len(self.linebuffer)
+ while nl < 0 and size > 0:
+ buf = self.read(min(size, 100))
+ if not buf:
+ break
+ self.linebuffer += buf
+ size -= len(buf)
+
+ # check for a newline in buffer
+ nl, nllen = self._checkfornewline()
+
+ # we either ran out of bytes in the file, or
+ # met the specified size limit without finding a newline,
+ # so return current buffer
+ if nl < 0:
+ s = self.linebuffer
+ self.linebuffer = ''
+ return s
+
+ buf = self.linebuffer[:nl]
+ self.lastdiscard = self.linebuffer[nl:nl + nllen]
+ self.linebuffer = self.linebuffer[nl + nllen:]
+
+ # line is always returned with \n as newline char (except possibly
+ # for a final incomplete line in the file, which is handled above).
+ return buf + "\n"
+
+ def readlines(self, sizehint = -1):
+ """Return a list with all (following) lines. The sizehint parameter
+ is ignored in this implementation.
+ """
+ result = []
+ while True:
+ line = self.readline()
+ if not line: break
+ result.append(line)
+ return result
+
+ def read(self, size = None):
+ # act like file() obj and return empty string if size is 0
+ if size == 0:
+ return ''
+
+ # determine read size
+ bytesToRead = self.compress_size - self.bytes_read
+
+ # adjust read size for encrypted files since the first 12 bytes
+ # are for the encryption/password information
+ if self.decrypter is not None:
+ bytesToRead -= 12
+
+ if size is not None and size >= 0:
+ if self.compress_type == ZIP_STORED:
+ lr = len(self.readbuffer)
+ bytesToRead = min(bytesToRead, size - lr)
+ elif self.compress_type == ZIP_DEFLATED:
+ if len(self.readbuffer) > size:
+ # the user has requested fewer bytes than we've already
+ # pulled through the decompressor; don't read any more
+ bytesToRead = 0
+ else:
+ # user will use up the buffer, so read some more
+ lr = len(self.rawbuffer)
+ bytesToRead = min(bytesToRead, self.compreadsize - lr)
+
+ # avoid reading past end of file contents
+ if bytesToRead + self.bytes_read > self.compress_size:
+ bytesToRead = self.compress_size - self.bytes_read
+
+ # try to read from file (if necessary)
+ if bytesToRead > 0:
+ bytes = self.fileobj.read(bytesToRead)
+ self.bytes_read += len(bytes)
+ self.rawbuffer += bytes
+
+ # handle contents of raw buffer
+ if self.rawbuffer:
+ newdata = self.rawbuffer
+ self.rawbuffer = ''
+
+ # decrypt new data if we were given an object to handle that
+ if newdata and self.decrypter is not None:
+ newdata = ''.join(map(self.decrypter, newdata))
+
+ # decompress newly read data if necessary
+ if newdata and self.compress_type == ZIP_DEFLATED:
+ newdata = self.dc.decompress(newdata)
+ self.rawbuffer = self.dc.unconsumed_tail
+ if self.eof and len(self.rawbuffer) == 0:
+ # we're out of raw bytes (both from the file and
+ # the local buffer); flush just to make sure the
+ # decompressor is done
+ newdata += self.dc.flush()
+ # prevent decompressor from being used again
+ self.dc = None
+
+ self.readbuffer += newdata
+
+
+ # return what the user asked for
+ if size is None or len(self.readbuffer) <= size:
+ bytes = self.readbuffer
+ self.readbuffer = ''
+ else:
+ bytes = self.readbuffer[:size]
+ self.readbuffer = self.readbuffer[size:]
+
+ return bytes
+
+
class ZipFile:
""" Class with methods to open, read, write, close, list zip files.
- z = ZipFile(file, mode="r", compression=ZIP_STORED, allowZip64=True)
+ z = ZipFile(file, mode="r", compression=ZIP_STORED, allowZip64=False)
file: Either the path to the file, or a file-like object.
If it is a path, the file will be opened and closed by ZipFile.
@@ -317,8 +654,9 @@
def __init__(self, file, mode="r", compression=ZIP_STORED, allowZip64=False):
"""Open the ZIP file with mode read "r", write "w" or append "a"."""
- self._allowZip64 = allowZip64
- self._didModify = False
+ if mode not in ("r", "w", "a"):
+ raise RuntimeError('ZipFile() requires mode "r", "w", or "a"')
+
if compression == ZIP_STORED:
pass
elif compression == ZIP_DEFLATED:
@@ -327,18 +665,30 @@
"Compression requires the (missing) zlib module"
else:
raise RuntimeError, "That compression method is not supported"
+
+ self._allowZip64 = allowZip64
+ self._didModify = False
self.debug = 0 # Level of printing: 0 through 3
self.NameToInfo = {} # Find file info given name
self.filelist = [] # List of ZipInfo instances for archive
self.compression = compression # Method of compression
self.mode = key = mode.replace('b', '')[0]
+ self.pwd = None
+ self.comment = ''
# Check if we were passed a file-like object
if isinstance(file, basestring):
self._filePassed = 0
self.filename = file
modeDict = {'r' : 'rb', 'w': 'wb', 'a' : 'r+b'}
- self.fp = open(file, modeDict[mode])
+ try:
+ self.fp = open(file, modeDict[mode])
+ except IOError:
+ if mode == 'a':
+ mode = key = 'w'
+ self.fp = open(file, modeDict[mode])
+ else:
+ raise
else:
self._filePassed = 1
self.fp = file
@@ -380,18 +730,19 @@
raise BadZipfile, "File is not a zip file"
if self.debug > 1:
print endrec
- size_cd = endrec[5] # bytes in central directory
- offset_cd = endrec[6] # offset of central directory
- self.comment = endrec[8] # archive comment
- # endrec[9] is the offset of the "End of Central Dir" record
- if endrec[9] > ZIP64_LIMIT:
- x = endrec[9] - size_cd - 56 - 20
- else:
- x = endrec[9] - size_cd
+ size_cd = endrec[_ECD_SIZE] # bytes in central directory
+ offset_cd = endrec[_ECD_OFFSET] # offset of central directory
+ self.comment = endrec[_ECD_COMMENT] # archive comment
+
# "concat" is zero, unless zip was concatenated to another file
- concat = x - offset_cd
+ concat = endrec[_ECD_LOCATION] - size_cd - offset_cd
+ if endrec[_ECD_SIGNATURE] == stringEndArchive64:
+ # If Zip64 extension structures are present, account for them
+ concat -= (sizeEndCentDir64 + sizeEndCentDir64Locator)
+
if self.debug > 2:
- print "given, inferred, offset", offset_cd, x, concat
+ inferred = concat + offset_cd
+ print "given, inferred, offset", offset_cd, inferred, concat
# self.start_dir: Position of start of central directory
self.start_dir = offset_cd + concat
fp.seek(self.start_dir, 0)
@@ -399,8 +750,7 @@
fp = cStringIO.StringIO(data)
total = 0
while total < size_cd:
- centdir = fp.read(46)
- total = total + 46
+ centdir = fp.read(sizeCentralDir)
if centdir[0:4] != stringCentralDir:
raise BadZipfile, "Bad magic number for central directory"
centdir = struct.unpack(structCentralDir, centdir)
@@ -411,22 +761,27 @@
x = ZipInfo(filename)
x.extra = fp.read(centdir[_CD_EXTRA_FIELD_LENGTH])
x.comment = fp.read(centdir[_CD_COMMENT_LENGTH])
- total = (total + centdir[_CD_FILENAME_LENGTH]
- + centdir[_CD_EXTRA_FIELD_LENGTH]
- + centdir[_CD_COMMENT_LENGTH])
x.header_offset = centdir[_CD_LOCAL_HEADER_OFFSET]
(x.create_version, x.create_system, x.extract_version, x.reserved,
x.flag_bits, x.compress_type, t, d,
x.CRC, x.compress_size, x.file_size) = centdir[1:12]
x.volume, x.internal_attr, x.external_attr = centdir[15:18]
# Convert date/time code to (year, month, day, hour, min, sec)
+ x._raw_time = t
x.date_time = ( (d>>9)+1980, (d>>5)&0xF, d&0x1F,
t>>11, (t>>5)&0x3F, (t&0x1F) * 2 )
x._decodeExtra()
x.header_offset = x.header_offset + concat
+ x.filename = x._decodeFilename()
self.filelist.append(x)
self.NameToInfo[x.filename] = x
+
+ # update total bytes read from central directory
+ total = (total + sizeCentralDir + centdir[_CD_FILENAME_LENGTH]
+ + centdir[_CD_EXTRA_FIELD_LENGTH]
+ + centdir[_CD_COMMENT_LENGTH])
+
if self.debug > 2:
print "total", total
@@ -452,67 +807,174 @@
def testzip(self):
"""Read all the files and check the CRC."""
+ chunk_size = 2 ** 20
for zinfo in self.filelist:
try:
- self.read(zinfo.filename) # Check CRC-32
+ # Read by chunks, to avoid an OverflowError or a
+ # MemoryError with very large embedded files.
+ f = self.open(zinfo.filename, "r")
+ while f.read(chunk_size): # Check CRC-32
+ pass
except BadZipfile:
return zinfo.filename
-
def getinfo(self, name):
"""Return the instance of ZipInfo given 'name'."""
- return self.NameToInfo[name]
+ info = self.NameToInfo.get(name)
+ if info is None:
+ raise KeyError(
+ 'There is no item named %r in the archive' % name)
- def read(self, name):
+ return info
+
+ def setpassword(self, pwd):
+ """Set default password for encrypted files."""
+ self.pwd = pwd
+
+ def read(self, name, pwd=None):
"""Return file bytes (as a string) for name."""
- if self.mode not in ("r", "a"):
- raise RuntimeError, 'read() requires mode "r" or "a"'
+ return self.open(name, "r", pwd).read()
+
+ def open(self, name, mode="r", pwd=None):
+ """Return file-like object for 'name'."""
+ if mode not in ("r", "U", "rU"):
+ raise RuntimeError, 'open() requires mode "r", "U", or "rU"'
if not self.fp:
raise RuntimeError, \
"Attempt to read ZIP archive that was already closed"
- zinfo = self.getinfo(name)
- filepos = self.fp.tell()
- self.fp.seek(zinfo.header_offset, 0)
+ # Only open a new file for instances where we were not
+ # given a file object in the constructor
+ if self._filePassed:
+ zef_file = self.fp
+ else:
+ zef_file = open(self.filename, 'rb')
+
+ # Make sure we have an info object
+ if isinstance(name, ZipInfo):
+ # 'name' is already an info object
+ zinfo = name
+ else:
+ # Get info object for name
+ zinfo = self.getinfo(name)
+
+ zef_file.seek(zinfo.header_offset, 0)
# Skip the file header:
- fheader = self.fp.read(30)
+ fheader = zef_file.read(sizeFileHeader)
if fheader[0:4] != stringFileHeader:
raise BadZipfile, "Bad magic number for file header"
fheader = struct.unpack(structFileHeader, fheader)
- fname = self.fp.read(fheader[_FH_FILENAME_LENGTH])
+ fname = zef_file.read(fheader[_FH_FILENAME_LENGTH])
if fheader[_FH_EXTRA_FIELD_LENGTH]:
- self.fp.read(fheader[_FH_EXTRA_FIELD_LENGTH])
+ zef_file.read(fheader[_FH_EXTRA_FIELD_LENGTH])
if fname != zinfo.orig_filename:
raise BadZipfile, \
'File name in directory "%s" and header "%s" differ.' % (
zinfo.orig_filename, fname)
- bytes = self.fp.read(zinfo.compress_size)
- self.fp.seek(filepos, 0)
- if zinfo.compress_type == ZIP_STORED:
- pass
- elif zinfo.compress_type == ZIP_DEFLATED:
- if not zlib:
- raise RuntimeError, \
- "De-compression requires the (missing) zlib module"
- # zlib compress/decompress code by Jeremy Hylton of CNRI
- dc = zlib.decompressobj(-15)
- bytes = dc.decompress(bytes)
- # need to feed in unused pad byte so that zlib won't choke
- ex = dc.decompress('Z') + dc.flush()
- if ex:
- bytes = bytes + ex
+ # check for encrypted flag & handle password
+ is_encrypted = zinfo.flag_bits & 0x1
+ zd = None
+ if is_encrypted:
+ if not pwd:
+ pwd = self.pwd
+ if not pwd:
+ raise RuntimeError, "File %s is encrypted, " \
+ "password required for extraction" % name
+
+ zd = _ZipDecrypter(pwd)
+ # The first 12 bytes in the cypher stream is an encryption header
+ # used to strengthen the algorithm. The first 11 bytes are
+ # completely random, while the 12th contains the MSB of the CRC,
+ # or the MSB of the file time depending on the header type
+ # and is used to check the correctness of the password.
+ bytes = zef_file.read(12)
+ h = map(zd, bytes[0:12])
+ if zinfo.flag_bits & 0x8:
+ # compare against the file type from extended local headers
+ check_byte = (zinfo._raw_time >> 8) & 0xff
+ else:
+ # compare against the CRC otherwise
+ check_byte = (zinfo.CRC >> 24) & 0xff
+ if ord(h[11]) != check_byte:
+ raise RuntimeError("Bad password for file", name)
+
+ # build and return a ZipExtFile
+ if zd is None:
+ zef = ZipExtFile(zef_file, zinfo)
else:
- raise BadZipfile, \
- "Unsupported compression method %d for file %s" % \
- (zinfo.compress_type, name)
- crc = binascii.crc32(bytes)
- if crc != zinfo.CRC:
- raise BadZipfile, "Bad CRC-32 for file %s" % name
- return bytes
+ zef = ZipExtFile(zef_file, zinfo, zd)
+
+ # set universal newlines on ZipExtFile if necessary
+ if "U" in mode:
+ zef.set_univ_newlines(True)
+ return zef
+
+ def extract(self, member, path=None, pwd=None):
+ """Extract a member from the archive to the current working directory,
+ using its full name. Its file information is extracted as accurately
+ as possible. `member' may be a filename or a ZipInfo object. You can
+ specify a different directory using `path'.
+ """
+ if not isinstance(member, ZipInfo):
+ member = self.getinfo(member)
+
+ if path is None:
+ path = os.getcwd()
+
+ return self._extract_member(member, path, pwd)
+
+ def extractall(self, path=None, members=None, pwd=None):
+ """Extract all members from the archive to the current working
+ directory. `path' specifies a different directory to extract to.
+ `members' is optional and must be a subset of the list returned
+ by namelist().
+ """
+ if members is None:
+ members = self.namelist()
+
+ for zipinfo in members:
+ self.extract(zipinfo, path, pwd)
+
+ def _extract_member(self, member, targetpath, pwd):
+ """Extract the ZipInfo object 'member' to a physical
+ file on the path targetpath.
+ """
+ # build the destination pathname, replacing
+ # forward slashes to platform specific separators.
+ # Strip trailing path separator, unless it represents the root.
+ if (targetpath[-1:] in (os.path.sep, os.path.altsep)
+ and len(os.path.splitdrive(targetpath)[1]) > 1):
+ targetpath = targetpath[:-1]
+
+ # don't include leading "/" from file name if present
+ if member.filename[0] == '/':
+ targetpath = os.path.join(targetpath, member.filename[1:])
+ else:
+ targetpath = os.path.join(targetpath, member.filename)
+
+ targetpath = os.path.normpath(targetpath)
+
+ # Create all upper directories if necessary.
+ upperdirs = os.path.dirname(targetpath)
+ if upperdirs and not os.path.exists(upperdirs):
+ os.makedirs(upperdirs)
+
+ if member.filename[-1] == '/':
+ if not os.path.isdir(targetpath):
+ os.mkdir(targetpath)
+ return targetpath
+
+ source = self.open(member, pwd=pwd)
+ target = file(targetpath, "wb")
+ shutil.copyfileobj(source, target)
+ source.close()
+ target.close()
+
+ return targetpath
def _writecheck(self, zinfo):
"""Check for errors before writing a file to the archive."""
@@ -540,7 +1002,12 @@
def write(self, filename, arcname=None, compress_type=None):
"""Put the bytes from filename into the archive under the name
arcname."""
+ if not self.fp:
+ raise RuntimeError(
+ "Attempt to write to ZIP archive that was already closed")
+
st = os.stat(filename)
+ isdir = stat.S_ISDIR(st.st_mode)
mtime = time.localtime(st.st_mtime)
date_time = mtime[0:6]
# Create ZipInfo instance to store file information
@@ -549,6 +1016,8 @@
arcname = os.path.normpath(os.path.splitdrive(arcname)[1])
while arcname[0] in (os.sep, os.altsep):
arcname = arcname[1:]
+ if isdir:
+ arcname += '/'
zinfo = ZipInfo(arcname, date_time)
zinfo.external_attr = (st[0] & 0xFFFF) << 16L # Unix attributes
if compress_type is None:
@@ -562,6 +1031,16 @@
self._writecheck(zinfo)
self._didModify = True
+
+ if isdir:
+ zinfo.file_size = 0
+ zinfo.compress_size = 0
+ zinfo.CRC = 0
+ self.filelist.append(zinfo)
+ self.NameToInfo[zinfo.filename] = zinfo
+ self.fp.write(zinfo.FileHeader())
+ return
+
fp = open(filename, "rb")
# Must overwrite CRC and sizes with correct data later
zinfo.CRC = CRC = 0
@@ -578,7 +1057,7 @@
if not buf:
break
file_size = file_size + len(buf)
- CRC = binascii.crc32(buf, CRC)
+ CRC = crc32(buf, CRC) & 0xffffffff
if cmpr:
buf = cmpr.compress(buf)
compress_size = compress_size + len(buf)
@@ -596,7 +1075,7 @@
# Seek backwards and write CRC and file sizes
position = self.fp.tell() # Preserve current position in file
self.fp.seek(zinfo.header_offset + 14, 0)
- self.fp.write(struct.pack("<lLL", zinfo.CRC, zinfo.compress_size,
+ self.fp.write(struct.pack("<LLL", zinfo.CRC, zinfo.compress_size,
zinfo.file_size))
self.fp.seek(position, 0)
self.filelist.append(zinfo)
@@ -610,13 +1089,19 @@
zinfo = ZipInfo(filename=zinfo_or_arcname,
date_time=time.localtime(time.time())[:6])
zinfo.compress_type = self.compression
+ zinfo.external_attr = 0600 << 16
else:
zinfo = zinfo_or_arcname
+
+ if not self.fp:
+ raise RuntimeError(
+ "Attempt to write to ZIP archive that was already closed")
+
zinfo.file_size = len(bytes) # Uncompressed size
zinfo.header_offset = self.fp.tell() # Start of header bytes
self._writecheck(zinfo)
self._didModify = True
- zinfo.CRC = binascii.crc32(bytes) # CRC-32 checksum
+ zinfo.CRC = crc32(bytes) & 0xffffffff # CRC-32 checksum
if zinfo.compress_type == ZIP_DEFLATED:
co = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
zlib.DEFLATED, -15)
@@ -630,7 +1115,7 @@
self.fp.flush()
if zinfo.flag_bits & 0x08:
# Write CRC and file sizes after the file data
- self.fp.write(struct.pack("<lLL", zinfo.CRC, zinfo.compress_size,
+ self.fp.write(struct.pack("<LLL", zinfo.CRC, zinfo.compress_size,
zinfo.file_size))
self.filelist.append(zinfo)
self.NameToInfo[zinfo.filename] = zinfo
@@ -658,15 +1143,15 @@
or zinfo.compress_size > ZIP64_LIMIT:
extra.append(zinfo.file_size)
extra.append(zinfo.compress_size)
- file_size = 0xffffffff #-1
- compress_size = 0xffffffff #-1
+ file_size = 0xffffffff
+ compress_size = 0xffffffff
else:
file_size = zinfo.file_size
compress_size = zinfo.compress_size
if zinfo.header_offset > ZIP64_LIMIT:
extra.append(zinfo.header_offset)
- header_offset = -1 # struct "l" format: 32 one bits
+ header_offset = 0xffffffffL
else:
header_offset = zinfo.header_offset
@@ -674,7 +1159,7 @@
if extra:
# Append a ZIP64 field to the extra's
extra_data = struct.pack(
- '<hh' + 'q'*len(extra),
+ '<HH' + 'Q'*len(extra),
1, 8*len(extra), *extra) + extra_data
extract_version = max(45, zinfo.extract_version)
@@ -683,44 +1168,68 @@
extract_version = zinfo.extract_version
create_version = zinfo.create_version
- centdir = struct.pack(structCentralDir,
- stringCentralDir, create_version,
- zinfo.create_system, extract_version, zinfo.reserved,
- zinfo.flag_bits, zinfo.compress_type, dostime, dosdate,
- zinfo.CRC, compress_size, file_size,
- len(zinfo.filename), len(extra_data), len(zinfo.comment),
- 0, zinfo.internal_attr, zinfo.external_attr,
- header_offset)
+ try:
+ filename, flag_bits = zinfo._encodeFilenameFlags()
+ centdir = struct.pack(structCentralDir,
+ stringCentralDir, create_version,
+ zinfo.create_system, extract_version, zinfo.reserved,
+ flag_bits, zinfo.compress_type, dostime, dosdate,
+ zinfo.CRC, compress_size, file_size,
+ len(filename), len(extra_data), len(zinfo.comment),
+ 0, zinfo.internal_attr, zinfo.external_attr,
+ header_offset)
+ except DeprecationWarning:
+ print >>sys.stderr, (structCentralDir,
+ stringCentralDir, create_version,
+ zinfo.create_system, extract_version, zinfo.reserved,
+ zinfo.flag_bits, zinfo.compress_type, dostime, dosdate,
+ zinfo.CRC, compress_size, file_size,
+ len(zinfo.filename), len(extra_data), len(zinfo.comment),
+ 0, zinfo.internal_attr, zinfo.external_attr,
+ header_offset)
+ raise
self.fp.write(centdir)
- self.fp.write(zinfo.filename)
+ self.fp.write(filename)
self.fp.write(extra_data)
self.fp.write(zinfo.comment)
pos2 = self.fp.tell()
# Write end-of-zip-archive record
- if pos1 > ZIP64_LIMIT:
+ centDirCount = count
+ centDirSize = pos2 - pos1
+ centDirOffset = pos1
+ if (centDirCount >= ZIP_FILECOUNT_LIMIT or
+ centDirOffset > ZIP64_LIMIT or
+ centDirSize > ZIP64_LIMIT):
# Need to write the ZIP64 end-of-archive records
zip64endrec = struct.pack(
structEndArchive64, stringEndArchive64,
- 44, 45, 45, 0, 0, count, count, pos2 - pos1, pos1)
+ 44, 45, 45, 0, 0, centDirCount, centDirCount,
+ centDirSize, centDirOffset)
self.fp.write(zip64endrec)
zip64locrec = struct.pack(
structEndArchive64Locator,
stringEndArchive64Locator, 0, pos2, 1)
self.fp.write(zip64locrec)
+ centDirCount = min(centDirCount, 0xFFFF)
+ centDirSize = min(centDirSize, 0xFFFFFFFF)
+ centDirOffset = min(centDirOffset, 0xFFFFFFFF)
- # XXX Why is `pos3` computed next? It's never referenced.
- pos3 = self.fp.tell()
- endrec = struct.pack(structEndArchive, stringEndArchive,
- 0, 0, count, count, pos2 - pos1, -1, 0)
- self.fp.write(endrec)
+ # check for valid comment length
+ if len(self.comment) >= ZIP_MAX_COMMENT:
+ if self.debug > 0:
+ msg = 'Archive comment is too long; truncating to %d bytes' \
+ % ZIP_MAX_COMMENT
+ self.comment = self.comment[:ZIP_MAX_COMMENT]
- else:
- endrec = struct.pack(structEndArchive, stringEndArchive,
- 0, 0, count, count, pos2 - pos1, pos1, 0)
- self.fp.write(endrec)
+ endrec = struct.pack(structEndArchive, stringEndArchive,
+ 0, 0, centDirCount, centDirCount,
+ centDirSize, centDirOffset, len(self.comment))
+ self.fp.write(endrec)
+ self.fp.write(self.comment)
self.fp.flush()
+
if not self._filePassed:
self.fp.close()
self.fp = None
--
Repository URL: http://hg.python.org/jython
More information about the Jython-checkins mailing list