[Scipy-svn] r3890 - in trunk/scipy: sandbox/arpack/tests sparse splinalg splinalg/isolve splinalg/isolve/tests splinalg/tests

scipy-svn at scipy.org scipy-svn at scipy.org
Sat Feb 2 13:17:43 EST 2008


Author: wnbell
Date: 2008-02-02 12:17:25 -0600 (Sat, 02 Feb 2008)
New Revision: 3890

Modified:
   trunk/scipy/sandbox/arpack/tests/test_speigs.py
   trunk/scipy/sparse/compressed.py
   trunk/scipy/splinalg/interface.py
   trunk/scipy/splinalg/isolve/iterative.py
   trunk/scipy/splinalg/isolve/minres.py
   trunk/scipy/splinalg/isolve/tests/test_iterative.py
   trunk/scipy/splinalg/isolve/utils.py
   trunk/scipy/splinalg/tests/test_interface.py
Log:
iterative solvers now use LinearOperator
added M argument for preconditioners


Modified: trunk/scipy/sandbox/arpack/tests/test_speigs.py
===================================================================
--- trunk/scipy/sandbox/arpack/tests/test_speigs.py	2008-02-02 04:26:51 UTC (rev 3889)
+++ trunk/scipy/sandbox/arpack/tests/test_speigs.py	2008-02-02 18:17:25 UTC (rev 3890)
@@ -22,8 +22,8 @@
         vals = vals[uv_sortind]
         vecs = vecs[:,uv_sortind]
 
-        from scipy.splinalg.isolve.iterative import get_matvec
-        matvec = get_matvec(A)
+        from scipy.splinalg.interface import aslinearoperator
+        matvec = aslinearoperator(A).matvec
         #= lambda x: N.asarray(A*x)[0]
         nev=4
         eigvs = ARPACK_eigs(matvec, A.shape[0], nev=nev)

Modified: trunk/scipy/sparse/compressed.py
===================================================================
--- trunk/scipy/sparse/compressed.py	2008-02-02 04:26:51 UTC (rev 3889)
+++ trunk/scipy/sparse/compressed.py	2008-02-02 18:17:25 UTC (rev 3890)
@@ -365,10 +365,19 @@
             raise TypeError, "need a dense vector"
 
     def rmatvec(self, other, conjugate=True):
+        """Multiplies the vector 'other' by the transpose of the sparse
+        matrix, returning a dense vector as a result.
+        
+        If 'conjugate' is True:
+            - returns A.transpose().conj() * other
+        Otherwise:
+            - returns A.transpose() * other.
+        
+        """
         if conjugate:
-            return transpose( self.transpose().conj().matvec(transpose(other)) )
+            return self.transpose().conj().matvec( other )
         else:
-            return transpose( self.transpose().matvec(transpose(other)) )
+            return self.transpose().matvec( other )
 
     def getdata(self, ind):
         return self.data[ind]
@@ -376,6 +385,7 @@
     def diagonal(self):
         """Returns the main diagonal of the matrix
         """
+        #TODO support k-th diagonal
         fn = getattr(sparsetools, self.format + "_diagonal")
         y = empty( min(self.shape), dtype=upcast(self.dtype) )
         fn(self.shape[0], self.shape[1], self.indptr, self.indices, self.data, y)

Modified: trunk/scipy/splinalg/interface.py
===================================================================
--- trunk/scipy/splinalg/interface.py	2008-02-02 04:26:51 UTC (rev 3889)
+++ trunk/scipy/splinalg/interface.py	2008-02-02 18:17:25 UTC (rev 3890)
@@ -1,4 +1,5 @@
-import numpy as np
+import numpy
+from numpy import matrix, ndarray, asarray, dot, atleast_2d
 from scipy.sparse.sputils import isshape
 from scipy.sparse import isspmatrix
 
