[Scipy-svn] r2230 - trunk/Lib/special

scipy-svn at scipy.org
Sun Sep 24 04:24:25 EDT 2006


Author: rkern
Date: 2006-09-24 03:24:24 -0500 (Sun, 24 Sep 2006)
New Revision: 2230

Modified:
   trunk/Lib/special/orthogonal.py
Log:
Use modern numpy idioms.

Modified: trunk/Lib/special/orthogonal.py
===================================================================
--- trunk/Lib/special/orthogonal.py	2006-09-24 08:08:41 UTC (rev 2229)
+++ trunk/Lib/special/orthogonal.py	2006-09-24 08:24:24 UTC (rev 2230)
@@ -1,5 +1,3 @@
-## Automatically adapted for scipy Oct 05, 2005 by convertcode.py
-
 #!/usr/bin/env python
 #
 # Author:  Travis Oliphant 2000
@@ -60,21 +58,25 @@
 """
 
 from __future__ import nested_scopes
-from numpy import *
-from numpy.oldnumeric import take
+
+# Scipy imports.
+import numpy as np
+from numpy import all, any, exp, inf, pi, sqrt
+from scipy.linalg import eig
+
+# Local imports.
 import _cephes as cephes
 _gam = cephes.gamma
-from scipy.linalg import eig
 
 def poch(z,m):
     """Pochhammer symbol (z)_m = (z)(z+1)....(z+m-1) = gamma(z+m)/gamma(z)"""
     return _gam(z+m) / _gam(z)
 
-class orthopoly1d(poly1d):
+class orthopoly1d(np.poly1d):
     def __init__(self, roots, weights=None, hn=1.0, kn=1.0, wfunc=None, limits=None, monic=0):
-        poly1d.__init__(self, roots, r=1)
+        np.poly1d.__init__(self, roots, r=1)
         equiv_weights = [weights[k] / wfunc(roots[k]) for k in range(len(roots))]
-        self.__dict__['weights'] = array(zip(roots,weights,equiv_weights))
+        self.__dict__['weights'] = np.array(zip(roots,weights,equiv_weights))
         self.__dict__['weight_func'] = wfunc
         self.__dict__['limits'] = limits
         mu = sqrt(hn)
@@ -88,7 +90,7 @@
 def gen_roots_and_weights(n,an_func,sqrt_bn_func,mu):
     """[x,w] = gen_roots_and_weights(n,an_func,sqrt_bn_func,mu)
 
-    Returns the roots (x) of an nth order orthogonal polynomail,
+    Returns the roots (x) of an nth order orthogonal polynomial,
     and weights (w) to use in appropriate Gaussian quadrature with that
     orthogonal polynomial.
 
@@ -99,14 +101,16 @@
     sqrt_bn_func(n)     should return sqrt(B_n)
     mu ( = h_0 )        is the integral of the weight over the orthogonal interval
     """
-    nn = arange(1.0,n)
+    nn = np.arange(1.0,n)
     sqrt_bn = sqrt_bn_func(nn)
-    an = an_func(concatenate(([0],nn)))
-    [x,v] = eig((diag(an)+diag(sqrt_bn,1)+diag(sqrt_bn,-1)))
+    an = an_func(np.concatenate(([0], nn)))
+    x, v = eig((np.diagflat(an) + 
+                np.diagflat(sqrt_bn,1) + 
+                np.diagflat(sqrt_bn,-1)))
     answer = []
-    sortind = argsort(real(x))
-    answer.append(take(x,sortind,axis=0))
-    answer.append(take(mu*v[0]**2,sortind,axis=0))
+    sortind = x.real.argsort()
+    answer.append(x[sortind])
+    answer.append((mu*v[0]**2)[sortind])
     return answer
 
 # Jacobi Polynomials 1               P^(alpha,beta)_n(x)
@@ -118,17 +122,17 @@
     function (1-x)**alpha (1+x)**beta with alpha,beta > -1.
     """
     if any(alpha <= -1) or any(beta <= -1):
-        raise ValueError, "alpha and beta must be greater than -1."
+        raise ValueError("alpha and beta must be greater than -1.")
     assert(n>0), "n must be positive."
 
     (p,q) = (alpha,beta)
     # from recurrence relations
     sbn_J = lambda k: 2.0/(2.0*k+p+q)*sqrt((k+p)*(k+q)/(2*k+q+p+1)) * \
-                (where(k==1,1.0,sqrt(k*(k+p+q)/(2.0*k+p+q-1))))
+                (np.where(k==1,1.0,sqrt(k*(k+p+q)/(2.0*k+p+q-1))))
     if any(p == q):  # XXX any or all???
         an_J = lambda k: 0.0*k
     else:
