[Scipy-svn] r3888 - in trunk/scipy/splinalg/isolve: . tests

scipy-svn at scipy.org scipy-svn at scipy.org
Fri Feb 1 18:29:41 EST 2008


Author: wnbell
Date: 2008-02-01 17:29:37 -0600 (Fri, 01 Feb 2008)
New Revision: 3888

Added:
   trunk/scipy/splinalg/isolve/utils.py
Modified:
   trunk/scipy/splinalg/isolve/__init__.py
   trunk/scipy/splinalg/isolve/minres.py
   trunk/scipy/splinalg/isolve/tests/test_iterative.py
Log:
updated MINRES code
abstracted iterative solver setup code


Modified: trunk/scipy/splinalg/isolve/__init__.py
===================================================================
--- trunk/scipy/splinalg/isolve/__init__.py	2008-02-01 19:49:30 UTC (rev 3887)
+++ trunk/scipy/splinalg/isolve/__init__.py	2008-02-01 23:29:37 UTC (rev 3888)
@@ -2,6 +2,7 @@
 
 #from info import __doc__
 from iterative import *
+from minres import minres
 
 __all__ = filter(lambda s:not s.startswith('_'),dir())
 from scipy.testing.pkgtester import Tester

Modified: trunk/scipy/splinalg/isolve/minres.py
===================================================================
--- trunk/scipy/splinalg/isolve/minres.py	2008-02-01 19:49:30 UTC (rev 3887)
+++ trunk/scipy/splinalg/isolve/minres.py	2008-02-01 23:29:37 UTC (rev 3888)
@@ -1,11 +1,10 @@
-from numpy import sqrt, inner, finfo, asarray, zeros
+from numpy import ndarray, matrix, sqrt, inner, finfo, asarray, zeros
 from numpy.linalg import norm
 
-def psolve(x): return x
-def check_sizes(A,x,b): pass
+from utils import make_system
 
 def minres(A, b, x0=None, shift=0.0, tol=1e-5, maxiter=None, xtype=None,
-        precond=None, callback=None, show=False, check=False):
+           M=None, callback=None, show=False, check=True):
     """Use the Minimum Residual Method (MINRES) to solve Ax=b 
     
     MINRES minimizes norm(A*x - b) for the symmetric matrix A.  Unlike
@@ -30,23 +29,19 @@
             http://www.stanford.edu/group/SOL/software/minres/matlab/
 
     """ 
+    A,M,x,b,postprocess = make_system(A,M,x0,b,xtype)
 
-    show  = True  #TODO remove
-    check = True  #TODO remove
+    matvec = A.matvec
+    psolve = M.matvec
 
     first = 'Enter minres.   '
     last  = 'Exit  minres.   '
 
-    assert(A.shape[0] == A.shape[1])
-    assert(A.shape[1] == len(b))
-
-    b = asarray(b).ravel()
     n = A.shape[0]
 
     if maxiter is None:
         maxiter = 5 * n
 
-    matvec = A.matvec
 
     msg   =[' beta2 = 0.  If M = I, b and x are eigenvectors    ',   # -1
             ' beta1 = 0.  The exact solution is  x = 0          ',   #  0
@@ -56,9 +51,9 @@
             ' x has converged to an eigenvector                 ',   #  4
             ' acond has exceeded 0.1/eps                        ',   #  5
             ' The iteration limit was reached                   ',   #  6
-            ' Aname  does not define a symmetric matrix         ',   #  7
-            ' Mname  does not define a symmetric matrix         ',   #  8
-            ' Mname  does not define a pos-def preconditioner   ']   #  9
+            ' A  does not define a symmetric matrix             ',   #  7
+            ' M  does not define a symmetric matrix             ',   #  8
+            ' M  does not define a pos-def preconditioner       ']   #  9
 
      
     if show:
@@ -90,7 +85,7 @@
     if beta1 < 0:
         raise ValueError('indefinite preconditioner')
     elif beta1 == 0:
-        return x
+        return (postprocess(x), 0)
     
     beta1 = sqrt( beta1 )
 
@@ -262,7 +257,7 @@
         if callback is not None:
             callback(x)
 
-        if istop > 0: break
+        if istop != 0: break #TODO check this
         
 
     if show:
@@ -273,7 +268,7 @@
         print last + ' Arnorm  =  %12.4e'                       %  (Arnorm,)
         print last + msg[istop+1]
 
-    return x
+    return (postprocess(x),0)
 
 
 if __name__ == '__main__':
@@ -283,7 +278,7 @@
     from scipy.splinalg import cg
     #from scipy.sandbox.multigrid import *
 
-    n = 100
+    n = 10
 
     residuals = []
 
@@ -292,7 +287,9 @@
 
     #A = poisson((10,),format='csr')
     A = spdiags( [arange(1,n+1,dtype=float)], [0], n, n, format='csr')
-    b = ones( A.shape[0] )
+    M = spdiags( [1.0/arange(1,n+1,dtype=float)], [0], n, n, format='csr')
+    A.psolve = M.matvec
+    b = 0*ones( A.shape[0] )
     x = minres(A,b,tol=1e-12,maxiter=None,callback=cb)
     #x = cg(A,b,x0=b,tol=1e-12,maxiter=None,callback=cb)[0]
 

