[Scipy-svn] r3745 - trunk/scipy/sparse
scipy-svn at scipy.org
scipy-svn at scipy.org
Sat Dec 29 15:35:34 EST 2007
Author: wnbell
Date: 2007-12-29 14:35:32 -0600 (Sat, 29 Dec 2007)
New Revision: 3745
Modified:
trunk/scipy/sparse/construct.py
Log:
Made spkron(A,B) more efficient for nearly dense B: when at least half of B's entries are nonzero, the product is now built directly in BSR format.
Modified: trunk/scipy/sparse/construct.py
===================================================================
--- trunk/scipy/sparse/construct.py 2007-12-29 02:53:36 UTC (rev 3744)
+++ trunk/scipy/sparse/construct.py 2007-12-29 20:35:32 UTC (rev 3745)
@@ -12,6 +12,7 @@
from csr import csr_matrix, isspmatrix_csr
from csc import csc_matrix, isspmatrix_csc
+from bsr import bsr_matrix
from coo import coo_matrix
from dok import dok_matrix
from lil import lil_matrix
@@ -110,36 +111,56 @@
"""
#TODO optimize for small dense B and CSR A -> BSR
- A,B = coo_matrix(A),coo_matrix(B)
- output_shape = (A.shape[0]*B.shape[0],A.shape[1]*B.shape[1])
+ B = coo_matrix(B)
- if A.nnz == 0 or B.nnz == 0:
- # kronecker product is the zero matrix
- return coo_matrix( output_shape )
+
+ if 2*B.nnz >= B.shape[0] * B.shape[1]:
+ #B is fairly dense, use BSR
+ A = csr_matrix(A,copy=True)
+
+ output_shape = (A.shape[0]*B.shape[0],A.shape[1]*B.shape[1])
- # expand entries of a into blocks
- row = A.row.repeat(B.nnz)
- col = A.col.repeat(B.nnz)
- data = A.data.repeat(B.nnz)
+ if A.nnz == 0 or B.nnz == 0:
+ # kronecker product is the zero matrix
+ return coo_matrix( output_shape )
+
+ B = B.toarray()
+ data = A.data.repeat(B.size).reshape(-1,B.shape[0],B.shape[1])
+ data = data * B
+
+ return bsr_matrix((data,A.indices,A.indptr),shape=output_shape)
+ else:
+ #use COO
+ A = coo_matrix(A)
+ output_shape = (A.shape[0]*B.shape[0],A.shape[1]*B.shape[1])
- row *= B.shape[0]
- col *= B.shape[1]
+ if A.nnz == 0 or B.nnz == 0:
+ # kronecker product is the zero matrix
+ return coo_matrix( output_shape )
- # increment block indices
- row,col = row.reshape(-1,B.nnz),col.reshape(-1,B.nnz)
- row += B.row
- col += B.col
- row,col = row.reshape(-1),col.reshape(-1)
+ # expand entries of a into blocks
+ row = A.row.repeat(B.nnz)
+ col = A.col.repeat(B.nnz)
+ data = A.data.repeat(B.nnz)
- # compute block entries
- data = data.reshape(-1,B.nnz) * B.data
- data = data.reshape(-1)
+ row *= B.shape[0]
+ col *= B.shape[1]
- return coo_matrix((data,(row,col)), shape=output_shape).asformat(format)
+ # increment block indices
+ row,col = row.reshape(-1,B.nnz),col.reshape(-1,B.nnz)
+ row += B.row
+ col += B.col
+ row,col = row.reshape(-1),col.reshape(-1)
+ # compute block entries
+ data = data.reshape(-1,B.nnz) * B.data
+ data = data.reshape(-1)
+ return coo_matrix((data,(row,col)), shape=output_shape).asformat(format)
+
+
def lil_eye((r,c), k=0, dtype='d'):
"""Generate a lil_matrix of dimensions (r,c) with the k-th
diagonal set to 1.
More information about the Scipy-svn
mailing list