import cmath
from collections import defaultdict
import kwant
import tinyarray as ta
import numpy as np

from mpl_toolkits.mplot3d import axes3d
from matplotlib import pyplot, cm


# Honeycomb lattice; its sublattices are unpacked into `a` and `b`, which are
# used below to create sites and lattice vectors.
lat = kwant.lattice.honeycomb()
a, b = lat.sublattices


# Make an ancillary system with two symmetries.  With current Kwant 1, this
# system cannot be finalized.
# Two translation periods, expressed through lattice vectors of sublattice a.
periods = (a.vec((1, 0)), a.vec((0, 1)))
sym = kwant.TranslationalSymmetry(*periods)
anc = kwant.Builder(sym)
# Add one site from each sublattice and all nearest-neighbor hoppings.  The
# values are placeholders (None); later code only reads the sites and hoppings
# of this builder.
for sublattice in (a, b):
    anc[sublattice(0, 0)] = None
anc[lat.neighbors()] = None


# Make an equivalent Kwant system without explicit symmetries that can be
# finalized.  The k vector will be a parameter to this system.  We have to be a
# bit careful as different hoppings in the periodic system correspond to the
# same hopping in the pseudo-periodic system.
# Builder without symmetries holding the same sites as the ancillary system,
# each with a constant onsite value of 0.
sys = kwant.Builder()
sys[anc.sites()] = 0

# Map each pair (tail, image of head in the fundamental domain) to the list of
# real-space displacements between the actual head site and that image.
# Several hoppings of the periodic system may fold onto the same pair, hence a
# list per key.
#
# The loop variables are deliberately not called `a` and `b`: doing so would
# rebind the module-level sublattice names to individual sites.
shifts = defaultdict(list)
for tail, head in anc.hoppings():
    # tail is always in the fundamental domain.
    head_fd = sym.to_fd(head)
    shifts[tail, head_fd].append(head.pos - head_fd.pos)

def hopping(a, b, k):
    """Return the k-dependent hopping value for the site pair (a, b).

    The value is the sum of the phase factors ``exp(1j * dot(shift, k))``
    over every displacement recorded in the module-level ``shifts`` table
    under the key ``(a, b)``.

    Parameters:
        a, b: the two sites of the hopping, used as the lookup key.
        k: wave vector, combined with each shift via ``ta.dot``.
    """
    total = 0
    for displacement in shifts[a, b]:
        total += cmath.exp(1j * ta.dot(displacement, k))
    return total

# Use `hopping` as the value function of every site pair collected in
# `shifts`; it receives the wave vector as an extra argument.
sys[shifts.keys()] = hopping
# From here on `sys` refers to the finalized system, not the builder.
sys = sys.finalized()


# Calculate the bands: diagonalize the k-dependent Hamiltonian at every point
# of a square grid of momenta.  The first axis of `energies` runs over kx, the
# second over ky, and the last over the eigenvalues at that k point
# (eigvalsh returns them in ascending order).
momenta = np.linspace(-5, 5, 111)
energies = np.array([
    [np.linalg.eigvalsh(sys.hamiltonian_submatrix(args=[(kx, ky)])).real
     for ky in momenta]
    for kx in momenta
])

# Plot the bands over the (kx, ky) grid: eigenvalue 0 as a colored surface and
# eigenvalue 1 as a wireframe.  The column/row views of `momenta` broadcast to
# the full grid.
figure = pyplot.figure()
ax = figure.add_subplot(111, projection='3d')
kx_grid = momenta[:, None]
ky_grid = momenta[None, :]
strides = dict(rstride=3, cstride=3)
ax.plot_surface(kx_grid, ky_grid, energies[..., 0],
                cmap=cm.coolwarm, **strides)
ax.plot_wireframe(kx_grid, ky_grid, energies[..., 1], **strides)
pyplot.show()
