[Numpy-discussion] speed up np.diag
Citi, Luca
lciti at essex.ac.uk
Fri Jul 10 10:55:58 EDT 2009
Hello,
I happened to have a look at the code for np.diag
and found it more complicated than necessary.
I think it can be rewritten more cleanly and
efficiently by replacing the index-array arithmetic
with a strided slice of the flattened array.
Both versions are appended below.
The speed improvement is significant:
In [145]: x = S.rand(1000,1300)
In [146]: assert all(diag(x,-13) == diag2(x,-13))
In [147]: %timeit diag(x,-13)
10000 loops, best of 3: 37.4 µs per loop
In [148]: %timeit diag2(x,-13)
100000 loops, best of 3: 14.1 µs per loop
In [149]: x = x.T
In [150]: assert all(diag(x,-13) == diag2(x,-13))
In [151]: %timeit diag(x,-13)
10000 loops, best of 3: 81.1 µs per loop
In [152]: %timeit diag2(x,-13)
100000 loops, best of 3: 14.8 µs per loop
The new version takes slightly more than one third
of the original's time for the C-contiguous input
and slightly more than one sixth for the
F-contiguous input.
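The fractions quoted above follow from the `%timeit` figures. A quick check, using only the numbers copied from the session (nothing is re-measured here):

```python
# Best-of-3 timings from the session above, in microseconds
timings = {
    "C-contiguous": (37.4, 14.1),  # (original diag, suggested diag2)
    "F-contiguous": (81.1, 14.8),
}

ratios = {}
for layout, (old, new) in timings.items():
    # fraction of the original running time taken by the new version
    ratios[layout] = new / old
    print(f"{layout}: new/old = {ratios[layout]:.3f}, speed-up = {old / new:.1f}x")
```

This gives about 0.377 (a 2.7x speed-up) for the C-contiguous case and about 0.182 (5.5x) for the F-contiguous case.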
Best,
Luca
## ORIGINAL
from numpy import arange, asarray, zeros

def diag(v, k=0):
    v = asarray(v)
    s = v.shape
    if len(s)==1:
        n = s[0]+abs(k)
        res = zeros((n,n), v.dtype)
        if (k>=0):
            i = arange(0,n-k)
            fi = i+k+i*n
        else:
            i = arange(0,n+k)
            fi = i+(i-k)*n
        res.flat[fi] = v
        return res
    elif len(s)==2:
        N1,N2 = s
        if k >= 0:
            M = min(N1,N2-k)
            i = arange(0,M)
            fi = i+k+i*N2
        else:
            M = min(N1+k,N2)
            i = arange(0,M)
            fi = i + (i-k)*N2
        return v.flat[fi]
    else:
        raise ValueError("Input must be 1- or 2-d.")
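The index array `fi` built by the original code is `k + i*(N2+1)`, an arithmetic progression with step `N2+1`, which is what makes a plain strided slice possible. A small sketch checking this for `k >= 0` on an arbitrarily chosen tall array (the shape and offset are illustrative assumptions):

```python
import numpy as np

# A tall example (6 rows, 4 columns) and offset, chosen for illustration
N1, N2, k = 6, 4, 1
v = np.arange(N1 * N2).reshape(N1, N2)

# Flat indices built the way the original diag does for k >= 0
M = min(N1, N2 - k)
i = np.arange(0, M)
fi = i + k + i * N2

# The same indices form an arithmetic progression with step N2 + 1
assert np.array_equal(fi, k + i * (N2 + 1))

# Hence the fancy-index gather equals a basic strided slice, cut at M
fancy = v.flat[fi]
strided = v.ravel()[k::N2 + 1][:M]

# Without the cut at M the slice runs past the last diagonal entry
# and keeps picking up elements from later rows
unbounded = v.ravel()[k::N2 + 1]
```

Here `fi` is `[1, 6, 11]`, while the unbounded slice also yields `16` and `21`, which sit at positions (4, 0) and (5, 1) and are not on the diagonal.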
## SUGGESTED
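from numpy import asarray, empty, zeros

def diag(v, k=0):
    v = asarray(v)
    s = v.shape
    if len(s) == 1:
        n = s[0]+abs(k)
        res = zeros((n,n), v.dtype)
        # flat index of the first diagonal entry: (0, k) or (-k, 0)
        i = k if k >= 0 else (-k) * n
        # only the first n-k rows hold diagonal entries; bounding the
        # view stops .flat from running past them (and repeating v)
        res[:n-k].flat[i::n+1] = v
        return res
    elif len(s) == 2:
        if v.flags.f_contiguous:
            # the transpose of an F-contiguous array is C-contiguous
            v, k, s = v.T, -k, s[::-1]
        if k >= s[1] or -k >= s[0]:
            # the requested diagonal lies outside the array
            return empty(0, v.dtype)
        i = k if k >= 0 else (-k) * s[1]
        # a step of s[1]+1 moves one row down and one column right
        return v[:s[1]-k].flat[i::s[1]+1]
    else:
        raise ValueError("Input must be 1- or 2-d.")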
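The strided-slice idea can be checked against NumPy's own routines over more shapes than the single 1000x1300 case timed above. The sketch below uses a hypothetical `strided_diag` following the same idea, with row bounds (`res[:n-k]`, `v[:s[1]-k]`) and an early return for offsets outside the array, since an unbounded slice with step `n+1` keeps going past the last diagonal entry on tall arrays or large offsets. It compares against `np.diag` (1-d input) and `np.diagonal` (2-d input) for a few shapes, offsets, and both memory layouts; the shape list and offset ranges are arbitrary choices.

```python
import itertools
import numpy as np

def strided_diag(v, k=0):
    """Hypothetical sketch of the strided-slice diag, with row bounds added."""
    v = np.asarray(v)
    s = v.shape
    if v.ndim == 1:
        n = s[0] + abs(k)
        res = np.zeros((n, n), v.dtype)
        i = k if k >= 0 else (-k) * n
        # Bounding the rows keeps flat assignment from cycling v into
        # positions below the last diagonal entry
        res[:n - k].flat[i::n + 1] = v
        return res
    if v.ndim == 2:
        if v.flags.f_contiguous:
            # diag(A, k) == diag(A.T, -k), and A.T is C-contiguous
            v, k, s = v.T, -k, s[::-1]
        if k >= s[1] or -k >= s[0]:
            return np.empty(0, dtype=v.dtype)
        i = k if k >= 0 else (-k) * s[1]
        return v[:s[1] - k].flat[i::s[1] + 1]
    raise ValueError("Input must be 1- or 2-d.")

failures = []

# 2-d input: wide, tall, square and degenerate shapes, C and F order
for rows, cols in [(4, 6), (6, 4), (5, 5), (1, 3), (3, 1)]:
    base = np.arange(rows * cols).reshape(rows, cols)
    for order, k in itertools.product("CF", range(-7, 8)):
        a = np.asarray(base, order=order)
        if not np.array_equal(strided_diag(a, k), np.diagonal(a, k)):
            failures.append((rows, cols, order, k))

# 1-d input: build a matrix with v on the k-th diagonal
for m, k in itertools.product(range(1, 5), range(-4, 5)):
    vec = np.arange(1, m + 1)
    if not np.array_equal(strided_diag(vec, k), np.diag(vec, k)):
        failures.append((m, k))

print("mismatches:", failures)
```

An empty `failures` list means the bounded strided slices agree with NumPy on every case tried; dropping the `[:n - k]` or `[:s[1] - k]` bound makes cases such as the tall `(6, 4)` array or 1-d input with `k >= 2` show up as mismatches.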