@@ -7,7 +8,42 @@
 class LinearOperator:
     def __init__( self, shape, matvec, rmatvec=None, dtype=None ):
         """Common interface for performing matrix vector products
+
+        Many iterative methods (e.g. cg, gmres) do not need to know the
+        individual entries of a matrix to solve a linear system A*x=b. 
+        Such solvers only require the computation of matrix vector 
+        products, A*v where v is a dense vector.  This class serves as
+        an abstract interface between iterative solvers and matrix-like
+        objects.
+
+        Required Parameters:
+            shape     : tuple of matrix dimensions (M,N)
+            matvec(x) : function that returns A * x
+
+        Optional Parameters:
+            rmatvec(x) : function that returns A^H * x where A^H represents 
+                         the Hermitian (conjugate) transpose of A
+            dtype      : data type of the matrix
+                        
+
+        See Also:
+            aslinearoperator() : Construct LinearOperators for SciPy classes
+
+        Example:
+
+        >>> from scipy.splinalg import LinearOperator
+        >>> from scipy import *
+        >>> def mv(x):
+        ...     return array([ 2*x[0], 3*x[1]])
+        ... 
+        >>> A = LinearOperator( (2,2), matvec=mv )
+        >>> A
+        <2x2 LinearOperator with unspecified dtype>
+        >>> A.matvec( ones(2) )
+        array([ 2.,  3.])
+        
         """
+
         shape = tuple(shape)
 
         if not isshape(shape):
@@ -24,7 +60,7 @@
             self.rmatvec = rmatvec
 
         if dtype is not None:
-            self.dtype = np.dtype(dtype)
+            self.dtype = numpy.dtype(dtype)
 
     def __repr__(self):
         M,N = self.shape
@@ -60,11 +96,16 @@
     if isinstance(A, LinearOperator):
         return A
 
-    elif isinstance(A, np.ndarray) or isinstance(A,np.matrix):
+    elif isinstance(A, ndarray) or isinstance(A, matrix):
+        if len(A.shape) > 2:
+            raise ValueError('array must have rank <= 2')
+
+        A = atleast_2d(asarray(A))
+
         def matvec(x):
-            return np.dot(np.asarray(A),x)
+            return dot(A,x)
         def rmatvec(x):
-            return np.dot(x,np.asarray(A))
+            return dot(A.conj().transpose(),x)
         return LinearOperator( A.shape, matvec, rmatvec=rmatvec, dtype=A.dtype )
 
     elif isspmatrix(A):

Modified: trunk/scipy/splinalg/isolve/iterative.py
===================================================================
--- trunk/scipy/splinalg/isolve/iterative.py	2008-02-02 04:26:51 UTC (rev 3889)
+++ trunk/scipy/splinalg/isolve/iterative.py	2008-02-02 18:17:25 UTC (rev 3890)
@@ -15,102 +15,12 @@
 import numpy as sb
 import copy
 
-try:
-    False, True
-except NameError:
-    False, True = 0, 1
+from scipy.splinalg.interface import LinearOperator
+from utils import make_system
 
 _type_conv = {'f':'s', 'd':'d', 'F':'c', 'D':'z'}
 
-_coerce_rules = {('f','f'):'f', ('f','d'):'d', ('f','F'):'F',
-                 ('f','D'):'D', ('d','f'):'d', ('d','d'):'d',
-                 ('d','F'):'D', ('d','D'):'D', ('F','f'):'F',
-                 ('F','d'):'D', ('F','F'):'F', ('F','D'):'D',
-                 ('D','f'):'D', ('D','d'):'D', ('D','F'):'D',
-                 ('D','D'):'D'}
-
-class get_matvec:
-    methname = 'matvec'
-    def __init__(self, obj, *args):
-        self.obj = obj
-        self.args = args
-        if isinstance(obj, sb.matrix):
-            self.callfunc = self.type1m
-            return
-        if isinstance(obj, sb.ndarray):
-            self.callfunc = self.type1
-            return
-        meth = getattr(obj,self.methname,None)
-        if not callable(meth):
-            raise ValueError, "Object must be an array "\
-                  "or have a callable %s attribute." % (self.methname,)
-
-        self.obj = meth
-        self.callfunc = self.type2
-
-    def __call__(self, x):
-        return self.callfunc(x)
-
-    def type1(self, x):
-        return sb.dot(self.obj, x)
-
-    def type1m(self, x):
-        return sb.dot(self.obj.A, x)
-
-    def type2(self, x):
-        return self.obj(x,*self.args)
-
-class get_rmatvec(get_matvec):
-    methname = 'rmatvec'
-    def type1(self, x):
-        return sb.dot(x, self.obj)
-    def type1m(self, x):
-        return sb.dot(x, self.obj.A)
-
-class get_psolve:
-    methname = 'psolve'
-    def __init__(self, obj, *args):
-        self.obj = obj
-        self.args = args
-        meth = getattr(obj,self.methname,None)
-        if meth is None:  # no preconditiong available
-            self.callfunc = self.type1
-            return
-
-        if not callable(meth):
-            raise ValueError, "Preconditioning method %s "\
-                  "must be callable." % (self.methname,)
-
-        self.obj = meth
-        self.callfunc = self.type2
-
-    def __call__(self, x):
-        return self.callfunc(x)
-
-    def type1(self, x):
-        return x
-
-    def type2(self, x):
-        return self.obj(x,*self.args)
-
-class get_rpsolve(get_psolve):
-    methname = 'rpsolve'
-
-class get_psolveq(get_psolve):
-
-    def __call__(self, x, which):
-        return self.callfunc(x, which)
-
-    def type1(self, x, which):
-        return x
-
-    def type2(self, x, which):
-        return self.obj(x,which,*self.args)
-
-class get_rpsolveq(get_psolveq):
-    methname = 'rpsolve'
-
-def bicg(A, b, x0=None, tol=1e-5, maxiter=None, xtype=None, callback=None):
+def bicg(A, b, x0=None, tol=1e-5, maxiter=None, xtype=None, M=None, callback=None):
     """Use BIConjugate Gradient iteration to solve A x = b
 
     Inputs:
@@ -145,43 +55,22 @@
                 iteration.  It is called as callback(xk), where xk is the
                 current parameter vector.
     """
