[Scipy-svn] r6271 - trunk/scipy/sparse/linalg/eigen/arpack

scipy-svn at scipy.org scipy-svn at scipy.org
Fri Mar 26 01:35:16 EDT 2010


Author: cdavid
Date: 2010-03-26 00:35:16 -0500 (Fri, 26 Mar 2010)
New Revision: 6271

Modified:
   trunk/scipy/sparse/linalg/eigen/arpack/arpack.py
Log:
REF: split symmetric/unsymmetric cases into two classes.

Modified: trunk/scipy/sparse/linalg/eigen/arpack/arpack.py
===================================================================
--- trunk/scipy/sparse/linalg/eigen/arpack/arpack.py	2010-03-26 05:35:07 UTC (rev 6270)
+++ trunk/scipy/sparse/linalg/eigen/arpack/arpack.py	2010-03-26 05:35:16 UTC (rev 6271)
@@ -52,7 +52,7 @@
 _ndigits = {'f':5, 'd':12, 'F':5, 'D':12}
 
 class _ArpackParams(object):
-    def __init__(self, n, k, tp, matvec, mode="symmetric", sigma=None,
+    def __init__(self, n, k, tp, matvec, sigma=None,
                  ncv=None, v0=None, maxiter=None, which="LM", tol=0):
         if k <= 0:
             raise ValueError("k must be positive, k=%d" % k)
@@ -84,39 +84,7 @@
         if ncv > n or ncv < k:
             raise ValueError("ncv must be k<=ncv<=n, ncv=%s" % ncv)
 
-        ltr = _type_conv[tp]
-
         self.v = np.zeros((n, ncv), tp) # holds Ritz vectors
-        self.rwork = None # Only used for unsymmetric, complex solver
-
-        if mode == "unsymmetric":
-            if not which in ["LM", "SM", "LR", "SR", "LI", "SI"]:
-                raise ValueError("Parameter which must be one of %s" % ' '.join(whiches))
-
-            self.workd = np.zeros(3 * n, tp)
-            self.workl = np.zeros(3 * ncv * ncv + 6 * ncv, tp)
-            self._arpack_solver = _arpack.__dict__[ltr + 'naupd']
-            self.iterate = self._unsymmetric_solver
-            self.extract = _arpack.__dict__[ltr + 'neupd']
-
-            if tp in 'FD':
-                self.rwork = np.zeros(ncv, tp.lower())
-
-            self.ipntr = np.zeros(14, "int")
-        elif mode == "symmetric":
-            if not which in ['LM','SM','LA','SA','BE']:
-                raise ValueError("which must be one of %s" % ' '.join(whiches))
-
-            self.workd = np.zeros(3 * n, tp)
-            self.workl = np.zeros(ncv * (ncv + 8), tp)
-            self._arpack_solver = _arpack.__dict__[ltr + 'saupd']
-            self.iterate = self._symmetric_solver
-            self.extract = _arpack.__dict__[ltr + 'seupd']
-
-            self.ipntr = np.zeros(11, "int")
-        else:
-            raise ValueError("Unrecognized mode %s" % mode)
-
         self.iparam = np.zeros(11, "int")
 
         # set solver mode and parameters
@@ -128,7 +96,6 @@
         self.iparam[6] = mode1
 
         self.n = n
-        self.mode = mode
         self.matvec = matvec
         self.tol = tol
         self.k = k
@@ -142,18 +109,30 @@
         self.converged = False
         self.ido = 0
 
-    def _unsymmetric_solver(self):
-        if self.tp in 'fd':
-            self.ido, self.resid, self.v, self.iparam, self.ipntr, self.info = \
-                self._arpack_solver(self.ido, self.bmat, self.which, self.k, self.tol,
-                        self.resid, self.v, self.iparam, self.ipntr,
-                        self.workd, self.workl, self.info)
-        else:
-            self.ido, self.resid, self.v, self.iparam, self.ipntr, self.info =\
-                self._arpack_solver(self.ido, self.bmat, self.which, self.k, self.tol,
-                        self.resid, self.v, self.iparam, self.ipntr,
-                        self.workd, self.workl, self.rwork, self.info)
+class _SymmetricArpackParams(_ArpackParams):
+    def __init__(self, n, k, tp, matvec, sigma=None,
+                 ncv=None, v0=None, maxiter=None, which="LM", tol=0):
+        if not which in ['LM', 'SM', 'LA', 'SA', 'BE']:
+            raise ValueError("which must be one of %s" % ' '.join(whiches))
 
+        _ArpackParams.__init__(self, n, k, tp, matvec, sigma,
+                 ncv, v0, maxiter, which, tol)
+
+        self.workd = np.zeros(3 * n, self.tp)
+        self.workl = np.zeros(self.ncv * (self.ncv + 8), self.tp)
+
+        ltr = _type_conv[self.tp]
+        self._arpack_solver = _arpack.__dict__[ltr + 'saupd']
+        self.extract = _arpack.__dict__[ltr + 'seupd']
+
+        self.ipntr = np.zeros(11, "int")
+
+    def iterate(self):
+        self.ido, self.resid, self.v, self.iparam, self.ipntr, self.info = \
+            self._arpack_solver(self.ido, self.bmat, self.which, self.k, self.tol,
+                    self.resid, self.v, self.iparam, self.ipntr,
+                    self.workd, self.workl, self.info)
+
         xslice = slice(self.ipntr[0]-1, self.ipntr[0]-1+self.n)
         yslice = slice(self.ipntr[1]-1, self.ipntr[1]-1+self.n)
         if self.ido == -1:
@@ -170,12 +149,44 @@
             elif self.info == -1:
                 warnings.warn("Maximum number of iterations taken: %s" % self.iparam[2])
 
-    def _symmetric_solver(self):
-        self.ido, self.resid, self.v, self.iparam, self.ipntr, self.info = \
-            self._arpack_solver(self.ido, self.bmat, self.which, self.k, self.tol,
-                    self.resid, self.v, self.iparam, self.ipntr,
-                    self.workd, self.workl, self.info)
+            if self.iparam[4] < self.k:
+                warnings.warn("Only %d/%d eigenvectors converged" % (self.iparam[4], self.k))
 
+class _UnsymmetricArpackParams(_ArpackParams):
+    def __init__(self, n, k, tp, matvec, sigma=None,
+                 ncv=None, v0=None, maxiter=None, which="LM", tol=0):
+        if not which in ["LM", "SM", "LR", "SR", "LI", "SI"]:
+            raise ValueError("Parameter which must be one of %s" % ' '.join(whiches))
+
+        _ArpackParams.__init__(self, n, k, tp, matvec, sigma,
+                 ncv, v0, maxiter, which, tol)
+
+        self.workd = np.zeros(3 * n, self.tp)
+        self.workl = np.zeros(3 * self.ncv * self.ncv + 6 * self.ncv, self.tp)
+
+        ltr = _type_conv[self.tp]
+        self._arpack_solver = _arpack.__dict__[ltr + 'naupd']
+        self.extract = _arpack.__dict__[ltr + 'neupd']
+
+        self.ipntr = np.zeros(14, "int")
+
+        if self.tp in 'FD':
+            self.rwork = np.zeros(self.ncv, self.tp.lower())
+        else:
+            self.rwork = None
+
+    def iterate(self):
+        if self.tp in 'fd':
+            self.ido, self.resid, self.v, self.iparam, self.ipntr, self.info = \
+                self._arpack_solver(self.ido, self.bmat, self.which, self.k, self.tol,
+                        self.resid, self.v, self.iparam, self.ipntr,
+                        self.workd, self.workl, self.info)
+        else:
+            self.ido, self.resid, self.v, self.iparam, self.ipntr, self.info =\
+                self._arpack_solver(self.ido, self.bmat, self.which, self.k, self.tol,
+                        self.resid, self.v, self.iparam, self.ipntr,
+                        self.workd, self.workl, self.rwork, self.info)
+
         xslice = slice(self.ipntr[0]-1, self.ipntr[0]-1+self.n)
         yslice = slice(self.ipntr[1]-1, self.ipntr[1]-1+self.n)
         if self.ido == -1:
@@ -192,9 +203,6 @@
             elif self.info == -1:
                 warnings.warn("Maximum number of iterations taken: %s" % self.iparam[2])
 
-            if self.iparam[4] < self.k:
-                warnings.warn("Only %d/%d eigenvectors converged" % (self.iparam[4], self.k))
-
 def eigen(A, k=6, M=None, sigma=None, which='LM', v0=None,
           ncv=None, maxiter=None, tol=0,
           return_eigenvectors=True):
@@ -277,7 +285,7 @@
     n = A.shape[0]
 
     matvec = lambda x : A.matvec(x)
-    params = _ArpackParams(n, k, A.dtype.char, matvec, "unsymmetric", sigma,
+    params = _UnsymmetricArpackParams(n, k, A.dtype.char, matvec, sigma,
                            ncv, v0, maxiter, which, tol)
 
     if M is not None:
@@ -453,7 +461,7 @@
         raise NotImplementedError("generalized eigenproblem not supported yet")
 
     matvec = lambda x : A.matvec(x)
-    params = _ArpackParams(n, k, A.dtype.char, matvec, "symmetric", sigma,
+    params = _SymmetricArpackParams(n, k, A.dtype.char, matvec, sigma,
                            ncv, v0, maxiter, which, tol)
 
     while not params.converged:




More information about the Scipy-svn mailing list