[Numpy-svn] r5458 - trunk/numpy/core/tests

numpy-svn at scipy.org numpy-svn at scipy.org
Sat Jul 19 17:58:21 EDT 2008


Author: ptvirtan
Date: 2008-07-19 16:58:14 -0500 (Sat, 19 Jul 2008)
New Revision: 5458

Modified:
   trunk/numpy/core/tests/test_umath.py
Log:
Add tests for complex functions: test against Python's cmath, check the branch cuts and C99 compliance at inf, nan special points.

Modified: trunk/numpy/core/tests/test_umath.py
===================================================================
--- trunk/numpy/core/tests/test_umath.py	2008-07-19 21:56:46 UTC (rev 5457)
+++ trunk/numpy/core/tests/test_umath.py	2008-07-19 21:58:14 UTC (rev 5458)
@@ -1,6 +1,8 @@
 from numpy.testing import *
 import numpy.core.umath as ncu
 import numpy as np
+import nose
+from numpy import inf, nan, pi
 
 class TestDivision(TestCase):
     def test_division_int(self):
@@ -35,7 +37,6 @@
         assert_almost_equal(x**14, [-76443+16124j, 23161315+58317492j,
                                     5583548873 +  2465133864j])
 
-
 class TestLog1p(TestCase):
     def test_log1p(self):
         assert_almost_equal(ncu.log1p(0.2), ncu.log(1.2))
@@ -179,10 +180,10 @@
         assert_equal(np.choose(c, (a, 1)), np.array([1,1]))
 
 
-class TestComplexFunctions(TestCase):
+class TestComplexFunctions(object):
     funcs = [np.arcsin , np.arccos , np.arctan, np.arcsinh, np.arccosh,
              np.arctanh, np.sin    , np.cos   , np.tan    , np.exp,
-             np.log    , np.sqrt   , np.log10]
+             np.log    , np.sqrt   , np.log10,  np.log1p]
 
     def test_it(self):
         for f in self.funcs:
@@ -204,7 +205,206 @@
             assert_almost_equal(fcf, fcd, decimal=6, err_msg='fch-fcd %s'%f)
             assert_almost_equal(fcl, fcd, decimal=15, err_msg='fch-fcl %s'%f)
 
+    def test_branch_cuts(self):
+        # check branch cuts and continuity on them
+        yield _check_branch_cut, np.log,   -0.5, 1j, 1, -1, True
+        yield _check_branch_cut, np.log10, -0.5, 1j, 1, -1, True
+        yield _check_branch_cut, np.log1p, -1.5, 1j, 1, -1, True
+        yield _check_branch_cut, np.sqrt,  -0.5, 1j, 1, -1
+        
+        yield _check_branch_cut, np.arcsin, [ -2, 2],   [1j, -1j], 1, -1
+        yield _check_branch_cut, np.arccos, [ -2, 2],   [1j, -1j], 1, -1
+        yield _check_branch_cut, np.arctan, [-2j, 2j],  [1,  -1 ], -1, 1
+        
+        yield _check_branch_cut, np.arcsinh, [-2j,  2j], [-1,   1], -1, 1
+        yield _check_branch_cut, np.arccosh, [ -1, 0.5], [1j,  1j], 1, -1
+        yield _check_branch_cut, np.arctanh, [ -2,   2], [1j, -1j], 1, -1
 