-    b = sb.asarray(b)+0.0
+    A,M,x,b,postprocess = make_system(A,M,x0,b,xtype)
+
     n = len(b)
     if maxiter is None:
         maxiter = n*10
 
-    if x0 is None:
-        x = sb.zeros(n)
-    else:
-        x = copy.copy(x0)
+    matvec, rmatvec = A.matvec, A.rmatvec
+    psolve, rpsolve = M.matvec, M.rmatvec
+    ltr = _type_conv[x.dtype.char]
+    revcom   = getattr(_iterative, ltr + 'bicgrevcom')
+    stoptest = getattr(_iterative, ltr + 'stoptest2')
 
-    if xtype is None:
-        try:
-            atyp = A.dtype.char
-        except AttributeError:
-            atyp = None
-        if atyp is None:
-            atyp = A.matvec(x).dtype.char
-        typ = _coerce_rules[b.dtype.char,atyp]
-    elif xtype == 0:
-        typ = b.dtype.char
-    else:
-        typ = xtype
-        if typ not in 'fdFD':
-            raise ValueError, "xtype must be 'f', 'd', 'F', or 'D'"
-
-    x = sb.asarray(x,typ)
-    b = sb.asarray(b,typ)
-
-    matvec, psolve, rmatvec, rpsolve = (None,)*4
-    ltr = _type_conv[typ]
-    revcom = _iterative.__dict__[ltr+'bicgrevcom']
-    stoptest = _iterative.__dict__[ltr+'stoptest2']
-
     resid = tol
     ndx1 = 1
     ndx2 = -1
-    work = sb.zeros(6*n,typ)
+    work = sb.zeros(6*n,dtype=x.dtype)
     ijob = 1
     info = 0
     ftflag = True
@@ -198,26 +87,16 @@
         if (ijob == -1):
             break
         elif (ijob == 1):
-            if matvec is None:
-                matvec = get_matvec(A)
             work[slice2] *= sclr2
             work[slice2] += sclr1*matvec(work[slice1])
         elif (ijob == 2):
-            if rmatvec is None:
-                rmatvec = get_rmatvec(A)
             work[slice2] *= sclr2
             work[slice2] += sclr1*rmatvec(work[slice1])
         elif (ijob == 3):
-            if psolve is None:
-                psolve = get_psolve(A)
             work[slice1] = psolve(work[slice2])
         elif (ijob == 4):
-            if rpsolve is None:
-                rpsolve = get_rpsolve(A)
             work[slice1] = rpsolve(work[slice2])
         elif (ijob == 5):
-            if matvec is None:
-                matvec = get_matvec(A)
             work[slice2] *= sclr2
             work[slice2] += sclr1*matvec(x)
         elif (ijob == 6):
@@ -227,10 +106,10 @@
             bnrm2, resid, info = stoptest(work[slice1], b, bnrm2, tol, info)
         ijob = 2
 
-    return x, info
+    return postprocess(x), info
 
 
-def bicgstab(A, b, x0=None, tol=1e-5, maxiter=None, xtype=None, callback=None):
+def bicgstab(A, b, x0=None, tol=1e-5, maxiter=None, xtype=None, M=None, callback=None):
     """Use BIConjugate Gradient STABilized iteration to solve A x = b
 
     Inputs:
@@ -264,44 +143,22 @@
                 iteration.  It is called as callback(xk), where xk is the
                 current parameter vector.
     """
