[Numpy-discussion] speed up np.diag

Citi, Luca lciti at essex.ac.uk
Fri Jul 10 10:55:58 EDT 2009


Hello,
I happened to look at the code for np.diag
and found it more complicated than necessary.
I think it can be rewritten more cleanly and
more efficiently. Both versions are appended below.
The speed improvement is significant:

In [145]: x = S.rand(1000,1300)
In [146]: assert all(diag(x,-13) == diag2(x,-13))
In [147]: %timeit diag(x,-13)
10000 loops, best of 3: 37.4 µs per loop
In [148]: %timeit diag2(x,-13)
100000 loops, best of 3: 14.1 µs per loop
In [149]: x = x.T
In [150]: assert all(diag(x,-13) == diag2(x,-13))
In [151]: %timeit diag(x,-13)
10000 loops, best of 3: 81.1 µs per loop
In [152]: %timeit diag2(x,-13)
100000 loops, best of 3: 14.8 µs per loop

The new version takes slightly more than one third
of the original time for the C-contiguous input
(14.1/37.4, about 0.38) and slightly more than one
sixth for the F-contiguous input (14.8/81.1, about 0.18).
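The F-contiguous input in the session is simply `x.T`, a transposed view of the same data. Handling it cheaply relies on the identity that the k-th diagonal of a matrix is the (-k)-th diagonal of its transpose. A small check with `np.diag`, assuming a smaller 100x130 array and `np.random.default_rng` in place of `S.rand` (both chosen only for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random((100, 130))   # C-contiguous, like S.rand(1000,1300) but smaller
xt = x.T                     # transposed view: same memory, F-contiguous

print(x.flags.c_contiguous, xt.flags.f_contiguous)   # True True
print(np.shares_memory(x, xt))                      # True

# the k-th diagonal of xt is the (-k)-th diagonal of x, so an F-contiguous
# input can be handled through its C-contiguous transpose
for k in (-13, 0, 13):
    assert np.array_equal(np.diag(xt, k), np.diag(x, -k))
```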

Best,
Luca



## ORIGINAL
from numpy import asarray, zeros, arange

def diag(v, k=0):
    v = asarray(v)
    s = v.shape
    if len(s)==1:
        n = s[0]+abs(k)
        res = zeros((n,n), v.dtype)
        if (k>=0):
            i = arange(0,n-k)
            fi = i+k+i*n
        else:
            i = arange(0,n+k)
            fi = i+(i-k)*n
        res.flat[fi] = v
        return res
    elif len(s)==2:
        N1,N2 = s
        if k >= 0:
            M = min(N1,N2-k)
            i = arange(0,M)
            fi = i+k+i*N2
        else:
            M = min(N1+k,N2)
            i = arange(0,M)
            fi = i + (i-k)*N2
        return v.flat[fi]
    else:
        raise ValueError("Input must be 1- or 2-d.")
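In the 2-d branch above, `fi` holds the row-major (flat) index of each diagonal element: i*N2 + i + k for a[i, i+k], or (i-k)*N2 + i for a[i-k, i]. Consecutive entries differ by N2+1, so `fi` is an arithmetic progression that a basic strided slice of `.flat` can express without building an index array. A check on an arbitrary 4x6 example (shape and k values are illustrative):

```python
import numpy as np

N1, N2 = 4, 6                       # illustrative shape
a = np.arange(N1 * N2).reshape(N1, N2)

for k in (1, -2):
    # index array built as in the original 2-d branch
    if k >= 0:
        M = min(N1, N2 - k)
        i = np.arange(0, M)
        fi = i + k + i * N2         # flat index of a[i, i+k]
    else:
        M = min(N1 + k, N2)
        i = np.arange(0, M)
        fi = i + (i - k) * N2       # flat index of a[i-k, i]

    # the same indices as an arithmetic progression with step N2 + 1
    start = k if k >= 0 else -k * N2
    step = N2 + 1
    assert np.array_equal(fi, np.arange(start, start + M * step, step))

    # fancy indexing with fi and a basic strided slice pick the same elements
    assert np.array_equal(a.flat[fi], a.flat[start:start + M * step:step])
    assert np.array_equal(a.flat[fi], np.diag(a, k))
```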


## SUGGESTED
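from numpy import asarray, zeros

def diag(v, k=0):
    v = asarray(v)
    s = v.shape
    if len(s) == 1:
        n = s[0]+abs(k)
        res = zeros((n,n), v.dtype)
        # flat index of res[0,k] (k >= 0) or res[-k,0] (k < 0); rows have length n
        i = k if k >= 0 else (-k) * n
        # only the first n-k rows hold diagonal elements, so the strided
        # slice of that view has exactly len(v) slots
        res[:n-k].flat[i::n+1] = v
        return res
    elif len(s) == 2:
        if v.flags.f_contiguous:
            # diag(v, k) == diag(v.T, -k), and v.T is C-contiguous
            v, k, s = v.T, -k, s[::-1]
        i = k if k >= 0 else (-k) * s[1]
        # length of the k-th diagonal, so the slice stops there
        m = max(0, min(s[0], s[1]-k) if k >= 0 else min(s[0]+k, s[1]))
        return v.flat[i:i+m*(s[1]+1):s[1]+1]
    else:
        raise ValueError("Input must be 1- or 2-d.")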
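A strided slice of `.flat` only selects the diagonal if it stops at the last diagonal element. Left unbounded, it can wrap into a later row, for example on a matrix with more rows than columns, or when building a matrix from a 1-d input with k >= 2. The sketch below uses two hypothetical helpers, `diag2d_strided` and `diag1d_strided` (not NumPy functions), that bound the slice by the diagonal length M from the original code, and compares them with `np.diag` on small C- and F-ordered inputs:

```python
import numpy as np

def diag2d_strided(v, k=0):
    """Hypothetical helper: k-th diagonal of a 2-d array via a bounded .flat slice."""
    v = np.asarray(v)
    if v.flags.f_contiguous:
        # diag(v, k) == diag(v.T, -k), and v.T is a C-contiguous view
        v, k = v.T, -k
    n1, n2 = v.shape
    # number of elements on the k-th diagonal (0 when k is out of range)
    m = max(0, min(n1, n2 - k) if k >= 0 else min(n1 + k, n2))
    start = k if k >= 0 else -k * n2       # flat index of v[0, k] or v[-k, 0]
    step = n2 + 1
    # the explicit stop keeps the slice from wrapping into later rows
    return v.flat[start:start + m * step:step]

def diag1d_strided(v, k=0):
    """Hypothetical helper: square matrix with the 1-d v on its k-th diagonal."""
    v = np.asarray(v)
    n = v.shape[0] + abs(k)
    res = np.zeros((n, n), v.dtype)
    start = k if k >= 0 else -k * n        # flat index of res[0, k] or res[-k, 0]
    # for k > 0 only the first n - k rows hold diagonal elements; restricting
    # to them gives the strided slice exactly len(v) slots
    res[:n - k].flat[start::n + 1] = v
    return res

# Unbounded slices overshoot: a tall 4x2 matrix, k = 0 ...
a = np.arange(8).reshape(4, 2)
print(a.flat[0::3])              # [0 3 6]; 6 is a[3, 0], not on the diagonal
print(diag2d_strided(a))         # [0 3], same as np.diag(a)
# ... and a 1-d input of length 2 with k = 2 (4x4 result)
print(np.zeros((4, 4)).flat[2::5].size)   # 3 slots for 2 values

# Brute-force comparison with np.diag on small C- and F-ordered inputs
for n1 in range(1, 6):
    for n2 in range(1, 6):
        for order in "CF":
            b = np.arange(n1 * n2).reshape((n1, n2), order=order)
            for k in range(-n1 - 1, n2 + 2):
                assert np.array_equal(diag2d_strided(b, k), np.diag(b, k))
for n in range(1, 5):
    for k in range(-4, 5):
        v = np.arange(1, n + 1)
        assert np.array_equal(diag1d_strided(v, k), np.diag(v, k))
```

Bounding the slice keeps the single strided copy from the rewrite while matching the element count M that the original computes for its index array.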