+        # check against bogus branch cuts: assert continuity between quadrants
+        yield _check_branch_cut, np.arcsin, [-2j, 2j], [ 1,  1], 1, 1
+        yield _check_branch_cut, np.arccos, [-2j, 2j], [ 1,  1], 1, 1
+        yield _check_branch_cut, np.arctan, [ -2,  2], [1j, 1j], 1, 1
+
+        yield _check_branch_cut, np.arcsinh, [ -2,  2, 0], [1j, 1j, 1 ], 1, 1
+        yield _check_branch_cut, np.arccosh, [-2j, 2j, 2], [1,  1,  1j], 1, 1
+        yield _check_branch_cut, np.arctanh, [-2j, 2j, 0], [1,  1,  1j], 1, 1
+
+    def test_branch_cuts_failing(self):
+        # XXX: signed zeros are not OK for sqrt or for the arc* functions
+        yield _check_branch_cut, np.sqrt,  -0.5, 1j, 1, -1, True
+        yield _check_branch_cut, np.arcsin, [ -2, 2],   [1j, -1j], 1, -1, True
+        yield _check_branch_cut, np.arccos, [ -2, 2],   [1j, -1j], 1, -1, True
+        yield _check_branch_cut, np.arctan, [-2j, 2j],  [1,  -1 ], -1, 1, True
+        yield _check_branch_cut, np.arcsinh, [-2j,  2j], [-1,   1], -1, 1, True
+        yield _check_branch_cut, np.arccosh, [ -1, 0.5], [1j,  1j], 1, -1, True
+        yield _check_branch_cut, np.arctanh, [ -2,   2], [1j, -1j], 1, -1, True
+    test_branch_cuts_failing = dec.skipknownfailure(test_branch_cuts_failing)
+        
+    def test_against_cmath(self):
+        import cmath, sys
+
+        # cmath.asinh is broken in some versions of Python, see
+        # http://bugs.python.org/issue1381
+        broken_cmath_asinh = False
+        if sys.version_info < (2,5,3):
+            broken_cmath_asinh = True
+        
+        points = [-2, 2j, 2, -2j, -1-1j, -1+1j, +1-1j, +1+1j]
+        name_map = {'arcsin': 'asin', 'arccos': 'acos', 'arctan': 'atan',
+                    'arcsinh': 'asinh', 'arccosh': 'acosh', 'arctanh': 'atanh'}
+        atol = 4*np.finfo(np.complex).eps
+        for func in self.funcs:
+            fname = func.__name__.split('.')[-1]
+            cname = name_map.get(fname, fname)
+            try: cfunc = getattr(cmath, cname)
+            except AttributeError: continue
+            for p in points:
+                a = complex(func(np.complex_(p)))
+                b = cfunc(p)
+                
+                if cname == 'asinh' and broken_cmath_asinh:
+                    continue 
+
+                assert abs(a - b) < atol, "%s %s: %s; cmath: %s"%(fname,p,a,b)
+
+class TestC99(object):
+    """Check special functions at special points against the C99 standard"""
+    # NB: inherits from object instead of TestCase since using test generators
+    
+    #
+    # Non-conforming results are with XXX added to the exception field.
+    #
+    
+    def test_clog(self):
+        for p, v, e in [
+            ((-0., 0.), (-inf, pi), 'divide'),
+            ((+0., 0.), (-inf, 0.), 'divide'),
+            ((1., inf), (inf, pi/2), ''),
+            ((1., nan), (nan, nan), ''),
+            ((-inf, 1.), (inf, pi), ''),
+            ((inf, 1.), (inf, 0.), ''),
+            ((-inf, inf), (inf, 3*pi/4), ''),
+            ((inf, inf), (inf, pi/4), ''),
+            ((inf, nan), (inf, nan), ''),
+            ((-inf, nan), (inf, nan), ''),
+            ((nan, 0.), (nan, nan), ''),
+            ((nan, 1.), (nan, nan), ''),
+            ((nan, inf), (inf, nan), ''),
+            ((+nan, nan), (nan, nan), ''),
+        ]:
+            yield self._check, np.log, p, v, e
+    
+    def test_csqrt(self):
+        for p, v, e in [
+            ((-0., 0.), (0.,0.),  'XXX'), # now (-0., 0.)
+            ((0., 0.), (0.,0.),  ''),
+            ((1., inf), (inf,inf), 'XXX invalid'), # now (inf, nan)
+            ((nan, inf), (inf,inf), 'XXX'), # now (nan, nan)
+            ((-inf, 1.), (0.,inf), ''),
+            ((inf, 1.), (inf,0.), ''),
+            ((-inf,nan), (nan, -inf), ''), # could also be +inf
+            ((inf, nan), (inf, nan),  ''),
+            ((nan, 1.), (nan, nan), ''),
+            ((nan, nan), (nan, nan), ''),
+        ]:
+            yield self._check, np.sqrt, p, v, e
+
+    def test_cacos(self):
+        for p, v, e in [
+            ((0., 0.), (pi/2, -0.), 'XXX'), # now (-0., 0.)
+            ((-0., 0.), (pi/2, -0.), ''),
+            ((0., nan), (pi/2, nan), 'XXX'), # now (nan, nan)
+            ((-0., nan), (pi/2, nan), 'XXX'), # now (nan, nan)
+            ((1., inf), (pi/2, -inf), 'XXX'), # now (nan, -inf)
+            ((1., nan), (nan, nan), ''),
+            ((-inf, 1.), (pi, -inf), 'XXX'), # now (nan, -inf)
+            ((inf, 1.), (0., -inf), 'XXX'), # now (nan, -inf)
+            ((-inf, inf), (3*pi/4, -inf), 'XXX'), # now (nan, nan)
+            ((inf, inf), (pi/4, -inf), 'XXX'), # now (nan, nan)
+            ((inf, nan), (nan, +-inf), 'XXX'), # now (nan, nan)
+            ((-inf, nan), (nan, +-inf), 'XXX'), # now: (nan, nan)
+            ((nan, 1.), (nan, nan), ''),
+            ((nan, inf), (nan, -inf), 'XXX'), # now: (nan, nan)
+            ((nan, nan), (nan, nan), ''),
+        ]:
+            yield self._check, np.arccos, p, v, e
+
+    def test_cacosh(self):
+        for p, v, e in [
+            ((0., 0), (0, pi/2), ''),
+            ((-0., 0), (0, pi/2), ''),
+            ((1., inf), (inf, pi/2), 'XXX'), # now: (nan, nan)
+            ((1., nan), (nan, nan), ''),
+            ((-inf, 1.), (inf, pi), 'XXX'), # now: (inf, nan)
+            ((inf, 1.), (inf, 0.), 'XXX'), # now: (inf, nan)
+            ((-inf, inf), (inf, 3*pi/4), 'XXX'), # now: (nan, nan)
+            ((inf, inf), (inf, pi/4), 'XXX'), # now: (nan, nan)
+            ((inf, nan), (inf, nan), 'XXX'), # now: (nan, nan)
+            ((-inf, nan), (inf, nan), 'XXX'), # now: (nan, nan)
+            ((nan, 1.), (nan, nan), ''),
+            ((nan, inf), (inf, nan), 'XXX'), # now: (nan, nan)
+            ((nan, nan), (nan, nan), '')
+        ]:
+            yield self._check, np.arccosh, p, v, e
+
+    def test_casinh(self):
+        for p, v, e in [
+            ((0., 0), (0, 0), ''),
+            ((1., inf), (inf, pi/2), 'XXX'), # now: (inf, nan)
+            ((1., nan), (nan, nan), ''),
+            ((inf, 1.), (inf, 0.), 'XXX'), # now: (inf, nan)
+            ((inf, inf), (inf, pi/4), 'XXX'), # now: (nan, nan)
+            ((inf, nan), (nan, nan), 'XXX'), # now: (nan, nan)
+            ((nan, 0.), (nan, 0.), 'XXX'), # now: (nan, nan)
+            ((nan, 1.), (nan, nan), ''),
+            ((nan, inf), (+-inf, nan), 'XXX'), # now: (nan, nan)
+            ((nan, nan), (nan, nan), ''),
+        ]:
+            yield self._check, np.arcsinh, p, v, e
+
+    def test_catanh(self):
+        for p, v, e in [
+            ((0., 0), (0, 0), ''),
+            ((0., nan), (0., nan), 'XXX'), # now: (nan, nan)
+            ((1., 0.), (inf, 0.), 'XXX divide'), # now: (nan, nan)
+            ((1., inf), (inf, 0.), 'XXX'), # now: (nan, nan)
+            ((1., nan), (nan, nan), ''),
+            ((inf, 1.), (0., pi/2), 'XXX'), # now: (nan, nan)
+            ((inf, inf), (0, pi/2), 'XXX'), # now: (nan, nan)
+            ((inf, nan), (0, nan), 'XXX'), # now: (nan, nan)
+            ((nan, 1.), (nan, nan), ''),
+            ((nan, inf), (+0, pi/2), 'XXX'), # now: (nan, nan)
+            ((nan, nan), (nan, nan), ''),
+        ]:
+            yield self._check, np.arctanh, p, v, e
+
+    def _check(self, func, point, value, exc=''):
+        if 'XXX' in exc:
+            raise nose.SkipTest
+        if isinstance(point, tuple): point = complex(*point)
+        if isinstance(value, tuple): value = complex(*value)
+        v = dict(divide='ignore', invalid='ignore',
+                 over='ignore', under='ignore')
+        old_err = np.seterr(**v)
+        try:
+            # check sign of zero, nan, etc.
+            got = complex(func(point))
+            got = "(%s, %s)" % (repr(got.real), repr(got.imag))
+            expected = "(%s, %s)" % (repr(value.real), repr(value.imag))
+            assert got == expected, (got, expected)
+            
+            # check exceptions
+            if exc in ('divide', 'invalid', 'over', 'under'):
+                v[exc] = 'raise'
+                np.seterr(**v)
+                assert_raises(FloatingPointError, func, point)
+            else:
+                for k in v.keys(): v[k] = 'raise'
+                np.seterr(**v)
+                func(point)
+        finally:
+            np.seterr(**old_err)
+
 class TestAttributes(TestCase):
     def test_attributes(self):
         add = ncu.add
