[Scipy-svn] r6271 - trunk/scipy/sparse/linalg/eigen/arpack
scipy-svn at scipy.org
scipy-svn at scipy.org
Fri Mar 26 01:35:16 EDT 2010
Author: cdavid
Date: 2010-03-26 00:35:16 -0500 (Fri, 26 Mar 2010)
New Revision: 6271
Modified:
trunk/scipy/sparse/linalg/eigen/arpack/arpack.py
Log:
REF: split symmetric/unsymmetric cases into two classes.
Modified: trunk/scipy/sparse/linalg/eigen/arpack/arpack.py
===================================================================
--- trunk/scipy/sparse/linalg/eigen/arpack/arpack.py 2010-03-26 05:35:07 UTC (rev 6270)
+++ trunk/scipy/sparse/linalg/eigen/arpack/arpack.py 2010-03-26 05:35:16 UTC (rev 6271)
@@ -52,7 +52,7 @@
_ndigits = {'f':5, 'd':12, 'F':5, 'D':12}
class _ArpackParams(object):
- def __init__(self, n, k, tp, matvec, mode="symmetric", sigma=None,
+ def __init__(self, n, k, tp, matvec, sigma=None,
ncv=None, v0=None, maxiter=None, which="LM", tol=0):
if k <= 0:
raise ValueError("k must be positive, k=%d" % k)
@@ -84,39 +84,7 @@
if ncv > n or ncv < k:
raise ValueError("ncv must be k<=ncv<=n, ncv=%s" % ncv)
- ltr = _type_conv[tp]
-
self.v = np.zeros((n, ncv), tp) # holds Ritz vectors
- self.rwork = None # Only used for unsymmetric, complex solver
-
- if mode == "unsymmetric":
- if not which in ["LM", "SM", "LR", "SR", "LI", "SI"]:
- raise ValueError("Parameter which must be one of %s" % ' '.join(whiches))
-
- self.workd = np.zeros(3 * n, tp)
- self.workl = np.zeros(3 * ncv * ncv + 6 * ncv, tp)
- self._arpack_solver = _arpack.__dict__[ltr + 'naupd']
- self.iterate = self._unsymmetric_solver
- self.extract = _arpack.__dict__[ltr + 'neupd']
-
- if tp in 'FD':
- self.rwork = np.zeros(ncv, tp.lower())
-
- self.ipntr = np.zeros(14, "int")
- elif mode == "symmetric":
- if not which in ['LM','SM','LA','SA','BE']:
- raise ValueError("which must be one of %s" % ' '.join(whiches))
-
- self.workd = np.zeros(3 * n, tp)
- self.workl = np.zeros(ncv * (ncv + 8), tp)
- self._arpack_solver = _arpack.__dict__[ltr + 'saupd']
- self.iterate = self._symmetric_solver
- self.extract = _arpack.__dict__[ltr + 'seupd']
-
- self.ipntr = np.zeros(11, "int")
- else:
- raise ValueError("Unrecognized mode %s" % mode)
-
self.iparam = np.zeros(11, "int")
# set solver mode and parameters
@@ -128,7 +96,6 @@
self.iparam[6] = mode1
self.n = n
- self.mode = mode
self.matvec = matvec
self.tol = tol
self.k = k
@@ -142,18 +109,30 @@
self.converged = False
self.ido = 0
- def _unsymmetric_solver(self):
- if self.tp in 'fd':
- self.ido, self.resid, self.v, self.iparam, self.ipntr, self.info = \
- self._arpack_solver(self.ido, self.bmat, self.which, self.k, self.tol,
- self.resid, self.v, self.iparam, self.ipntr,
- self.workd, self.workl, self.info)
- else:
- self.ido, self.resid, self.v, self.iparam, self.ipntr, self.info =\
- self._arpack_solver(self.ido, self.bmat, self.which, self.k, self.tol,
- self.resid, self.v, self.iparam, self.ipntr,
- self.workd, self.workl, self.rwork, self.info)
+class _SymmetricArpackParams(_ArpackParams):
+ def __init__(self, n, k, tp, matvec, sigma=None,
+ ncv=None, v0=None, maxiter=None, which="LM", tol=0):
+ if not which in ['LM', 'SM', 'LA', 'SA', 'BE']:
+ raise ValueError("which must be one of %s" % ' '.join(whiches))
+ _ArpackParams.__init__(self, n, k, tp, matvec, sigma,
+ ncv, v0, maxiter, which, tol)
+
+ self.workd = np.zeros(3 * n, self.tp)
+ self.workl = np.zeros(self.ncv * (self.ncv + 8), self.tp)
+
+ ltr = _type_conv[self.tp]
+ self._arpack_solver = _arpack.__dict__[ltr + 'saupd']
+ self.extract = _arpack.__dict__[ltr + 'seupd']
+
+ self.ipntr = np.zeros(11, "int")
+
+ def iterate(self):
+ self.ido, self.resid, self.v, self.iparam, self.ipntr, self.info = \
+ self._arpack_solver(self.ido, self.bmat, self.which, self.k, self.tol,
+ self.resid, self.v, self.iparam, self.ipntr,
+ self.workd, self.workl, self.info)
+
xslice = slice(self.ipntr[0]-1, self.ipntr[0]-1+self.n)
yslice = slice(self.ipntr[1]-1, self.ipntr[1]-1+self.n)
if self.ido == -1:
@@ -170,12 +149,44 @@
elif self.info == -1:
warnings.warn("Maximum number of iterations taken: %s" % self.iparam[2])
- def _symmetric_solver(self):
- self.ido, self.resid, self.v, self.iparam, self.ipntr, self.info = \
- self._arpack_solver(self.ido, self.bmat, self.which, self.k, self.tol,
- self.resid, self.v, self.iparam, self.ipntr,
- self.workd, self.workl, self.info)
+ if self.iparam[4] < self.k:
+ warnings.warn("Only %d/%d eigenvectors converged" % (self.iparam[4], self.k))
+class _UnsymmetricArpackParams(_ArpackParams):
+ def __init__(self, n, k, tp, matvec, sigma=None,
+ ncv=None, v0=None, maxiter=None, which="LM", tol=0):
+ if not which in ["LM", "SM", "LR", "SR", "LI", "SI"]:
+ raise ValueError("Parameter which must be one of %s" % ' '.join(whiches))
+
+ _ArpackParams.__init__(self, n, k, tp, matvec, sigma,
+ ncv, v0, maxiter, which, tol)
+
+ self.workd = np.zeros(3 * n, self.tp)
+ self.workl = np.zeros(3 * self.ncv * self.ncv + 6 * self.ncv, self.tp)
+
+ ltr = _type_conv[self.tp]
+ self._arpack_solver = _arpack.__dict__[ltr + 'naupd']
+ self.extract = _arpack.__dict__[ltr + 'neupd']
+
+ self.ipntr = np.zeros(14, "int")
+
+ if self.tp in 'FD':
+ self.rwork = np.zeros(self.ncv, self.tp.lower())
+ else:
+ self.rwork = None
+
+ def iterate(self):
+ if self.tp in 'fd':
+ self.ido, self.resid, self.v, self.iparam, self.ipntr, self.info = \
+ self._arpack_solver(self.ido, self.bmat, self.which, self.k, self.tol,
+ self.resid, self.v, self.iparam, self.ipntr,
+ self.workd, self.workl, self.info)
+ else:
+ self.ido, self.resid, self.v, self.iparam, self.ipntr, self.info =\
+ self._arpack_solver(self.ido, self.bmat, self.which, self.k, self.tol,
+ self.resid, self.v, self.iparam, self.ipntr,
+ self.workd, self.workl, self.rwork, self.info)
+
xslice = slice(self.ipntr[0]-1, self.ipntr[0]-1+self.n)
yslice = slice(self.ipntr[1]-1, self.ipntr[1]-1+self.n)
if self.ido == -1:
@@ -192,9 +203,6 @@
elif self.info == -1:
warnings.warn("Maximum number of iterations taken: %s" % self.iparam[2])
- if self.iparam[4] < self.k:
- warnings.warn("Only %d/%d eigenvectors converged" % (self.iparam[4], self.k))
-
def eigen(A, k=6, M=None, sigma=None, which='LM', v0=None,
ncv=None, maxiter=None, tol=0,
return_eigenvectors=True):
@@ -277,7 +285,7 @@
n = A.shape[0]
matvec = lambda x : A.matvec(x)
- params = _ArpackParams(n, k, A.dtype.char, matvec, "unsymmetric", sigma,
+ params = _UnsymmetricArpackParams(n, k, A.dtype.char, matvec, sigma,
ncv, v0, maxiter, which, tol)
if M is not None:
@@ -453,7 +461,7 @@
raise NotImplementedError("generalized eigenproblem not supported yet")
matvec = lambda x : A.matvec(x)
- params = _ArpackParams(n, k, A.dtype.char, matvec, "symmetric", sigma,
+ params = _SymmetricArpackParams(n, k, A.dtype.char, matvec, sigma,
ncv, v0, maxiter, which, tol)
while not params.converged:
More information about the Scipy-svn mailing list