[Scipy-svn] r3745 - trunk/scipy/sparse

scipy-svn at scipy.org scipy-svn at scipy.org
Sat Dec 29 15:35:34 EST 2007


Author: wnbell
Date: 2007-12-29 14:35:32 -0600 (Sat, 29 Dec 2007)
New Revision: 3745

Modified:
   trunk/scipy/sparse/construct.py
Log:
Made spkron(A, B) more efficient for nearly dense B: when at least half of B's entries are nonzero, the result is built directly as a BSR matrix.


Modified: trunk/scipy/sparse/construct.py
===================================================================
--- trunk/scipy/sparse/construct.py	2007-12-29 02:53:36 UTC (rev 3744)
+++ trunk/scipy/sparse/construct.py	2007-12-29 20:35:32 UTC (rev 3745)
@@ -12,6 +12,7 @@
 
 from csr import csr_matrix, isspmatrix_csr
 from csc import csc_matrix, isspmatrix_csc
+from bsr import bsr_matrix
 from coo import coo_matrix
 from dok import dok_matrix
 from lil import lil_matrix
@@ -110,36 +111,56 @@
 
     """
     #TODO optimize for small dense B and CSR A -> BSR
-    A,B = coo_matrix(A),coo_matrix(B)
-    output_shape = (A.shape[0]*B.shape[0],A.shape[1]*B.shape[1])
+    B = coo_matrix(B)
 
-    if A.nnz == 0 or B.nnz == 0:
-        # kronecker product is the zero matrix
-        return coo_matrix( output_shape )
+    
+    if 2*B.nnz >= B.shape[0] * B.shape[1]:
+        #B is fairly dense, use BSR
+        A = csr_matrix(A,copy=True)
+        
+        output_shape = (A.shape[0]*B.shape[0],A.shape[1]*B.shape[1])
 
-    # expand entries of a into blocks
-    row  = A.row.repeat(B.nnz)
-    col  = A.col.repeat(B.nnz)
-    data = A.data.repeat(B.nnz)
+        if A.nnz == 0 or B.nnz == 0:
+            # kronecker product is the zero matrix
+            return coo_matrix( output_shape )
+        
+        B = B.toarray()
+        data = A.data.repeat(B.size).reshape(-1,B.shape[0],B.shape[1])
+        data = data * B
+        
+        return bsr_matrix((data,A.indices,A.indptr),shape=output_shape)
+    else:
+        #use COO
+        A = coo_matrix(A)
+        output_shape = (A.shape[0]*B.shape[0],A.shape[1]*B.shape[1])
 
-    row *= B.shape[0]
-    col *= B.shape[1]
+        if A.nnz == 0 or B.nnz == 0:
+            # kronecker product is the zero matrix
+            return coo_matrix( output_shape )
 
-    # increment block indices
-    row,col = row.reshape(-1,B.nnz),col.reshape(-1,B.nnz)
-    row += B.row
-    col += B.col
-    row,col = row.reshape(-1),col.reshape(-1)
+        # expand entries of a into blocks
+        row  = A.row.repeat(B.nnz)
+        col  = A.col.repeat(B.nnz)
+        data = A.data.repeat(B.nnz)
 
-    # compute block entries
-    data = data.reshape(-1,B.nnz) * B.data
-    data = data.reshape(-1)
+        row *= B.shape[0]
+        col *= B.shape[1]
 
-    return coo_matrix((data,(row,col)), shape=output_shape).asformat(format)
+        # increment block indices
+        row,col = row.reshape(-1,B.nnz),col.reshape(-1,B.nnz)
+        row += B.row
+        col += B.col
+        row,col = row.reshape(-1),col.reshape(-1)
 
+        # compute block entries
+        data = data.reshape(-1,B.nnz) * B.data
+        data = data.reshape(-1)
 
+        return coo_matrix((data,(row,col)), shape=output_shape).asformat(format)
 
 
+
+
 def lil_eye((r,c), k=0, dtype='d'):
     """Generate a lil_matrix of dimensions (r,c) with the k-th
     diagonal set to 1.




More information about the Scipy-svn mailing list