@@ -216,6 +416,59 @@
         assert_equal(add.nout, 1)
         assert_equal(add.identity, 0)
 
+def _check_branch_cut(f, x0, dx, re_sign=1, im_sign=-1, sig_zero_ok=False,
+                      dtype=np.complex):
+    """
+    Check for a branch cut in a function.
 
+    Assert that `x0` lies on a branch cut of function `f` and `f` is
+    continuous from the direction `dx`.
+
+    Parameters
+    ----------
+    f : func
+        Function to check
+    x0 : array-like
+        Point on branch cut
+    dx : array-like
+        Direction to check continuity in
+    re_sign, im_sign : {1, -1}
+        Change of sign of the real or imaginary part expected
+    sig_zero_ok : bool
+        Whether to check if the branch cut respects signed zero (if applicable)
+    dtype : dtype
+        Dtype to check (should be complex)
+
+    """
+    x0 = np.atleast_1d(x0).astype(dtype)
+    dx = np.atleast_1d(dx).astype(dtype)
+    
+    scale = np.finfo(dtype).eps * 1e3
+    atol  = 1e-4
+    
+    y0 = f(x0)
+    yp = f(x0 + dx*scale*np.absolute(x0)/np.absolute(dx))
+    ym = f(x0 - dx*scale*np.absolute(x0)/np.absolute(dx))
+    
+    assert np.all(np.absolute(y0.real - yp.real) < atol), (y0, yp)
+    assert np.all(np.absolute(y0.imag - yp.imag) < atol), (y0, yp)
+    assert np.all(np.absolute(y0.real - ym.real*re_sign) < atol), (y0, ym)
+    assert np.all(np.absolute(y0.imag - ym.imag*im_sign) < atol), (y0, ym)
+    
+    if sig_zero_ok:
+        # check that signed zeros also work as a displacement
+        jr = (x0.real == 0) & (dx.real != 0)
+        ji = (x0.imag == 0) & (dx.imag != 0)
+        
+        x = -x0
+        x.real[jr] = 0.*dx.real
+        x.imag[ji] = 0.*dx.imag
+        x = -x
+        ym = f(x)
+        ym = ym[jr | ji]
+        y0 = y0[jr | ji]
+        assert np.all(np.absolute(y0.real - ym.real*re_sign) < atol), (y0, ym)
+        assert np.all(np.absolute(y0.imag - ym.imag*im_sign) < atol), (y0, ym)
+
 if __name__ == "__main__":
     run_module_suite()




More information about the Numpy-svn mailing list