-    b = sb.asarray(b)+0.0
+    A,M,x,b,postprocess = make_system(A,M,x0,b,xtype)
+
     n = len(b)
     if maxiter is None:
         maxiter = n*10
 
-    if x0 is None:
-        x = sb.zeros(n)
-    else:
-        x = copy.copy(x0)
+    matvec = A.matvec
+    psolve = M.matvec
+    ltr = _type_conv[x.dtype.char]
+    revcom   = getattr(_iterative, ltr + 'bicgstabrevcom')
+    stoptest = getattr(_iterative, ltr + 'stoptest2')
 
-
-    if xtype is None:
-        try:
-            atyp = A.dtype.char
-        except AttributeError:
-            atyp = None
-        if atyp is None:
-            atyp = A.matvec(x).dtype.char
-        typ = _coerce_rules[b.dtype.char,atyp]
-    elif xtype == 0:
-        typ = b.dtype.char
-    else:
-        typ = xtype
-        if typ not in 'fdFD':
-            raise ValueError, "xtype must be 'f', 'd', 'F', or 'D'"
-
-    x = sb.asarray(x,typ)
-    b = sb.asarray(b,typ)
-
-    matvec, psolve = (None,)*2
-    ltr = _type_conv[typ]
-    revcom = _iterative.__dict__[ltr+'bicgstabrevcom']
-    stoptest = _iterative.__dict__[ltr+'stoptest2']
-
     resid = tol
     ndx1 = 1
     ndx2 = -1
-    work = sb.zeros(7*n,typ)
+    work = sb.zeros(7*n,dtype=x.dtype)
     ijob = 1
     info = 0
     ftflag = True
@@ -338,10 +195,10 @@
             bnrm2, resid, info = stoptest(work[slice1], b, bnrm2, tol, info)
         ijob = 2
 
-    return x, info
+    return postprocess(x), info
 
 
-def cg(A, b, x0=None, tol=1e-5, maxiter=None, xtype=None, callback=None):
+def cg(A, b, x0=None, tol=1e-5, maxiter=None, xtype=None, M=None, callback=None):
     """Use Conjugate Gradient iteration to solve A x = b (A^H = A)
 
     Inputs:
@@ -376,44 +233,22 @@
                 iteration.  It is called as callback(xk), where xk is the
                 current parameter vector.
     """
-    b = sb.asarray(b)+0.0
+    A,M,x,b,postprocess = make_system(A,M,x0,b,xtype)
+
     n = len(b)
     if maxiter is None:
         maxiter = n*10
 
-    if x0 is None:
-        x = sb.zeros(n)
-    else:
-        x = copy.copy(x0)
+    matvec = A.matvec
+    psolve = M.matvec
+    ltr = _type_conv[x.dtype.char]
+    revcom   = getattr(_iterative, ltr + 'cgrevcom')
+    stoptest = getattr(_iterative, ltr + 'stoptest2')
 
-
-    if xtype is None:
-        try:
-            atyp = A.dtype.char
-        except AttributeError:
-            atyp = None
-        if atyp is None:
-            atyp = A.matvec(x).dtype.char
-        typ = _coerce_rules[b.dtype.char,atyp]
-    elif xtype == 0:
-        typ = b.dtype.char
-    else:
-        typ = xtype
-        if typ not in 'fdFD':
-            raise ValueError, "xtype must be 'f', 'd', 'F', or 'D'"
-
-    x = sb.asarray(x,typ)
-    b = sb.asarray(b,typ)
-
-    matvec, psolve = (None,)*2
-    ltr = _type_conv[typ]
-    revcom = _iterative.__dict__[ltr+'cgrevcom']
-    stoptest = _iterative.__dict__[ltr+'stoptest2']
-
     resid = tol
     ndx1 = 1
     ndx2 = -1
-    work = sb.zeros(4*n,typ)
+    work = sb.zeros(4*n,dtype=x.dtype)
     ijob = 1
     info = 0
     ftflag = True
@@ -430,17 +265,11 @@
         if (ijob == -1):
             break
         elif (ijob == 1):
-            if matvec is None:
-                matvec = get_matvec(A)
             work[slice2] *= sclr2
             work[slice2] += sclr1*matvec(work[slice1])
         elif (ijob == 2):
-            if psolve is None:
-                psolve = get_psolve(A)
             work[slice1] = psolve(work[slice2])
         elif (ijob == 3):
