[Scipy-svn] r3451 - in trunk/scipy/sparse: . tests

scipy-svn at scipy.org scipy-svn at scipy.org
Sun Oct 21 20:21:18 EDT 2007


Author: wnbell
Date: 2007-10-21 19:21:14 -0500 (Sun, 21 Oct 2007)
New Revision: 3451

Modified:
   trunk/scipy/sparse/sparse.py
   trunk/scipy/sparse/tests/test_sparse.py
Log:
fix sparse matrix * dense matrix multiplication;
the result is now a dense matrix


Modified: trunk/scipy/sparse/sparse.py
===================================================================
--- trunk/scipy/sparse/sparse.py	2007-10-21 20:55:46 UTC (rev 3450)
+++ trunk/scipy/sparse/sparse.py	2007-10-22 00:21:14 UTC (rev 3451)
@@ -17,7 +17,7 @@
 from numpy import zeros, isscalar, real, imag, asarray, asmatrix, matrix, \
                   ndarray, amax, amin, rank, conj, searchsorted, ndarray,   \
                   less, where, greater, array, transpose, empty, ones, \
-                  arange, shape, intc, clip, prod, unravel_index
+                  arange, shape, intc, clip, prod, unravel_index, hstack
 import numpy
 from scipy.sparse.sparsetools import cscmux, csrmux, \
      cootocsr, csrtocoo, cootocsc, csctocoo, csctocsr, csrtocsc, \
@@ -345,8 +345,7 @@
 
     def dot(self, other):
         """ A generic interface for matrix-matrix or matrix-vector
-        multiplication.  Returns A.transpose().conj() * other or
-        A.transpose() * other.
+        multiplication.  
         """
 
         try:
@@ -355,23 +354,16 @@
             # If it's a list or whatever, treat it like a matrix
             other = asmatrix(other)
 
-        if len(other.shape) == 1:
-            result = self.matvec(other)
-        elif isdense(other) and asarray(other).squeeze().ndim <= 1:
-            # If it's a row or column vector, return a DENSE result
-            result = self.matvec(other)
+        if isdense(other) and asarray(other).squeeze().ndim <= 1:
+            # it's a dense row or column vector
+            return self.matvec(other)
         elif len(other.shape) == 2:
-            # Return a sparse result
-            result = self.matmat(other)
+            # it's a 2d dense array, dense matrix, or sparse matrix
+            return self.matmat(other)
         else:
             raise ValueError, "could not interpret dimensions"
+        
 
-        if isinstance(other, matrix) and isdense(result):
-            return asmatrix(result)
-        else:
-            # if the result is sparse or 'other' is an array:
-            return result
-
     def matmat(self, other):
         csc = self.tocsc()
         return csc.matmat(other)
@@ -663,9 +655,10 @@
             other = self._tothis(other)
             return self._binopt(other,fn,in_shape=(M,N),out_shape=(M,N))
         elif isdense(other):
-            # This is SLOW!  We need a more efficient implementation
-            # of sparse * dense matrix multiplication!
-            return self.matmat(csc_matrix(other))
+            # TODO make sparse * dense matrix multiplication more efficient
+            
+            # matvec each column of other 
+            return hstack( [ self * col.reshape(-1,1) for col in other.T ] )
         else:
             raise TypeError, "need a dense or sparse matrix"
 

Modified: trunk/scipy/sparse/tests/test_sparse.py
===================================================================
--- trunk/scipy/sparse/tests/test_sparse.py	2007-10-21 20:55:46 UTC (rev 3450)
+++ trunk/scipy/sparse/tests/test_sparse.py	2007-10-22 00:21:14 UTC (rev 3451)
@@ -213,40 +213,40 @@
         # Currently M.matvec(asarray(col)) is rank-1, whereas M.matvec(col)
         # is rank-2.  Is this desirable?
 
-    def check_matmat(self):
+    def check_matmat_sparse(self):
         a = matrix([[3,0,0],[0,1,0],[2,0,3.0],[2,3,0]])
         a2 = array([[3,0,0],[0,1,0],[2,0,3.0],[2,3,0]])
         b = matrix([[0,1],[1,0],[0,2]],'d')
         asp = self.spmatrix(a)
         bsp = self.spmatrix(b)
         assert_array_almost_equal((asp*bsp).todense(), a*b)
-        assert_array_almost_equal((asp*b).todense(), a*b)
-        assert_array_almost_equal((a*bsp).todense(), a*b)
-        assert_array_almost_equal((a2*bsp).todense(), a*b)
+        assert_array_almost_equal( asp*b, a*b)
+        assert_array_almost_equal( a*bsp, a*b)
+        assert_array_almost_equal( a2*bsp, a*b)
 
         # Now try performing cross-type multplication:
         csp = bsp.tocsc()
         c = b
         assert_array_almost_equal((asp*csp).todense(), a*c)
         assert_array_almost_equal((asp.matmat(csp)).todense(), a*c)
-        assert_array_almost_equal((asp*c).todense(), a*c)
+        assert_array_almost_equal( asp*c, a*c)
         
-        assert_array_almost_equal((a*csp).todense(), a*c)
-        assert_array_almost_equal((a2*csp).todense(), a*c)
+        assert_array_almost_equal( a*csp, a*c)
+        assert_array_almost_equal( a2*csp, a*c)
         csp = bsp.tocsr()
         assert_array_almost_equal((asp*csp).todense(), a*c)
         assert_array_almost_equal((asp.matmat(csp)).todense(), a*c)
-        assert_array_almost_equal((asp*c).todense(), a*c)
+        assert_array_almost_equal( asp*c, a*c)
 
-        assert_array_almost_equal((a*csp).todense(), a*c)
-        assert_array_almost_equal((a2*csp).todense(), a*c)
+        assert_array_almost_equal( a*csp, a*c)
+        assert_array_almost_equal( a2*csp, a*c)
         csp = bsp.tocoo()
         assert_array_almost_equal((asp*csp).todense(), a*c)
         assert_array_almost_equal((asp.matmat(csp)).todense(), a*c)
-        assert_array_almost_equal((asp*c).todense(), a*c)
+        assert_array_almost_equal( asp*c, a*c)
 
-        assert_array_almost_equal((a*csp).todense(), a*c)
-        assert_array_almost_equal((a2*csp).todense(), a*c)
+        assert_array_almost_equal( a*csp, a*c)
+        assert_array_almost_equal( a2*csp, a*c)
 
         # Test provided by Andy Fraser, 2006-03-26
         L = 30
@@ -262,6 +262,18 @@
         assert_array_almost_equal(B.todense(), A.todense() * A.T.todense())
         assert_array_almost_equal(B.todense(), A.todense() * A.todense().T)
     
+    def check_matmat_dense(self):
+        a = matrix([[3,0,0],[0,1,0],[2,0,3.0],[2,3,0]])
+        asp = self.spmatrix(a)
+        
+        # check both array and matrix types
+        bs = [ array([[1,2],[3,4],[5,6]]), matrix([[1,2],[3,4],[5,6]]) ]
+
+        for b in bs:
+            result = asp*b
+            assert( isinstance(result, type(b)) )
+            assert_equal( result.shape, (4,2) )
+            assert_equal( result, dot(a,b) )
     
     def check_tocoo(self):
         a = self.datsp.tocoo()




More information about the Scipy-svn mailing list