import kwant
import matplotlib.pyplot as plt
import numpy as np
from scipy.constants import constants
import numpy.linalg as nla

# Description:
# In this script, we will compute the retarded Green's function of a 1D system with 
# random on-site potential at a few sites within the scattering region using two different, but equivalent, methods:
# (I) Using direct inversion of the Hamiltonian with leads self energies (computed from the Kwant software): 
# 				G^r(E) = [E+i\eta - H - \Sigma^r_{leads}]^{-1}, where \eta = very small energy broadening
# (II) Using scattering wave function presented in "Numerical simulations of time resolved quantum electronics"(https://arxiv.org/pdf/1307.6419.pdf)
#  by Kwant authors. Specifically, we will use Eq. 25 from that reference in energy domain:
# 				G^r(E) = -i*\Psi(E)*\Psi^\dag(E), where \Psi(E) is obtained from the Kwant software (see below)

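# Conclusion:
# With formula (II) taken literally, the two methods appear to agree ONLY IF the on-site potential
# is removed from the scattering region, i.e. when the scattering region is homogeneous and identical
# to the leads. With an on-site potential inside the scattering region, the two methods give different
# retarded Green's functions that do not match at all.
# Explanation: Eq. 25 of the reference is a time-domain relation,
#     G^r(t,t') = -i \theta(t-t') \sum_\alpha \int dE/(2\pi) \psi_{\alpha E}(t) \psi^\dag_{\alpha E}(t'),
# and the step function \theta(t-t') does not turn into a simple product in the energy domain.
# What the energy-resolved scattering states give directly is the spectral function
#     A(E) = i[G^r(E) - G^a(E)] = \sum_\alpha \psi_\alpha(E) \psi^\dag_\alpha(E),
# where the sum runs over the scattering states of ALL leads and ALL open modes. The apparent agreement
# for the clean chain is a coincidence: there, the states from a single lead reproduce the absolute
# values of the real and imaginary parts of G^r entry by entry, even though the complex values differ.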



# System parameters: an L x W strip of a square lattice with lattice
# constant `a` and nearest-neighbour hopping -t.
L, W = 5, 1   # number of sites along x (length) and along y (width)
a = 1.0       # lattice constant
t = 1.0       # hopping energy
# Small positive imaginary part added to the energy when the Hamiltonian is
# inverted to obtain the retarded Green's function (see the energy loop below).
energy_broadening = 1e-5

# Build system: an L x W patch of the square lattice with on-site energy 2t,
# nearest-neighbour hopping -t, and a random on-site potential, uniform in
# [0, 10), on the interior columns 1 < i < L - 1 only.
sys = kwant.Builder()
lat = kwant.lattice.square(a)

# Define the scattering region/device Hamiltonian
for i in range(L):
    for j in range(W):
        # On-site Hamiltonian
        onsite = 2 * t
        # On-site potential on a few interior columns. The previous test
        # `i > 1 or i < L - 1` was true for every i, so disorder was applied to
        # the whole region. np.random.rand() (no argument) returns a scalar, so
        # the on-site value stays a plain number rather than a length-1 array.
        # The potential need not be random; a fixed value works as well.
        if 1 < i < L - 1:
            onsite += 10 * np.random.rand()
        sys[lat(i, j)] = onsite
        # Hopping in y-direction
        if j > 0:
            sys[lat(i, j), lat(i, j - 1)] = -t
        # Hopping in x-direction
        if i > 0:
            sys[lat(i, j), lat(i - 1, j)] = -t

def _attach_lead(builder, symmetry):
    """Attach a clean, W-wide lead with translational `symmetry` to `builder`.

    The lead uses the same on-site energy (2t) and hoppings (-t) as the clean
    scattering region. It is attached to `builder` and returned in finalized
    form so that its self-energy can be evaluated on its own.
    """
    lead = kwant.Builder(symmetry)
    for j in range(W):
        lead[lat(0, j)] = 2 * t
        # Hopping in y-direction
        if j > 0:
            lead[lat(0, j), lat(0, j - 1)] = -t
        # Hopping along the lead (x-direction)
        lead[lat(1, j), lat(0, j)] = -t
    builder.attach_lead(lead)
    return lead.finalized()


# Left lead (attached first)
sym_left_lead = kwant.TranslationalSymmetry((-a, 0))
left_lead = _attach_lead(sys, sym_left_lead)

# Right lead (attached second)
sym_right_lead = kwant.TranslationalSymmetry((a, 0))
right_lead = _attach_lead(sys, sym_right_lead)

sys = sys.finalized()

# Device Hamiltonian as a dense matrix (sites ordered as in the finalized
# system), a zeroed complex buffer for the lead self-energies, and a complex
# identity matrix of the same size.
device_H = sys.hamiltonian_submatrix()
n_rows, n_cols = device_H.shape
device_leads_self_energy = np.zeros((n_rows, n_cols), dtype=np.complex128)
I = np.eye(n_rows, n_cols, dtype=np.complex128)

# Plot system
kwant.plot(sys)

# Energy span. For W = 1 the clean leads (on-site 2t, hopping -t) have the band
# E(k) = 2t - 2t cos(k a), i.e. [0, 4t]. The grid stops just below the upper
# band edge, where the lead velocity vanishes and unit-current-normalised
# scattering states are ill-defined.
e_size = 100
energy_span = np.linspace(0.001, 4*t - 0.001, e_size)
debug = True
for e in energy_span:
    # Lead self-energies, each placed on the device sites its lead couples to.
    # `sys.lead_interfaces` makes the placement independent of site ordering and
    # correct for W > 1, where each self-energy is a W x W block, not a scalar.
    device_leads_self_energy[:] = 0
    for lead, interface in zip(sys.leads, sys.lead_interfaces):
        device_leads_self_energy[np.ix_(interface, interface)] += lead.selfenergy(e)

    # Method (I): G^r(E) = [E + i*eta - H - Sigma^r_leads]^{-1} by direct
    # inversion, and from it the spectral function A = i (G^r - G^a), G^a = (G^r)^dagger.
    device_retarded_green_gf = nla.inv(
        (e + 1j*energy_broadening)*I - device_H - device_leads_self_energy)
    spectral_gf = 1j*(device_retarded_green_gf - device_retarded_green_gf.conj().T)

    # Method (II): energy-resolved scattering states. Because of the step
    # function theta(t - t') in the time-domain Eq. 25 of arXiv:1307.6419, they
    # do not give G^r(E) as -i Psi Psi^dagger. They give the spectral function
    # A = sum_alpha psi_alpha psi_alpha^dagger, summed over the states of ALL leads
    # and ALL open modes. wf(k) has one row per mode of lead k.
    wf = kwant.solvers.default.wave_function(sys, e)
    psi = np.vstack([wf(k) for k in range(len(sys.leads))])
    spectral_wf = psi.T @ psi.conj()  # [i, j] = sum_alpha psi_alpha(i) conj(psi_alpha(j))

    # Compare the complex matrices directly; comparing |Re| and |Im| separately
    # would hide sign and phase errors. The finite eta adds a small discrepancy
    # (largest near the band edges).
    diff = spectral_gf - spectral_wf
    abs_err = np.max(np.abs(diff))
    norm_gf = nla.norm(spectral_gf)
    rel_err = nla.norm(diff) / norm_gf if norm_gf > 0 else nla.norm(diff)
    if debug and rel_err > 1e1*energy_broadening:
        print('wf and rgf spectral function rel err=', rel_err, 'abs err=', abs_err, 'at energy', e)

    print('e', e)
    