from modules import system
import kwant
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

# Build device: geometry of the scattering region and the leads
dimensions = {
    'L_device': 200,
    'L_lead': 20,
    'prefactor': 10,
}
device = system.Device(dimensions, experiment='longitudinal')
device.make_device()

# Load data from Fermi surface calculation (only the dataset attributes are used)
attrs = xr.load_dataset('./data/fermi_surf.nc').attrs

# Hamiltonian parameters with trigonal warping (alpha = 1)
params_1 = {
    'B': 0,
    'mu': attrs['energy1'],
    'dmu': 0.,
    'm': attrs['m'],
    'm2': 0,
    'alpha': 1,
    'chi': None,
    'L_device': dimensions['L_device'],
    'L_lead': dimensions['L_lead'],
}
# Same parameters without trigonal warping: alpha switched off and the
# chemical potential taken from attrs['energy0'] instead of attrs['energy1'].
params_0 = {**params_1, 'alpha': 0, 'mu': attrs['energy0']}
# Wavefunctions injected from lead index 0 at energy 0, one set per
# parameter choice.
wf0 = kwant.wave_function(device.fsyst, energy=0., params=params_0)
psis0 = wf0(0)
wf1 = kwant.wave_function(device.fsyst, energy=0., params=params_1)
psis1 = wf1(0)

# +
# Define semicircle: the cut radius is half of the device length
r_cut = dimensions['L_device'] / 2

def circle(site_to, site_from):
    """Return True for hoppings that leave the cut circle.

    The circle has radius ``r_cut`` and is centred at
    ``(-dimensions['L_device'], 0)``. A hopping is selected when
    ``site_from`` lies on or inside the circle and ``site_to`` lies
    strictly outside it, so each crossing bond is picked in one direction.

    Parameters
    ----------
    site_to, site_from : site objects exposing a 2-D ``pos`` array
        The two ends of the hopping, in the order the caller passes them.
    """
    centre = np.array([-dimensions['L_device'], 0])
    dist_from, dist_to = (
        np.linalg.norm(site.pos - centre) for site in (site_from, site_to)
    )
    return dist_from <= r_cut < dist_to

# Define current operator in semicircle: restricted to the hoppings accepted
# by `circle`; sum=False keeps one current value per selected hopping.
J = kwant.operator.Current(
    device.fsyst,
    where=circle,
    sum=False,
)
# Site-index pairs of the selected hoppings (indexed as [:, 0] / [:, 1] below)
sites_ids = np.asarray(J.where)

# Compute angle in semicircle for each corresponding hopping
@np.vectorize
def return_angle(site_id1, site_id2):
    """Return ``(cut_angle, hopping_angle)`` for the hopping between two sites.

    Both angles come from ``np.arctan2(y, x)``, so they are in radians,
    in [-pi, pi], and measured from the +x axis.

    * ``cut_angle``: the direction of the hopping midpoint, seen from
      the centre ``(-dimensions['L_device'], 0)``.
    * ``hopping_angle``: the direction of the vector from site ``site_id1``
      to site ``site_id2``.

    Site positions must be 2-D (x, y). Because of ``np.vectorize``,
    passing arrays of site indices returns a pair of arrays.
    """
    centre = np.array([-dimensions['L_device'], 0])
    all_sites = device.fsyst.sites
    pos_a = all_sites[site_id1].pos
    pos_b = all_sites[site_id2].pos

    midpoint = 0.5 * (pos_a + pos_b) - centre
    bond = pos_b - pos_a

    cut_angle = np.arctan2(midpoint[1], midpoint[0])
    hopping_angle = np.arctan2(bond[1], bond[0])
    return cut_angle, hopping_angle

cut_angles, hopp_angles = return_angle(sites_ids[:, 0], sites_ids[:, 1])
# -

# Compute current density passing through semicircle: for each parameter set,
# add up the per-hopping current of every injected wavefunction.
current0, current1 = (
    sum(J(psi, params=params) for psi in psis)
    for params, psis in ((params_0, psis0), (params_1, psis1))
)

# +
# Define injection angle
# NOTE(review): inj_angle, angle_interval and points are not used anywhere
# below in this script; confirm whether they are still needed.
inj_angle = 2 * np.arcsin(1 / 50)
angle_interval = np.pi / len(sites_ids)
points = int(inj_angle / angle_interval)

# Reorder every per-hopping array by increasing cut angle
sort_idx = np.argsort(cut_angles)
cut_angles, hopp_angles, current0, current1 = (
    values[sort_idx]
    for values in (cut_angles, hopp_angles, current0, current1)
)

# Row 0: params_0 (alpha = 0); row 1: params_1 (alpha = 1)
current = np.stack([current0, current1])
# -

# Create dataset: one current profile per alpha value, indexed by cut angle
ds = xr.Dataset(
    {'current': (('alpha', 'theta'), current)},
    coords={'theta': cut_angles, 'alpha': [0, 1]},
)

# +
# Take rolling average to average out C-C bonds: a centred moving mean over
# 300 consecutive hoppings along theta. With xarray's default min_periods
# (equal to the window), windows that run past either end give NaN.
ds_roll = ds.rolling({'theta': 300}, center=True).mean()

# +
# Inspect distribution on a polar half-plane
fig, ax = plt.subplots(subplot_kw={'polar': True})

ds_roll.current.plot(hue='alpha', ax=ax)

# Only show -90..90 degrees, with theta = 0 pointing up
ax.set_thetamin(-90)
ax.set_thetamax(90)
ax.set_theta_zero_location('N')
ax.set_xlabel('')
ax.set_ylabel('')

plt.show()

# +
fig, ax = plt.subplots()

# NOTE(review): both alpha curves are divided by one global maximum (taken
# over alpha and theta), not each by its own peak; use .max('theta') if
# every curve should peak at 1 like the cosine reference.
(ds_roll.current / ds_roll.current.max()).plot(hue='alpha', ax=ax)
# Dashed black cos(theta) reference profile
ax.plot(ds_roll.theta, np.cos(ds_roll.theta), ls='--', c='k')

ax.set_yticks([0, 1])
ax.set_xlabel('')
ax.set_ylabel('')

plt.show()
# -

# Store data
ds.to_netcdf('./data/angle_distribution.nc')

# Store as CSV. DataFrame.to_csv returns None when given a path, so it is
# called for its side effect only and `df` keeps the DataFrame.
df = ds.to_dataframe()
df.to_csv('./data/angle_distribution.csv')
