[Numpy-discussion] How exactly ought 'dot' to work?

Nathaniel Smith njs at pobox.com
Sat Feb 22 17:03:03 EST 2014

Hi all,

Currently numpy's 'dot' acts a bit weird for ndim>2 or ndim<1. In
practice this doesn't usually matter much, because these are very
rarely used. But, I would like to nail down the behaviour so we can
say something precise in the matrix multiplication PEP. So here's one
proposal.

CURRENT:

dot(0d, any) -> scalar multiplication
dot(any, 0d) -> scalar multiplication
dot(1d, 1d) -> inner product
dot(2d, 1d) -> treat 1d as column matrix, matrix-multiply, then discard
    added axis
dot(1d, 2d) -> treat 1d as row matrix, matrix-multiply, then discard
    added axis
dot(2d, 2d) -> matrix multiply
dot(2-or-more d, 2-or-more d) -> a complicated outer product thing:
Specifically, if the inputs have shapes (r, n, m), (s, m, k), then
numpy returns an array with shape (r, n, s, k), created like:
    for i in range(r):
        for j in range(s):
            output[i, :, j, :] = np.dot(input1[i, :, :], input2[j, :, :])
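
As a sanity check, the cases listed above can be tried directly against
np.dot. This is only a small sketch; the array sizes are arbitrary and
chosen for illustration. Note that in the stacked case the result axes
come out in the order (r, n, s, k).

```python
import numpy as np

a2 = np.ones((3, 4))   # 2d
v4 = np.ones(4)        # 1d, length 4
v3 = np.ones(3)        # 1d, length 3

# 0d operand: plain scalar multiplication
assert np.dot(2.0, a2).shape == (3, 4)
# 1d . 1d: inner product, 0d result
assert np.dot(v4, v4).shape == ()
# 2d . 1d: 1d acts as a column, added axis is discarded
assert np.dot(a2, v4).shape == (3,)
# 1d . 2d: 1d acts as a row, added axis is discarded
assert np.dot(v3, a2).shape == (4,)

# ndim > 2: the leading axes are NOT aligned, they are combined
r, s, n, m, k = 2, 5, 3, 4, 6
x = np.random.rand(r, n, m)
y = np.random.rand(s, m, k)
out = np.dot(x, y)
assert out.shape == (r, n, s, k)
assert np.allclose(out[1, :, 2, :], np.dot(x[1], y[2]))
```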


PROPOSED:

General rule: given dot on shape1, shape2, we try to match these
shapes against two templates like
  (..., n?, m) and (..., m, k?)
where ... indicates zero or more dimensions, and ? indicates an
optional axis. ? axes are always matched before ... axes, so for an
input with ndim>=2, the ? axis is always matched. An unmatched ? axis
is treated as having size 1.

Next, the ... axes are broadcast against each other in the usual way
(prepending 1s to make lengths the same, requiring corresponding
entries to either match or have the value 1).  And then the actual
computations are performed using the usual broadcasting rules.

Finally, we return an output with shape (..., n?, k?). Here "..."
indicates the result of broadcasting the input ...'s against each
other. And, n? and k? mean: "either the value taken from the input
shape, if the corresponding entry was matched -- but if no match was
made, then we leave this entry out." The idea is that just as a column
vector on the right is "m x 1", a 1d vector on the right is treated as
"m x <nothing>". For purposes of actually computing the product,
<nothing> acts like 1, as mentioned above. But it makes a difference
in what we return: in each of these cases we copy the input shape into
the output, so we can get an output with shape (n, <nothing>), or
(<nothing>, k), or (<nothing>, <nothing>), which work out to be (n,),
(k,) and (), respectively. This gives a (somewhat) intuitive principle
for why dot(1d, 1d), dot(1d, 2d), dot(2d, 1d) are handled the way they
are, and a general template for extending this behaviour to other
operations like gufunc 'solve'.
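
To make the template matching concrete, here is a minimal pure-Python
sketch of the shape rule described above. The function name
proposed_dot_shape is made up for illustration and is not part of
numpy; it only computes the output shape and does no arithmetic.

```python
from itertools import zip_longest

def proposed_dot_shape(shape1, shape2):
    """Output shape under the proposed rule (hypothetical helper)."""
    if len(shape1) == 0 or len(shape2) == 0:
        raise ValueError("0d operands are an error under the proposal")

    # Match shape1 against (..., n?, m): last axis is m, n? needs ndim >= 2.
    m1 = shape1[-1]
    n = shape1[-2:-1]          # () if n? is unmatched, else (n,)
    batch1 = shape1[:-2]

    # Match shape2 against (..., m, k?): a 1d input matches m only.
    if len(shape2) == 1:
        m2, k, batch2 = shape2[0], (), ()
    else:
        m2, k, batch2 = shape2[-2], shape2[-1:], shape2[:-2]

    if m1 != m2:
        raise ValueError("m axes do not match: %r vs %r" % (m1, m2))

    # Broadcast the ... axes: right-align, pad with 1s, sizes must
    # match or be 1.
    batch = []
    for d1, d2 in zip_longest(reversed(batch1), reversed(batch2), fillvalue=1):
        if d1 != d2 and 1 not in (d1, d2):
            raise ValueError("cannot broadcast %r with %r" % (batch1, batch2))
        batch.append(d2 if d1 == 1 else d1)

    # Unmatched ? axes are simply left out of the result.
    return tuple(reversed(batch)) + tuple(n) + tuple(k)

print(proposed_dot_shape((4,), (4,)))            # 1d . 1d
print(proposed_dot_shape((3, 4), (4,)))          # 2d . 1d
print(proposed_dot_shape((4,), (4, 6)))          # 1d . 2d
print(proposed_dot_shape((5, 1, 3, 4), (2, 4, 6)))  # broadcast ... axes
```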

Anyway, the end result of this is that the PROPOSED behaviour differs
from the current behaviour in the following ways:
- passing 0d arrays to 'dot' becomes an error. (This in particular is
an important thing to know, because if core Python adds an operator
for 'dot', then we must decide what it should do for Python scalars,
which are logically 0d.)
- ndim>2 arrays are now handled by aligning and broadcasting the extra
axes, instead of taking an outer product. So dot((r, m, n), (r, n, k))
returns (r, m, k), not (r, m, r, k).
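
The two differences can be seen side by side with a rough emulation of
the proposed semantics. This is a sketch only: proposed_dot is a
hypothetical helper, and building it on np.einsum is my choice for
illustration, not something the proposal prescribes.

```python
import numpy as np

def proposed_dot(a, b):
    """Rough emulation of the proposed rule (hypothetical, not numpy API)."""
    a, b = np.asarray(a), np.asarray(b)
    if a.ndim == 0 or b.ndim == 0:
        raise ValueError("0d operands are an error under the proposal")
    # An unmatched ? axis acts like size 1 for the computation...
    a2 = a[np.newaxis, :] if a.ndim == 1 else a
    b2 = b[:, np.newaxis] if b.ndim == 1 else b
    # The ellipsis axes are broadcast against each other by einsum.
    out = np.einsum('...nm,...mk->...nk', a2, b2)
    # ...and is then left out of the result.
    if a.ndim == 1:
        out = out[..., 0, :]   # drop the temporary n axis
    if b.ndim == 1:
        out = out[..., 0]      # drop the temporary k axis
    return out

r, m, n, k = 2, 3, 4, 6
x = np.random.rand(r, m, n)
y = np.random.rand(r, n, k)

# Current np.dot combines the leading axes; the proposal aligns them.
assert np.dot(x, y).shape == (r, m, r, k)
assert proposed_dot(x, y).shape == (r, m, k)
assert np.allclose(proposed_dot(x, y)[1], np.dot(x[1], y[1]))

# Current np.dot accepts a 0d operand; the proposal rejects it.
assert np.dot(2.0, x).shape == x.shape
try:
    proposed_dot(2.0, x)
except ValueError:
    pass
```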


Nathaniel J. Smith
Postdoctoral researcher - Informatics - University of Edinburgh
