[pypy-svn] r75980 - pypy/branch/fast-forward/pypy/module/math

benjamin at codespeak.net benjamin at codespeak.net
Wed Jul 7 17:13:18 CEST 2010


Author: benjamin
Date: Wed Jul  7 17:13:17 2010
New Revision: 75980

Modified:
   pypy/branch/fast-forward/pypy/module/math/__init__.py
   pypy/branch/fast-forward/pypy/module/math/interp_math.py
Log:
add gamma and lgamma

Modified: pypy/branch/fast-forward/pypy/module/math/__init__.py
==============================================================================
--- pypy/branch/fast-forward/pypy/module/math/__init__.py	(original)
+++ pypy/branch/fast-forward/pypy/module/math/__init__.py	Wed Jul  7 17:13:17 2010
@@ -47,5 +47,7 @@
        'expm1'          : 'interp_math.expm1',
        'erf'            : 'interp_math.erf',
        'erfc'           : 'interp_math.erfc',
+       'gamma'          : 'interp_math.gamma',
+       'lgamma'         : 'interp_math.lgamma',
 }
 

Modified: pypy/branch/fast-forward/pypy/module/math/interp_math.py
==============================================================================
--- pypy/branch/fast-forward/pypy/module/math/interp_math.py	(original)
+++ pypy/branch/fast-forward/pypy/module/math/interp_math.py	Wed Jul  7 17:13:17 2010
@@ -432,6 +432,16 @@
     return math1(space, _erfc, x)
 erfc.unwrap_spec = [ObjSpace, float]
 
+def gamma(space, x):  # exported as math.gamma by the module's __init__.py
+    """Compute the gamma function for x."""
+    return math1(space, _gamma, x)  # NOTE(review): math1 not shown; confirm error mapping
+gamma.unwrap_spec = [ObjSpace, float]  # presumably x is unwrapped to a float
+
+def lgamma(space, x):  # exported as math.lgamma by the module's __init__.py
+    """Compute the natural logarithm of the gamma function for x."""
+    return math1(space, _lgamma, x)  # NOTE(review): math1 not shown; confirm error mapping
+lgamma.unwrap_spec = [ObjSpace, float]  # presumably x is unwrapped to a float
+
 # Implementation of the error function, the complimentary error function, the
 # gamma function, and the natural log of the gamma function.  This exist in
 # libm, but I hear those implementations are horrible.
@@ -488,3 +498,137 @@
     else:
         cf = _erfc_contfrac(absx)
         return cf if x > 0. else 2. - cf