Modified: trunk/scipy/splinalg/isolve/tests/test_iterative.py
===================================================================
--- trunk/scipy/splinalg/isolve/tests/test_iterative.py	2008-02-01 19:49:30 UTC (rev 3887)
+++ trunk/scipy/splinalg/isolve/tests/test_iterative.py	2008-02-01 23:29:37 UTC (rev 3888)
@@ -9,7 +9,7 @@
 from scipy.linalg import norm
 from scipy.sparse import spdiags
 
-from scipy.splinalg.isolve import cg, cgs, bicg, bicgstab, gmres, qmr
+from scipy.splinalg.isolve import cg, cgs, bicg, bicgstab, gmres, qmr, minres
 
 #def callback(x):
 #    global A, b
@@ -36,7 +36,7 @@
         self.solvers.append( (bicgstab, False, False) )
         self.solvers.append( (gmres,    False, False) )
         self.solvers.append( (qmr,      False, False) )
-        #self.solvers.append( (minres,   True,  False) )
+        self.solvers.append( (minres,   True,  False) )
         
         # list of tuples (A, symmetric, positive_definite )
         self.cases = []
@@ -91,7 +91,7 @@
                 if req_pos and not pos: continue
 
                 M,N = A.shape
-                D = spdiags( [1.0/A.diagonal()], [0], M, N)
+                D = spdiags( [abs(1.0/A.diagonal())], [0], M, N)
                 def precond(b,which=None):
                     return D*b
 

Added: trunk/scipy/splinalg/isolve/utils.py
===================================================================
--- trunk/scipy/splinalg/isolve/utils.py	2008-02-01 19:49:30 UTC (rev 3887)
+++ trunk/scipy/splinalg/isolve/utils.py	2008-02-01 23:29:37 UTC (rev 3888)
@@ -0,0 +1,79 @@
+from numpy import asanyarray, asmatrix, array, matrix, zeros
+
+from scipy.splinalg.interface import aslinearoperator, LinearOperator
+
+_coerce_rules = {('f','f'):'f', ('f','d'):'d', ('f','F'):'F',
+                 ('f','D'):'D', ('d','f'):'d', ('d','d'):'d',
+                 ('d','F'):'D', ('d','D'):'D', ('F','f'):'F',
+                 ('F','d'):'D', ('F','F'):'F', ('F','D'):'D',
+                 ('D','f'):'D', ('D','d'):'D', ('D','F'):'D',
+                 ('D','D'):'D'}
+
+def coerce(x,y):
+    """Return the common type code of x and y; codes outside 'fdFD' count as 'd'."""
+    if x not in 'fdFD':
+        x = 'd'
+    if y not in 'fdFD':
+        y = 'd'
+    return _coerce_rules[x,y]
+
+def id(x):
+    return x
+
+def make_system(A, M, x0, b, xtype=None):
+    """Normalize A, M, x0 and b for an iterative solver; return (A, M, x, b, postprocess)."""
+    A_ = A
+    A = aslinearoperator(A)
+
+    if A.shape[0] != A.shape[1]:
+        raise ValueError('expected square matrix (shape=%s)' % (A.shape,))
+
+    N = A.shape[0]
+
+    b = asanyarray(b)
+
+    if not (b.shape == (N,1) or b.shape == (N,)):
+        raise ValueError('A and b have incompatible dimensions')
+
+    def postprocess(x):
+        if isinstance(b,matrix):
+            x = asmatrix(x)
+        return x.reshape(b.shape)
+
+    if xtype is None:
+        if hasattr(A,'dtype'):
+            xtype = A.dtype.char
+        else:
+            xtype = A.matvec(b).dtype.char
+        xtype = coerce(xtype, b.dtype.char)
+    elif xtype == 0:
+        xtype = b.dtype.char
+    else:
+        if xtype not in 'fdFD':
+            raise ValueError("xtype must be 'f', 'd', 'F', or 'D'")
+
+    if x0 is None:
+        x = zeros(N, dtype=xtype)
+    else:
+        x = array(x0, dtype=xtype)
+        if not (x.shape == (N,1) or x.shape == (N,)):
+            raise ValueError('A and x have incompatible dimensions')
+        x = x.ravel()
+
+    # process preconditioner
+    if M is None:
+        if hasattr(A_,'psolve'):
+            psolve = A_.psolve
+        else:
+            psolve = id
+        if hasattr(A_,'rpsolve'):
+            rpsolve = A_.rpsolve
+        else:
+            rpsolve = id
+        M = LinearOperator(A.shape, matvec=psolve, rmatvec=rpsolve, dtype=A.dtype)
+    else:
+        M = aslinearoperator(M)  # solvers call M.matvec, so wrap dense/sparse M
+        if A.shape != M.shape:
+            raise ValueError('matrix and preconditioner have different shapes')
+
+    return A, M, x, b, postprocess




More information about the Scipy-svn mailing list