from cmath import exp
import kwant, numpy
from matplotlib import pyplot

# Square lattice shared by the scattering region and all three leads.
# The hopping-phase functions below read the `pos` of its sites directly,
# so the magnetic phases are expressed in this lattice's coordinates.
lat = kwant.lattice.square()


def phase_x(a, b, B):
    """Return the hopping from site ``b`` to site ``a`` in the x-invariant gauge.

    The Peierls phase is ``-B * mean_y * (x_b - x_a)``, where ``mean_y`` is
    the average y coordinate of the two sites.  It depends only on y and on
    the x displacement, so it does not change under translations along x
    (vertical hops carry no phase).  This is the gauge used by the two
    horizontal leads.

    Parameters
    ----------
    a, b : site
        Objects with a ``pos`` attribute holding ``(x, y)``.
    B : float
        Field strength (presumably flux per plaquette for unit lattice
        spacing -- TODO confirm units).

    Returns
    -------
    complex
        ``-exp(1j * phase)``: unit-magnitude hopping with a minus sign.
    """
    xa, ya = a.pos[0], a.pos[1]
    xb, yb = b.pos[0], b.pos[1]
    mean_y = 0.5 * (ya + yb)
    # Keep the original multiplication order so the result is bit-identical.
    angle = -B * (mean_y * (xb - xa))
    return -exp(1j * angle)


def gauge_x_to_y(s):
    """Return the gauge function chi(s): the product of ``s.pos``'s coordinates.

    For 2D square-lattice sites this is ``x * y``.  Adding the term
    ``chi(a) - chi(b)`` to the x-gauge hopping phase (as ``phase_y`` does)
    moves the phase from horizontal to vertical hops.
    """
    coords = numpy.asarray(s.pos)
    return coords.prod()


def phase_y(a, b, B):
    """Return the hopping from ``b`` to ``a`` in the y-invariant gauge.

    This is the x-gauge Peierls phase plus the gauge-transformation term
    ``gauge_x_to_y(a) - gauge_x_to_y(b)``.  For a horizontal hop (equal y)
    the two terms cancel, so the phase is zero.  For a vertical hop at
    column x the phase is ``-B * x * (y_a - y_b)``.  The result is
    therefore invariant under translations along y, as the vertical lead
    needs.

    Parameters
    ----------
    a, b : site
        Objects with a ``pos`` attribute holding ``(x, y)``.
    B : float
        Field strength, same convention as ``phase_x``.

    Returns
    -------
    complex
        ``-exp(1j * phase)``.
    """
    xa, ya = a.pos[0], a.pos[1]
    xb, yb = b.pos[0], b.pos[1]
    landau_term = 0.5 * (ya + yb) * (xb - xa)
    # Summed left to right exactly as in the one-expression form.
    total = landau_term + gauge_x_to_y(a) - gauge_x_to_y(b)
    return -exp(1j * (-B * total))


def gauge_scatreg(s):
    """Gauge function used inside the scattering region.

    Returns ``gauge_x_to_y(s)`` for sites in the module-level set
    ``gauge_transformed_sites`` and ``0`` for every other site.  That set
    is bound by ``main()`` after the system is finalized.  Nothing in this
    file binds it earlier, so calling this before ``main()`` has run raises
    ``NameError``.
    """
    if s not in gauge_transformed_sites:
        return 0
    return gauge_x_to_y(s)


def phase_scatreg(a, b, B):
    """Return the hopping from ``b`` to ``a`` inside the scattering region.

    Uses the x-gauge Peierls phase plus ``gauge_scatreg(a) -
    gauge_scatreg(b)``.
    - If neither site is in ``gauge_transformed_sites``, this equals
      ``phase_x``.
    - If both are, it equals ``phase_y``.
    - Mixed hops carry the one-sided gauge factor.

    This keeps the region a pure gauge transform of the x-gauge
    Hamiltonian.

    Parameters
    ----------
    a, b : site
        Objects with a ``pos`` attribute holding ``(x, y)``.
    B : float
        Field strength, same convention as ``phase_x``.

    Returns
    -------
    complex
        ``-exp(1j * phase)``.
    """
    xa, ya = a.pos[0], a.pos[1]
    xb, yb = b.pos[0], b.pos[1]
    landau_term = 0.5 * (ya + yb) * (xb - xa)
    # Summed left to right exactly as in the one-expression form.
    total = landau_term + gauge_scatreg(a) - gauge_scatreg(b)
    return -exp(1j * (-B * total))