-            if matvec is None:
-                matvec = get_matvec(A)
             work[slice2] *= sclr2
             work[slice2] += sclr1*matvec(x)
         elif (ijob == 4):
@@ -450,10 +279,10 @@
             bnrm2, resid, info = stoptest(work[slice1], b, bnrm2, tol, info)
         ijob = 2
 
-    return x, info
+    return postprocess(x), info
 
 
-def cgs(A, b, x0=None, tol=1e-5, maxiter=None, xtype=None, callback=None):
+def cgs(A, b, x0=None, tol=1e-5, maxiter=None, xtype=None, M=None, callback=None):
     """Use Conjugate Gradient Squared iteration to solve A x = b
 
     Inputs:
@@ -488,43 +317,22 @@
                 iteration.  It is called as callback(xk), where xk is the
                 current parameter vector.
     """
-    b = sb.asarray(b) + 0.0
+    A,M,x,b,postprocess = make_system(A,M,x0,b,xtype)
+
     n = len(b)
     if maxiter is None:
         maxiter = n*10
+    
+    matvec = A.matvec
+    psolve = M.matvec
+    ltr = _type_conv[x.dtype.char]
+    revcom   = getattr(_iterative, ltr + 'cgsrevcom')
+    stoptest = getattr(_iterative, ltr + 'stoptest2')
 
-    if x0 is None:
-        x = sb.zeros(n)
-    else:
-        x = copy.copy(x0)
-
-    if xtype is None:
-        try:
-            atyp = A.dtype.char
-        except AttributeError:
-            atyp = None
-        if atyp is None:
-            atyp = A.matvec(x).dtype.char
-        typ = _coerce_rules[b.dtype.char,atyp]
-    elif xtype == 0:
-        typ = b.dtype.char
-    else:
-        typ = xtype
-        if typ not in 'fdFD':
-            raise ValueError, "xtype must be 'f', 'd', 'F', or 'D'"
-
-    x = sb.asarray(x,typ)
-    b = sb.asarray(b,typ)
-
-    matvec, psolve = (None,)*2
-    ltr = _type_conv[typ]
-    revcom = _iterative.__dict__[ltr+'cgsrevcom']
-    stoptest = _iterative.__dict__[ltr+'stoptest2']
-
     resid = tol
     ndx1 = 1
     ndx2 = -1
-    work = sb.zeros(7*n,typ)
+    work = sb.zeros(7*n,dtype=x.dtype)
     ijob = 1
     info = 0
     ftflag = True
@@ -541,17 +349,11 @@
         if (ijob == -1):
             break
         elif (ijob == 1):
-            if matvec is None:
-                matvec = get_matvec(A)
             work[slice2] *= sclr2
             work[slice2] += sclr1*matvec(work[slice1])
         elif (ijob == 2):
-            if psolve is None:
-                psolve = get_psolve(A)
             work[slice1] = psolve(work[slice2])
         elif (ijob == 3):
-            if matvec is None:
-                matvec = get_matvec(A)
             work[slice2] *= sclr2
             work[slice2] += sclr1*matvec(x)
         elif (ijob == 4):
@@ -561,10 +363,10 @@
             bnrm2, resid, info = stoptest(work[slice1], b, bnrm2, tol, info)
         ijob = 2
 
-    return x, info
+    return postprocess(x), info
 
 
-def gmres(A, b, x0=None, tol=1e-5, restrt=None, maxiter=None, xtype=None, callback=None):
+def gmres(A, b, x0=None, tol=1e-5, restrt=None, maxiter=None, xtype=None, M=None, callback=None):
     """Use Generalized Minimal RESidual iteration to solve A x = b
 
     Inputs:
@@ -601,44 +403,25 @@
                 iteration.  It is called as callback(xk), where xk is the
                 current parameter vector.
     """
-    b = sb.asarray(b)+0.0
+    A,M,x,b,postprocess = make_system(A,M,x0,b,xtype)
+
     n = len(b)
     if maxiter is None:
         maxiter = n*10
 
-    if x0 is None:
-        x = sb.zeros(n)
-    else:
-        x = copy.copy(x0)
+    matvec = A.matvec
+    psolve = M.matvec
+    ltr = _type_conv[x.dtype.char]
+    revcom   = getattr(_iterative, ltr + 'gmresrevcom')
+    stoptest = getattr(_iterative, ltr + 'stoptest2')
 