-        an_J = lambda k: where(k==0,(q-p)/(p+q+2.0),
+        an_J = lambda k: np.where(k==0,(q-p)/(p+q+2.0),
                                (q*q - p*p)/((2.0*k+p+q)*(2.0*k+p+q+2)))
     g = cephes.gamma
     mu0 = 2.0**(p+q+1)*g(p+1)*g(q+1)/(g(p+q+2))
@@ -145,7 +149,8 @@
     """
     assert(n>=0), "n must be nonnegative"
     wfunc = lambda x: (1-x)**alpha * (1+x)**beta
-    if n==0: return orthopoly1d([],[],1.0,1.0,wfunc,(-1,1),monic)
+    if n==0:
+        return orthopoly1d([],[],1.0,1.0,wfunc,(-1,1),monic)
     x,w,mu = j_roots(n,alpha,beta,mu=1)
     ab1 = alpha+beta+1.0
     hn = 2**ab1/(2*n+ab1)*_gam(n+alpha+1)
@@ -164,17 +169,17 @@
     function (1-x)**(p-q) x**(q-1) with p-q > -1 and q > 0.
     """
     # from recurrence relation
-    if not ( any( (p1 - q1) > -1 ) and any( q1 > 0 ) ):
-        raise ValueError, "(p - q) > -1 and q > 0 please."
+    if not ( any((p1 - q1) > -1) and any(q1 > 0) ):
+        raise ValueError("(p - q) > -1 and q > 0 please.")
     if (n <= 0):
-        raise ValueError, "n must be positive."
+        raise ValueError("n must be positive.")
 
     p,q = p1,q1
 
-    sbn_Js = lambda k: sqrt(where(k==1,q*(p-q+1.0)/(p+2.0), \
+    sbn_Js = lambda k: sqrt(np.where(k==1,q*(p-q+1.0)/(p+2.0), \
                                   k*(k+q-1.0)*(k+p-1.0)*(k+p-q) \
                                   / ((2.0*k+p-2) * (2.0*k+p))))/(2*k+p-1.0)
-    an_Js = lambda k: where(k==0,q/(p+1.0),(2.0*k*(k+p)+q*(p-1.0)) / ((2.0*k+p+1.0)*(2*k+p-1.0)))
+    an_Js = lambda k: np.where(k==0,q/(p+1.0),(2.0*k*(k+p)+q*(p-1.0)) / ((2.0*k+p+1.0)*(2*k+p-1.0)))
 
     # could also use definition
     #  Gn(p,q,x) = constant_n * P^(p-q,q-1)_n(2x-1)
@@ -201,9 +206,10 @@
     (1-x)**(p-q) (x)**(q-1) with p>q-1 and q > 0.
     """
     if (n<0):
-        raise ValueError, "n must be nonnegative"
+        raise ValueError("n must be nonnegative")
     wfunc = lambda x: (1.0-x)**(p-q) * (x)**(q-1.)
-    if n==0: return orthopoly1d([],[],1.0,1.0,wfunc,(-1,1),monic)
+    if n==0:
+        return orthopoly1d([],[],1.0,1.0,wfunc,(-1,1),monic)
     n1 = n
     x,w,mu0 = js_roots(n1,p,q,mu=1)
     hn = _gam(n+1)*_gam(n+q)*_gam(n+p)*_gam(n+p-q+1)
@@ -222,7 +228,7 @@
     [0,inf] with weighting function exp(-x) x**alpha with alpha > -1.
     """
     if not all(alpha > -1):
-        raise ValueError, "alpha > -1"
+        raise ValueError("alpha > -1")
     assert(n>0), "n must be positive."
     (p,q) = (alpha,0.0)
     sbn_La = lambda k: -sqrt(k*(k + p))  # from recurrence relation
@@ -240,7 +246,7 @@
     exp(-x) x**alpha with alpha > -1
     """
     if any(alpha <= -1):
-        raise ValueError, "alpha must be > -1"
+        raise ValueError("alpha must be > -1")
     assert(n>=0), "n must be nonnegative"
     if n==0: n1 = n+1
     else: n1 = n
@@ -378,7 +384,7 @@
     """
     assert(n>0), "n must be positive."
     # from recurrence relation
-    sbn_J = lambda k: where(k==1,sqrt(2)/2.0,0.5)
+    sbn_J = lambda k: np.where(k==1,sqrt(2)/2.0,0.5)
     an_J = lambda k: 0.0*k
     g = cephes.gamma
     mu0 = pi
@@ -394,7 +400,8 @@
     """
     assert(n>=0), "n must be nonnegative"
     wfunc = lambda x: 1.0/sqrt(1-x*x)
-    if n==0: return orthopoly1d([],[],pi,1.0,wfunc,(-1,1),monic)
+    if n==0:
+        return orthopoly1d([],[],pi,1.0,wfunc,(-1,1),monic)
     n1 = n
     x,w,mu = t_roots(n1,mu=1)
     hn = pi/2




More information about the Scipy-svn mailing list