[Scipy-svn] r2230 - trunk/Lib/special
scipy-svn at scipy.org
scipy-svn at scipy.org
Sun Sep 24 04:24:25 EDT 2006
Author: rkern
Date: 2006-09-24 03:24:24 -0500 (Sun, 24 Sep 2006)
New Revision: 2230
Modified:
trunk/Lib/special/orthogonal.py
Log:
Use modern numpy idioms.
Modified: trunk/Lib/special/orthogonal.py
===================================================================
--- trunk/Lib/special/orthogonal.py 2006-09-24 08:08:41 UTC (rev 2229)
+++ trunk/Lib/special/orthogonal.py 2006-09-24 08:24:24 UTC (rev 2230)
@@ -1,5 +1,3 @@
-## Automatically adapted for scipy Oct 05, 2005 by convertcode.py
-
#!/usr/bin/env python
#
# Author: Travis Oliphant 2000
@@ -60,21 +58,25 @@
"""
from __future__ import nested_scopes
-from numpy import *
-from numpy.oldnumeric import take
+
+# Scipy imports.
+import numpy as np
+from numpy import all, any, exp, inf, pi, sqrt
+from scipy.linalg import eig
+
+# Local imports.
import _cephes as cephes
_gam = cephes.gamma
-from scipy.linalg import eig
def poch(z,m):
"""Pochhammer symbol (z)_m = (z)(z+1)....(z+m-1) = gamma(z+m)/gamma(z)"""
return _gam(z+m) / _gam(z)
-class orthopoly1d(poly1d):
+class orthopoly1d(np.poly1d):
def __init__(self, roots, weights=None, hn=1.0, kn=1.0, wfunc=None, limits=None, monic=0):
- poly1d.__init__(self, roots, r=1)
+ np.poly1d.__init__(self, roots, r=1)
equiv_weights = [weights[k] / wfunc(roots[k]) for k in range(len(roots))]
- self.__dict__['weights'] = array(zip(roots,weights,equiv_weights))
+ self.__dict__['weights'] = np.array(zip(roots,weights,equiv_weights))
self.__dict__['weight_func'] = wfunc
self.__dict__['limits'] = limits
mu = sqrt(hn)
@@ -88,7 +90,7 @@
def gen_roots_and_weights(n,an_func,sqrt_bn_func,mu):
"""[x,w] = gen_roots_and_weights(n,an_func,sqrt_bn_func,mu)
- Returns the roots (x) of an nth order orthogonal polynomail,
+ Returns the roots (x) of an nth order orthogonal polynomial,
and weights (w) to use in appropriate Gaussian quadrature with that
orthogonal polynomial.
@@ -99,14 +101,16 @@
sqrt_bn_func(n) should return sqrt(B_n)
mu ( = h_0 ) is the integral of the weight over the orthogonal interval
"""
- nn = arange(1.0,n)
+ nn = np.arange(1.0,n)
sqrt_bn = sqrt_bn_func(nn)
- an = an_func(concatenate(([0],nn)))
- [x,v] = eig((diag(an)+diag(sqrt_bn,1)+diag(sqrt_bn,-1)))
+ an = an_func(np.concatenate(([0], nn)))
+ x, v = eig((np.diagflat(an) +
+ np.diagflat(sqrt_bn,1) +
+ np.diagflat(sqrt_bn,-1)))
answer = []
- sortind = argsort(real(x))
- answer.append(take(x,sortind,axis=0))
- answer.append(take(mu*v[0]**2,sortind,axis=0))
+ sortind = x.real.argsort()
+ answer.append(x[sortind])
+ answer.append((mu*v[0]**2)[sortind])
return answer
# Jacobi Polynomials 1 P^(alpha,beta)_n(x)
@@ -118,17 +122,17 @@
function (1-x)**alpha (1+x)**beta with alpha,beta > -1.
"""
if any(alpha <= -1) or any(beta <= -1):
- raise ValueError, "alpha and beta must be greater than -1."
+ raise ValueError("alpha and beta must be greater than -1.")
assert(n>0), "n must be positive."
(p,q) = (alpha,beta)
# from recurrence relations
sbn_J = lambda k: 2.0/(2.0*k+p+q)*sqrt((k+p)*(k+q)/(2*k+q+p+1)) * \
- (where(k==1,1.0,sqrt(k*(k+p+q)/(2.0*k+p+q-1))))
+ (np.where(k==1,1.0,sqrt(k*(k+p+q)/(2.0*k+p+q-1))))
if any(p == q): # XXX any or all???
an_J = lambda k: 0.0*k
else:
- an_J = lambda k: where(k==0,(q-p)/(p+q+2.0),
+ an_J = lambda k: np.where(k==0,(q-p)/(p+q+2.0),
(q*q - p*p)/((2.0*k+p+q)*(2.0*k+p+q+2)))
g = cephes.gamma
mu0 = 2.0**(p+q+1)*g(p+1)*g(q+1)/(g(p+q+2))
@@ -145,7 +149,8 @@
"""
assert(n>=0), "n must be nonnegative"
wfunc = lambda x: (1-x)**alpha * (1+x)**beta
- if n==0: return orthopoly1d([],[],1.0,1.0,wfunc,(-1,1),monic)
+ if n==0:
+ return orthopoly1d([],[],1.0,1.0,wfunc,(-1,1),monic)
x,w,mu = j_roots(n,alpha,beta,mu=1)
ab1 = alpha+beta+1.0
hn = 2**ab1/(2*n+ab1)*_gam(n+alpha+1)
@@ -164,17 +169,17 @@
function (1-x)**(p-q) x**(q-1) with p-q > -1 and q > 0.
"""
# from recurrence relation
- if not ( any( (p1 - q1) > -1 ) and any( q1 > 0 ) ):
- raise ValueError, "(p - q) > -1 and q > 0 please."
+ if not ( any((p1 - q1) > -1) and any(q1 > 0) ):
+ raise ValueError("(p - q) > -1 and q > 0 please.")
if (n <= 0):
- raise ValueError, "n must be positive."
+ raise ValueError("n must be positive.")
p,q = p1,q1
- sbn_Js = lambda k: sqrt(where(k==1,q*(p-q+1.0)/(p+2.0), \
+ sbn_Js = lambda k: sqrt(np.where(k==1,q*(p-q+1.0)/(p+2.0), \
k*(k+q-1.0)*(k+p-1.0)*(k+p-q) \
/ ((2.0*k+p-2) * (2.0*k+p))))/(2*k+p-1.0)
- an_Js = lambda k: where(k==0,q/(p+1.0),(2.0*k*(k+p)+q*(p-1.0)) / ((2.0*k+p+1.0)*(2*k+p-1.0)))
+ an_Js = lambda k: np.where(k==0,q/(p+1.0),(2.0*k*(k+p)+q*(p-1.0)) / ((2.0*k+p+1.0)*(2*k+p-1.0)))
# could also use definition
# Gn(p,q,x) = constant_n * P^(p-q,q-1)_n(2x-1)
@@ -201,9 +206,10 @@
(1-x)**(p-q) (x)**(q-1) with p>q-1 and q > 0.
"""
if (n<0):
- raise ValueError, "n must be nonnegative"
+ raise ValueError("n must be nonnegative")
wfunc = lambda x: (1.0-x)**(p-q) * (x)**(q-1.)
- if n==0: return orthopoly1d([],[],1.0,1.0,wfunc,(-1,1),monic)
+ if n==0:
+ return orthopoly1d([],[],1.0,1.0,wfunc,(-1,1),monic)
n1 = n
x,w,mu0 = js_roots(n1,p,q,mu=1)
hn = _gam(n+1)*_gam(n+q)*_gam(n+p)*_gam(n+p-q+1)
@@ -222,7 +228,7 @@
[0,inf] with weighting function exp(-x) x**alpha with alpha > -1.
"""
if not all(alpha > -1):
- raise ValueError, "alpha > -1"
+ raise ValueError("alpha > -1")
assert(n>0), "n must be positive."
(p,q) = (alpha,0.0)
sbn_La = lambda k: -sqrt(k*(k + p)) # from recurrence relation
@@ -240,7 +246,7 @@
exp(-x) x**alpha with alpha > -1
"""
if any(alpha <= -1):
- raise ValueError, "alpha must be > -1"
+ raise ValueError("alpha must be > -1")
assert(n>=0), "n must be nonnegative"
if n==0: n1 = n+1
else: n1 = n
@@ -378,7 +384,7 @@
"""
assert(n>0), "n must be positive."
# from recurrence relation
- sbn_J = lambda k: where(k==1,sqrt(2)/2.0,0.5)
+ sbn_J = lambda k: np.where(k==1,sqrt(2)/2.0,0.5)
an_J = lambda k: 0.0*k
g = cephes.gamma
mu0 = pi
@@ -394,7 +400,8 @@
"""
assert(n>=0), "n must be nonnegative"
wfunc = lambda x: 1.0/sqrt(1-x*x)
- if n==0: return orthopoly1d([],[],pi,1.0,wfunc,(-1,1),monic)
+ if n==0:
+ return orthopoly1d([],[],pi,1.0,wfunc,(-1,1),monic)
n1 = n
x,w,mu = t_roots(n1,mu=1)
hn = pi/2
More information about the Scipy-svn mailing list