from pathlib import Path

import lmfit
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as st
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from quantify_core.data.handling import (
    load_dataset,
    set_datadir,
)

# Configure quantify_core's data directory. The path is relative, so it is
# interpreted relative to the current working directory, not this file.
data_directory = Path("../data")
set_datadir(data_directory)


def lorentz(x, offset, gamma, center, amp):
    """Evaluate an area-normalised Lorentzian peak on a constant background.

    Parameters
    ----------
    x : float or array_like
        Point(s) at which to evaluate the line shape.
    offset : float
        Constant background added everywhere.
    gamma : float
        Full width at half maximum of the peak.
    center : float
        Position of the peak maximum.
    amp : float
        Integrated area of the peak above ``offset``.

    Returns
    -------
    float or numpy.ndarray
        Line-shape value(s), broadcast against ``x``.
    """
    half_width = gamma / 2
    denominator = (x - center) ** 2 + half_width**2
    # Same left-to-right association as ((amp / pi) * hw) / denom, so results
    # are bit-for-bit identical to the straight one-line formula.
    return amp / np.pi * half_width / denominator + offset


def gaussian(x, offset, amp, center, sigma):
    """Evaluate a Gaussian peak of height ``amp`` on a constant background.

    Parameters
    ----------
    x : float or array_like
        Point(s) at which to evaluate the peak.
    offset : float
        Constant background added everywhere.
    amp : float
        Peak height above ``offset`` (value at ``x == center``).
    center : float
        Peak position.
    sigma : float
        Standard deviation; only its square is used, so its sign is irrelevant.

    Returns
    -------
    float or numpy.ndarray
        ``offset + amp * exp(-((x - center) / sigma)**2 / 2)``, broadcast against ``x``.
    """
    z = (x - center) / sigma
    peak = amp * np.exp(-(z**2) / 2)
    return offset + peak


dataset_nv_long = load_dataset("20231113-112012-033-79e95d-ple_laserscan")

# For every x0 point, sum the y0 counts over all integration times x1 that are
# at or after the dataset's "start_resonant_us" attribute.
# y0 is indexed positionally as (x0, x1); a boolean column mask selects the
# resonant x1 samples. This replaces an element-by-element Python double loop
# over xarray scalars. The sum is accumulated in float64, matching the float
# array the loop used to fill.
resonant_mask = (
    dataset_nv_long.x1.values >= dataset_nv_long.attrs["start_resonant_us"]
)
integrated_resonant_counts = np.sum(
    dataset_nv_long.y0.values[:, resonant_mask], axis=1, dtype=float
)

dataset_nv = load_dataset("20231113-130922-632-cea001-fast_laserscan")

# Linear voltage -> frequency calibration. It uses two anchor points at the
# first and last x1 (voltage) samples.
# NOTE(review): y2 is presumably a measured frequency readout indexed (x0, x1)
# like y0. Rows 65:85 are a hand-picked window for the start anchor. Confirm
# both against the acquisition code.
start_f = np.mean(dataset_nv.y2[:, 0].values[65:85])
end_f = dataset_nv.y2.mean(axis=0).values[-1]
start_v = dataset_nv.x1.values[0]
end_v = dataset_nv.x1.values[-1]


def voltage_to_freq(v):
    """Map voltage(s) ``v`` linearly onto the calibrated frequency axis.

    ``v`` may be a scalar or an array; the result broadcasts like ``v``.
    This is a named function rather than a lambda bound to a name
    (PEP 8, E731). The anchors ``start_f``, ``end_f``, ``start_v`` and
    ``end_v`` are module globals read at call time.
    """
    return start_f + (end_f - start_f) / (end_v - start_v) * (v - start_v)


frequencies = voltage_to_freq(dataset_nv.x1)

# Scan repetitions selected for the per-repetition Lorentzian fits, in four
# groups. Each group is paired element-wise with a hand-picked initial guess
# for the Lorentzian centre, on the same axis as `frequencies`.
# fmt: off
reps_1 = [6, 7, 8, 9, 10, 12, 15, 16, 20, 22, 23, 24, 26, 27, 28, 29, 30]
centers_1 = [
    98.2, 98.2, 97.8, 97.9, 97.9, 98.5, 98.1, 98.1, 98.2,
    98.2, 98.2, 98, 98, 98, 98.3, 98.3, 98.3,
]

reps_2 = [31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 42, 43, 44, 45, 47, 48, 49, 50]
centers_2 = [
    98.3, 98.3, 98.3, 98.3, 98.3, 98.3, 98.3, 98.0, 98.3,
    98.3, 98.3, 98.3, 98.3, 98.3, 98.3, 98.5, 98.0, 98.5,
]

