[Scipy-svn] r6274 - in trunk/scipy/linalg: . tests

scipy-svn at scipy.org scipy-svn at scipy.org
Fri Mar 26 20:17:23 EDT 2010


Author: warren.weckesser
Date: 2010-03-26 19:17:22 -0500 (Fri, 26 Mar 2010)
New Revision: 6274

Modified:
   trunk/scipy/linalg/basic.py
   trunk/scipy/linalg/tests/test_basic.py
Log:
ENH: Allow linalg.block_diag to accept scalar and 1D arguments (ticket #1128)

Modified: trunk/scipy/linalg/basic.py
===================================================================
--- trunk/scipy/linalg/basic.py	2010-03-26 05:35:35 UTC (rev 6273)
+++ trunk/scipy/linalg/basic.py	2010-03-27 00:17:22 UTC (rev 6274)
@@ -18,7 +18,7 @@
 from numpy import asarray, zeros, sum, greater_equal, subtract, arange,\
      conjugate, dot, transpose
 import numpy
-from numpy import asarray_chkfinite, outer, concatenate, reshape, single
+from numpy import asarray_chkfinite, atleast_2d, outer, concatenate, reshape, single
 from numpy import matrix as Matrix
 from numpy.linalg import LinAlgError
 from scipy.linalg import calc_lwork
@@ -894,7 +894,7 @@
     return concatenate(concatenate(o, axis=1), axis=1)
 
 def block_diag(*arrs):
-    """Create a diagonal matrix from the provided arrays.
+    """Create a block diagonal matrix from the provided arrays.
 
     Given the inputs `A`, `B` and `C`, the output will have these
     arrays arranged on the diagonal::
@@ -908,8 +908,9 @@
 
     Parameters
     ----------
-    A, B, C, ... : 2-D ndarray
-        Input arrays.
+    A, B, C, ... : array-like, up to 2D
+        Input arrays.  A 1D array or array-like sequence with length n is
+        treated as a 2D array with shape (1,n).
 
     Returns
     -------
@@ -929,15 +930,28 @@
     >>> B = [[3, 4, 5],
     ...      [6, 7, 8]]
     >>> C = [[7]]
-    >>> print block_diag(A, B, C)
-    [[ 1.  0.  0.  0.  0.  0.]
-     [ 0.  1.  0.  0.  0.  0.]
-     [ 0.  0.  3.  4.  5.  0.]
-     [ 0.  0.  6.  7.  8.  0.]
-     [ 0.  0.  0.  0.  0.  7.]]
+    >>> print(block_diag(A, B, C))
+    [[1 0 0 0 0 0]
+     [0 1 0 0 0 0]
+     [0 0 3 4 5 0]
+     [0 0 6 7 8 0]
+     [0 0 0 0 0 7]]
+    >>> block_diag(1.0, [2, 3], [[4, 5], [6, 7]])
+    array([[ 1.,  0.,  0.,  0.,  0.],
+           [ 0.,  2.,  3.,  0.,  0.],
+           [ 0.,  0.,  0.,  4.,  5.],
+           [ 0.,  0.,  0.,  6.,  7.]])
 
     """
-    arrs = [asarray(a) for a in arrs]
+    if arrs == ():
+        arrs = ([],)
+    arrs = [atleast_2d(a) for a in arrs]
+
+    bad_args = [k for k in range(len(arrs)) if arrs[k].ndim > 2]
+    if bad_args:
+        raise ValueError("arguments in the following positions have dimension "
+                            "greater than 2: %s" % bad_args) 
+
     shapes = numpy.array([a.shape for a in arrs])
     out = zeros(sum(shapes, axis=0), dtype=arrs[0].dtype)
 
@@ -947,4 +961,3 @@
         r += rr
         c += cc
     return out
-

Modified: trunk/scipy/linalg/tests/test_basic.py
===================================================================
--- trunk/scipy/linalg/tests/test_basic.py	2010-03-26 05:35:35 UTC (rev 6273)
+++ trunk/scipy/linalg/tests/test_basic.py	2010-03-27 00:17:22 UTC (rev 6274)
@@ -463,7 +463,24 @@
 
         x = block_diag([[True]])
         assert_equal(x.dtype, bool)
+        
+    def test_scalar_and_1d_args(self):
+        a = block_diag(1)
+        assert_equal(a.shape, (1,1))
+        assert_array_equal(a, [[1]])
+        
+        a = block_diag([2,3], 4)
+        assert_array_equal(a, [[2, 3, 0], [0, 0, 4]])
 
+    def test_bad_arg(self):
+        assert_raises(ValueError, block_diag, [[[1]]])
+
+    def test_no_args(self):
+        a = block_diag()
+        assert_equal(a.ndim, 2)
+        assert_equal(a.nbytes, 0)
+
+
 class TestPinv(TestCase):
 
     def test_simple(self):




More information about the Scipy-svn mailing list