[Scipy-svn] r7088 - in trunk/scipy/linalg: . tests
scipy-svn at scipy.org
scipy-svn at scipy.org
Wed Jan 26 18:06:55 EST 2011
Author: ptvirtan
Date: 2011-01-26 17:06:53 -0600 (Wed, 26 Jan 2011)
New Revision: 7088
Modified:
trunk/scipy/linalg/decomp_qr.py
trunk/scipy/linalg/generic_flapack.pyf
trunk/scipy/linalg/tests/test_decomp.py
Log:
ENH: improvements for linalg.rq
- API: now accepts the same keywords as linalg.qr; implemented
mode='economic'
- type support: now supports non-square as well as complex arrays.
This made it possible to re-enable many previously commented-out tests.
- performance: now uses LAPACK orgrq/ungrq to compute the Q matrix
instead of explicit matrix multiplications. On a 500 x 500 array on my
Core Duo, runtime drops from 17.6 s to 0.16 s, i.e. a speedup of
roughly 100x.
- Also, some docstring fixes in both linalg.qr and linalg.rq.
Thanks to Fabian Pedregosa for the patch.
Modified: trunk/scipy/linalg/decomp_qr.py
===================================================================
--- trunk/scipy/linalg/decomp_qr.py 2011-01-25 21:28:41 UTC (rev 7087)
+++ trunk/scipy/linalg/decomp_qr.py 2011-01-26 23:06:53 UTC (rev 7088)
@@ -1,9 +1,7 @@
"""QR decomposition functions."""
-from warnings import warn
-
import numpy
-from numpy import asarray_chkfinite, complexfloating
+from numpy import asarray_chkfinite
# Local imports
import special_matrices
@@ -52,17 +50,16 @@
Examples
--------
- >>> from scipy import random, linalg, dot
+ >>> from scipy import random, linalg, dot, allclose
>>> a = random.randn(9, 6)
>>> q, r = linalg.qr(a)
>>> allclose(a, dot(q, r))
True
>>> q.shape, r.shape
((9, 9), (9, 6))
-
>>> r2 = linalg.qr(a, mode='r')
>>> allclose(r, r2)
-
+ True
>>> q3, r3 = linalg.qr(a, mode='economic')
>>> q3.shape, r3.shape
((9, 6), (6, 6))
@@ -185,7 +182,7 @@
return Q, R
-def rq(a, overwrite_a=False, lwork=None):
+def rq(a, overwrite_a=False, lwork=None, mode='full'):
"""Compute RQ decomposition of a square real matrix.
Calculate the decomposition :lm:`A = R Q` where Q is unitary/orthogonal
@@ -194,12 +191,16 @@
Parameters
----------
a : array, shape (M, M)
- Square real matrix to be decomposed
+ Matrix to be decomposed
overwrite_a : boolean
Whether data in a is overwritten (may improve performance)
lwork : integer
Work array size, lwork >= a.shape[1]. If None or -1, an optimal size
is computed.
+ mode : {'full', 'r', 'economic'}
+ Determines what information is to be returned: either both Q and R
+ ('full', default), only R ('r') or both Q and R but computed in
+ economy-size ('economic', see Notes).
Returns
-------
@@ -208,17 +209,34 @@
Raises LinAlgError if decomposition fails
+ Examples
+ --------
+ >>> from scipy import linalg
+ >>> from numpy import random, dot, allclose
+ >>> a = random.randn(6, 9)
+ >>> r, q = linalg.rq(a)
+ >>> allclose(a, dot(r, q))
+ True
+ >>> r.shape, q.shape
+ ((6, 9), (9, 9))
+ >>> r2 = linalg.rq(a, mode='r')
+ >>> allclose(r, r2)
+ True
+ >>> r3, q3 = linalg.rq(a, mode='economic')
+ >>> r3.shape, q3.shape
+ ((6, 6), (6, 9))
+
"""
- # TODO: implement support for non-square and complex arrays
+ if not mode in ['full', 'r', 'economic']:
+ raise ValueError(\
+ "Mode argument should be one of ['full', 'r', 'economic']")
+
a1 = asarray_chkfinite(a)
if len(a1.shape) != 2:
raise ValueError('expected matrix')
- M,N = a1.shape
- if M != N:
- raise ValueError('expected square matrix')
- if issubclass(a1.dtype.type, complexfloating):
- raise ValueError('expected real (non-complex) matrix')
+ M, N = a1.shape
overwrite_a = overwrite_a or (_datanotshared(a1, a))
+
gerqf, = get_lapack_funcs(('gerqf',), (a1,))
if lwork is None or lwork == -1:
# get optimal work array
@@ -226,20 +244,40 @@
lwork = work[0].real.astype(numpy.int)
rq, tau, work, info = gerqf(a1, lwork=lwork, overwrite_a=overwrite_a)
if info < 0:
- raise ValueError('illegal value in %d-th argument of internal geqrf'
+ raise ValueError('illegal value in %d-th argument of internal gerqf'
% -info)
- gemm, = get_blas_funcs(('gemm',), (rq,))
- t = rq.dtype.char
- R = special_matrices.triu(rq)
- Q = numpy.identity(M, dtype=t)
- ident = numpy.identity(M, dtype=t)
- zeros = numpy.zeros
+ if not mode == 'economic' or N < M:
+ R = special_matrices.triu(rq, N-M)
+ else:
+ R = special_matrices.triu(rq[-M:, -M:])
- k = min(M, N)
- for i in range(k):
- v = zeros((M,), t)
- v[N-k+i] = 1
- v[0:N-k+i] = rq[M-k+i, 0:N-k+i]
- H = gemm(-tau[i], v, v, 1+0j, ident, trans_b=2)
- Q = gemm(1, Q, H)
+ if mode == 'r':
+ return R
+
+ if find_best_lapack_type((a1,))[0] in ('s', 'd'):
+ gor_un_grq, = get_lapack_funcs(('orgrq',), (rq,))
+ else:
+ gor_un_grq, = get_lapack_funcs(('ungrq',), (rq,))
+
+ if N < M:
+ # get optimal work array
+ Q, work, info = gor_un_grq(rq[-N:], tau, lwork=-1, overwrite_a=1)
+ lwork = work[0].real.astype(numpy.int)
+ Q, work, info = gor_un_grq(rq[-N:], tau, lwork=lwork, overwrite_a=1)
+ elif mode == 'economic':
+ # get optimal work array
+ Q, work, info = gor_un_grq(rq, tau, lwork=-1, overwrite_a=1)
+ lwork = work[0].real.astype(numpy.int)
+ Q, work, info = gor_un_grq(rq, tau, lwork=lwork, overwrite_a=1)
+ else:
+ rq1 = numpy.empty((N, N), dtype=rq.dtype)
+ rq1[-M:] = rq
+ # get optimal work array
+ Q, work, info = gor_un_grq(rq1, tau, lwork=-1, overwrite_a=1)
+ lwork = work[0].real.astype(numpy.int)
+ Q, work, info = gor_un_grq(rq1, tau, lwork=lwork, overwrite_a=1)
+
+ if info < 0:
+ raise ValueError("illegal value in %d-th argument of internal orgrq"
+ % -info)
return R, Q
Modified: trunk/scipy/linalg/generic_flapack.pyf
===================================================================
--- trunk/scipy/linalg/generic_flapack.pyf 2011-01-25 21:28:41 UTC (rev 7087)
+++ trunk/scipy/linalg/generic_flapack.pyf 2011-01-26 23:06:53 UTC (rev 7088)
@@ -573,7 +573,50 @@
<type_in> dimension(MAX(lwork,1)),intent(out),depend(lwork) :: work
integer intent(out) :: info
end subroutine <tchar=c,z>ungqr
-+
+
+ subroutine <tchar=s,d>orgrq(m,n,k,a,tau,work,lwork,info)
+
+ ! q,work,info = orgrq(a,tau,lwork=3*m,overwrite_a=0)
+ ! Generates an M-by-N real matrix Q with orthonormal rows,
+ ! which is defined as the last M rows of a product of K elementary
+ ! reflectors of order N (e.g. output of gerqf); LAPACK requires lwork >= max(1,m)
+
+ callstatement (*f2py_func)(&m,&n,&k,a,&m,tau,work,&lwork,&info)
+ callprotoargument int*,int*,int*,<type_in_c>*,int*,<type_in_c>*,<type_in_c>*,int*,int*
+
+ integer intent(hide),depend(a):: m = shape(a,0)
+ integer intent(hide),depend(a):: n = shape(a,1)
+ integer intent(hide),depend(tau):: k = shape(tau,0)
+ <type_in> dimension(m,n),intent(in,out,copy,out=q) :: a
+ <type_in> dimension(k),intent(in) :: tau
+
+ integer optional,intent(in),depend(m),check(lwork>=m||lwork==-1) :: lwork=3*m
+ <type_in> dimension(MAX(lwork,1)),intent(out),depend(lwork) :: work
+ integer intent(out) :: info
+ end subroutine <tchar=s,d>orgrq
+
+ subroutine <tchar=c,z>ungrq(m,n,k,a,tau,work,lwork,info)
+
+ ! q,work,info = ungrq(a,tau,lwork=3*m,overwrite_a=0)
+ ! Generates an M-by-N complex matrix Q with orthonormal rows,
+ ! which is defined as the last M rows of a product of K elementary
+ ! reflectors of order N (e.g. output of gerqf); LAPACK requires lwork >= max(1,m)
+
+ callstatement (*f2py_func)(&m,&n,&k,a,&m,tau,work,&lwork,&info)
+ callprotoargument int*,int*,int*,<type_in_c>*,int*,<type_in_c>*,<type_in_c>*,int*,int*
+
+ integer intent(hide),depend(a):: m = shape(a,0)
+ integer intent(hide),depend(a):: n = shape(a,1)
+ integer intent(hide),depend(tau):: k = shape(tau,0)
+ <type_in> dimension(m,n),intent(in,out,copy,out=q) :: a
+ <type_in> dimension(k),intent(in) :: tau
+
+ integer optional,intent(in),depend(m),check(lwork>=m||lwork==-1) :: lwork=3*m
+ <type_in> dimension(MAX(lwork,1)),intent(out),depend(lwork) :: work
+ integer intent(out) :: info
+ end subroutine <tchar=c,z>ungrq
+
+
subroutine <tchar=s,d>geev(compute_vl,compute_vr,n,a,wr,wi,vl,ldvl,vr,ldvr,work,lwork,info)
! wr,wi,vl,vr,info = geev(a,compute_vl=1,compute_vr=1,lwork=4*n,overwrite_a=0)
Modified: trunk/scipy/linalg/tests/test_decomp.py
===================================================================
--- trunk/scipy/linalg/tests/test_decomp.py 2011-01-25 21:28:41 UTC (rev 7087)
+++ trunk/scipy/linalg/tests/test_decomp.py 2011-01-26 23:06:53 UTC (rev 7088)
@@ -919,7 +919,7 @@
def test_simple(self):
a = [[8,2,3],[2,9,3],[5,3,6]]
r,q = rq(a)
- assert_array_almost_equal(dot(transpose(q),q),identity(3))
+ assert_array_almost_equal(dot(q, transpose(q)),identity(3))
assert_array_almost_equal(dot(r,q),a)
def test_random(self):
@@ -927,54 +927,63 @@
for k in range(2):
a = random([n,n])
r,q = rq(a)
- assert_array_almost_equal(dot(transpose(q),q),identity(n))
+ assert_array_almost_equal(dot(q, transpose(q)),identity(n))
assert_array_almost_equal(dot(r,q),a)
-# TODO: implement support for non-square and complex arrays
+ def test_simple_trap(self):
+ a = [[8,2,3],[2,9,3]]
+ r,q = rq(a)
+ assert_array_almost_equal(dot(transpose(q),q),identity(3))
+ assert_array_almost_equal(dot(r,q),a)
-## def test_simple_trap(self):
-## a = [[8,2,3],[2,9,3]]
-## r,q = rq(a)
-## assert_array_almost_equal(dot(transpose(q),q),identity(2))
-## assert_array_almost_equal(dot(r,q),a)
+ def test_simple_tall(self):
+ a = [[8,2],[2,9],[5,3]]
+ r,q = rq(a)
+ assert_array_almost_equal(dot(transpose(q),q),identity(2))
+ assert_array_almost_equal(dot(r,q),a)
-## def test_simple_tall(self):
-## a = [[8,2],[2,9],[5,3]]
-## r,q = rq(a)
-## assert_array_almost_equal(dot(transpose(q),q),identity(3))
-## assert_array_almost_equal(dot(r,q),a)
+ def test_simple_complex(self):
+ a = [[3,3+4j,5],[5,2,2+7j],[3,2,7]]
+ r,q = rq(a)
+ assert_array_almost_equal(dot(q, conj(transpose(q))),identity(3))
+ assert_array_almost_equal(dot(r,q),a)
-## def test_simple_complex(self):
-## a = [[3,3+4j,5],[5,2,2+7j],[3,2,7]]
-## r,q = rq(a)
-## assert_array_almost_equal(dot(conj(transpose(q)),q),identity(3))
-## assert_array_almost_equal(dot(r,q),a)
+ def test_random_tall(self):
+ m = 200
+ n = 100
+ for k in range(2):
+ a = random([m,n])
+ r,q = rq(a)
+ assert_array_almost_equal(dot(q, transpose(q)),identity(n))
+ assert_array_almost_equal(dot(r,q),a)
-## def test_random_tall(self):
-## m = 200
-## n = 100
-## for k in range(2):
-## a = random([m,n])
-## r,q = rq(a)
-## assert_array_almost_equal(dot(transpose(q),q),identity(m))
-## assert_array_almost_equal(dot(r,q),a)
+ def test_random_trap(self):
+ m = 100
+ n = 200
+ for k in range(2):
+ a = random([m,n])
+ r,q = rq(a)
+ assert_array_almost_equal(dot(q, transpose(q)),identity(n))
+ assert_array_almost_equal(dot(r,q),a)
-## def test_random_trap(self):
-## m = 100
-## n = 200
-## for k in range(2):
-## a = random([m,n])
-## r,q = rq(a)
-## assert_array_almost_equal(dot(transpose(q),q),identity(m))
-## assert_array_almost_equal(dot(r,q),a)
+ def test_random_trap_economic(self):
+ m = 100
+ n = 200
+ for k in range(2):
+ a = random([m,n])
+ r,q = rq(a, mode='economic')
+ assert_array_almost_equal(dot(q,transpose(q)),identity(m))
+ assert_array_almost_equal(dot(r,q),a)
+ assert_equal(q.shape, (m, n))
+ assert_equal(r.shape, (m, m))
-## def test_random_complex(self):
-## n = 20
-## for k in range(2):
-## a = random([n,n])+1j*random([n,n])
-## r,q = rq(a)
-## assert_array_almost_equal(dot(conj(transpose(q)),q),identity(n))
-## assert_array_almost_equal(dot(r,q),a)
+ def test_random_complex(self):
+ n = 20
+ for k in range(2):
+ a = random([n,n])+1j*random([n,n])
+ r,q = rq(a)
+ assert_array_almost_equal(dot(q, conj(transpose(q))),identity(n))
+ assert_array_almost_equal(dot(r,q),a)
transp = transpose
any = sometrue
More information about the Scipy-svn
mailing list