[Scipy-svn] r6270 - trunk/scipy/sparse/linalg/eigen/arpack

scipy-svn at scipy.org scipy-svn at scipy.org
Fri Mar 26 01:35:07 EDT 2010


Author: cdavid
Date: 2010-03-26 00:35:07 -0500 (Fri, 26 Mar 2010)
New Revision: 6270

Modified:
   trunk/scipy/sparse/linalg/eigen/arpack/arpack.py
Log:
REF: abstract solver call for ARPACK.

Modified: trunk/scipy/sparse/linalg/eigen/arpack/arpack.py
===================================================================
--- trunk/scipy/sparse/linalg/eigen/arpack/arpack.py	2010-03-26 05:34:57 UTC (rev 6269)
+++ trunk/scipy/sparse/linalg/eigen/arpack/arpack.py	2010-03-26 05:35:07 UTC (rev 6270)
@@ -52,7 +52,7 @@
 _ndigits = {'f':5, 'd':12, 'F':5, 'D':12}
 
 class _ArpackParams(object):
-    def __init__(self, n, k, tp, mode="symmetric", sigma=None,
+    def __init__(self, n, k, tp, matvec, mode="symmetric", sigma=None,
                  ncv=None, v0=None, maxiter=None, which="LM", tol=0):
         if k <= 0:
             raise ValueError("k must be positive, k=%d" % k)
@@ -95,7 +95,8 @@
 
             self.workd = np.zeros(3 * n, tp)
             self.workl = np.zeros(3 * ncv * ncv + 6 * ncv, tp)
-            self.solver = _arpack.__dict__[ltr + 'naupd']
+            self._arpack_solver = _arpack.__dict__[ltr + 'naupd']
+            self.iterate = self._unsymmetric_solver
             self.extract = _arpack.__dict__[ltr + 'neupd']
 
             if tp in 'FD':
@@ -108,7 +109,8 @@
 
             self.workd = np.zeros(3 * n, tp)
             self.workl = np.zeros(ncv * (ncv + 8), tp)
-            self.solver = _arpack.__dict__[ltr + 'saupd']
+            self._arpack_solver = _arpack.__dict__[ltr + 'saupd']
+            self.iterate = self._symmetric_solver
             self.extract = _arpack.__dict__[ltr + 'seupd']
 
             self.ipntr = np.zeros(11, "int")
@@ -127,6 +129,7 @@
 
         self.n = n
         self.mode = mode
+        self.matvec = matvec
         self.tol = tol
         self.k = k
         self.maxiter = maxiter
@@ -136,6 +139,62 @@
         self.info = info
         self.bmat = 'I'
 
+        self.converged = False
+        self.ido = 0
+
+    def _unsymmetric_solver(self):
+        """Advance the non-symmetric ARPACK (*naupd) iteration by one step.
+
+        Calls the reverse-communication routine once. If it requests a
+        product (ido == -1 or 1), stores matvec(workd[x]) into workd[y];
+        otherwise marks the iteration as converged and checks ``info``.
+        """
+        if self.tp in 'fd':
+            # real dtypes: the wrapper takes no rwork argument
+            self.ido, self.resid, self.v, self.iparam, self.ipntr, self.info = \
+                self._arpack_solver(self.ido, self.bmat, self.which, self.k, self.tol,
+                        self.resid, self.v, self.iparam, self.ipntr,
+                        self.workd, self.workl, self.info)
+        else:
+            # complex dtypes additionally pass the rwork workspace
+            self.ido, self.resid, self.v, self.iparam, self.ipntr, self.info =\
+                self._arpack_solver(self.ido, self.bmat, self.which, self.k, self.tol,
+                        self.resid, self.v, self.iparam, self.ipntr,
+                        self.workd, self.workl, self.rwork, self.info)
+
+        # ipntr holds 1-based offsets into workd of the input (x) and
+        # output (y) vectors for the requested product
+        xslice = slice(self.ipntr[0]-1, self.ipntr[0]-1+self.n)
+        yslice = slice(self.ipntr[1]-1, self.ipntr[1]-1+self.n)
+        if self.ido in (-1, 1):
+            # initialization (-1) or regular step (1): compute y=Ax
+            self.workd[yslice] = self.matvec(self.workd[xslice])
+        else:
+            self.converged = True
+
+            # ARPACK signals errors with negative info codes (including -1),
+            # and info == 1 means the maximum number of iterations was taken.
+            if self.info < 0:
+                raise RuntimeError("Error info=%d in arpack" % self.info)
+            elif self.info == 1:
+                warnings.warn("Maximum number of iterations taken: %s" % self.iparam[2])
+
+    def _symmetric_solver(self):
+        """Advance the symmetric ARPACK (*saupd) iteration by one step.
+
+        Calls the reverse-communication routine once. If it requests a
+        product (ido == -1 or 1), stores matvec(workd[x]) into workd[y];
+        otherwise marks the iteration as converged, checks ``info`` and
+        warns when fewer than k Ritz values converged.
+        """
+        self.ido, self.resid, self.v, self.iparam, self.ipntr, self.info = \
+            self._arpack_solver(self.ido, self.bmat, self.which, self.k, self.tol,
+                    self.resid, self.v, self.iparam, self.ipntr,
+                    self.workd, self.workl, self.info)
+
+        # ipntr holds 1-based offsets into workd of the input (x) and
+        # output (y) vectors for the requested product
+        xslice = slice(self.ipntr[0]-1, self.ipntr[0]-1+self.n)
+        yslice = slice(self.ipntr[1]-1, self.ipntr[1]-1+self.n)
+        if self.ido in (-1, 1):
+            # initialization (-1) or regular step (1): compute y=Ax
+            self.workd[yslice] = self.matvec(self.workd[xslice])
+        else:
+            self.converged = True
+
+            # ARPACK signals errors with negative info codes (including -1),
+            # and info == 1 means the maximum number of iterations was taken.
+            if self.info < 0:
+                raise RuntimeError("Error info=%d in arpack" % self.info)
+            elif self.info == 1:
+                warnings.warn("Maximum number of iterations taken: %s" % self.iparam[2])
+
+            # iparam[4] is the number of converged Ritz values reported by ARPACK
+            if self.iparam[4] < self.k:
+                warnings.warn("Only %d/%d eigenvectors converged" % (self.iparam[4], self.k))
+
 def eigen(A, k=6, M=None, sigma=None, which='LM', v0=None,
           ncv=None, maxiter=None, tol=0,
           return_eigenvectors=True):
@@ -217,42 +276,16 @@
         raise ValueError('expected square matrix (shape=%s)' % A.shape)
     n = A.shape[0]
 
-    params = _ArpackParams(n, k, A.dtype.char, "unsymmetric", sigma,
+    matvec = lambda x : A.matvec(x)
+    params = _ArpackParams(n, k, A.dtype.char, matvec, "unsymmetric", sigma,
                            ncv, v0, maxiter, which, tol)
 
     if M is not None:
         raise NotImplementedError("generalized eigenproblem not supported yet")
 
-    ido = 0
+    while not params.converged:
+        params.iterate()
 
-    while True:
-        if params.tp in 'fd':
-            ido, params.resid, params.v, params.iparam, params.ipntr, params.info = \
-                params.solver(ido, params.bmat, params.which, params.k, params.tol,
-                        params.resid, params.v, params.iparam, params.ipntr,
-                        params.workd, params.workl, params.info)
-        else:
-            ido, params.resid, params.v, params.iparam, params.ipntr, params.info =\
-                params.solver(ido, params.bmat, params.which, params.k, params.tol,
-                        params.resid, params.v, params.iparam, params.ipntr,
-                        params.workd, params.workl, params.rwork, params.info)
-
-        xslice = slice(params.ipntr[0]-1, params.ipntr[0]-1+n)
-        yslice = slice(params.ipntr[1]-1, params.ipntr[1]-1+n)
-        if ido == -1:
-            # initialization
-            params.workd[yslice] = A.matvec(params.workd[xslice])
-        elif ido == 1:
-            # compute y=Ax
-            params.workd[yslice] = A.matvec(params.workd[xslice])
-        else:
-            break
-
-    if params.info < -1 :
-        raise RuntimeError("Error info=%d in arpack" % params.info)
-    elif params.info == -1:
-        warnings.warn("Maximum number of iterations taken: %s" % self.iparam[2])
-
     # now extract eigenvalues and (optionally) eigenvectors
     rvec = return_eigenvectors
     ierr = 0
@@ -419,35 +452,13 @@
     if M is not None:
         raise NotImplementedError("generalized eigenproblem not supported yet")
 
-    params = _ArpackParams(n, k, A.dtype.char, "symmetric", sigma,
+    matvec = lambda x : A.matvec(x)
+    params = _ArpackParams(n, k, A.dtype.char, matvec, "symmetric", sigma,
                            ncv, v0, maxiter, which, tol)
 
-    ido = 0
-    while True:
-        ido, params.resid, params.v, params.iparam, params.ipntr, params.info = \
-            params.solver(ido, params.bmat, params.which, params.k, params.tol,
-                    params.resid, params.v, params.iparam, params.ipntr,
-                    params.workd, params.workl, params.info)
+    while not params.converged:
+        params.iterate()
 
-        xslice = slice(params.ipntr[0]-1, params.ipntr[0]-1+n)
-        yslice = slice(params.ipntr[1]-1, params.ipntr[1]-1+n)
-        if ido == -1:
-            # initialization
-            params.workd[yslice] = A.matvec(params.workd[xslice])
-        elif ido == 1:
-            # compute y=Ax
-            params.workd[yslice] = A.matvec(params.workd[xslice])
-        else:
-            break
-
-    if params.info < -1 :
-        raise RuntimeError("Error info=%d in arpack" % params.info)
-    elif params.info == 1:
-        warnings.warn("Maximum number of iterations taken: %s" % params.iparam[2])
-
-    if params.iparam[4] < k:
-        warnings.warn("Only %d/%d eigenvectors converged" % (params.iparam[4], k))
-
     # now extract eigenvalues and (optionally) eigenvectors
     rvec = return_eigenvectors
     ierr = 0




More information about the Scipy-svn mailing list