[Scipy-svn] r7116 - in trunk/scipy/linalg: . tests
scipy-svn at scipy.org
scipy-svn at scipy.org
Mon Jan 31 16:52:58 EST 2011
Author: ptvirtan
Date: 2011-01-31 15:52:58 -0600 (Mon, 31 Jan 2011)
New Revision: 7116
Added:
trunk/scipy/linalg/_testutils.py
Modified:
trunk/scipy/linalg/basic.py
trunk/scipy/linalg/decomp.py
trunk/scipy/linalg/decomp_cholesky.py
trunk/scipy/linalg/decomp_lu.py
trunk/scipy/linalg/decomp_qr.py
trunk/scipy/linalg/decomp_schur.py
trunk/scipy/linalg/decomp_svd.py
trunk/scipy/linalg/misc.py
trunk/scipy/linalg/tests/test_basic.py
trunk/scipy/linalg/tests/test_decomp.py
trunk/scipy/linalg/tests/test_decomp_cholesky.py
Log:
BUG: linalg: more robust data overwrite behavior
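Some routines in linalg used an unreliable check to decide whether input data
could be overwritten; for non-ndarray objects that provide an array interface
(__array_interface__) but no __array__ method, the check wrongly concluded that
the caller's data could be overwritten.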
Also add new tests checking the data overwrite behavior.
Added: trunk/scipy/linalg/_testutils.py
===================================================================
--- trunk/scipy/linalg/_testutils.py (rev 0)
+++ trunk/scipy/linalg/_testutils.py 2011-01-31 21:52:58 UTC (rev 7116)
@@ -0,0 +1,57 @@
+import numpy as np
+
class _FakeMatrix(object):
    """
    Array-like wrapper that exposes `data` only via ``__array_interface__``.

    No ``__array__`` method is defined, so numpy has to build its view from
    the copied interface dict; the resulting array shares memory with `data`.
    """

    def __init__(self, data):
        interface = data.__array_interface__
        # Keep `data` referenced: the interface dict only carries a raw data
        # pointer, so the owning array must stay alive with this wrapper.
        self._data = data
        self.__array_interface__ = interface
+
class _FakeMatrix2(object):
    """
    Array-like wrapper that exposes `data` only through ``__array__``.

    When no conversion is requested, ``__array__`` returns the wrapped array
    itself, so converting the wrapper shares memory with `data`.
    """

    def __init__(self, data):
        self._data = data

    def __array__(self, dtype=None, copy=None):
        """
        Return the wrapped array, honouring numpy's ``__array__`` protocol.

        numpy passes `dtype` when a specific dtype is requested (e.g.
        ``np.asarray(obj, dtype=...)``) and, since NumPy 2.0, may pass
        `copy`; a zero-argument signature cannot accept either.

        Parameters
        ----------
        dtype : dtype, optional
            Requested dtype; the data is cast (copied) only if it differs.
        copy : bool or None, optional
            True forces a copy; False forbids one (ValueError if a cast
            would be required); None copies only when casting.
        """
        data = self._data
        if dtype is not None and np.dtype(dtype) != data.dtype:
            if copy is False:
                raise ValueError("a copy is required to cast to %r" % (dtype,))
            return data.astype(dtype)
        if copy:
            return data.copy()
        return data
+
def _get_array(shape, dtype):
    """
    Get a test array of given shape and data type.

    Returned NxN matrices are posdef, and 2xN are banded-posdef.
    The 2xN rule takes precedence, so a (2, 2) request yields the banded
    layout rather than a dense matrix. Any other shape gets reproducible
    pseudo-random data (seed 1234).

    Parameters
    ----------
    shape : tuple of int
        Shape of the array to create.
    dtype : dtype
        Data type of the returned array.

    Returns
    -------
    ndarray
        A freshly allocated array of the requested shape and dtype.
    """
    if len(shape) == 2 and shape[0] == 2:
        # Band storage of tridiag(-1, 2, -1) with one superdiagonal: row 0
        # holds the superdiagonal (entry 0 is unused padding), row 1 the
        # main diagonal.
        x = np.zeros(shape, dtype=dtype)
        x[0, 1:] = -1
        x[1] = 2
        return x
    elif len(shape) == 2 and shape[0] == shape[1]:
        # Dense tridiag(-1, 2, -1): symmetric positive definite.
        x = np.zeros(shape, dtype=dtype)
        j = np.arange(shape[0])
        x[j, j] = 2
        x[j[:-1], j[:-1] + 1] = -1
        x[j[:-1] + 1, j[:-1]] = -1
        return x
    else:
        # Use a private generator: reseeding the global numpy RNG here would
        # silently change the random streams seen by unrelated tests. The
        # values are identical to np.random.seed(1234); np.random.randn(...).
        rng = np.random.RandomState(1234)
        return rng.randn(*shape).astype(dtype)
+
def _id(x):
    """Return `x` untouched: the 'no wrapper' case for overwrite tests."""
    unchanged = x
    return unchanged
+
def assert_no_overwrite(call, shapes, dtypes=None):
    """
    Test that a call does not overwrite its input arguments.

    For every combination of dtype, memory order ("C" and "F") and input
    wrapper (plain ndarray via ``_id``, ``_FakeMatrix`` and ``_FakeMatrix2``),
    fresh inputs of the given `shapes` are built with ``_get_array``, passed
    to `call`, and afterwards compared against untouched originals.

    Parameters
    ----------
    call : callable
        Function under test; invoked as ``call(*inputs)``.
    shapes : sequence of tuple
        Shape of each positional input.
    dtypes : sequence of dtype, optional
        Defaults to float32, float64, complex64 and complex128.

    Raises
    ------
    AssertionError
        If any input differs from its original after the call. The message
        names the dtype, memory order and wrapper that triggered it.
    """
    if dtypes is None:
        dtypes = [np.float32, np.float64, np.complex64, np.complex128]

    for dtype in dtypes:
        for order in ["C", "F"]:
            for faker in [_id, _FakeMatrix, _FakeMatrix2]:
                orig_inputs = [_get_array(s, dtype) for s in shapes]
                inputs = [faker(x.copy(order)) for x in orig_inputs]
                call(*inputs)
                # Include the memory order: overwrite bugs often depend on
                # whether the data is C- or Fortran-contiguous.
                msg = "call modified inputs [%r, %r, %r]" % (dtype, order,
                                                            faker)
                for a, b in zip(inputs, orig_inputs):
                    # Unwrap fakes explicitly; both wrappers view the same
                    # buffer the call received.
                    np.testing.assert_equal(np.asarray(a), b, err_msg=msg)
Modified: trunk/scipy/linalg/basic.py
===================================================================
--- trunk/scipy/linalg/basic.py 2011-01-31 21:52:21 UTC (rev 7115)
+++ trunk/scipy/linalg/basic.py 2011-01-31 21:52:58 UTC (rev 7116)
@@ -12,7 +12,7 @@
from flinalg import get_flinalg_funcs
from lapack import get_lapack_funcs
-from misc import LinAlgError
+from misc import LinAlgError, _datacopied
from scipy.linalg import calc_lwork
import decomp_svd
@@ -49,8 +49,8 @@
raise ValueError('expected square matrix')
if a1.shape[0] != b1.shape[0]:
raise ValueError('incompatible dimensions')
- overwrite_a = overwrite_a or (a1 is not a and not hasattr(a,'__array__'))
- overwrite_b = overwrite_b or (b1 is not b and not hasattr(b,'__array__'))
+ overwrite_a = overwrite_a or _datacopied(a1, a)
+ overwrite_b = overwrite_b or _datacopied(b1, b)
if debug:
print 'solve:overwrite_a=',overwrite_a
print 'solve:overwrite_b=',overwrite_b
@@ -117,7 +117,7 @@
raise ValueError('expected square matrix')
if a1.shape[0] != b1.shape[0]:
raise ValueError('incompatible dimensions')
- overwrite_b = overwrite_b or (b1 is not b and not hasattr(b,'__array__'))
+ overwrite_b = overwrite_b or _datacopied(b1, b)
if debug:
print 'solve:overwrite_b=',overwrite_b
trans = {'N': 0, 'T': 1, 'C': 2}.get(trans, trans)
@@ -174,7 +174,7 @@
raise ValueError("invalid values for the number of lower and upper diagonals:"
" l+u+1 (%d) does not equal ab.shape[0] (%d)" % (l+u+1, ab.shape[0]))
- overwrite_b = overwrite_b or (b1 is not b and not hasattr(b,'__array__'))
+ overwrite_b = overwrite_b or _datacopied(b1, b)
gbsv, = get_lapack_funcs(('gbsv',), (a1, b1))
a2 = zeros((2*l+u+1, a1.shape[1]), dtype=gbsv.dtype)
@@ -285,7 +285,7 @@
a1 = asarray_chkfinite(a)
if len(a1.shape) != 2 or a1.shape[0] != a1.shape[1]:
raise ValueError('expected square matrix')
- overwrite_a = overwrite_a or (a1 is not a and not hasattr(a,'__array__'))
+ overwrite_a = overwrite_a or _datacopied(a1, a)
#XXX: I found no advantage or disadvantage of using finv.
## finv, = get_flinalg_funcs(('inv',),(a1,))
## if finv is not None:
@@ -350,7 +350,7 @@
a1 = asarray_chkfinite(a)
if len(a1.shape) != 2 or a1.shape[0] != a1.shape[1]:
raise ValueError('expected square matrix')
- overwrite_a = overwrite_a or (a1 is not a and not hasattr(a,'__array__'))
+ overwrite_a = overwrite_a or _datacopied(a1, a)
fdet, = get_flinalg_funcs(('det',), (a1,))
a_det, info = fdet(a1, overwrite_a=overwrite_a)
if info < 0:
@@ -426,8 +426,8 @@
else:
b2[:m,0] = b1
b1 = b2
- overwrite_a = overwrite_a or (a1 is not a and not hasattr(a,'__array__'))
- overwrite_b = overwrite_b or (b1 is not b and not hasattr(b,'__array__'))
+ overwrite_a = overwrite_a or _datacopied(a1, a)
+ overwrite_b = overwrite_b or _datacopied(b1, b)
if gelss.module_name[:7] == 'flapack':
lwork = calc_lwork.gelss(gelss.prefix, m, n, nrhs)[1]
v, x, s, rank, info = gelss(a1, b1, cond=cond, lwork=lwork,
Modified: trunk/scipy/linalg/decomp.py
===================================================================
--- trunk/scipy/linalg/decomp.py 2011-01-31 21:52:21 UTC (rev 7115)
+++ trunk/scipy/linalg/decomp.py 2011-01-31 21:52:58 UTC (rev 7116)
@@ -21,7 +21,7 @@
# Local imports
from scipy.linalg import calc_lwork
-from misc import LinAlgError, _datanotshared
+from misc import LinAlgError, _datacopied
from lapack import get_lapack_funcs
from blas import get_blas_funcs
@@ -43,7 +43,7 @@
def _geneig(a1, b, left, right, overwrite_a, overwrite_b):
b1 = asarray(b)
- overwrite_b = overwrite_b or _datanotshared(b1, b)
+ overwrite_b = overwrite_b or _datacopied(b1, b)
if len(b1.shape) != 2 or b1.shape[0] != b1.shape[1]:
raise ValueError('expected square matrix')
ggev, = get_lapack_funcs(('ggev',), (a1, b1))
@@ -135,7 +135,7 @@
a1 = asarray_chkfinite(a)
if len(a1.shape) != 2 or a1.shape[0] != a1.shape[1]:
raise ValueError('expected square matrix')
- overwrite_a = overwrite_a or (_datanotshared(a1, a))
+ overwrite_a = overwrite_a or (_datacopied(a1, a))
if b is not None:
b = asarray_chkfinite(b)
if b.shape != a1.shape:
@@ -265,14 +265,14 @@
a1 = asarray_chkfinite(a)
if len(a1.shape) != 2 or a1.shape[0] != a1.shape[1]:
raise ValueError('expected square matrix')
- overwrite_a = overwrite_a or (_datanotshared(a1, a))
+ overwrite_a = overwrite_a or (_datacopied(a1, a))
if iscomplexobj(a1):
cplx = True
else:
cplx = False
if b is not None:
b1 = asarray_chkfinite(b)
- overwrite_b = overwrite_b or _datanotshared(b1, b)
+ overwrite_b = overwrite_b or _datacopied(b1, b)
if len(b1.shape) != 2 or b1.shape[0] != b1.shape[1]:
raise ValueError('expected square matrix')
@@ -455,7 +455,7 @@
"""
if eigvals_only or overwrite_a_band:
a1 = asarray_chkfinite(a_band)
- overwrite_a_band = overwrite_a_band or (_datanotshared(a1, a_band))
+ overwrite_a_band = overwrite_a_band or (_datacopied(a1, a_band))
else:
a1 = array(a_band)
if issubclass(a1.dtype.type, inexact) and not isfinite(a1).all():
@@ -734,9 +734,9 @@
a1 = asarray(a)
if len(a1.shape) != 2 or (a1.shape[0] != a1.shape[1]):
raise ValueError('expected square matrix')
- overwrite_a = overwrite_a or (_datanotshared(a1, a))
+ overwrite_a = overwrite_a or (_datacopied(a1, a))
gehrd,gebal = get_lapack_funcs(('gehrd','gebal'), (a1,))
- ba, lo, hi, pivscale, info = gebal(a, permute=1, overwrite_a=overwrite_a)
+ ba, lo, hi, pivscale, info = gebal(a1, permute=1, overwrite_a=overwrite_a)
if info < 0:
raise ValueError('illegal value in %d-th argument of internal gebal '
'(hessenberg)' % -info)
Modified: trunk/scipy/linalg/decomp_cholesky.py
===================================================================
--- trunk/scipy/linalg/decomp_cholesky.py 2011-01-31 21:52:21 UTC (rev 7115)
+++ trunk/scipy/linalg/decomp_cholesky.py 2011-01-31 21:52:58 UTC (rev 7116)
@@ -3,7 +3,7 @@
from numpy import asarray_chkfinite
# Local imports
-from misc import LinAlgError, _datanotshared
+from misc import LinAlgError, _datacopied
from lapack import get_lapack_funcs
__all__ = ['cholesky', 'cho_factor', 'cho_solve', 'cholesky_banded',
@@ -17,7 +17,7 @@
if len(a1.shape) != 2 or a1.shape[0] != a1.shape[1]:
raise ValueError('expected square matrix')
- overwrite_a = overwrite_a or _datanotshared(a1, a)
+ overwrite_a = overwrite_a or _datacopied(a1, a)
potrf, = get_lapack_funcs(('potrf',), (a1,))
c, info = potrf(a1, lower=lower, overwrite_a=overwrite_a, clean=clean)
if info > 0:
@@ -104,7 +104,7 @@
See also
--------
- cho_solve : Solve a linear set equations using the Cholesky factorization
+ cho_solve : Solve a linear set equations using the Cholesky factorization
of a matrix.
"""
@@ -140,7 +140,7 @@
if c.shape[1] != b1.shape[0]:
raise ValueError("incompatible dimensions.")
- overwrite_b = overwrite_b or (b1 is not b and not hasattr(b,'__array__'))
+ overwrite_b = overwrite_b or _datacopied(b1, b)
potrs, = get_lapack_funcs(('potrs',), (c, b1))
x, info = potrs(c, b1, lower=lower, overwrite_b=overwrite_b)
@@ -208,7 +208,7 @@
b : array
Right-hand side
overwrite_b : bool
- If True, the function will overwrite the values in `b`.
+ If True, the function will overwrite the values in `b`.
Returns
-------
@@ -221,7 +221,7 @@
Notes
-----
-
+
.. versionadded:: 0.8.0
"""
Modified: trunk/scipy/linalg/decomp_lu.py
===================================================================
--- trunk/scipy/linalg/decomp_lu.py 2011-01-31 21:52:21 UTC (rev 7115)
+++ trunk/scipy/linalg/decomp_lu.py 2011-01-31 21:52:58 UTC (rev 7116)
@@ -5,7 +5,7 @@
from numpy import asarray, asarray_chkfinite
# Local imports
-from misc import _datanotshared
+from misc import _datacopied
from lapack import get_lapack_funcs
from flinalg import get_flinalg_funcs
@@ -48,9 +48,9 @@
a1 = asarray(a)
if len(a1.shape) != 2 or (a1.shape[0] != a1.shape[1]):
raise ValueError('expected square matrix')
- overwrite_a = overwrite_a or (_datanotshared(a1, a))
+ overwrite_a = overwrite_a or (_datacopied(a1, a))
getrf, = get_lapack_funcs(('getrf',), (a1,))
- lu, piv, info = getrf(a, overwrite_a=overwrite_a)
+ lu, piv, info = getrf(a1, overwrite_a=overwrite_a)
if info < 0:
raise ValueError('illegal value in %d-th argument of '
'internal getrf (lu_factor)' % -info)
@@ -91,7 +91,7 @@
"""
b1 = asarray_chkfinite(b)
- overwrite_b = overwrite_b or (b1 is not b and not hasattr(b, '__array__'))
+ overwrite_b = overwrite_b or _datacopied(b1, b)
if lu.shape[0] != b1.shape[0]:
raise ValueError("incompatible dimensions.")
@@ -148,7 +148,7 @@
a1 = asarray_chkfinite(a)
if len(a1.shape) != 2:
raise ValueError('expected matrix')
- overwrite_a = overwrite_a or (_datanotshared(a1, a))
+ overwrite_a = overwrite_a or (_datacopied(a1, a))
flu, = get_flinalg_funcs(('lu',), (a1,))
p, l, u, info = flu(a1, permute_l=permute_l, overwrite_a=overwrite_a)
if info < 0:
Modified: trunk/scipy/linalg/decomp_qr.py
===================================================================
--- trunk/scipy/linalg/decomp_qr.py 2011-01-31 21:52:21 UTC (rev 7115)
+++ trunk/scipy/linalg/decomp_qr.py 2011-01-31 21:52:58 UTC (rev 7116)
@@ -7,7 +7,7 @@
import special_matrices
from blas import get_blas_funcs
from lapack import get_lapack_funcs, find_best_lapack_type
-from misc import _datanotshared
+from misc import _datacopied
def qr(a, overwrite_a=False, lwork=None, mode='full'):
@@ -77,7 +77,7 @@
if len(a1.shape) != 2:
raise ValueError("expected 2D array")
M, N = a1.shape
- overwrite_a = overwrite_a or (_datanotshared(a1, a))
+ overwrite_a = overwrite_a or (_datacopied(a1, a))
geqrf, = get_lapack_funcs(('geqrf',), (a1,))
if lwork is None or lwork == -1:
@@ -157,7 +157,7 @@
if len(a1.shape) != 2:
raise ValueError('expected matrix')
M,N = a1.shape
- overwrite_a = overwrite_a or (_datanotshared(a1, a))
+ overwrite_a = overwrite_a or (_datacopied(a1, a))
geqrf, = get_lapack_funcs(('geqrf',), (a1,))
if lwork is None or lwork == -1:
# get optimal work array
@@ -235,7 +235,7 @@
if len(a1.shape) != 2:
raise ValueError('expected matrix')
M, N = a1.shape
- overwrite_a = overwrite_a or (_datanotshared(a1, a))
+ overwrite_a = overwrite_a or (_datacopied(a1, a))
gerqf, = get_lapack_funcs(('gerqf',), (a1,))
if lwork is None or lwork == -1:
Modified: trunk/scipy/linalg/decomp_schur.py
===================================================================
--- trunk/scipy/linalg/decomp_schur.py 2011-01-31 21:52:21 UTC (rev 7115)
+++ trunk/scipy/linalg/decomp_schur.py 2011-01-31 21:52:58 UTC (rev 7116)
@@ -5,7 +5,7 @@
# Local imports.
import misc
-from misc import LinAlgError, _datanotshared
+from misc import LinAlgError, _datacopied
from lapack import get_lapack_funcs
from decomp import eigvals
@@ -63,13 +63,13 @@
else:
a1 = a1.astype('F')
typ = 'F'
- overwrite_a = overwrite_a or (_datanotshared(a1, a))
+ overwrite_a = overwrite_a or (_datacopied(a1, a))
gees, = get_lapack_funcs(('gees',), (a1,))
if lwork is None or lwork == -1:
# get optimal work array
- result = gees(lambda x: None, a, lwork=-1)
+ result = gees(lambda x: None, a1, lwork=-1)
lwork = result[-2][0].real.astype(numpy.int)
- result = gees(lambda x: None, a, lwork=lwork, overwrite_a=overwrite_a)
+ result = gees(lambda x: None, a1, lwork=lwork, overwrite_a=overwrite_a)
info = result[-1]
if info < 0:
raise ValueError('illegal value in %d-th argument of internal gees'
Modified: trunk/scipy/linalg/decomp_svd.py
===================================================================
--- trunk/scipy/linalg/decomp_svd.py 2011-01-31 21:52:21 UTC (rev 7115)
+++ trunk/scipy/linalg/decomp_svd.py 2011-01-31 21:52:58 UTC (rev 7116)
@@ -5,7 +5,7 @@
from scipy.linalg import calc_lwork
# Local imports.
-from misc import LinAlgError, _datanotshared
+from misc import LinAlgError, _datacopied
from lapack import get_lapack_funcs
@@ -73,7 +73,7 @@
if len(a1.shape) != 2:
raise ValueError('expected matrix')
m,n = a1.shape
- overwrite_a = overwrite_a or (_datanotshared(a1, a))
+ overwrite_a = overwrite_a or (_datacopied(a1, a))
gesdd, = get_lapack_funcs(('gesdd',), (a1,))
if gesdd.module_name[:7] == 'flapack':
lwork = calc_lwork.gesdd(gesdd.prefix, m, n, compute_uv)[1]
Modified: trunk/scipy/linalg/misc.py
===================================================================
--- trunk/scipy/linalg/misc.py 2011-01-31 21:52:21 UTC (rev 7115)
+++ trunk/scipy/linalg/misc.py 2011-01-31 21:52:58 UTC (rev 7116)
@@ -9,13 +9,14 @@
return np.linalg.norm(np.asarray_chkfinite(a), ord=ord)
norm.__doc__ = np.linalg.norm.__doc__
+def _datacopied(arr, original):
+ """
+ Strict check for `arr` not sharing any data with `original`,
+ under the assumption that arr = asarray(original)
-def _datanotshared(a1,a):
- if a1 is a:
+ """
+ if arr is original:
return False
- else:
- #try comparing data pointers
- try:
- return a1.__array_interface__['data'][0] != a.__array_interface__['data'][0]
- except:
- return True
\ No newline at end of file
+ if not isinstance(original, np.ndarray) and hasattr(original, '__array__'):
+ return False
+ return arr.base is None
Modified: trunk/scipy/linalg/tests/test_basic.py
===================================================================
--- trunk/scipy/linalg/tests/test_basic.py 2011-01-31 21:52:21 UTC (rev 7115)
+++ trunk/scipy/linalg/tests/test_basic.py 2011-01-31 21:52:58 UTC (rev 7116)
@@ -29,6 +29,7 @@
from scipy.linalg import solve, inv, det, lstsq, pinv, pinv2, norm,\
solve_banded, solveh_banded, solve_triangular
+from scipy.linalg._testutils import assert_no_overwrite
def random(size):
return rand(*size)
@@ -561,5 +562,26 @@
assert_equal(norm([1,0,3], 0), 2)
assert_equal(norm([1,2,3], 0), 3)
class TestOverwrite(object):
    """
    Basic solvers and inversion routines must leave their array arguments
    unmodified, for plain ndarrays as well as array-like wrappers.
    """

    def test_solve(self):
        assert_no_overwrite(solve, [(3, 3), (3,)])

    def test_solve_triangular(self):
        assert_no_overwrite(solve_triangular, [(3, 3), (3,)])

    def test_solve_banded(self):
        # (l, u) = (2, 1), so the band array needs l + u + 1 = 4 rows.
        def solve_21(ab, b):
            return solve_banded((2, 1), ab, b)
        assert_no_overwrite(solve_21, [(4, 6), (6,)])

    def test_solveh_banded(self):
        assert_no_overwrite(solveh_banded, [(2, 6), (6,)])

    def test_inv(self):
        assert_no_overwrite(inv, [(3, 3)])

    def test_det(self):
        assert_no_overwrite(det, [(3, 3)])

    def test_lstsq(self):
        # Overdetermined system: 3 equations, 2 unknowns.
        assert_no_overwrite(lstsq, [(3, 2), (3,)])

    def test_pinv(self):
        assert_no_overwrite(pinv, [(3, 3)])

    def test_pinv2(self):
        assert_no_overwrite(pinv2, [(3, 3)])
+
if __name__ == "__main__":
run_module_suite()
Modified: trunk/scipy/linalg/tests/test_decomp.py
===================================================================
--- trunk/scipy/linalg/tests/test_decomp.py 2011-01-31 21:52:21 UTC (rev 7115)
+++ trunk/scipy/linalg/tests/test_decomp.py 2011-01-31 21:52:58 UTC (rev 7116)
@@ -20,7 +20,7 @@
from scipy.linalg import eig, eigvals, lu, svd, svdvals, cholesky, qr, \
schur, rsf2csf, lu_solve, lu_factor, solve, diagsvd, hessenberg, rq, \
- eig_banded, eigvals_banded, eigh
+ eig_banded, eigvals_banded, eigh, eigvalsh
from scipy.linalg.flapack import dgbtrf, dgbtrs, zgbtrf, zgbtrs, \
dsbev, dsbevd, dsbevx, zhbevd, zhbevx
@@ -32,6 +32,8 @@
from numpy.random import rand, normal, seed
+from scipy.linalg._testutils import assert_no_overwrite
+
# digit precision to use in asserts for different types
DIGITS = {'d':11, 'D':11, 'f':4, 'F':4}
@@ -1101,24 +1103,36 @@
-class TestDataNotShared(TestCase):
+class TestDatacopied(TestCase):
- def test_datanotshared(self):
- from scipy.linalg.decomp import _datanotshared
+ def test_datacopied(self):
+ from scipy.linalg.decomp import _datacopied
M = matrix([[0,1],[2,3]])
A = asarray(M)
L = M.tolist()
M2 = M.copy()
- assert_equal(_datanotshared(M,M),False)
- assert_equal(_datanotshared(M,A),False)
+ class Fake1:
+ def __array__(self):
+ return A
- assert_equal(_datanotshared(M,L),True)
- assert_equal(_datanotshared(M,M2),True)
- assert_equal(_datanotshared(A,M2),True)
+ class Fake2:
+ __array_interface__ = A.__array_interface__
+ F1 = Fake1()
+ F2 = Fake2()
+ AF1 = asarray(F1)
+ AF2 = asarray(F2)
+
+ for item, status in [(M, False), (A, False), (L, True),
+ (M2, False), (F1, False), (F2, False)]:
+ arr = asarray(item)
+ assert_equal(_datacopied(arr, item), status,
+ err_msg=repr(item))
+
+
def test_aligned_mem_float():
"""Check linalg works with non-aligned memory"""
# Allocate 402 bytes of memory (allocated on boundary)
@@ -1207,5 +1221,45 @@
# not properly tested
# cholesky, rsf2csf, lu_solve, solve, eig_banded, eigvals_banded, eigh, diagsvd
+
class TestOverwrite(object):
    """
    Decomposition and eigenvalue routines must leave their array arguments
    unmodified, for plain ndarrays as well as array-like wrappers.
    """

    def test_eig(self):
        # Standard problem first, then the generalized (a, b) problem.
        for shapes in ([(3, 3)], [(3, 3), (3, 3)]):
            assert_no_overwrite(eig, shapes)

    def test_eigh(self):
        for shapes in ([(3, 3)], [(3, 3), (3, 3)]):
            assert_no_overwrite(eigh, shapes)

    def test_eig_banded(self):
        assert_no_overwrite(eig_banded, [(3, 2)])

    def test_eigvals(self):
        assert_no_overwrite(eigvals, [(3, 3)])

    def test_eigvalsh(self):
        assert_no_overwrite(eigvalsh, [(3, 3)])

    def test_eigvals_banded(self):
        assert_no_overwrite(eigvals_banded, [(3, 2)])

    def test_hessenberg(self):
        assert_no_overwrite(hessenberg, [(3, 3)])

    def test_lu_factor(self):
        assert_no_overwrite(lu_factor, [(3, 3)])

    def test_lu_solve(self):
        factored = lu_factor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 8]]))

        def solve_with_lu(b):
            return lu_solve(factored, b)
        assert_no_overwrite(solve_with_lu, [(3,)])

    def test_lu(self):
        assert_no_overwrite(lu, [(3, 3)])

    def test_qr(self):
        assert_no_overwrite(qr, [(3, 3)])

    def test_rq(self):
        assert_no_overwrite(rq, [(3, 3)])

    def test_schur(self):
        assert_no_overwrite(schur, [(3, 3)])

    def test_schur_complex(self):
        # Only real inputs: complex output forces an internal cast/copy.
        def complex_schur(a):
            return schur(a, 'complex')
        assert_no_overwrite(complex_schur, [(3, 3)],
                            dtypes=[np.float32, np.float64])

    def test_svd(self):
        assert_no_overwrite(svd, [(3, 3)])

    def test_svdvals(self):
        assert_no_overwrite(svdvals, [(3, 3)])
+
if __name__ == "__main__":
run_module_suite()
Modified: trunk/scipy/linalg/tests/test_decomp_cholesky.py
===================================================================
--- trunk/scipy/linalg/tests/test_decomp_cholesky.py 2011-01-31 21:52:21 UTC (rev 7115)
+++ trunk/scipy/linalg/tests/test_decomp_cholesky.py 2011-01-31 21:52:58 UTC (rev 7116)
@@ -4,8 +4,10 @@
from numpy import array, transpose, dot, conjugate, zeros_like
from numpy.random import rand
-from scipy.linalg import cholesky, cholesky_banded, cho_solve_banded
+from scipy.linalg import cholesky, cholesky_banded, cho_solve_banded, \
+ cho_factor, cho_solve
+from scipy.linalg._testutils import assert_no_overwrite
def random(size):
return rand(*size)
@@ -138,3 +140,20 @@
b = array([0.0, 0.5j, 3.8j, 3.8])
x = cho_solve_banded((c, True), b)
assert_array_almost_equal(x, [0.0, 0.0, 1.0j, 1.0])
+
class TestOverwrite(object):
    """
    Cholesky routines must leave their array arguments unmodified, for plain
    ndarrays as well as array-like wrappers.
    """

    def test_cholesky(self):
        assert_no_overwrite(cholesky, [(3, 3)])

    def test_cho_factor(self):
        assert_no_overwrite(cho_factor, [(3, 3)])

    def test_cho_solve(self):
        factored = cho_factor(array([[2, -1, 0], [-1, 2, -1], [0, -1, 2]]))

        def solve_with_cho(b):
            return cho_solve(factored, b)
        assert_no_overwrite(solve_with_cho, [(3,)])

    def test_cholesky_banded(self):
        assert_no_overwrite(cholesky_banded, [(2, 3)])

    def test_cho_solve_banded(self):
        # Band storage of tridiag(-1, 2, -1); the leading 0 is padding.
        banded = cholesky_banded(array([[0, -1, -1], [2, 2, 2]]))

        def solve_with_band(b):
            return cho_solve_banded((banded, False), b)
        assert_no_overwrite(solve_with_band, [(3,)])
More information about the Scipy-svn
mailing list