def make_lead(direction, width, r0, phase):
    """Build a translationally invariant square-lattice lead.

    Parameters
    ----------
    direction : pair of ints
        Lattice direction of the lead's translational symmetry (passed to
        ``lat.vec``).
    width : int
        Number of sites in one cross-section of the lead.
    r0 : pair of ints
        Lattice coordinates of the first cross-section site.
    phase : callable
        Hopping value function ``phase(a, b, B)`` (e.g. ``phase_x`` or
        ``phase_y``) applied to all nearest-neighbour hoppings.

    Returns
    -------
    kwant.Builder
        The lead builder.  Every site has on-site value 4.
    """
    dx, dy = direction[0], direction[1]
    x0, y0 = r0[0], r0[1]
    symmetry = kwant.TranslationalSymmetry(lat.vec(direction))
    lead = kwant.Builder(symmetry)
    # The cross-section runs perpendicular to the lead: each step (-dy, dx)
    # is `direction` rotated by +90 degrees.
    cross_section = (lat(x0 - n * dy, y0 + n * dx) for n in range(width))
    lead[cross_section] = 4
    lead[lat.neighbors()] = phase
    return lead


def main(E=1):
    """Sweep the magnetic field and plot the lead-0 to lead-1 transmission.

    Steps:
    1. Build a scattering region from the 40 x 40 block of sites, keeping
       only those with ``y > 19 - |x - 17|``.
    2. Attach three leads: two horizontal leads in the x-gauge, then one
       vertical lead in the y-gauge.
    3. Show the system, then finalize it.
    4. For each field in ``numpy.arange(-0.21, 0.5, 0.01)``, compute
       ``transmission(1, 0)`` at energy ``E``.  Kwant's argument order is
       (to_lead, from_lead).
    5. Print each value and plot the curve.

    Side effect: rebinds the module-level ``gauge_transformed_sites`` to
    the scattering-region sites coupled to lead 2.  ``gauge_scatreg`` reads
    this set.
    """
    global gauge_transformed_sites

    def inside(x, y):
        # Keep only sites strictly above the V-shaped lower boundary.
        return y > 19 - abs(x - 17)

    syst = kwant.Builder()
    region = (lat(x, y) for x in range(40) for y in range(40) if inside(x, y))
    syst[region] = 4
    syst[lat.neighbors()] = phase_scatreg

    # Attachment order fixes the lead indices used below (0, 1, 2).
    lead_specs = [
        ((1, 0), 18, (50, 0), phase_x),
        ((-1, 0), 20, (50, 23), phase_x),
        ((0, 1), 18, (19, 0), phase_y),
    ]
    for direction, width, start, hopping in lead_specs:
        syst.attach_lead(make_lead(direction, width, start, hopping))

    kwant.plot(syst)
    syst = syst.finalized()

    # Sites touching lead 2 must use lead 2's (y-invariant) gauge.
    gauge_transformed_sites = {syst.sites[i] for i in syst.lead_interfaces[2]}

    fields = numpy.arange(-0.21, 0.5, 0.01)
    transmissions = []
    for field in fields:
        t = kwant.smatrix(syst, E, args=[field]).transmission(1, 0)
        transmissions.append(t)
        print(field, t)

    fig = pyplot.figure()
    ax = fig.add_subplot(1, 1, 1)
    ax.plot(fields, transmissions)
    pyplot.show()


if __name__ == "__main__":
    # Script entry point: run the field sweep at main()'s default energy.
    main()
