[Numpy-discussion] help using np.einsum for stacked matrix multiplication

Andrew Nelson andyfaff at gmail.com
Wed Oct 29 16:27:01 EDT 2014

On Wed, Oct 29, 2014 at 10:39 AM, Andrew Nelson <andyfaff at gmail.com> wrote:

> Dear list,
> I have a 4D array, A, that has the shape (NX, NY, 2, 2).  I wish to
> perform matrix multiplication of the 'NY' 2x2 matrices, resulting in the
> matrix B.  B would have shape (NX, 2, 2).  I believe that np.einsum would
> be up to the task, but I'm not quite sure of the subscripts I would need
> to achieve this.

OK, I'll try to explain in more detail what I'm trying to do (I'm not
skilled in matrix algebra).

Say I have a series of matrices, M, which are all 2x2: M_0, M_1, ...,
M_{NY-1}. These all need to be multiplied together in order,
i.e. N = M_0 x M_1 x ... x M_{NY-1}.
Note that I want to multiply M_0 by M_1, the result of that by M_2, the
result of that by M_3, and so on.
I can hold the NY matrices in a single array that has shape (NY, 2, 2).
The first row in that array would be M_0, the last would be M_{NY-1}.  The
output of all that matrix multiplication would be a single 2x2 matrix.  So
I would expect the operation to map the shapes like this:

 #there are NY-1 matrix multiplications involved here.
M[NY, 2, 2] ----->  N[2, 2]
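
To make the single-chain case concrete, here is a minimal sketch of it. The
size NY = 5 and the random contents of M are placeholder values chosen only
for illustration, and functools.reduce is just one way of writing the
running product; it still performs NY-1 separate multiplications.

```python
import functools

import numpy as np

NY = 5  # placeholder size, chosen only for illustration
rng = np.random.default_rng(0)
M = rng.standard_normal((NY, 2, 2))  # M[0] is M_0, M[NY - 1] is M_{NY-1}

# Left-to-right running product: ((M_0 . M_1) . M_2) . ... . M_{NY-1}.
# Iterating over M yields the NY individual 2x2 matrices in order.
N = functools.reduce(np.dot, M)

print(N.shape)
```

The order of the operands matters here, since matrix multiplication does
not commute in general.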

Now for the next level of complication: I have NX of those M[NY, 2, 2]
arrays, so I need to do the above series of matrix multiplications NX
times.  I could hold all this in an array, P, with shape (NX, NY, 2, 2).
Each of the NX rows is independent of the others.
Currently I am doing this in a nested loop; pseudocode follows:

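import numpy as np

# P has shape (NX, NY, 2, 2); output[i] holds the product for row i.
output = np.zeros((NX, 2, 2))

for i in range(NX):
    # Running product M_0 . M_1 ... M_{NY-1} for row i.
    temp = np.identity(2)
    for j in range(NY):
        temp = np.dot(temp, P[i, j])

    output[i] = temp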

I posted my original question because I would like to replace that doubly
nested loop with something more elegant, and a good deal faster.
Is there an np.einsum call that can do that?
(Please forgive me if this still isn't clear or precise enough.)
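
For reference, one possible direction is sketched below. It is only an
illustration and not a confirmed answer: it uses np.einsum to do the 2x2
product for all NX rows at once, which removes the loop over NX but keeps
the loop over NY. The sizes, the random P and the function name
chain_products are hypothetical and used only for this sketch.

```python
import numpy as np

NX, NY = 4, 6  # placeholder sizes, chosen only for illustration
rng = np.random.default_rng(0)
P = rng.standard_normal((NX, NY, 2, 2))


def chain_products(P):
    """Return output[i] = P[i, 0] . P[i, 1] ... P[i, NY-1] for every i."""
    nx, ny = P.shape[:2]
    # Start every chain from the 2x2 identity, one writable copy per row i.
    output = np.broadcast_to(np.identity(2), (nx, 2, 2)).copy()
    for j in range(ny):
        # Batched 2x2 product: for each n, output[n] = output[n] . P[n, j].
        output = np.einsum('nij,njk->nik', output, P[:, j])
    return output


output = chain_products(P)
print(output.shape)
```

Whether this is actually faster than the nested loop depends on NX and NY;
it should help most when NX is large, because the remaining Python loop
runs only NY times.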

Dr. Andrew Nelson