reps_3 = [
    51, 54, 56, 58, 59, 60, 61, 62, 63, 67,
    68, 69, 72, 73, 74, 75, 78, 82, 84, 85,
]
centers_3 = [
    98.0, 98.15, 98.5, 98, 98.4, 98.4, 98.4, 98.0, 98.7, 98.4,
    98.3, 98.3, 98.7, 98.3, 98.5, 98.5, 98.5, 98.2, 98.5, 98.2,
]

reps_4 = [86, 87, 89, 91, 92, 93, 94, 95, 96, 97, 98, 99]
centers_4 = [98.2, 98.2, 98.6, 98.6, 98.2, 98.2, 98.2, 98.2, 98.2, 98.2, 98.2, 98.2]
# fmt: on

# Flattened, still element-wise aligned: reps[k] <-> centers[k].
reps = [*reps_1, *reps_2, *reps_3, *reps_4]
centers = [*centers_1, *centers_2, *centers_3, *centers_4]

# NOTE(review): `lorentz_plot` is not referenced in the visible code.
lorentz_plot = 16
# Half-open index window [start, end) into the frequency axis used for fitting.
start, end = 100, 320

peaks_fwhm, peaks_fwhm_err, lorentz_centers = list(), list(), list()

# The loop below pairs reps[k] with centers[k]. Fail loudly if the
# hand-maintained lists drift apart, rather than mis-seeding or skipping fits.
if len(reps) != len(centers):
    raise ValueError(
        f"reps ({len(reps)}) and centers ({len(centers)}) must have equal length"
    )

# Fit a Lorentzian to each selected repetition over the index window
# [start, end). In `lorentz`'s parametrisation, gamma is the FWHM. The model
# and its fixed initial guesses are built once; only the centre guess changes
# per repetition. make_params() builds fresh Parameters from the current hints
# on every call.
fit_model_lorentz = lmfit.Model(lorentz)
fit_model_lorentz.set_param_hint("offset", value=1)
fit_model_lorentz.set_param_hint("gamma", value=0.04)
fit_model_lorentz.set_param_hint("amp", value=5)

for rep, center_guess in zip(reps, centers):
    fit_model_lorentz.set_param_hint("center", value=center_guess)
    fit_result_lorentz = fit_model_lorentz.fit(
        dataset_nv.y0[rep][start:end],
        x=frequencies[start:end],
        params=fit_model_lorentz.make_params(),
    )
    fit_params = fit_result_lorentz.params
    peaks_fwhm.append(fit_params["gamma"].value)
    # stderr is None when lmfit could not estimate the uncertainties.
    peaks_fwhm_err.append(fit_params["gamma"].stderr)
    lorentz_centers.append(fit_params["center"].value)


# NOTE(review): `reps_with_nv` and `frequencies_with_nv` are not read in the
# visible code; they are kept so the module namespace is unchanged.
reps_with_nv = np.linspace(0, len(reps) - 1, len(reps))
frequencies_with_nv = list()

# Select the fitted traces with the SAME positional indexing as the
# per-repetition fit loop (`dataset_nv.y0[rep]`). This keeps
# traces_with_nv[k] paired with lorentz_centers[k], which was fitted on
# exactly that trace. The pairing holds regardless of the x0 coordinate
# labelling; a lookup by x0 value only agrees when x0 == 0..N-1.
traces_with_nv = [dataset_nv.y0[rep] for rep in reps]

# Plain (uncentred) sum of the selected traces.
summed_gaussian = np.sum(traces_with_nv, axis=0)

# Re-centred sum: shift each trace by its own fitted Lorentzian centre, then
# sum its counts into fixed bins spanning -1.5 ... +1.5 in 0.02 steps.
bin_freq = np.arange(-1.5, 1.52, 0.02)
bin_sum = np.zeros(len(bin_freq) - 1)
for trace, cent in zip(traces_with_nv, lorentz_centers):
    act_binned, _, _ = st.binned_statistic(
        frequencies - cent, trace, statistic="sum", bins=bin_freq
    )
    bin_sum += act_binned

# Gaussian fit to the plain sum of the selected traces over the index window
# [start, end). Despite its name, `summed_gaussian` holds summed counts.
fit_model_gauss = lmfit.Model(gaussian)
for hint_name, hint_value in (
    ("offset", 100),
    ("amp", 100),
    ("center", 98.2),
    ("sigma", 0.2),
):
    fit_model_gauss.set_param_hint(hint_name, value=hint_value)
fit_result_gauss = fit_model_gauss.fit(
    summed_gaussian[start:end],
    x=frequencies[start:end],
    params=fit_model_gauss.make_params(),
)

