[Scipy-svn] r6274 - in trunk/scipy/linalg: . tests
scipy-svn at scipy.org
scipy-svn at scipy.org
Fri Mar 26 20:17:23 EDT 2010
Author: warren.weckesser
Date: 2010-03-26 19:17:22 -0500 (Fri, 26 Mar 2010)
New Revision: 6274
Modified:
trunk/scipy/linalg/basic.py
trunk/scipy/linalg/tests/test_basic.py
Log:
ENH: Allow linalg.block_diag to accept scalar and 1D arguments (ticket #1128)
Modified: trunk/scipy/linalg/basic.py
===================================================================
--- trunk/scipy/linalg/basic.py 2010-03-26 05:35:35 UTC (rev 6273)
+++ trunk/scipy/linalg/basic.py 2010-03-27 00:17:22 UTC (rev 6274)
@@ -18,7 +18,7 @@
from numpy import asarray, zeros, sum, greater_equal, subtract, arange,\
conjugate, dot, transpose
import numpy
-from numpy import asarray_chkfinite, outer, concatenate, reshape, single
+from numpy import asarray_chkfinite, atleast_2d, outer, concatenate, reshape, single
from numpy import matrix as Matrix
from numpy.linalg import LinAlgError
from scipy.linalg import calc_lwork
@@ -894,7 +894,7 @@
return concatenate(concatenate(o, axis=1), axis=1)
def block_diag(*arrs):
- """Create a diagonal matrix from the provided arrays.
+ """Create a block diagonal matrix from the provided arrays.
Given the inputs `A`, `B` and `C`, the output will have these
arrays arranged on the diagonal::
@@ -908,8 +908,9 @@
Parameters
----------
- A, B, C, ... : 2-D ndarray
- Input arrays.
+ A, B, C, ... : array-like, up to 2D
+ Input arrays. A 1D array or array-like sequence with length n is
+ treated as a 2D array with shape (1,n).
Returns
-------
@@ -929,15 +930,28 @@
>>> B = [[3, 4, 5],
... [6, 7, 8]]
>>> C = [[7]]
- >>> print block_diag(A, B, C)
- [[ 1. 0. 0. 0. 0. 0.]
- [ 0. 1. 0. 0. 0. 0.]
- [ 0. 0. 3. 4. 5. 0.]
- [ 0. 0. 6. 7. 8. 0.]
- [ 0. 0. 0. 0. 0. 7.]]
+ >>> print(block_diag(A, B, C))
+ [[1 0 0 0 0 0]
+ [0 1 0 0 0 0]
+ [0 0 3 4 5 0]
+ [0 0 6 7 8 0]
+ [0 0 0 0 0 7]]
+ >>> block_diag(1.0, [2, 3], [[4, 5], [6, 7]])
+ array([[ 1., 0., 0., 0., 0.],
+ [ 0., 2., 3., 0., 0.],
+ [ 0., 0., 0., 4., 5.],
+ [ 0., 0., 0., 6., 7.]])
"""
- arrs = [asarray(a) for a in arrs]
+ if arrs == ():
+ arrs = ([],)
+ arrs = [atleast_2d(a) for a in arrs]
+
+ bad_args = [k for k in range(len(arrs)) if arrs[k].ndim > 2]
+ if bad_args:
+ raise ValueError("arguments in the following positions have dimension "
+ "greater than 2: %s" % bad_args)
+
shapes = numpy.array([a.shape for a in arrs])
out = zeros(sum(shapes, axis=0), dtype=arrs[0].dtype)
@@ -947,4 +961,3 @@
r += rr
c += cc
return out
-
Modified: trunk/scipy/linalg/tests/test_basic.py
===================================================================
--- trunk/scipy/linalg/tests/test_basic.py 2010-03-26 05:35:35 UTC (rev 6273)
+++ trunk/scipy/linalg/tests/test_basic.py 2010-03-27 00:17:22 UTC (rev 6274)
@@ -463,7 +463,24 @@
x = block_diag([[True]])
assert_equal(x.dtype, bool)
+
+ def test_scalar_and_1d_args(self):
+ a = block_diag(1)
+ assert_equal(a.shape, (1,1))
+ assert_array_equal(a, [[1]])
+
+ a = block_diag([2,3], 4)
+ assert_array_equal(a, [[2, 3, 0], [0, 0, 4]])
+ def test_bad_arg(self):
+ assert_raises(ValueError, block_diag, [[[1]]])
+
+ def test_no_args(self):
+ a = block_diag()
+ assert_equal(a.ndim, 2)
+ assert_equal(a.nbytes, 0)
+
+
class TestPinv(TestCase):
def test_simple(self):
More information about the Scipy-svn mailing list