import kwant
import numpy as np
from numpy import sin, cos,tan,sqrt
from matplotlib import pyplot


def site_symbol(site):
    """Choose the plot symbol for ``site`` based on the lattice it belongs to.

    Intended to be passed as the ``site_symbol`` callback of ``kwant.plot``.
    Relies on the module-level ``lat`` that ``make_system`` assigns, so it
    must only be called after ``make_system`` has run.

    Returns ``'S'`` when the site's family is one of ``lat.sublattices`` and
    the polygon specification ``('P', 3, 0)`` for every other site.
    """
    on_first_lattice = site.family in lat.sublattices
    return 'S' if on_first_lattice else ('P', 3, 0)


def make_system(a, t, W, L1):
    """Build the (unfinalized) kwant system with a left and a right lead.

    The scattering region lives on the lattice ``lat``; a handful of extra
    sites and the right lead live on ``lat2``, which uses the same primitive
    vectors but a different basis.

    Parameters
    ----------
    a : scale factor applied to the primitive vectors and the basis.
    t : value assigned to every nearest-neighbour hopping.
    W : half-width; sites are kept when ``-W < y < W``.
    L1 : half-length; scattering-region sites are kept when ``-L1 < x < L1``.

    Returns
    -------
    The ``kwant.Builder`` with both leads attached (left lead first).

    Side effects
    ------------
    * Assigns the module-level names ``lat`` and ``lat2`` (``site_symbol``
      reads ``lat``).
    * Calls ``kwant.plot(..., show=False)`` on the system before the leads
      are attached.
    """
    global lat, lat2

    # Primitive vectors and basis are stored column-wise, scaled by ``a``.
    prim_vecs = np.array([[sqrt(3) / 2, sqrt(3)],
                          [1.5, 0.0]]) * a
    basis_first = np.array([[0.0, 0.0],
                            [-0.5, 0.5]]) * a
    vec_1, vec_2 = prim_vecs.T
    atom_1, atom_2 = basis_first.T
    # Every shape below is grown starting from the first basis atom of ``lat``.
    seed_point = basis_first[:, 0]

    lat = kwant.lattice.general([vec_1, vec_2], [atom_1, atom_2])
    syst = kwant.Builder()

    def inside_scattering_region(pos):
        # Open rectangle centred on the origin.
        x, y = pos
        return abs(x) < L1 and abs(y) < W

    def inside_lead(pos):
        # Strip of the same width, unbounded along x.
        _, y = pos
        return abs(y) < W

    def build_lead(lattice, period, transverse):
        # Lead on ``lattice`` that is periodic along ``lattice.vec(period)``;
        # ``transverse`` is handed to ``add_site_family`` as the other vector
        # for both sublattices.
        symmetry = kwant.TranslationalSymmetry(lattice.vec(period))
        for index in (0, 1):
            symmetry.add_site_family(lattice.sublattices[index],
                                     other_vectors=[transverse])
        lead = kwant.Builder(symmetry)
        lead[lattice.shape(inside_lead, seed_point)] = 0
        lead[lattice.neighbors()] = t
        return lead

    #### Scattering region. ####
    syst[lat.shape(inside_scattering_region, seed_point)] = 0
    syst[lat.neighbors()] = t

    #### Left lead (on ``lat``). ####
    left_lead = build_lead(lat, (0, -1), (-2, 1))

    #### Second lattice: hexagonal, with a different - but equivalent! - basis.
    basis_second = np.array([[0.0, sqrt(3) / 2],
                             [-0.5, -1.0]]) * a
    lat2 = kwant.lattice.general([vec_1, vec_2],
                                 [basis_second[:, 0], basis_second[:, 1]])
    sub_a, sub_b = lat2.sublattices[0], lat2.sublattices[1]

    #### Extra ``lat2`` sites on the right edge of the scattering region:
    ####  -- range(0, 1)  - single hexagon
    ####  -- range(-2, 3) - full width
    # NOTE(review): only ``lat.neighbors()`` and ``lat2.neighbors()`` hoppings
    # are assigned; no hopping between a ``lat`` site and a ``lat2`` site is
    # set explicitly anywhere in this function - confirm that is intended.
    for i in range(1):
        column, row = 2 * i, 8 - i
        syst[sub_a(column, row + 1)] = 0.0
        syst[sub_b(column, row)] = 0.0
        syst[sub_a(column + 1, row)] = 0.0
        syst[sub_b(column + 1, row)] = 0.0

    syst[lat2.neighbors()] = t
    syst[lat.neighbors()] = t

    # Draw the system as it looks before any lead is attached; the figure is
    # not shown here because of show=False.
    kwant.plot(syst, site_symbol=site_symbol, show=False)

    #### Right lead (on ``lat2``). ####
    right_lead = build_lead(lat2, (0, 1), (2, -1))

    # Attachment order fixes the lead numbering: left is 0, right is 1.
    syst.attach_lead(left_lead)
    syst.attach_lead(right_lead)

    return syst

# Build the system, plot it with the leads attached, then finalize it so the
# solver can use it.
syst = make_system(a=1.42, t=3.0, W=10.5, L1=20)
kwant.plot(syst, site_symbol=site_symbol)
syst = syst.finalized()

data1 = []  # transmission(1, 0) for each energy
data2 = []  # num_propagating(0) for each energy
# NOTE(review): make_system assigns plain numbers to all onsite and hopping
# values, so these params (note t=-3.0 versus t=3.0 above) presumably have no
# effect on the result - confirm.
params = {'a': 1.42, 't': -3.0}

# 20 energies: 0.01, 0.21, ..., 3.81.
energies = [0.2 * step + 0.01 for step in range(20)]
for energy in energies:
    print('En=', energy)
    smatrix = kwant.smatrix(syst, energy, params=params)
    data1.append(smatrix.transmission(1, 0))
    data2.append(smatrix.num_propagating(0))

# Plot both series against energy on one figure.
pyplot.figure()
pyplot.plot(energies, data1, 'r--', label='T')
pyplot.plot(energies, data2, 'r>', label='N')
pyplot.xlabel("energy")
pyplot.ylabel("conductance [e^2/h]")
legend = pyplot.legend(loc='upper left')
pyplot.show()