+
+def _sinpi(x):  # sin(pi * x), reducing abs(x) mod 2 before scaling by pi
+    sign = rarithmetic.copysign(1., x)
+    y = math.fmod(abs(x), 2.)
+    n = int(round(2. * y))  # nearest multiple of 1/2 to y, as an int in 0..4
+    if n == 0:
+        return sign * math.sin(math.pi * y)
+    if n == 1:
+        return sign * math.cos(math.pi * (y - .5))
+    if n == 2:  # sin(pi*(1-y)) rather than -sin(pi*(y-1)): +0.0 at y == 1
+        return sign * math.sin(math.pi * (1. - y))
+    if n == 3:
+        return sign * -math.cos(math.pi * (y - 1.5))
+    if n == 4:
+        return sign * math.sin(math.pi * (y - 2.))
+    else:
+        raise AssertionError("should not reach")
+
+_lanczos_g = 6.024680040776729583740234375  # the g used by _gamma/_lgamma
+_lanczos_g_minus_half = 5.524680040776729583740234375  # == _lanczos_g - 0.5
+_lanczos_num_coeffs = [  # numerator of _lanczos_sum, constant term first
+    23531376880.410759688572007674451636754734846804940,
+    42919803642.649098768957899047001988850926355848959,
+    35711959237.355668049440185451547166705960488635843,
+    17921034426.037209699919755754458931112671403265390,
+    6039542586.3520280050642916443072979210699388420708,
+    1439720407.3117216736632230727949123939715485786772,
+    248874557.86205415651146038641322942321632125127801,
+    31426415.585400194380614231628318205362874684987640,
+    2876370.6289353724412254090516208496135991145378768,
+    186056.26539522349504029498971604569928220784236328,
+    8071.6720023658162106380029022722506138218516325024,
+    210.82427775157934587250973392071336271166969580291,
+    2.5066282746310002701649081771338373386264310793408
+]
+_lanczos_den_coeffs = [  # denominator of _lanczos_sum, same term ordering
+    0.0, 39916800.0, 120543840.0, 150917976.0, 105258076.0, 45995730.0,
+    13339535.0, 2637558.0, 357423.0, 32670.0, 1925.0, 66.0, 1.0]
+_gamma_integrals = [  # entry n-1 is (n-1)! == gamma(n), for n = 1..23
+    1.0, 1.0, 2.0, 6.0, 24.0, 120.0, 720.0, 5040.0, 40320.0, 362880.0,
+    3628800.0, 39916800.0, 479001600.0, 6227020800.0, 87178291200.0,
+    1307674368000.0, 20922789888000.0, 355687428096000.0,
+    6402373705728000.0, 121645100408832000.0, 2432902008176640000.0,
+    51090942171709440000.0, 1124000727777607680000.0]
+
+def _lanczos_sum(x):  # rational function num(x) / den(x); requires x > 0
+    assert x > 0.
+    num = den = 0.
+    ncoeffs = len(_lanczos_den_coeffs)
+    if x < 5.:  # Horner's rule in x, highest-degree coefficient first
+        for i in range(ncoeffs - 1, -1, -1):
+            num = num * x + _lanczos_num_coeffs[i]
+            den = den * x + _lanczos_den_coeffs[i]
+        return num / den
+    for i in range(ncoeffs):  # larger x: Horner in 1/x (presumably vs overflow)
+        num = num / x + _lanczos_num_coeffs[i]
+        den = den / x + _lanczos_den_coeffs[i]
+    return num / den
+
+def _gamma(x):  # gamma(x): special cases, exact table, then Lanczos formula
+    if rarithmetic.isnan(x) or (rarithmetic.isinf(x) and x > 0.):
+        return x  # gamma(nan) = nan, gamma(+inf) = +inf
+    if rarithmetic.isinf(x):
+        raise ValueError("math domain error")  # gamma(-inf)
+    if x == 0.:
+        raise ValueError("math domain error")  # pole at +-0.0
+    if x == math.floor(x):
+        if x < 0.:
+            raise ValueError("math domain error")  # poles at negative ints
+        if x <= len(_gamma_integrals):  # exact (x-1)! for 1 <= x <= 23
+            return _gamma_integrals[int(x) - 1]
+    absx = abs(x)
+    if absx < 1e-20:
+        r = 1. / x  # gamma(x) ~ 1/x for tiny x
+        if rarithmetic.isinf(r):
+            raise OverflowError("math range error")
+        return r
+    if absx > 200.:
+        if x < 0.:
+            return 0. / _sinpi(x)  # +-0.0, signed like the reflection branch
+        else:
+            raise OverflowError("math range error")
+    y = absx + _lanczos_g_minus_half
+    if absx > _lanczos_g_minus_half:  # z: rounding error of the sum y,
+        q = y - absx                    # found by subtracting the larger
+        z = q - _lanczos_g_minus_half   # addend first
+    else:
+        q = y - _lanczos_g_minus_half
+        z = q - absx
+    z = z * _lanczos_g / y
+    if x < 0.:  # reflection formula; result has the sign of _sinpi(x)
+        r = -math.pi / _sinpi(absx) / absx * math.exp(y) / _lanczos_sum(absx)
+        r -= z * r
+        if absx < 140.:
+            r /= math.pow(y, absx - .5)
+        else:  # split y**(absx - .5) in two to keep the factor finite
+            sqrtpow = math.pow(y, absx / 2. - .25)
+            r /= sqrtpow
+            r /= sqrtpow
+    else:
+        r = _lanczos_sum(absx) / math.exp(y)
+        r += z * r
+        if absx < 140.:
+            r *= math.pow(y, absx - .5)
+        else:  # split y**(absx - .5) in two to keep the factor finite
+            sqrtpow = math.pow(y, absx / 2. - .25)
+            r *= sqrtpow
+            r *= sqrtpow
+    if rarithmetic.isinf(r):
+        raise OverflowError("math range error")
+    return r
+
+def _lgamma(x):  # natural log of abs(gamma(x)), via the Lanczos sum
+    if rarithmetic.isnan(x):
+        return x
+    if rarithmetic.isinf(x):
+        return rarithmetic.INFINITY  # lgamma(+-inf) = +inf
+    if x == math.floor(x) and x <= 2.:
+        if x <= 0.:
+            raise ValueError("math domain error")  # poles at ints <= 0
+        return 0.  # lgamma(1) == lgamma(2) == 0
+    absx = abs(x)
+    if absx < 1e-20:
+        return -math.log(absx)  # gamma(x) ~ 1/x for tiny x
+    if x > 0.:
+        r = (math.log(_lanczos_sum(x)) - _lanczos_g + (x - .5) *
+             (math.log(x + _lanczos_g - .5) - 1))
+    else:  # reflection formula; abs(_sinpi) gives the magnitude only
+        r = (math.log(math.pi) - math.log(abs(_sinpi(absx))) - math.log(absx) -
+             (math.log(_lanczos_sum(absx)) - _lanczos_g +
+              (absx - .5) * (math.log(absx + _lanczos_g - .5) - 1)))
+    if rarithmetic.isinf(r):
+        raise OverflowError("math range error")  # finite x, result too large
+    return r



More information about the Pypy-commit mailing list