[Numpy-discussion] Scipy dot
Nicolas SCHEFFER
scheffer.nicolas at gmail.com
Wed Nov 7 17:41:37 EST 2012
Hi,
I've written a snippet of code that could be called scipy.dot, a drop-in
replacement for numpy.dot.
It's dead easy: it just calls the right BLAS function depending on
whether the arrays are in C or F order, which avoids the temporary
copy that makes np.dot(A, A.T) slow.
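To see why the memory layout lets a BLAS call skip that copy, here is a small numpy-only check (illustrative; the 4x3 float64 shape is arbitrary). Transposing a C-ordered array only swaps its strides, so A.T is already a Fortran-ordered view of the same buffer.

```python
import numpy as np

# A C-ordered (row-major) matrix, the default layout of np.random.randn
A = np.random.randn(4, 3)

# Transposing swaps the strides instead of moving data:
# A.T is an F-ordered (column-major) view of the very same buffer.
print(A.flags.c_contiguous, A.T.flags.f_contiguous)  # True True
print(A.strides, A.T.strides)                        # (24, 8) (8, 24)
print(np.shares_memory(A, A.T))                      # True

# A Fortran BLAS gemm can therefore take A.T as-is and apply the
# transpose itself via its trans flag, rather than copying A into F order.
```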
While this is not the scipy mailing list, I was wondering whether this
snippet would be relevant and/or useful to others (numpy folks, scipy
folks), or whether it could be integrated directly into numpy so that
we keep the nice A.dot(B) syntax.
These temporary copies have been a bottleneck for lots of users, and
it seems everybody has their own snippet to work around them.
This code is probably not written as it should be; I hope the community
can help improve it! ;)
The first FIXME is to make it work for arrays with dimensions other than 2.
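For that FIXME, one possible approach (a sketch, not checked against every numpy.dot rule) is to reshape the common 1-D and N-D cases into a 2-D problem, call a 2-D kernel, and reshape the result back. The wrapper name dot_nd and its dot2d parameter are hypothetical; dot2d defaults to np.dot only so the sketch runs on its own, and a 2-D gemm-based dot would be passed in its place. The case B.ndim > 2, where numpy.dot contracts over B's second-to-last axis, is deliberately left out.

```python
import numpy as np

def dot_nd(A, B, dot2d=np.dot):
    """Hypothetical wrapper extending a 2-D-only kernel `dot2d` to the
    numpy.dot cases where B is 1-D or 2-D and A has any ndim >= 1."""
    A = np.asarray(A)
    B = np.asarray(B)
    if A.ndim == 0 or B.ndim == 0:
        return A * B  # numpy.dot treats scalars as plain multiplication
    if B.ndim > 2:
        raise NotImplementedError("B.ndim > 2 is not handled by this sketch")
    # Collapse A's leading axes into rows: (..., k) -> (m, k).
    # This is a view when A is C-contiguous, otherwise reshape copies.
    lead_shape = A.shape[:-1]
    A2 = A.reshape(int(np.prod(lead_shape)), A.shape[-1])
    # Promote a 1-D B to a column vector: (k,) -> (k, 1)
    B2 = B if B.ndim == 2 else B[:, np.newaxis]
    C2 = dot2d(A2, B2)
    # Restore numpy.dot's output shape: A.shape[:-1] + B.shape[1:]
    return C2.reshape(lead_shape + B.shape[1:])
```

With the 2-D dot from this post as dot2d, the final reshape may copy its F-ordered result, since reshape defaults to C order.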
Suggestions highly appreciated!
Thanks!
===
Code (also on http://pastebin.com/QrRk0kEf)
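import numpy as np
import scipy.linalg as sp

def dot(A, B, out=None):
    """A drop-in replacement for numpy.dot (2-D arrays only).
    Computes A.B with a direct BLAS gemm call, passing transposed
    views instead of making C-ordered copies.
    Note: unlike numpy.dot, the returned array is in F order."""
    # pick sgemm/dgemm/cgemm/zgemm to match the input dtypes
    gemm = sp.get_blas_funcs('gemm', arrays=(A, B))
    if out is None:
        m, k, k2, n = A.shape + B.shape
        if k != k2:
            raise ValueError("matrices are not aligned")
        # common dtype of the inputs (np.max over dtypes is not reliable)
        dtype = np.result_type(A.dtype, B.dtype)
        out = np.empty((m, n), dtype, order='F')
    # gemm expects F-ordered operands; the .T of a C-ordered array is an
    # F-ordered view of the same memory, so pass it with a trans flag.
    # The branches are exclusive (elif) so gemm runs exactly once.
    if A.flags.c_contiguous and B.flags.c_contiguous:
        out = gemm(alpha=1., a=A.T, b=B.T, trans_a=True, trans_b=True,
                   c=out, overwrite_c=True)
    elif A.flags.c_contiguous and B.flags.f_contiguous:
        out = gemm(alpha=1., a=A.T, b=B, trans_a=True, c=out, overwrite_c=True)
    elif A.flags.f_contiguous and B.flags.c_contiguous:
        out = gemm(alpha=1., a=A, b=B.T, trans_b=True, c=out, overwrite_c=True)
    else:
        # both F-ordered, or non-contiguous (f2py then copies as needed)
        out = gemm(alpha=1., a=A, b=B, c=out, overwrite_c=True)
    # keep gemm's return value: it only writes into `out` in place when
    # `out` already has gemm's dtype and F order
    return out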
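The four if-branches above all apply the same rule to each operand independently: if it is F-ordered, pass it as-is; if it is C-ordered, pass its transpose with the trans flag set. A minimal sketch of that refactoring follows, assuming scipy's get_blas_funcs and the positional gemm(alpha, a, b) call; the names blas_operand and gemm_dot are hypothetical, and unlike the snippet above this version allocates its own output instead of taking out.

```python
import numpy as np
from scipy.linalg import get_blas_funcs

def blas_operand(M):
    """Return (X, trans) with X F-ordered and op(X) == M, where op is
    the transpose when trans is True. Hypothetical helper name."""
    if M.flags.f_contiguous:
        return M, False                  # already column-major
    if M.flags.c_contiguous:
        return M.T, True                 # F-ordered view of the same buffer
    return np.asfortranarray(M), False   # strided input: one explicit copy

def gemm_dot(A, B):
    """2-D product A.B via a single gemm call; the result is F-ordered."""
    gemm = get_blas_funcs('gemm', arrays=(A, B))
    a, trans_a = blas_operand(A)
    b, trans_b = blas_operand(B)
    return gemm(1.0, a, b, trans_a=int(trans_a), trans_b=int(trans_b))
```

Handling each operand separately also gives non-contiguous inputs (e.g. A[::2, ::2]) a defined path, which the original four branches leave unhandled.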
==
Timing (EPD, MKL):
In [15]: A = np.array(np.random.randn(1000, 1000), 'f')
In [16]: %timeit np.dot(A, A)
100 loops, best of 3: 7.19 ms per loop
In [17]: %timeit np.dot(A.T, A.T)
10 loops, best of 3: 27.7 ms per loop
In [18]: %timeit np.dot(A, A.T)
100 loops, best of 3: 18.3 ms per loop
In [19]: %timeit np.dot(A.T, A)
100 loops, best of 3: 18.7 ms per loop
In [20]: %timeit dot(A, A)
100 loops, best of 3: 7.16 ms per loop
In [21]: %timeit dot(A.T, A.T)
100 loops, best of 3: 6.67 ms per loop
In [22]: %timeit dot(A, A.T)
100 loops, best of 3: 6.79 ms per loop
In [23]: %timeit dot(A.T, A)
100 loops, best of 3: 7.02 ms per loop
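To repeat these measurements outside IPython, a plain-Python harness along these lines could be used (a sketch; the function name time_layouts and its defaults are assumptions chosen to mirror the session above, and absolute numbers will depend on the BLAS build).

```python
import timeit
import numpy as np

def time_layouts(fn, n=1000, dtype='f', number=10, repeat=3):
    """Best per-call time in seconds of fn(X, Y) for the four layout
    combinations in the %timeit session (hypothetical helper)."""
    A = np.array(np.random.randn(n, n), dtype)  # C-ordered, as above
    cases = {
        'dot(A, A)': (A, A),
        'dot(A.T, A.T)': (A.T, A.T),
        'dot(A, A.T)': (A, A.T),
        'dot(A.T, A)': (A.T, A),
    }
    results = {}
    for label, (X, Y) in cases.items():
        # default args bind this iteration's operands to the lambda
        times = timeit.repeat(lambda X=X, Y=Y: fn(X, Y),
                              number=number, repeat=repeat)
        results[label] = min(times) / number  # "best of" per call
    return results

if __name__ == "__main__":
    for label, t in time_layouts(np.dot).items():
        print(f"np.{label:15s} {t * 1e3:.2f} ms")
```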