[Scipy-svn] r7088 - in trunk/scipy/linalg: . tests

scipy-svn at scipy.org scipy-svn at scipy.org
Wed Jan 26 18:06:55 EST 2011


Author: ptvirtan
Date: 2011-01-26 17:06:53 -0600 (Wed, 26 Jan 2011)
New Revision: 7088

Modified:
   trunk/scipy/linalg/decomp_qr.py
   trunk/scipy/linalg/generic_flapack.pyf
   trunk/scipy/linalg/tests/test_decomp.py
Log:
ENH: improvements for linalg.rq

- API: now accepts the same keywords as linalg.qr, including the newly
  implemented mode='economic'.

- type support: now supports non-square as well as complex arrays,
  which made it possible to re-enable many previously commented-out tests.

- performance: now uses LAPACK orgrq (ungrq for complex input) to compute
  the Q matrix instead of explicit matrix multiplications. On a 500 x 500
  array on my Core Duo, this reduces the run time from 17.6 s to 0.16 s,
  i.e. a speedup of roughly 100x.

- Also includes some docstring fixes for both linalg.qr and linalg.rq.

Thanks to Fabian Pedregosa for the patch.

Modified: trunk/scipy/linalg/decomp_qr.py
===================================================================
--- trunk/scipy/linalg/decomp_qr.py	2011-01-25 21:28:41 UTC (rev 7087)
+++ trunk/scipy/linalg/decomp_qr.py	2011-01-26 23:06:53 UTC (rev 7088)
@@ -1,9 +1,7 @@
 """QR decomposition functions."""
 
-from warnings import warn
-
 import numpy
-from numpy import asarray_chkfinite, complexfloating
+from numpy import asarray_chkfinite
 
 # Local imports
 import special_matrices
@@ -52,17 +50,16 @@
 
     Examples
     --------
-    >>> from scipy import random, linalg, dot
+    >>> from scipy import random, linalg, dot, allclose
     >>> a = random.randn(9, 6)
     >>> q, r = linalg.qr(a)
     >>> allclose(a, dot(q, r))
     True
     >>> q.shape, r.shape
     ((9, 9), (9, 6))
-
     >>> r2 = linalg.qr(a, mode='r')
     >>> allclose(r, r2)
-
+    True
     >>> q3, r3 = linalg.qr(a, mode='economic')
     >>> q3.shape, r3.shape
     ((9, 6), (6, 6))
@@ -185,7 +182,7 @@
     return Q, R
 
 
-def rq(a, overwrite_a=False, lwork=None):
+def rq(a, overwrite_a=False, lwork=None, mode='full'):
     """Compute RQ decomposition of a square real matrix.
 
     Calculate the decomposition :lm:`A = R Q` where Q is unitary/orthogonal
@@ -194,12 +191,16 @@
     Parameters
     ----------
     a : array, shape (M, M)
-        Square real matrix to be decomposed
+        Matrix to be decomposed
     overwrite_a : boolean
         Whether data in a is overwritten (may improve performance)
     lwork : integer
         Work array size, lwork >= a.shape[1]. If None or -1, an optimal size
         is computed.
+    mode : {'full', 'r', 'economic'}
+        Determines what information is to be returned: either both Q and R
+        ('full', default), only R ('r') or both Q and R but computed in
+        economy-size ('economic', see Notes).
 
     Returns
     -------
@@ -208,17 +209,34 @@
 
     Raises LinAlgError if decomposition fails
 
+    Examples
+    --------
+    >>> from scipy import linalg
+    >>> from numpy import random, dot, allclose
+    >>> a = random.randn(6, 9)
+    >>> r, q = linalg.rq(a)
+    >>> allclose(a, dot(r, q))
+    True
+    >>> r.shape, q.shape
+    ((6, 9), (9, 9))
+    >>> r2 = linalg.rq(a, mode='r')
+    >>> allclose(r, r2)
+    True
+    >>> r3, q3 = linalg.rq(a, mode='economic')
+    >>> r3.shape, q3.shape
+    ((6, 6), (6, 9))
+
     """