# Lorentzian fit to the re-centred, binned sum. The centre guess is 0 because
# each trace was shifted by its own fitted centre before binning.
# NOTE(review): x is the LEFT edge of each bin (bin_freq[:-1]), not the bin
# centre, so the fitted centre is offset by half a bin. The width (gamma) is
# unaffected.
fit_model_lorentz = lmfit.Model(lorentz)
for hint_name, hint_value in (
    ("offset", 100),
    ("gamma", 0.1),
    ("center", 0),
    ("amp", 500),
):
    fit_model_lorentz.set_param_hint(hint_name, value=hint_value)
fit_result_lorentz = fit_model_lorentz.fit(
    bin_sum,
    x=bin_freq[:-1],
    params=fit_model_lorentz.make_params(),
)

# Normalisation constants: the maximum of each fitted curve, sampled on a
# dense 1000-point grid.
gauss_eval_grid = np.linspace(frequencies[start].values, frequencies[end].values, 1000)
norm_guass = max(gaussian(x=gauss_eval_grid, **fit_result_gauss.best_values))
lorentz_eval_grid = np.linspace(bin_freq[0], bin_freq[-1], 1000)
norm_lorentz = max(lorentz(x=lorentz_eval_grid, **fit_result_lorentz.best_values))

fig = plt.figure(figsize=(12, 6))

# 2x2 layout: a wider left column and a narrower right column. hspace=0
# stacks the two panels of each column with no vertical gap.
panel_layout = dict(
    height_ratios=[1.5, 2],
    width_ratios=[2, 1],
    hspace=0.0,
    wspace=0.14,
)
panel = gridspec.GridSpec(2, 2, figure=fig, **panel_layout)

# Each bottom panel shares its x axis with the panel directly above it.
ax_1 = plt.subplot(panel[0, 0])
ax_2 = plt.subplot(panel[1, 0], sharex=ax_1)
ax_3 = plt.subplot(panel[0, 1])
ax_4 = plt.subplot(panel[1, 1], sharex=ax_3)

# With shared x axes, only the bottom row needs visible x tick labels.
for top_axis in (ax_1, ax_3):
    plt.setp(top_axis.get_xticklabels(), visible=False)

def _attach_inset_colorbar(ax, mappable):
    """Draw a white-styled vertical colorbar for *mappable* inside *ax*.

    The colorbar lives in an inset axes in the upper-left corner of *ax*
    (10% wide, 30% tall). Returns the ``Colorbar``; its ``.ax`` is the inset.
    """
    inset = inset_axes(
        ax,
        width="10%",
        height="30%",
        bbox_to_anchor=(0.01, 0, 0.8, 0.95),
        bbox_transform=ax.transAxes,
        loc="upper left",
    )
    colorbar = fig.colorbar(mappable, cax=inset, orientation="vertical")
    colorbar.set_label("Counts (Cts/s)", color="white", fontsize=10)
    colorbar.ax.tick_params(color="white", labelsize=10, labelcolor="white")
    colorbar.outline.set_edgecolor("white")
    return colorbar


# Panel (a), bottom: long-scan colour map over x0 and integration time x1.
img = ax_2.pcolormesh(
    dataset_nv_long.x0,
    dataset_nv_long.x1,
    dataset_nv_long.y0.transpose(),
    vmax=50,
)
# The leading spaces shift this label to the right, presumably so it reads as
# one x label shared by both bottom panels.
ax_2.set_xlabel(
    "                                                                           Laser Excitation Frequency w.r.t. 470.4 THz (GHz)",
    fontsize=11,
)
ax_2.set_ylabel("Integration Time (µs)", fontsize=11)
cbar = _attach_inset_colorbar(ax_2, img)

# Panel (a), top: counts integrated over the resonant integration window.
ax_1.text(-0.105, 1.15, "(a)", transform=ax_1.transAxes, fontsize=12, va="top")
ax_1.plot(dataset_nv_long.x0, integrated_resonant_counts, color="navy")
ax_1.set_ylabel("Integrated Counts (Cts/s)", fontsize=11)

# Panel (b), bottom: fast-scan colour map over frequency and scan repetition.
img = ax_4.pcolormesh(
    frequencies,
    dataset_nv.x0,
    dataset_nv.y0,
    vmax=10,
)
ax_4.set_xlim(97.2, 99.2)
ax_4.set_ylabel("Scan Repetition", fontsize=11)
cbar = _attach_inset_colorbar(ax_4, img)
axinset = cbar.ax

# Panel (b), top. It shows the plain summed spectrum with its Gaussian fit,
# and the re-centred binned sum with its Lorentzian fit. Each data set is
# divided by its own normalisation constant. The re-centred data (plotted at
# left bin edges) are moved back onto the absolute axis by the Gaussian centre.
gauss_curve_x = np.linspace(frequencies[start].values, frequencies[end].values, 1000)
lorentz_curve_x = np.linspace(bin_freq[0], bin_freq[-2], 1000)
gauss_center = fit_result_gauss.params["center"].value

