[Numpy-svn] r5458 - trunk/numpy/core/tests
numpy-svn at scipy.org
numpy-svn at scipy.org
Sat Jul 19 17:58:21 EDT 2008
Author: ptvirtan
Date: 2008-07-19 16:58:14 -0500 (Sat, 19 Jul 2008)
New Revision: 5458
Modified:
trunk/numpy/core/tests/test_umath.py
Log:
Add tests for complex functions: test against Python's cmath, and check the branch cuts and C99 compliance at the special points inf and nan.
Modified: trunk/numpy/core/tests/test_umath.py
===================================================================
--- trunk/numpy/core/tests/test_umath.py 2008-07-19 21:56:46 UTC (rev 5457)
+++ trunk/numpy/core/tests/test_umath.py 2008-07-19 21:58:14 UTC (rev 5458)
@@ -1,6 +1,8 @@
from numpy.testing import *
import numpy.core.umath as ncu
import numpy as np
+import nose
+from numpy import inf, nan, pi
class TestDivision(TestCase):
def test_division_int(self):
@@ -35,7 +37,6 @@
assert_almost_equal(x**14, [-76443+16124j, 23161315+58317492j,
5583548873 + 2465133864j])
-
class TestLog1p(TestCase):
def test_log1p(self):
assert_almost_equal(ncu.log1p(0.2), ncu.log(1.2))
@@ -179,10 +180,10 @@
assert_equal(np.choose(c, (a, 1)), np.array([1,1]))
-class TestComplexFunctions(TestCase):
+class TestComplexFunctions(object):
funcs = [np.arcsin , np.arccos , np.arctan, np.arcsinh, np.arccosh,
np.arctanh, np.sin , np.cos , np.tan , np.exp,
- np.log , np.sqrt , np.log10]
+ np.log , np.sqrt , np.log10, np.log1p]
def test_it(self):
for f in self.funcs:
@@ -204,7 +205,206 @@
assert_almost_equal(fcf, fcd, decimal=6, err_msg='fch-fcd %s'%f)
assert_almost_equal(fcl, fcd, decimal=15, err_msg='fch-fcl %s'%f)
+ def test_branch_cuts(self):
+ # check branch cuts and continuity on them
+ yield _check_branch_cut, np.log, -0.5, 1j, 1, -1, True
+ yield _check_branch_cut, np.log10, -0.5, 1j, 1, -1, True
+ yield _check_branch_cut, np.log1p, -1.5, 1j, 1, -1, True
+ yield _check_branch_cut, np.sqrt, -0.5, 1j, 1, -1
+
+ yield _check_branch_cut, np.arcsin, [ -2, 2], [1j, -1j], 1, -1
+ yield _check_branch_cut, np.arccos, [ -2, 2], [1j, -1j], 1, -1
+ yield _check_branch_cut, np.arctan, [-2j, 2j], [1, -1 ], -1, 1
+
+ yield _check_branch_cut, np.arcsinh, [-2j, 2j], [-1, 1], -1, 1
+ yield _check_branch_cut, np.arccosh, [ -1, 0.5], [1j, 1j], 1, -1
+ yield _check_branch_cut, np.arctanh, [ -2, 2], [1j, -1j], 1, -1
+ # check against bogus branch cuts: assert continuity between quadrants
+ yield _check_branch_cut, np.arcsin, [-2j, 2j], [ 1, 1], 1, 1
+ yield _check_branch_cut, np.arccos, [-2j, 2j], [ 1, 1], 1, 1
+ yield _check_branch_cut, np.arctan, [ -2, 2], [1j, 1j], 1, 1
+
+ yield _check_branch_cut, np.arcsinh, [ -2, 2, 0], [1j, 1j, 1 ], 1, 1
+ yield _check_branch_cut, np.arccosh, [-2j, 2j, 2], [1, 1, 1j], 1, 1
+ yield _check_branch_cut, np.arctanh, [-2j, 2j, 0], [1, 1, 1j], 1, 1
+
+ def test_branch_cuts_failing(self):
+ # XXX: signed zeros are not OK for sqrt or for the arc* functions
+ yield _check_branch_cut, np.sqrt, -0.5, 1j, 1, -1, True
+ yield _check_branch_cut, np.arcsin, [ -2, 2], [1j, -1j], 1, -1, True
+ yield _check_branch_cut, np.arccos, [ -2, 2], [1j, -1j], 1, -1, True
+ yield _check_branch_cut, np.arctan, [-2j, 2j], [1, -1 ], -1, 1, True
+ yield _check_branch_cut, np.arcsinh, [-2j, 2j], [-1, 1], -1, 1, True
+ yield _check_branch_cut, np.arccosh, [ -1, 0.5], [1j, 1j], 1, -1, True
+ yield _check_branch_cut, np.arctanh, [ -2, 2], [1j, -1j], 1, -1, True
+ test_branch_cuts_failing = dec.skipknownfailure(test_branch_cuts_failing)
+
+ def test_against_cmath(self):
+ import cmath, sys
+
+ # cmath.asinh is broken in some versions of Python, see
+ # http://bugs.python.org/issue1381
+ broken_cmath_asinh = False
+ if sys.version_info < (2,5,3):
+ broken_cmath_asinh = True
+
+ points = [-2, 2j, 2, -2j, -1-1j, -1+1j, +1-1j, +1+1j]
+ name_map = {'arcsin': 'asin', 'arccos': 'acos', 'arctan': 'atan',
+ 'arcsinh': 'asinh', 'arccosh': 'acosh', 'arctanh': 'atanh'}
+ atol = 4*np.finfo(np.complex).eps
+ for func in self.funcs:
+ fname = func.__name__.split('.')[-1]
+ cname = name_map.get(fname, fname)
+ try: cfunc = getattr(cmath, cname)
+ except AttributeError: continue
+ for p in points:
+ a = complex(func(np.complex_(p)))
+ b = cfunc(p)
+
+ if cname == 'asinh' and broken_cmath_asinh:
+ continue
+
+ assert abs(a - b) < atol, "%s %s: %s; cmath: %s"%(fname,p,a,b)
+
+class TestC99(object):
+ """Check special functions at special points against the C99 standard"""
+ # NB: inherits from object instead of TestCase since using test generators
+
+ #
+ # Non-conforming results are with XXX added to the exception field.
+ #
+
+ def test_clog(self):
+ for p, v, e in [
+ ((-0., 0.), (-inf, pi), 'divide'),
+ ((+0., 0.), (-inf, 0.), 'divide'),
+ ((1., inf), (inf, pi/2), ''),
+ ((1., nan), (nan, nan), ''),
+ ((-inf, 1.), (inf, pi), ''),
+ ((inf, 1.), (inf, 0.), ''),
+ ((-inf, inf), (inf, 3*pi/4), ''),
+ ((inf, inf), (inf, pi/4), ''),
+ ((inf, nan), (inf, nan), ''),
+ ((-inf, nan), (inf, nan), ''),
+ ((nan, 0.), (nan, nan), ''),
+ ((nan, 1.), (nan, nan), ''),
+ ((nan, inf), (inf, nan), ''),
+ ((+nan, nan), (nan, nan), ''),
+ ]:
+ yield self._check, np.log, p, v, e
+
+ def test_csqrt(self):
+ for p, v, e in [
+ ((-0., 0.), (0.,0.), 'XXX'), # now (-0., 0.)
+ ((0., 0.), (0.,0.), ''),
+ ((1., inf), (inf,inf), 'XXX invalid'), # now (inf, nan)
+ ((nan, inf), (inf,inf), 'XXX'), # now (nan, nan)
+ ((-inf, 1.), (0.,inf), ''),
+ ((inf, 1.), (inf,0.), ''),
+ ((-inf,nan), (nan, -inf), ''), # could also be +inf
+ ((inf, nan), (inf, nan), ''),
+ ((nan, 1.), (nan, nan), ''),
+ ((nan, nan), (nan, nan), ''),
+ ]:
+ yield self._check, np.sqrt, p, v, e
+
+ def test_cacos(self):
+ for p, v, e in [
+ ((0., 0.), (pi/2, -0.), 'XXX'), # now (-0., 0.)
+ ((-0., 0.), (pi/2, -0.), ''),
+ ((0., nan), (pi/2, nan), 'XXX'), # now (nan, nan)
+ ((-0., nan), (pi/2, nan), 'XXX'), # now (nan, nan)
+ ((1., inf), (pi/2, -inf), 'XXX'), # now (nan, -inf)
+ ((1., nan), (nan, nan), ''),
+ ((-inf, 1.), (pi, -inf), 'XXX'), # now (nan, -inf)
+ ((inf, 1.), (0., -inf), 'XXX'), # now (nan, -inf)
+ ((-inf, inf), (3*pi/4, -inf), 'XXX'), # now (nan, nan)
+ ((inf, inf), (pi/4, -inf), 'XXX'), # now (nan, nan)
+ ((inf, nan), (nan, +-inf), 'XXX'), # now (nan, nan)
+ ((-inf, nan), (nan, +-inf), 'XXX'), # now: (nan, nan)
+ ((nan, 1.), (nan, nan), ''),
+ ((nan, inf), (nan, -inf), 'XXX'), # now: (nan, nan)
+ ((nan, nan), (nan, nan), ''),
+ ]:
+ yield self._check, np.arccos, p, v, e
+
+ def test_cacosh(self):
+ for p, v, e in [
+ ((0., 0), (0, pi/2), ''),
+ ((-0., 0), (0, pi/2), ''),
+ ((1., inf), (inf, pi/2), 'XXX'), # now: (nan, nan)
+ ((1., nan), (nan, nan), ''),
+ ((-inf, 1.), (inf, pi), 'XXX'), # now: (inf, nan)
+ ((inf, 1.), (inf, 0.), 'XXX'), # now: (inf, nan)
+ ((-inf, inf), (inf, 3*pi/4), 'XXX'), # now: (nan, nan)
+ ((inf, inf), (inf, pi/4), 'XXX'), # now: (nan, nan)
+ ((inf, nan), (inf, nan), 'XXX'), # now: (nan, nan)
+ ((-inf, nan), (inf, nan), 'XXX'), # now: (nan, nan)
+ ((nan, 1.), (nan, nan), ''),
+ ((nan, inf), (inf, nan), 'XXX'), # now: (nan, nan)
+ ((nan, nan), (nan, nan), '')
+ ]:
+ yield self._check, np.arccosh, p, v, e
+
+ def test_casinh(self):
+ for p, v, e in [
+ ((0., 0), (0, 0), ''),
+ ((1., inf), (inf, pi/2), 'XXX'), # now: (inf, nan)
+ ((1., nan), (nan, nan), ''),
+ ((inf, 1.), (inf, 0.), 'XXX'), # now: (inf, nan)
+ ((inf, inf), (inf, pi/4), 'XXX'), # now: (nan, nan)
+ ((inf, nan), (nan, nan), 'XXX'), # now: (nan, nan)
+ ((nan, 0.), (nan, 0.), 'XXX'), # now: (nan, nan)
+ ((nan, 1.), (nan, nan), ''),
+ ((nan, inf), (+-inf, nan), 'XXX'), # now: (nan, nan)
+ ((nan, nan), (nan, nan), ''),
+ ]:
+ yield self._check, np.arcsinh, p, v, e
+
+ def test_catanh(self):
+ for p, v, e in [
+ ((0., 0), (0, 0), ''),
+ ((0., nan), (0., nan), 'XXX'), # now: (nan, nan)
+ ((1., 0.), (inf, 0.), 'XXX divide'), # now: (nan, nan)
+ ((1., inf), (inf, 0.), 'XXX'), # now: (nan, nan)
+ ((1., nan), (nan, nan), ''),
+ ((inf, 1.), (0., pi/2), 'XXX'), # now: (nan, nan)
+ ((inf, inf), (0, pi/2), 'XXX'), # now: (nan, nan)
+ ((inf, nan), (0, nan), 'XXX'), # now: (nan, nan)
+ ((nan, 1.), (nan, nan), ''),
+ ((nan, inf), (+0, pi/2), 'XXX'), # now: (nan, nan)
+ ((nan, nan), (nan, nan), ''),
+ ]:
+ yield self._check, np.arctanh, p, v, e
+
+ def _check(self, func, point, value, exc=''):
+ if 'XXX' in exc:
+ raise nose.SkipTest
+ if isinstance(point, tuple): point = complex(*point)
+ if isinstance(value, tuple): value = complex(*value)
+ v = dict(divide='ignore', invalid='ignore',
+ over='ignore', under='ignore')
+ old_err = np.seterr(**v)
+ try:
+ # check sign of zero, nan, etc.
+ got = complex(func(point))
+ got = "(%s, %s)" % (repr(got.real), repr(got.imag))
+ expected = "(%s, %s)" % (repr(value.real), repr(value.imag))
+ assert got == expected, (got, expected)
+
+ # check exceptions
+ if exc in ('divide', 'invalid', 'over', 'under'):
+ v[exc] = 'raise'
+ np.seterr(**v)
+ assert_raises(FloatingPointError, func, point)
+ else:
+ for k in v.keys(): v[k] = 'raise'
+ np.seterr(**v)
+ func(point)
+ finally:
+ np.seterr(**old_err)
+
class TestAttributes(TestCase):
def test_attributes(self):
add = ncu.add
@@ -216,6 +416,59 @@
assert_equal(add.nout, 1)
assert_equal(add.identity, 0)
+def _check_branch_cut(f, x0, dx, re_sign=1, im_sign=-1, sig_zero_ok=False,
+ dtype=np.complex):
+ """
+ Check for a branch cut in a function.
+ Assert that `x0` lies on a branch cut of function `f` and `f` is
+ continuous from the direction `dx`.
+
+ Parameters
+ ----------
+ f : func
+ Function to check
+ x0 : array-like
+ Point on branch cut
+ dx : array-like
+ Direction to check continuity in
+ re_sign, im_sign : {1, -1}
+ Change of sign of the real or imaginary part expected
+ sig_zero_ok : bool
+ Whether to check if the branch cut respects signed zero (if applicable)
+ dtype : dtype
+ Dtype to check (should be complex)
+
+ """
+ x0 = np.atleast_1d(x0).astype(dtype)
+ dx = np.atleast_1d(dx).astype(dtype)
+
+ scale = np.finfo(dtype).eps * 1e3
+ atol = 1e-4
+
+ y0 = f(x0)
+ yp = f(x0 + dx*scale*np.absolute(x0)/np.absolute(dx))
+ ym = f(x0 - dx*scale*np.absolute(x0)/np.absolute(dx))
+
+ assert np.all(np.absolute(y0.real - yp.real) < atol), (y0, yp)
+ assert np.all(np.absolute(y0.imag - yp.imag) < atol), (y0, yp)
+ assert np.all(np.absolute(y0.real - ym.real*re_sign) < atol), (y0, ym)
+ assert np.all(np.absolute(y0.imag - ym.imag*im_sign) < atol), (y0, ym)
+
+ if sig_zero_ok:
+ # check that signed zeros also work as a displacement
+ jr = (x0.real == 0) & (dx.real != 0)
+ ji = (x0.imag == 0) & (dx.imag != 0)
+
+ x = -x0
+ x.real[jr] = 0.*dx.real
+ x.imag[ji] = 0.*dx.imag
+ x = -x
+ ym = f(x)
+ ym = ym[jr | ji]
+ y0 = y0[jr | ji]
+ assert np.all(np.absolute(y0.real - ym.real*re_sign) < atol), (y0, ym)
+ assert np.all(np.absolute(y0.imag - ym.imag*im_sign) < atol), (y0, ym)
+
if __name__ == "__main__":
run_module_suite()
More information about the Numpy-svn mailing list