[Numpy-discussion] How exactly ought 'dot' to work?
Jaime Fernández del Río
jaime.frio at gmail.com
Sat Feb 22 17:37:03 EST 2014
On Feb 22, 2014 2:03 PM, "Nathaniel Smith" <njs at pobox.com> wrote:
>
> Hi all,
>
> Currently numpy's 'dot' acts a bit weird for ndim>2 or ndim<1. In
> practice this doesn't usually matter much, because these are very
> rarely used. But, I would like to nail down the behaviour so we can
> say something precise in the matrix multiplication PEP. So here's one
> proposal.
>
> # CURRENT:
>
> dot(0d, any) -> scalar multiplication
> dot(any, 0d) -> scalar multiplication
> dot(1d, 1d) -> inner product
> dot(2d, 1d) -> treat 1d as column matrix, matrix-multiply, then
> discard added axis
> dot(1d, 2d) -> treat 1d as row matrix, matrix-multiply, then discard
> added axis
> dot(2d, 2d) -> matrix multiply
> dot(2-or-more d, 2-or-more d) -> a complicated outer product thing:
> Specifically, if the inputs have shapes (r, n, m), (s, m, k), then
> numpy returns an array with shape (r, n, s, k), created like:
>     for i in range(r):
>         for j in range(s):
>             output[i, :, j, :] = np.dot(input1[i, :, :], input2[j, :, :])
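A minimal sketch of the current N-d behaviour, assuming NumPy is installed. np.dot documents this case as a sum over the last axis of a and the second-to-last axis of b, so the output shape is a.shape[:-1] + b.shape[:-2] + b.shape[-1:], which for these inputs is (r, n, s, k). The helper dot_outer_reference is hypothetical; it just rebuilds that result with explicit loops for comparison.

```python
import numpy as np

def dot_outer_reference(a, b):
    """Hypothetical helper: rebuild np.dot for 3-d inputs with explicit loops.

    a has shape (r, n, m), b has shape (s, m, k); the result has shape
    a.shape[:-1] + b.shape[:-2] + b.shape[-1:] == (r, n, s, k).
    """
    r, n, m = a.shape
    s, m2, k = b.shape
    if m != m2:
        raise ValueError("contracted axes must have the same length")
    out = np.empty((r, n, s, k), dtype=np.result_type(a, b))
    for i in range(r):
        for j in range(s):
            # Every (i, j) pair is an ordinary 2-d matrix product.
            out[i, :, j, :] = np.dot(a[i, :, :], b[j, :, :])
    return out

rng = np.random.default_rng(0)
a = rng.standard_normal((2, 3, 4))   # (r, n, m)
b = rng.standard_normal((5, 4, 6))   # (s, m, k)
print(np.dot(a, b).shape)            # (2, 3, 5, 6)
print(np.allclose(np.dot(a, b), dot_outer_reference(a, b)))  # True
```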
>
> # PROPOSED:
>
> General rule: given dot on shape1, shape2, we try to match these
> shapes against two templates like
> (..., n?, m) and (..., m, k?)
> where ... indicates zero or more dimensions, and ? indicates an
> optional axis. ? axes are always matched before ... axes, so for an
> input with ndim>=2, the ? axis is always matched. An unmatched ? axis
> is treated as having size 1.
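The matching step can be sketched as a small shape-only function. This is an illustration under stated assumptions: match_templates is a hypothetical name, and it reports an unmatched optional axis as None rather than 1 so that the caller can later leave it out of the output.

```python
def match_templates(shape1, shape2):
    """Hypothetical helper: match shapes against (..., n?, m) and (..., m, k?).

    Returns ((batch1, n, m), (batch2, m, k)); an optional axis that was not
    matched is reported as None. Optional axes are matched before the '...'
    axes, so any input with ndim >= 2 always matches its optional axis.
    """
    shape1, shape2 = tuple(shape1), tuple(shape2)
    if not shape1 or not shape2:
        # A 0-d input has no axis available for the required 'm'.
        raise ValueError("0-d inputs do not match either template")
    # Left operand: (..., n?, m)
    m1 = shape1[-1]
    n = shape1[-2] if len(shape1) >= 2 else None
    batch1 = shape1[:-2]
    # Right operand: (..., m, k?)
    if len(shape2) >= 2:
        m2, k, batch2 = shape2[-2], shape2[-1], shape2[:-2]
    else:
        m2, k, batch2 = shape2[0], None, ()
    if m1 != m2:
        raise ValueError(f"core axis m mismatch: {m1} != {m2}")
    return (batch1, n, m1), (batch2, m2, k)

print(match_templates((7, 3, 4), (4,)))  # (((7,), 3, 4), ((), 4, None))
```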
>
> Next, the ... axes are broadcast against each other in the usual way
> (prepending 1s to make lengths the same, requiring corresponding
> entries to either match or have the value 1). And then the actual
> computations are performed using the usual broadcasting rules.
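The broadcasting of the "..." parts follows NumPy's usual rule. A sketch of that rule on plain shape tuples is below; broadcast_batch is a hypothetical name, and the comparison uses np.broadcast_shapes, which assumes NumPy 1.20 or later.

```python
import numpy as np

def broadcast_batch(batch1, batch2):
    """Hypothetical helper: broadcast two '...' shapes by the usual rule."""
    ndim = max(len(batch1), len(batch2))
    # Prepend 1s so that both shapes have the same length.
    b1 = (1,) * (ndim - len(batch1)) + tuple(batch1)
    b2 = (1,) * (ndim - len(batch2)) + tuple(batch2)
    out = []
    for d1, d2 in zip(b1, b2):
        # Corresponding entries must match, or one of them must be 1.
        if d1 != d2 and d1 != 1 and d2 != 1:
            raise ValueError(f"cannot broadcast {batch1} against {batch2}")
        out.append(d2 if d1 == 1 else d1)
    return tuple(out)

print(broadcast_batch((7, 1), (5,)))       # (7, 5)
print(np.broadcast_shapes((7, 1), (5,)))   # (7, 5)
```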
>
> Finally, we return an output with shape (..., n?, k?). Here "..."
> indicates the result of broadcasting the input ...'s against each
> other. And, n? and k? mean: "either the value taken from the input
> shape, if the corresponding entry was matched -- but if no match was
> made, then we leave this entry out." The idea is that just as a column
> vector on the right is "m x 1", a 1d vector on the right is treated as
> "m x <nothing>". For purposes of actually computing the product,
> <nothing> acts like 1, as mentioned above. But it makes a difference
> in what we return: in each of these cases we copy the input shape into
> the output, so we can get an output with shape (n, <nothing>), or
> (<nothing>, k), or (<nothing>, <nothing>), which work out to be (n,),
> (k,) and (), respectively. This gives a (somewhat) intuitive principle
> for why dot(1d, 1d), dot(1d, 2d), dot(2d, 1d) are handled the way they
> are, and a general template for extending this behaviour to other
> operations like gufunc 'solve'.
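Putting the pieces together, here is a hedged sketch of the proposed semantics: unmatched optional axes are inserted as size 1, the "..." axes broadcast, the m axis is summed over, and the unmatched axes are then dropped from the result. proposed_dot is a hypothetical name. Later NumPy releases added np.matmul, which follows the same shape rules for ndim >= 1 and rejects 0-d inputs, so the sketch is checked against it.

```python
import numpy as np

def proposed_dot(a, b):
    """Hypothetical sketch of the proposed 'dot' semantics.

    An unmatched optional axis (n? for a 1-d left operand, k? for a 1-d
    right operand) acts like size 1 for the computation and is then left
    out of the returned shape.
    """
    a, b = np.asarray(a), np.asarray(b)
    if a.ndim == 0 or b.ndim == 0:
        raise ValueError("0-d operands are an error under the proposal")
    a_vec, b_vec = a.ndim == 1, b.ndim == 1
    if a_vec:
        a = a[np.newaxis, :]   # (m,) -> (1, m): n? acts like 1
    if b_vec:
        b = b[:, np.newaxis]   # (m,) -> (m, 1): k? acts like 1
    if a.shape[-1] != b.shape[-2]:
        raise ValueError("core axis m does not match")
    # (..., n, m, 1) * (..., 1, m, k) broadcasts the '...' axes; summing
    # over the m axis leaves (..., n, k).
    out = (a[..., :, :, np.newaxis] * b[..., np.newaxis, :, :]).sum(axis=-2)
    # Leave out the axes that were never matched.
    if a_vec:
        out = out[..., 0, :]   # drop n
    if b_vec:
        out = out[..., 0]      # drop k
    return out

rng = np.random.default_rng(1)
x = rng.standard_normal((7, 1, 3, 4))   # (..., n, m)
y = rng.standard_normal((5, 4, 2))      # (..., m, k)
v = rng.standard_normal(4)              # (m,)
print(proposed_dot(x, y).shape)         # (7, 5, 3, 2)
print(proposed_dot(x, v).shape)         # (7, 1, 3)
print(proposed_dot(v, v).shape)         # ()
```

The (n, <nothing>), (<nothing>, k) and (<nothing>, <nothing>) cases from the paragraph above correspond to the x-with-v, v-with-matrix and v-with-v calls.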
>
> Anyway, the end result of this is that the PROPOSED behaviour differs
> from the current behaviour in the following ways:
> - passing 0d arrays to 'dot' becomes an error. (This in particular is
> an important thing to know, because if core Python adds an operator
> for 'dot', then we must decide what it should do for Python scalars,
> which are logically 0d.)
> - ndim>2 arrays are now handled by aligning and broadcasting the extra
> axes, instead of taking an outer product. So dot((r, m, n), (r, n, k))
> returns (r, m, k), not (r, m, r, k).
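Both differences can be seen by comparing np.dot with np.matmul, which later NumPy releases provide with the broadcasting behaviour described here. A minimal sketch, assuming a NumPy version that has np.matmul:

```python
import numpy as np

r, m, n, k = 2, 3, 4, 5
a = np.ones((r, m, n))
b = np.ones((r, n, k))

# Current 'dot': outer product over the extra axes.
print(np.dot(a, b).shape)      # (2, 3, 2, 5), i.e. (r, m, r, k)
# Broadcasting semantics, as implemented by np.matmul.
print(np.matmul(a, b).shape)   # (2, 3, 5), i.e. (r, m, k)

# 0-d operands: 'dot' treats them as scalar multiplication, matmul rejects them.
print(np.dot(2.0, a).shape)    # (2, 3, 4)
try:
    np.matmul(2.0, a)
    matmul_rejects_0d = False
except ValueError:
    matmul_rejects_0d = True
print(matmul_rejects_0d)       # True
```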
>
> Comments?
The proposed behavior for ndim > 2 is what matrix_multiply (is it still in
umath_tests?) does. The nice thing about the proposed new behavior is that
the old behavior is easy to reproduce by fiddling a little with the shape
of the first argument, while the opposite is not true.
Jaime
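A sketch of recovering the old result from the broadcasting version by reshaping the first argument. Assumptions: it uses np.matmul in place of matrix_multiply, old_dot_via_matmul is a hypothetical name, and both inputs have ndim >= 2. Reshaping alone yields the axes as (A..., B..., n, k); matching np.dot's exact order (A..., n, B..., k) also needs a final np.moveaxis.

```python
import numpy as np

def old_dot_via_matmul(a, b):
    """Hypothetical helper: rebuild np.dot's N-d result with matmul.

    Giving a one extra length-1 axis per batch axis of b makes matmul pair
    every batch of a with every batch of b (an outer product over the '...'
    axes). matmul returns (A..., B..., n, k); np.dot orders the result as
    (A..., n, B..., k), so the n axis is moved back into place.
    """
    a, b = np.asarray(a), np.asarray(b)
    batch_a = a.ndim - 2   # number of '...' axes in a, before reshaping
    a = a.reshape(a.shape[:-2] + (1,) * (b.ndim - 2) + a.shape[-2:])
    return np.moveaxis(np.matmul(a, b), -2, batch_a)

rng = np.random.default_rng(2)
p = rng.standard_normal((2, 3, 4))   # (r, n, m)
q = rng.standard_normal((5, 4, 6))   # (s, m, k)
print(np.allclose(old_dot_via_matmul(p, q), np.dot(p, q)))  # True
```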
>
> --
> Nathaniel J. Smith
> Postdoctoral researcher - Informatics - University of Edinburgh
> http://vorpus.org
> _______________________________________________
> NumPy-Discussion mailing list
> NumPy-Discussion at scipy.org
> http://mail.scipy.org/mailman/listinfo/numpy-discussion