[Numpy-discussion] Bug in einsum?

Jaakko Luttinen jaakko.luttinen at aalto.fi
Wed Mar 13 05:15:36 EDT 2013


I have encountered some very weird behaviour with einsum. I am trying to
compute something like R*A*R', where * denotes matrix multiplication
applied to each D x D slice of A, and R' is the transpose of R. However,
for particular shapes of R and A, the results are badly wrong.

I compare two ways of computing this with einsum:
First, I compute (R*A)*R' using two einsum calls.
Second, I compute the whole product in a single einsum call.
For some shapes, the two results are significantly different.
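Written out in index notation, with n indexing the leading axis of A (length 100 in the test) and using (R')_{lj} = R_{jl}, the two variants are the following. Substituting Z into the second line gives exactly the third, so the two should differ only by floating-point round-off.

```latex
\begin{aligned}
Z_{nil}       &= \sum_{k=1}^{D} R_{ik}\,A_{nkl}, \\
Y^{(1)}_{nij} &= \sum_{l=1}^{D} Z_{nil}\,R_{jl}, \\
Y^{(2)}_{nij} &= \sum_{k=1}^{D}\sum_{l=1}^{D} R_{ik}\,A_{nkl}\,R_{jl}.
\end{aligned}
```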

My test:
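import numpy as np
for D in range(30):
    # A stack of 100 random DxD matrices and one random DxD matrix R
    A = np.random.randn(100,D,D)
    R = np.random.randn(D,D)
    # Y1: two einsum calls, (R*A)*R'
    Y1 = np.einsum('...ik,...kj->...ij', R, A)
    Y1 = np.einsum('...ik,...kj->...ij', Y1, R.T)
    # Y2: a single einsum call over all three operands
    Y2 = np.einsum('...ik,...kl,...lj->...ij', R, A, R.T)
    print("D=%d" % D, np.allclose(Y1,Y2), np.linalg.norm(Y1-Y2))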

D=0 True 0.0
D=1 True 0.0
D=2 True 8.40339658678e-15
D=3 True 8.09995399928e-15
D=4 True 3.59428803435e-14
D=5 False 34.755610184
D=6 False 28.3576558351
D=7 False 41.5402690906
D=8 True 2.31709582841e-13
D=9 False 36.0161112799
D=10 True 4.76237746912e-13
D=11 True 4.57944440782e-13
D=12 True 4.90302218301e-13
D=13 True 6.96175851271e-13
D=14 True 1.10067181384e-12
D=15 True 1.29095933163e-12
D=16 True 1.3466837332e-12
D=17 True 1.52265065763e-12
D=18 True 2.05407923852e-12
D=19 True 2.33327630748e-12
D=20 True 2.96849358082e-12
D=21 True 3.31063706175e-12
D=22 True 4.28163620455e-12
D=23 True 3.58951880681e-12
D=24 True 4.69973694769e-12
D=25 True 5.47385264567e-12
D=26 True 5.49643316347e-12
D=27 True 6.75132988402e-12
D=28 True 7.86435437892e-12
D=29 True 7.85453681029e-12

So, for D = 5, 6, 7 and 9, allclose returns False and the error norm is
HUGE (28 to 42, compared with at most about 1e-11 for the other values).
This cannot be ordinary numerical inaccuracy, because the error norm is so
large. I don't know which result is correct (Y1 or Y2), but at least one
of them must be wrong.
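One way to tell which of Y1 and Y2 is wrong is to compare both against a reference that does not go through einsum at all. The sketch below is only an illustration: the helper names rar_reference and compare are hypothetical, and the reference simply computes R A[n] R' slice by slice with ordinary 2-D np.dot products. A fixed seed is used so that runs are repeatable. (On NumPy 1.10 or later, np.matmul(np.matmul(R, A), R.T) broadcasts the same way and could serve as the reference instead.)

```python
import numpy as np

def rar_reference(R, A):
    # Hypothetical helper: unambiguous reference computing R . A[n] . R^T
    # for each slice A[n] of the leading (stack) axis, using plain 2-D
    # np.dot products only. Slow, but it never calls einsum.
    out = np.empty((A.shape[0], R.shape[0], R.shape[0]))
    for n in range(A.shape[0]):
        out[n] = np.dot(np.dot(R, A[n]), R.T)
    return out

def compare(D, N=100, seed=0):
    # Hypothetical helper: same two einsum computations as in the test,
    # but with a fixed seed, each measured against the reference.
    np.random.seed(seed)
    A = np.random.randn(N, D, D)
    R = np.random.randn(D, D)
    Y1 = np.einsum('...ik,...kj->...ij', R, A)
    Y1 = np.einsum('...ik,...kj->...ij', Y1, R.T)
    Y2 = np.einsum('...ik,...kl,...lj->...ij', R, A, R.T)
    Yref = rar_reference(R, A)
    # Error norm of each einsum variant relative to the reference
    return np.linalg.norm(Y1 - Yref), np.linalg.norm(Y2 - Yref)

for D in range(1, 12):
    err1, err2 = compare(D)
    print("D=%d  |Y1-ref|=%.3g  |Y2-ref|=%.3g" % (D, err1, err2))
```

On an affected build, whichever of the two errors is large points at the faulty code path; on a correct build both should be at round-off level, like the passing rows in the output above.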

I ran the same test several times, and each time the same values of D
fail. If I change the shapes, the failing values of D may change too, but
there are usually several of them.

I'm running the latest version from GitHub (commit bd7104cef4) under
Python 3.2.3. With NumPy 1.6.1 under Python 2.7.3, the test crashes and
Python exits with the message "Floating point exception".

This seems so weird to me that I wonder whether I'm just doing something stupid.

Thanks a lot for any help!