ax_3.text(-0.2, 1.15, "(b)", transform=ax_3.transAxes, fontsize=12, va="top")
ax_3.scatter(
    frequencies,
    summed_gaussian / norm_guass,
    s=8,
    color="dodgerblue",
    label="Summed Counts",
)
gauss_curve_y = gaussian(x=gauss_curve_x, **fit_result_gauss.best_values) / norm_guass
ax_3.plot(
    gauss_curve_x,
    gauss_curve_y,
    color="navy",
    linestyle="dashed",
    label="Gaussian Fit",
)
ax_3.scatter(
    bin_freq[:-1] + gauss_center,
    bin_sum / norm_lorentz,
    s=8,
    color="yellowgreen",
    label="Centered Summed Counts",
)
lorentz_curve_y = (
    lorentz(x=lorentz_curve_x, **fit_result_lorentz.best_values) / norm_lorentz
)
ax_3.plot(
    lorentz_curve_x + gauss_center,
    lorentz_curve_y,
    color="forestgreen",
    linestyle="dashed",
    label="Lorentzian Fit",
)
ax_3.set_xlim(97.2, 99.2)
ax_3.set_ylabel("Normalized Counts", fontsize=11)
# The legend is anchored far outside ax_3 (below the bottom row) so that one
# 4-column legend serves the whole figure.
ax_3.legend(
    loc="upper left",
    bbox_to_anchor=(-2, -1.56),
    ncol=4,
    fancybox=True,
    shadow=True,
    fontsize=11,
)

# Gaussian width marker: two inward-pointing arrows at hard-coded data
# positions. They are not derived from the fit.
ax_3.annotate(
    "",
    xy=(97.97, 0.5),
    xycoords="data",
    xytext=(97.7, 0.5),
    textcoords="data",
    arrowprops=dict(arrowstyle="-|>", color="navy", lw=2),
)
ax_3.annotate(
    "",
    xy=(98.43, 0.5),
    xycoords="data",
    xytext=(98.7, 0.5),
    textcoords="data",
    arrowprops=dict(arrowstyle="-|>", color="navy", lw=2),
)

# Gaussian FWHM = 2*sqrt(2 ln 2) * sigma. sigma enters `gaussian` only
# squared, so the fit may legitimately return it with either sign; use
# |sigma| so the reported width is positive.
# lmfit leaves `stderr` as None when it cannot estimate uncertainties; fall
# back to NaN instead of raising TypeError on `number * None`.
fwhm_per_sigma = 2 * np.sqrt(2 * np.log(2))
sigma_param = fit_result_gauss.params["sigma"]
width_ple = fwhm_per_sigma * abs(sigma_param.value)
width_ple_err = (
    fwhm_per_sigma * sigma_param.stderr if sigma_param.stderr is not None else np.nan
)
# Fit widths are on the GHz frequency axis; x1000 converts them to MHz.
ax_3.annotate(
    "{:.1f} MHz".format(1000 * width_ple),
    xy=(98.55, 0.55),
    xycoords="data",
    xytext=(98.55, 0.55),
    textcoords="data",
    color="navy",
    fontsize=11,
)

# In `lorentz`, gamma is the FWHM. (gamma, amp) -> (-gamma, -amp) gives the
# same curve, so |gamma| is the width. A missing stderr is shown as NaN.
gamma_param = fit_result_lorentz.params["gamma"]
gamma_stderr = gamma_param.stderr if gamma_param.stderr is not None else np.nan
ax_3.annotate(
    "{:.1f} MHz\n+/- {:.1f} MHz".format(
        1000 * abs(gamma_param.value),
        1000 * gamma_stderr,
    ),
    xy=(97.3, 0.55),
    xycoords="data",
    xytext=(97.3, 0.55),
    textcoords="data",
    color="forestgreen",
    fontsize=11,
)
# Lorentzian width marker (double-headed arrow), also at hard-coded positions.
ax_3.annotate(
    "",
    xy=(98.1, 0.5),
    xycoords="data",
    xytext=(98.28, 0.5),
    textcoords="data",
    arrowprops=dict(arrowstyle="<|-|>", color="forestgreen", lw=1),
)

# Relative output path, interpreted relative to the current working directory.
figure_path = Path("../figures/fig_22_add.png")
# Figure.savefig does not create missing directories, so make sure the output
# folder exists rather than failing with FileNotFoundError.
figure_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(figure_path, dpi=300, bbox_inches="tight")
