[pypy-svn] pypy cmath: (lac, arigo)

arigo commits-noreply at bitbucket.org
Mon Jan 17 17:14:46 CET 2011


Author: Armin Rigo <arigo at tunes.org>
Branch: cmath
Changeset: r40779:b03a2f70599c
Date: 2011-01-17 17:14 +0100
http://bitbucket.org/pypy/pypy/changeset/b03a2f70599c/

Log:	(lac, arigo)

	atan(), atanh().

diff --git a/pypy/module/cmath/__init__.py b/pypy/module/cmath/__init__.py
--- a/pypy/module/cmath/__init__.py
+++ b/pypy/module/cmath/__init__.py
@@ -8,6 +8,8 @@
     'acosh': "Return the hyperbolic arc cosine of x.",
     'asin': "Return the arc sine of x.",
     'asinh': "Return the hyperbolic arc sine of x.",
+    'atan': "Return the arc tangent of x.",
+    'atanh': "Return the hyperbolic arc tangent of x.",
     }
 
 

diff --git a/pypy/module/cmath/test/test_cmath.py b/pypy/module/cmath/test/test_cmath.py
--- a/pypy/module/cmath/test/test_cmath.py
+++ b/pypy/module/cmath/test/test_cmath.py
@@ -140,12 +140,13 @@
             function = getattr(interp_cmath, 'c_' + fn)
         if 'divide-by-zero' in flags or 'invalid' in flags:
             try:
-                actual = function(arg)
+                actual = function(*arg)
             except ValueError:
                 continue
             else:
-                self.fail('ValueError not raised in test '
-                      '{}: {}(complex({!r}, {!r}))'.format(id, fn, ar, ai))
+                raise AssertionError('ValueError not raised in test '
+                                     '%s: %s(complex(%r, %r))' % (id, fn,
+                                                                  ar, ai))
 
         if 'overflow' in flags:
             try:

diff --git a/pypy/module/cmath/interp_cmath.py b/pypy/module/cmath/interp_cmath.py
--- a/pypy/module/cmath/interp_cmath.py
+++ b/pypy/module/cmath/interp_cmath.py
@@ -1,22 +1,31 @@
 import math
 from math import fabs
-from pypy.rlib.rarithmetic import copysign, asinh
+from pypy.rlib.rarithmetic import copysign, asinh, log1p
 from pypy.interpreter.gateway import ObjSpace, W_Root
 from pypy.module.cmath import Module, names_and_docstrings
 from pypy.module.cmath.constant import DBL_MIN, CM_SCALE_UP, CM_SCALE_DOWN
 from pypy.module.cmath.constant import CM_LARGE_DOUBLE, M_LN2
-from pypy.module.cmath.special_value import isfinite, special_type
+from pypy.module.cmath.constant import CM_SQRT_LARGE_DOUBLE, CM_SQRT_DBL_MIN
+from pypy.module.cmath.special_value import isfinite, special_type, INF
 from pypy.module.cmath.special_value import sqrt_special_values
 from pypy.module.cmath.special_value import acos_special_values
 from pypy.module.cmath.special_value import acosh_special_values
 from pypy.module.cmath.special_value import asinh_special_values
+from pypy.module.cmath.special_value import atanh_special_values
 
 
 def unaryfn(c_func):
     def wrapper(space, w_z):
         x = space.float_w(space.getattr(w_z, space.wrap('real')))
         y = space.float_w(space.getattr(w_z, space.wrap('imag')))
-        resx, resy = c_func(x, y)
+        try:  # translate c_func's ValueError/OverflowError into app-level errors
+            resx, resy = c_func(x, y)
+        except ValueError:
+            raise OperationError(space.w_ValueError,  # NOTE(review): OperationError does not appear among this module's imports -- confirm it is imported
+                                 space.wrap("math domain error"))
+        except OverflowError:
+            raise OperationError(space.w_OverflowError,
+                                 space.wrap("math range error"))
         return space.newcomplex(resx, resy)
     #
     name = c_func.func_name
@@ -27,6 +36,10 @@
     return c_func
 
 
+def c_neg(x, y):  # internal helper (no @unaryfn): -(x + i*y) as (real, imag)
+    return (-x, -y)
+
+
 @unaryfn
 def c_sqrt(x, y):
     # Method: use symmetries to reduce to the case when x = z.real and y
