[Numpy-discussion] Proposed new feature for numpy.einsum: repeated output subscripts as diagonal
Jaime Fernández del Río
jaime.frio at gmail.com
Thu Aug 21 01:24:22 EDT 2014
On Wed, Aug 20, 2014 at 6:26 AM, Pierre-Andre Noel <
noel.pierre.andre at gmail.com> wrote:
> Thanks all for the feedback!
> So there appears to be interest in this feature, and I think that I can
> implement it. However, it may take a while before I do so: I have other
> priorities right now.
> In view of jaimefrio's comment on
> https://github.com/numpy/numpy/issues/4965 as well as Eelco Hoogendoorn's
> reply above, here is how I currently intend to implement the feature.
> 1. Implement a `diag_view` function that uses strides to make a view. The
> function would use subscripts in a way very similar to `einsum`, except
> that no commas are allowed and all indices appearing on one side of `->`
> must also appear on the other side. Like the current `einsum`, indices on
> the right-hand side of `->` cannot be repeated. For example,
> `B=diag_view('iij->ij',A)` returns a 2D view `B` of the 3D array `A` where
> the off-diagonal elements in the first two dimensions of `A` are
> inaccessible in `B`.
> 2. The edits to `einsum` itself should be minimal. For the purpose of the
> following, suppose that the indices have the form `lhs+'->'+rhs`, where
> `lhs` and `rhs` are character strings. To make sure that the current
> behavior of `einsum` is neither slowed down nor broken by the new
> functionality, I intend to limit edits to the point where an error would be
> raised due to repeated indices in `rhs`. The following outlines what would
> replace the current error-raising.
> 2.1 Extract from `rhs` the first occurrence of each index; call
> that `rhs_first_oc`.
> 2.2 If no `out` has been provided to `einsum`, allocate a zeroed out
> `ndarray` of appropriate size, including off-diagonal entries; call that
> `full_out`. If an `out` was provided to `einsum`, set `full_out=out`.
> 2.3 Set `diag_out=diag_view(rhs+'->'+rhs_first_oc,full_out)`.
> 2.4 Call `einsum(lhs+'->'+rhs_first_oc, [...], out=diag_out)`. This
> call is recursive, but the recursion should stop there.
> 2.5 Return `full_out`.
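For concreteness, the flow quoted above can be sketched in pure Python roughly as follows. This is only an illustration under some assumptions: `diag_view` and `einsum_diag` are hypothetical names, neither exists in NumPy, ellipses and whitespace in the subscripts are not handled, and the view is built with `numpy.lib.stride_tricks.as_strided`. As in step 2.2, a user-supplied `out` is not zeroed, so its off-diagonal entries are left untouched.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided


def diag_view(subscripts, a):
    """Hypothetical helper (step 1): strided view along repeated indices.

    `subscripts` looks like 'iij->ij': no commas, every index appears on
    both sides, and no index is repeated on the right-hand side.
    """
    lhs, rhs = subscripts.split('->')
    if len(lhs) != a.ndim:
        raise ValueError("need exactly one input index per axis of `a`")
    if set(lhs) != set(rhs) or len(set(rhs)) != len(rhs):
        raise ValueError("invalid subscripts for diag_view")
    shape, strides = [], []
    for idx in rhs:
        axes = [k for k, c in enumerate(lhs) if c == idx]
        sizes = {a.shape[k] for k in axes}
        if len(sizes) != 1:
            raise ValueError("axes labelled %r differ in size" % idx)
        shape.append(sizes.pop())
        # Advancing all repeated axes together walks along the diagonal.
        strides.append(sum(a.strides[k] for k in axes))
    return as_strided(a, shape=shape, strides=strides)


def einsum_diag(subscripts, *operands, out=None):
    """Hypothetical wrapper (step 2): allow repeated output subscripts."""
    lhs, rhs = subscripts.split('->')
    # 2.1: keep the first occurrence of each output index, in order.
    rhs_first_oc = ''.join(dict.fromkeys(rhs))
    if rhs_first_oc == rhs:
        # Nothing repeated: defer to the existing behaviour.
        return np.einsum(subscripts, *operands, out=out)
    # 2.2: allocate a zeroed full-size output unless one was supplied.
    if out is None:
        sizes = {}
        for term, op in zip(lhs.split(','), operands):
            sizes.update(zip(term, np.shape(op)))
        full_out = np.zeros([sizes[c] for c in rhs],
                            dtype=np.result_type(*operands))
    else:
        full_out = out
    # 2.3: view of the diagonal entries of the full output.
    diag_out = diag_view(rhs + '->' + rhs_first_oc, full_out)
    # 2.4: ordinary einsum writing into the diagonal view.
    np.einsum(lhs + '->' + rhs_first_oc, *operands, out=diag_out)
    # 2.5
    return full_out


A = np.arange(18).reshape(3, 3, 2)
B = diag_view('iij->ij', A)          # B[i, j] aliases A[i, i, j]
D = einsum_diag('i->ii', np.array([1.0, 2.0, 3.0]))
```
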
I have looked into this a little, and I think there is an additional
complication. If I understand the structure of the code correctly,
`einsum`'s current entry point is the function `array_einsum` in
`multiarraymodule.c`, which accepts two different input methods: the
subscript one we have been discussing here, and another one that uses
lists of axes after each operand. This second method gets translated into
subscript notation by several functions in that same module
(`einsum_list_to_subscripts` and `einsum_sub_op_from_lists`). Only then is
the C API einsum function, `PyArray_EinsteinSum` in `einsum.c.src`, called,
and it only understands the subscript notation.
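To make the two input methods concrete, here is a small example of the same operations written both ways from Python. The array names are just placeholders for illustration.

```python
import numpy as np

a = np.arange(9).reshape(3, 3)
b = np.arange(12).reshape(3, 4)

# Subscript-string form: take the diagonal of `a`.
diag_sub = np.einsum('ii->i', a)

# List-of-axes form: each operand is followed by a list of integer axis
# labels, and the final list gives the labels of the output.
diag_list = np.einsum(a, [0, 0], [0])

# The same correspondence for a matrix product.
prod_sub = np.einsum('ij,jk->ik', a, b)
prod_list = np.einsum(a, [0, 1], b, [1, 2], [0, 2])
```

Any support for repeated output subscripts would have to cover both spellings.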
The simplest place to implement the changes you propose without any major
re-architecting is therefore in `PyArray_EinsteinSum`. And while the flow
you propose seems to me to be correct, doing that at the C level will probably
look somewhat different, e.g. you would probably let the iterator create an
array with all the axes, and then remove the repeated ones from the
iterator and modify the strides, instead of passing in a strided view with
If you were planning on writing your code in a Python wrapper, you would
need to figure out how to keep the alternative syntax code path working. I
haven't given it much thought, but it doesn't look easy without rewriting a
lot of stuff.
I see either solution as far too much complication for the reward. I
still think the simplest way of providing this functionality is to write a
function that does the opposite of your `diag_view` and expect the end user
to chain a call to it after the call to `einsum`. That said, if you can find
the time and the motivation to do the big change, I am perfectly OK with
it, of course!
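As a rough sketch of that simpler route, something like the following would do. `diag_embed` is a hypothetical name, not an existing NumPy function, and the sketch assumes NumPy 1.10 or later, where single-operand `einsum` returns a writeable view of the diagonal.

```python
import numpy as np


def diag_embed(subscripts, b):
    """Hypothetical inverse of `diag_view`: 'ij->ijj' puts `b` on a diagonal.

    Every index on the left must appear on the right; indices repeated on
    the right mark the axes whose diagonal receives the data.
    """
    lhs, rhs = subscripts.split('->')
    sizes = dict(zip(lhs, b.shape))
    full = np.zeros([sizes[c] for c in rhs], dtype=b.dtype)
    # With the subscripts reversed, einsum returns a view onto the diagonal
    # entries of `full`; assigning through it fills only those entries.
    np.einsum(rhs + '->' + lhs, full)[...] = b
    return full


a = np.arange(6.0).reshape(2, 3)
v = np.array([1.0, 2.0, 3.0])

# The user chains the two calls instead of writing 'ij,j->ijj' directly.
result = diag_embed('ij->ijj', np.einsum('ij,j->ij', a, v))
```
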
( > <) This is Conejo (Bunny). Copy Conejo into your signature and help him
with his plans for world domination.