-    if xtype is None:
-        try:
-            atyp = A.dtype.char
-        except AttributeError:
-            atyp = A.matvec(x).dtype.char
-        typ = _coerce_rules[b.dtype.char,atyp]
-    elif xtype == 0:
-        typ = b.dtype.char
-    else:
-        typ = xtype
-        if typ not in 'fdFD':
-            raise ValueError, "xtype must be 'f', 'd', 'F', or 'D'"
-
-    x = sb.asarray(x,typ)
-    b = sb.asarray(b,typ)
-
-    matvec, psolve = (None,)*2
-    ltr = _type_conv[typ]
-    revcom = _iterative.__dict__[ltr+'gmresrevcom']
-    stoptest = _iterative.__dict__[ltr+'stoptest2']
-
     if restrt is None:
         restrt = n
     resid = tol
     ndx1 = 1
     ndx2 = -1
-    work = sb.zeros((6+restrt)*n,typ)
-    work2 = sb.zeros((restrt+1)*(2*restrt+2),typ)
+    work  = sb.zeros((6+restrt)*n,dtype=x.dtype)
+    work2 = sb.zeros((restrt+1)*(2*restrt+2),dtype=x.dtype)
     ijob = 1
     info = 0
     ftflag = True
@@ -655,17 +438,11 @@
         if (ijob == -1):
             break
         elif (ijob == 1):
-            if matvec is None:
-                matvec = get_matvec(A)
             work[slice2] *= sclr2
             work[slice2] += sclr1*matvec(x)
         elif (ijob == 2):
-            if psolve is None:
-                psolve = get_psolve(A)
             work[slice1] = psolve(work[slice2])
         elif (ijob == 3):
-            if matvec is None:
-                matvec = get_matvec(A)
             work[slice2] *= sclr2
             work[slice2] += sclr1*matvec(work[slice1])
         elif (ijob == 4):
@@ -675,10 +452,10 @@
             bnrm2, resid, info = stoptest(work[slice1], b, bnrm2, tol, info)
         ijob = 2
 
-    return x, info
+    return postprocess(x), info
 
 
-def qmr(A, b, x0=None, tol=1e-5, maxiter=None, xtype=None, callback=None):
+def qmr(A, b, x0=None, tol=1e-5, maxiter=None, xtype=None, M1=None, M2=None, callback=None):
     """Use Quasi-Minimal Residual iteration to solve A x = b
 
     Inputs:
@@ -714,44 +491,39 @@
                 iteration.  It is called as callback(xk), where xk is the
                 current parameter vector.
     """
-    b = sb.asarray(b)+0.0
+    A_ = A
+    A,M,x,b,postprocess = make_system(A,None,x0,b,xtype)
+
+    if M1 is None and M2 is None:
+        if hasattr(A_,'psolve'):
+            def left_psolve(b):
+                return A_.psolve(b,'left')
+            def right_psolve(b):
+                return A_.psolve(b,'right')
+            def left_rpsolve(b):
+                return A_.rpsolve(b,'left')
+            def right_rpsolve(b):
+                return A_.rpsolve(b,'right')
+            M1 = LinearOperator(A.shape, matvec=left_psolve, rmatvec=left_rpsolve)
+            M2 = LinearOperator(A.shape, matvec=right_psolve, rmatvec=right_rpsolve)
+        else:
+            def id(b):
+                return b
+            M1 = LinearOperator(A.shape, matvec=id, rmatvec=id)
+            M2 = LinearOperator(A.shape, matvec=id, rmatvec=id)
+
     n = len(b)
     if maxiter is None:
         maxiter = n*10
 
-    if x0 is None:
-        x = sb.zeros(n)
-    else:
-        x = copy.copy(x0)
+    ltr = _type_conv[x.dtype.char]
+    revcom   = getattr(_iterative, ltr + 'qmrrevcom')
+    stoptest = getattr(_iterative, ltr + 'stoptest2')
 
-    if xtype is None:
-        try:
-            atyp = A.dtype.char
-        except AttributeError:
-            atyp = None
-        if atyp is None:
-            atyp = A.matvec(x).dtype.char
-        typ = _coerce_rules[b.dtype.char,atyp]
-    elif xtype == 0:
-        typ = b.dtype.char
-    else:
-        typ = xtype
-        if typ not in 'fdFD':
-            raise ValueError, "xtype must be 'f', 'd', 'F', or 'D'"
-
-    x = sb.asarray(x,typ)
-    b = sb.asarray(b,typ)
-
-
-    matvec, psolve, rmatvec, rpsolve = (None,)*4
-    ltr = _type_conv[typ]
-    revcom = _iterative.__dict__[ltr+'qmrrevcom']
-    stoptest = _iterative.__dict__[ltr+'stoptest2']
-
     resid = tol
     ndx1 = 1
     ndx2 = -1