@@ -125,6 +138,7 @@
 
 @unaryfn
 def c_asin(x, y):
+    # asin(z) = -i asinh(iz), where iz = -y + i*x; -i*(sx + i*sy) = sy - i*sx
     sx, sy = c_asinh(-y, x)
     return (sy, -sx)
 
@@ -148,3 +162,46 @@
         real = asinh(s1x*s2y - s2x*s1y)
         imag = math.atan2(y, s1x*s2x - s1y*s2y)
     return (real, imag)
+
+
+ at unaryfn
+def c_atan(x, y):
+    # atan(z) = -i atanh(iz)
+    sx, sy = c_atanh(-y, x)
+    return (sy, -sx)
+
+
+ at unaryfn
+def c_atanh(x, y):
+    if not isfinite(x) or not isfinite(y):
+        return atanh_special_values[special_type(x)][special_type(y)]
+
+    # Reduce to case where x >= 0., using atanh(z) = -atanh(-z).
+    if x < 0.:
+        return c_neg(*c_atanh(*c_neg(x, y)))
+
+    ay = fabs(y)
+    if x > CM_SQRT_LARGE_DOUBLE or ay > CM_SQRT_LARGE_DOUBLE:
+        # if abs(z) is large then we use the approximation
+        # atanh(z) ~ 1/z +/- i*pi/2 (+/- depending on the sign
+        # of y
+        h = math.hypot(x/2., y/2.)   # safe from overflow
+        real = x/4./h/h
+        # the two negations in the next line cancel each other out
+        # except when working with unsigned zeros: they're there to
+        # ensure that the branch cut has the correct continuity on
+        # systems that don't support signed zeros
+        imag = -copysign(math.pi/2., -y)
+    elif x == 1. and ay < CM_SQRT_DBL_MIN:
+        # C99 standard says:  atanh(1+/-0.) should be inf +/- 0i
+        if ay == 0.:
+            real = INF
+            imag = y
+            raise ValueError("result is infinite")
+        else:
+            real = -math.log(math.sqrt(ay)/math.sqrt(math.hypot(ay, 2.)))
+            imag = copysign(math.atan2(2., -ay)/2, y)
+    else:
+        real = log1p(4.*x/((1-x)*(1-x) + ay*ay))/4.
+        imag = -math.atan2(-2.*y, (1-x)*(1+x) - ay*ay)/2.
+    return (real, imag)

diff --git a/pypy/module/cmath/special_value.py b/pypy/module/cmath/special_value.py
--- a/pypy/module/cmath/special_value.py
+++ b/pypy/module/cmath/special_value.py
@@ -86,6 +86,16 @@
     (INF,N),    (N,N),     (N,-0.),   (N,0.),   (N,N),    (INF,N),   (N,N),
     ])
 
+atanh_special_values = build_table([  # c_atanh looks up [special_type(real)][special_type(imag)]
+    (-0.,-P12),(-0.,-P12),(-0.,-P12),(-0.,P12),(-0.,P12),(-0.,P12),(-0.,N),
+    (-0.,-P12),(U,U),     (U,U),     (U,U),    (U,U),    (-0.,P12),(N,N),
+    (-0.,-P12),(U,U),     (-0.,-0.), (-0.,0.), (U,U),    (-0.,P12),(-0.,N),
+    (0.,-P12), (U,U),     (0.,-0.),  (0.,0.),  (U,U),    (0.,P12), (0.,N),
+    (0.,-P12), (U,U),     (U,U),     (U,U),    (U,U),    (0.,P12), (N,N),
+    (0.,-P12), (0.,-P12), (0.,-P12), (0.,P12), (0.,P12), (0.,P12), (0.,N),
+    (0.,-P12), (N,N),     (N,N),     (N,N),    (N,N),    (0.,P12), (N,N),
+    ])
+
 sqrt_special_values = build_table([
     (INF,-INF), (0.,-INF), (0.,-INF), (0.,INF), (0.,INF), (INF,INF), (N,INF),
     (INF,-INF), (U,U),     (U,U),     (U,U),    (U,U),    (INF,INF), (N,N),


More information about the Pypy-commit mailing list