-    # TODO: implement support for non-square and complex arrays
+    if not mode in ['full', 'r', 'economic']:
+        raise ValueError(\
+                 "Mode argument should be one of ['full', 'r', 'economic']")
+
     a1 = asarray_chkfinite(a)
     if len(a1.shape) != 2:
         raise ValueError('expected matrix')
-    M,N = a1.shape
-    if M != N:
-        raise ValueError('expected square matrix')
-    if issubclass(a1.dtype.type, complexfloating):
-        raise ValueError('expected real (non-complex) matrix')
+    M, N = a1.shape
     overwrite_a = overwrite_a or (_datanotshared(a1, a))
+
     gerqf, = get_lapack_funcs(('gerqf',), (a1,))
     if lwork is None or lwork == -1:
         # get optimal work array
@@ -226,20 +244,40 @@
         lwork = work[0].real.astype(numpy.int)
     rq, tau, work, info = gerqf(a1, lwork=lwork, overwrite_a=overwrite_a)
     if info < 0:
-        raise ValueError('illegal value in %d-th argument of internal geqrf'
+        raise ValueError('illegal value in %d-th argument of internal gerqf'
                                                                     % -info)
-    gemm, = get_blas_funcs(('gemm',), (rq,))
-    t = rq.dtype.char
-    R = special_matrices.triu(rq)
-    Q = numpy.identity(M, dtype=t)
-    ident = numpy.identity(M, dtype=t)
-    zeros = numpy.zeros
+    if not mode == 'economic' or N < M:
+        R = special_matrices.triu(rq, N-M)
+    else:
+        R = special_matrices.triu(rq[-M:, -M:])
 
-    k = min(M, N)
-    for i in range(k):
-        v = zeros((M,), t)
-        v[N-k+i] = 1
-        v[0:N-k+i] = rq[M-k+i, 0:N-k+i]
-        H = gemm(-tau[i], v, v, 1+0j, ident, trans_b=2)
-        Q = gemm(1, Q, H)
+    if mode == 'r':
+        return R
+
+    if find_best_lapack_type((a1,))[0] in ('s', 'd'):
+        gor_un_grq, = get_lapack_funcs(('orgrq',), (rq,))
+    else:
+        gor_un_grq, = get_lapack_funcs(('ungrq',), (rq,))
+
+    if N < M:
+        # get optimal work array
+        Q, work, info = gor_un_grq(rq[-N:], tau, lwork=-1, overwrite_a=1)
+        lwork = work[0].real.astype(numpy.int)
+        Q, work, info = gor_un_grq(rq[-N:], tau, lwork=lwork, overwrite_a=1)
+    elif mode == 'economic':
+        # get optimal work array
+        Q, work, info = gor_un_grq(rq, tau, lwork=-1, overwrite_a=1)
+        lwork = work[0].real.astype(numpy.int)
+        Q, work, info = gor_un_grq(rq, tau, lwork=lwork, overwrite_a=1)
+    else:
+        rq1 = numpy.empty((N, N), dtype=rq.dtype)
+        rq1[-M:] = rq
+        # get optimal work array
+        Q, work, info = gor_un_grq(rq1, tau, lwork=-1, overwrite_a=1)
+        lwork = work[0].real.astype(numpy.int)
+        Q, work, info = gor_un_grq(rq1, tau, lwork=lwork, overwrite_a=1)
+
+    if info < 0:
+        raise ValueError("illegal value in %d-th argument of internal orgrq"
+                                                                    % -info)
     return R, Q

Modified: trunk/scipy/linalg/generic_flapack.pyf
===================================================================
--- trunk/scipy/linalg/generic_flapack.pyf	2011-01-25 21:28:41 UTC (rev 7087)
+++ trunk/scipy/linalg/generic_flapack.pyf	2011-01-26 23:06:53 UTC (rev 7088)
@@ -573,7 +573,50 @@
      <type_in> dimension(MAX(lwork,1)),intent(out),depend(lwork) :: work
      integer intent(out) :: info
    end subroutine <tchar=c,z>ungqr
-+
+
+   subroutine <tchar=s,d>orgrq(m,n,k,a,tau,work,lwork,info)
+
+   ! q,work,info = orgrq(a,tau,lwork=3*n,overwrite_a=0)
+   ! Generates an M-by-N real matrix Q with orthonormal rows,
+   ! which is defined as the last M rows of a product of K elementary
+   ! reflectors of order N (e.g. the output of gerqf)
+
+     callstatement (*f2py_func)(&m,&n,&k,a,&m,tau,work,&lwork,&info)
+     callprotoargument int*,int*,int*,<type_in_c>*,int*,<type_in_c>*,<type_in_c>*,int*,int*
+
+     integer intent(hide),depend(a):: m = shape(a,0)
+     integer intent(hide),depend(a):: n = shape(a,1)
+     integer intent(hide),depend(tau):: k = shape(tau,0)
+     <type_in> dimension(m,n),intent(in,out,copy,out=q) :: a
+     <type_in> dimension(k),intent(in) :: tau
+
+     integer optional,intent(in),depend(n),check(lwork>=n||lwork==-1) :: lwork=3*n
+     <type_in> dimension(MAX(lwork,1)),intent(out),depend(lwork) :: work
+     integer intent(out) :: info
+   end subroutine <tchar=s,d>orgrq
+
+   subroutine <tchar=c,z>ungrq(m,n,k,a,tau,work,lwork,info)
+
+   ! q,work,info = ungrq(a,tau,lwork=3*n,overwrite_a=0)
+   ! Generates an M-by-N complex matrix Q with orthonormal rows,
+   ! which is defined as the last M rows of a product of K elementary
+   ! reflectors of order N (e.g. the output of gerqf)
+
+     callstatement (*f2py_func)(&m,&n,&k,a,&m,tau,work,&lwork,&info)
+     callprotoargument int*,int*,int*,<type_in_c>*,int*,<type_in_c>*,<type_in_c>*,int*,int*
+
+     integer intent(hide),depend(a):: m = shape(a,0)
+     integer intent(hide),depend(a):: n = shape(a,1)
+     integer intent(hide),depend(tau):: k = shape(tau,0)
+     <type_in> dimension(m,n),intent(in,out,copy,out=q) :: a
+     <type_in> dimension(k),intent(in) :: tau
+
+     integer optional,intent(in),depend(n),check(lwork>=n||lwork==-1) :: lwork=3*n
+     <type_in> dimension(MAX(lwork,1)),intent(out),depend(lwork) :: work
+     integer intent(out) :: info
+   end subroutine <tchar=c,z>ungrq
+
+
    subroutine <tchar=s,d>geev(compute_vl,compute_vr,n,a,wr,wi,vl,ldvl,vr,ldvr,work,lwork,info)
 
      ! wr,wi,vl,vr,info = geev(a,compute_vl=1,compute_vr=1,lwork=4*n,overwrite_a=0)

Modified: trunk/scipy/linalg/tests/test_decomp.py
===================================================================
--- trunk/scipy/linalg/tests/test_decomp.py	2011-01-25 21:28:41 UTC (rev 7087)
+++ trunk/scipy/linalg/tests/test_decomp.py	2011-01-26 23:06:53 UTC (rev 7088)
@@ -919,7 +919,7 @@
     def test_simple(self):
         a = [[8,2,3],[2,9,3],[5,3,6]]
         r,q = rq(a)
-        assert_array_almost_equal(dot(transpose(q),q),identity(3))
+        assert_array_almost_equal(dot(q, transpose(q)),identity(3))
         assert_array_almost_equal(dot(r,q),a)
 
     def test_random(self):
@@ -927,54 +927,63 @@
         for k in range(2):
             a = random([n,n])
             r,q = rq(a)
-            assert_array_almost_equal(dot(transpose(q),q),identity(n))
+            assert_array_almost_equal(dot(q, transpose(q)),identity(n))
             assert_array_almost_equal(dot(r,q),a)
 
-# TODO: implement support for non-square and complex arrays
+    def test_simple_trap(self):
+        a = [[8,2,3],[2,9,3]]
+        r,q = rq(a)
+        assert_array_almost_equal(dot(transpose(q),q),identity(3))
+        assert_array_almost_equal(dot(r,q),a)
 
-##    def test_simple_trap(self):
-##        a = [[8,2,3],[2,9,3]]
-##        r,q = rq(a)
-##        assert_array_almost_equal(dot(transpose(q),q),identity(2))
-##        assert_array_almost_equal(dot(r,q),a)
+    def test_simple_tall(self):
+        a = [[8,2],[2,9],[5,3]]
+        r,q = rq(a)
+        assert_array_almost_equal(dot(transpose(q),q),identity(2))
+        assert_array_almost_equal(dot(r,q),a)
 
-##    def test_simple_tall(self):
-##        a = [[8,2],[2,9],[5,3]]
-##        r,q = rq(a)
-##        assert_array_almost_equal(dot(transpose(q),q),identity(3))
-##        assert_array_almost_equal(dot(r,q),a)
+    def test_simple_complex(self):
+        a = [[3,3+4j,5],[5,2,2+7j],[3,2,7]]
+        r,q = rq(a)
+        assert_array_almost_equal(dot(q, conj(transpose(q))),identity(3))
+        assert_array_almost_equal(dot(r,q),a)
 
-##    def test_simple_complex(self):
-##        a = [[3,3+4j,5],[5,2,2+7j],[3,2,7]]
-##        r,q = rq(a)
-##        assert_array_almost_equal(dot(conj(transpose(q)),q),identity(3))
-##        assert_array_almost_equal(dot(r,q),a)
+    def test_random_tall(self):
+        m = 200
+        n = 100
+        for k in range(2):
+            a = random([m,n])
+            r,q = rq(a)
+            assert_array_almost_equal(dot(q, transpose(q)),identity(n))
+            assert_array_almost_equal(dot(r,q),a)
 
-##    def test_random_tall(self):
-##        m = 200
-##        n = 100
-##        for k in range(2):
-##            a = random([m,n])
-##            r,q = rq(a)
-##            assert_array_almost_equal(dot(transpose(q),q),identity(m))
-##            assert_array_almost_equal(dot(r,q),a)
+    def test_random_trap(self):
+        m = 100
+        n = 200
+        for k in range(2):
+            a = random([m,n])
+            r,q = rq(a)
+            assert_array_almost_equal(dot(q, transpose(q)),identity(n))
+            assert_array_almost_equal(dot(r,q),a)
 
-##    def test_random_trap(self):
-##        m = 100
-##        n = 200
-##        for k in range(2):
-##            a = random([m,n])
-##            r,q = rq(a)
-##            assert_array_almost_equal(dot(transpose(q),q),identity(m))
-##            assert_array_almost_equal(dot(r,q),a)
+    def test_random_trap_economic(self):
+        m = 100
+        n = 200
+        for k in range(2):
+            a = random([m,n])
+            r,q = rq(a, mode='economic')
+            assert_array_almost_equal(dot(q,transpose(q)),identity(m))
+            assert_array_almost_equal(dot(r,q),a)
+            assert_equal(q.shape, (m, n))
+            assert_equal(r.shape, (m, m))
 
-##    def test_random_complex(self):
-##        n = 20
-##        for k in range(2):
-##            a = random([n,n])+1j*random([n,n])
-##            r,q = rq(a)
-##            assert_array_almost_equal(dot(conj(transpose(q)),q),identity(n))
-##            assert_array_almost_equal(dot(r,q),a)
+    def test_random_complex(self):
+        n = 20
+        for k in range(2):
+            a = random([n,n])+1j*random([n,n])
+            r,q = rq(a)
+            assert_array_almost_equal(dot(q, conj(transpose(q))),identity(n))
+            assert_array_almost_equal(dot(r,q),a)
 
 transp = transpose
 any = sometrue




More information about the Scipy-svn mailing list