-    work = sb.zeros(11*n,typ)
+    work = sb.zeros(11*n,x.dtype)
     ijob = 1
     info = 0
     ftflag = True
@@ -768,36 +540,22 @@
         if (ijob == -1):
             break
         elif (ijob == 1):
-            if matvec is None:
-                matvec = get_matvec(A)
             work[slice2] *= sclr2
-            work[slice2] += sclr1*matvec(work[slice1])
+            work[slice2] += sclr1*A.matvec(work[slice1])
         elif (ijob == 2):
-            if rmatvec is None:
-                rmatvec = get_rmatvec(A)
             work[slice2] *= sclr2
-            work[slice2] += sclr1*rmatvec(work[slice1])
+            work[slice2] += sclr1*A.rmatvec(work[slice1])
         elif (ijob == 3):
-            if psolve is None:
-                psolve = get_psolveq(A)
-            work[slice1] = psolve(work[slice2],'left')
+            work[slice1] = M1.matvec(work[slice2])
         elif (ijob == 4):
-            if psolve is None:
-                psolve = get_psolveq(A)
-            work[slice1] = psolve(work[slice2],'right')
+            work[slice1] = M2.matvec(work[slice2])
         elif (ijob == 5):
-            if rpsolve is None:
-                rpsolve = get_rpsolveq(A)
-            work[slice1] = rpsolve(work[slice2],'left')
+            work[slice1] = M1.rmatvec(work[slice2])
         elif (ijob == 6):
-            if rpsolve is None:
-                rpsolve = get_rpsolveq(A)
-            work[slice1] = rpsolve(work[slice2],'right')
+            work[slice1] = M2.rmatvec(work[slice2])
         elif (ijob == 7):
-            if matvec is None:
-                matvec = get_matvec(A)
             work[slice2] *= sclr2
-            work[slice2] += sclr1*matvec(x)
+            work[slice2] += sclr1*A.matvec(x)
         elif (ijob == 8):
             if ftflag:
                 info = -1
@@ -805,4 +563,4 @@
             bnrm2, resid, info = stoptest(work[slice1], b, bnrm2, tol, info)
         ijob = 2
 
-    return x, info
+    return postprocess(x), info

Modified: trunk/scipy/splinalg/isolve/minres.py
===================================================================
--- trunk/scipy/splinalg/isolve/minres.py	2008-02-02 04:26:51 UTC (rev 3889)
+++ trunk/scipy/splinalg/isolve/minres.py	2008-02-02 18:17:25 UTC (rev 3890)
@@ -3,6 +3,8 @@
 
 from utils import make_system
 
+__all__ = ['minres']
+
 def minres(A, b, x0=None, shift=0.0, tol=1e-5, maxiter=None, xtype=None,
            M=None, callback=None, show=False, check=False):
     """Use the Minimum Residual Method (MINRES) to solve Ax=b 

Modified: trunk/scipy/splinalg/isolve/tests/test_iterative.py
===================================================================
--- trunk/scipy/splinalg/isolve/tests/test_iterative.py	2008-02-02 04:26:51 UTC (rev 3889)
+++ trunk/scipy/splinalg/isolve/tests/test_iterative.py	2008-02-02 18:17:25 UTC (rev 3890)
@@ -17,6 +17,10 @@
 #    #print "||A.x - b|| = " + str(norm(dot(A,x)-b))
 
 
+#TODO check that methods preserve shape and type
+#TODO test complex matrices
+#TODO test both preconditioner methods
+
 data = ones((3,10))
 data[0,:] =  2
 data[1,:] = -1

Modified: trunk/scipy/splinalg/isolve/utils.py
===================================================================
--- trunk/scipy/splinalg/isolve/utils.py	2008-02-02 04:26:51 UTC (rev 3889)
+++ trunk/scipy/splinalg/isolve/utils.py	2008-02-02 18:17:25 UTC (rev 3890)
@@ -1,5 +1,7 @@
-from numpy import asanyarray, asmatrix, array, matrix, zeros
+from warnings import warn
 
+from numpy import asanyarray, asarray, asmatrix, array, matrix, zeros
+
 from scipy.splinalg.interface import aslinearoperator, LinearOperator
 
 _coerce_rules = {('f','f'):'f', ('f','d'):'d', ('f','F'):'F',
@@ -20,6 +22,31 @@
     return x
 
 def make_system(A, M, x0, b, xtype=None):
+    """Make a linear system Ax=b
+    
+    Parameters:
+        A - LinearOperator
+            - sparse or dense matrix (or any valid input to aslinearoperator)
+        M - LinearOperator or None
+            - preconditioner
+            - sparse or dense matrix (or any valid input to aslinearoperator)
+        x0 - array_like or None
+            - initial guess to iterative method
+        b  - array_like
+            - right hand side
+        xtype - None or one of 'fdFD'
+            - dtype of the x vector
+
+    Returns:
+        (A, M, x, b, postprocess) where:
+            - A is a LinearOperator 
+            - M is a LinearOperator
+            - x is the initial guess (rank 1 array)
+            - b is the rhs (rank 1 array)
+            - postprocess is a function that converts the solution vector
+              to the appropriate type and dimensions (e.g. (N,1) matrix)
+
+    """
     A_ = A
     A = aslinearoperator(A)
 
@@ -33,24 +60,32 @@
     if not (b.shape == (N,1) or b.shape == (N,)):
         raise ValueError('A and b have incompatible dimensions')
 
+    if b.dtype.char not in 'fdFD':
+        b = b.astype('d') # upcast non-FP types to double
+
     def postprocess(x):
         if isinstance(b,matrix):
             x = asmatrix(x)
         return x.reshape(b.shape)
 
-
     if xtype is None:
         if hasattr(A,'dtype'):
             xtype = A.dtype.char
         else:
             xtype = A.matvec(b).dtype.char
         xtype = coerce(xtype, b.dtype.char)
-    elif xtype == 0:
-        xtype = b.dtype.char
     else:
-        if xtype not in 'fdFD':
-            raise ValueError, "xtype must be 'f', 'd', 'F', or 'D'"
+        warn('Use of xtype argument is deprecated. '\
+                'Use LinearOperator( ... , dtype=xtype) instead.',\
+                DeprecationWarning)
+        if xtype == 0:
+            xtype = b.dtype.char
+        else:
+            if xtype not in 'fdFD':
+                raise ValueError, "xtype must be 'f', 'd', 'F', or 'D'"
 
+    b = asarray(b,dtype=xtype) #make b the same type as x
+
     if x0 is None:
         x = zeros(N, dtype=xtype)
     else:
@@ -71,6 +106,7 @@
             rpsolve = id
         M = LinearOperator(A.shape, matvec=psolve, rmatvec=rpsolve, dtype=A.dtype)
     else:
+        M = aslinearoperator(M)
         if A.shape != M.shape:
             raise ValueError('matrix and preconditioner have different shapes')
 

Modified: trunk/scipy/splinalg/tests/test_interface.py
===================================================================
--- trunk/scipy/splinalg/tests/test_interface.py	2008-02-02 04:26:51 UTC (rev 3889)
+++ trunk/scipy/splinalg/tests/test_interface.py	2008-02-02 18:17:25 UTC (rev 3890)
@@ -31,18 +31,9 @@
                 return y
 
             def rmatvec(self,x):
-                if len(x.shape) == 1:
-                    y = array([ 1*x[0] + 4*x[1],
-                                2*x[0] + 5*x[1],
-                                3*x[0] + 6*x[1]])
-                    return y
-                else:
-                    y = array([ 1*x[0,0] + 4*x[0,1],
-                                2*x[0,0] + 5*x[0,1],
-                                3*x[0,0] + 6*x[0,1]])
-                    return y.reshape(1,-1)
-
-                return y
+                return array([ 1*x[0] + 4*x[1],
+                               2*x[0] + 5*x[1],
+                               3*x[0] + 6*x[1]])
                
         cases.append( matlike() )
 
@@ -55,7 +46,7 @@
             assert_equal(A.matvec(array([[1],[2],[3]])),[[14],[32]])
 
             assert_equal(A.rmatvec(array([1,2])),  [9,12,15])
-            assert_equal(A.rmatvec(array([[1,2]])),[[9,12,15]])
+            assert_equal(A.rmatvec(array([[1],[2]])),[[9],[12],[15]])
 
             if hasattr(M,'dtype'):
                 assert_equal(A.dtype, M.dtype)




More information about the Scipy-svn mailing list