from pathlib import Path

import lmfit
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import uncertainties as un
from quantify_core.data.handling import (
    load_dataset,
    set_datadir,
)
from scipy.special import wofz
from uncertainties.umath import sqrt

# Point quantify's data directory at ../data (resolved against the current
# working directory) so the load_dataset calls below can find their datasets.
set_datadir(Path("..") / "data")


def lorentz(x, offset, gamma, center, amp):
    """Area-normalised Lorentzian scaled by ``amp`` on a constant background.

    Args:
        x: Evaluation point(s); scalars or NumPy arrays.
        offset: Constant background added to the peak.
        gamma: Full width at half maximum (FWHM) of the Lorentzian.
        center: Peak position.
        amp: Integrated area of the peak (the bare profile integrates to 1).

    Returns:
        ``offset`` plus the scaled Lorentzian evaluated at ``x``.
    """
    half_width = gamma / 2
    denominator = (x - center) ** 2 + half_width**2
    return offset + amp / np.pi * half_width / denominator


def gaussian(x, offset, amp, center, sigma):
    """Gaussian peak of height ``amp`` on a constant background.

    Args:
        x: Evaluation point(s); scalars or NumPy arrays.
        offset: Constant background.
        amp: Peak height above ``offset`` (not the area).
        center: Peak position.
        sigma: Standard deviation of the Gaussian.

    Returns:
        ``offset + amp * exp(-((x - center) / sigma)**2 / 2)``.
    """
    reduced = (x - center) / sigma
    return offset + amp * np.exp(-(reduced**2) / 2)


def Voigt(x, offset, center, amp, sigma, gamma):
    """Area-normalised Voigt profile scaled by ``amp`` on a constant background.

    Evaluated through the Faddeeva function ``w(z)``:
    ``V(x) = Re[w(z)] / (sigma * sqrt(2*pi))`` with
    ``z = (x - center + i*gamma) / (sigma * sqrt(2))``.

    Args:
        x: Evaluation point(s); scalars or NumPy arrays.
        offset: Constant background.
        center: Peak position.
        amp: Integrated area of the peak.
        sigma: Standard deviation of the Gaussian component.
        gamma: Half width at half maximum (HWHM) of the Lorentzian component.

    Returns:
        ``offset`` plus the scaled Voigt profile evaluated at ``x``.
    """
    z = (x - center + 1j * gamma) / sigma / np.sqrt(2)
    profile = amp * np.real(wofz(z)) / sigma / np.sqrt(2 * np.pi)
    return offset + profile


# Lower limit on the Lorentzian FWHM, in the units of the frequency axis (GHz
# per the x-axis label; widths are displayed as 1000 * value in MHz). The fits
# bound the Lorentzian HWHM ``gamma`` below by half of this value.
lifetime_limit = 13e-3
# Relative uncertainty assigned to a fit parameter when lmfit reports no stderr.
fit_error = 1e-1


def calc_voigt_width(sigma, gamma):
    """Approximate the Voigt FWHM, with uncertainty, from two lmfit parameters.

    Uses the Olivero-Longbothum approximation
    ``f_V ~= 0.5346 f_L + sqrt(0.2166 f_L**2 + f_G**2)``, where
    ``f_L = 2 * gamma`` is the Lorentzian FWHM and
    ``f_G = 2 * sigma * sqrt(2 ln 2)`` is the Gaussian FWHM.

    Args:
        sigma: lmfit Parameter holding the Gaussian standard deviation.
        gamma: lmfit Parameter holding the Lorentzian HWHM.

    Returns:
        An ``uncertainties`` number for the FWHM, in the same units as the
        parameters. When a parameter has no ``stderr``, a relative error of
        ``fit_error`` is used for it instead.
    """

    def as_ufloat(param):
        # lmfit leaves stderr as None when it cannot estimate the covariance.
        if param.stderr is not None:
            return un.ufloat(param.value, param.stderr)
        return un.ufloat(param.value, fit_error * param.value)

    sig = as_ufloat(sigma)
    fwhm_lorentz = 2 * as_ufloat(gamma)
    fwhm_gauss = 2 * sig * np.sqrt(2 * np.log(2))
    return 0.5346 * fwhm_lorentz + sqrt(0.2166 * fwhm_lorentz**2 + fwhm_gauss**2)


# Scan whose three lines are fitted and shown in the bottom row of panels.
dataset_nv = load_dataset("20220805-161039-403-0e202c-line_scan")
# Wide scan shown in the top panel (a).
dataset_nv_large_scan = load_dataset("20220805-142007-638-353985-line_scan")

# Sample-index windows of dataset_nv around each line. The fits use them as
# half-open slices [start:end]; the plotted fit grids run up to sample ``end``.
(start_1, end_1), (start_2, end_2), (start_3, end_3) = (
    (270, 370),
    (550, 750),
    (1040, 1140),
)

def _fit_voigt_window(center, amp, start, stop):
    """Fit the ``Voigt`` model to ``dataset_nv`` samples ``[start:stop]``.

    All three lines share the same initial guesses and bounds for offset,
    gamma and sigma; only the starting ``center`` and ``amp`` differ.

    Args:
        center: Initial guess for the line centre (frequency-axis units).
        amp: Initial guess for the integrated peak area.
        start: First sample index of the fit window (inclusive).
        stop: Last sample index of the fit window (exclusive).

    Returns:
        Tuple ``(model, result)``: the configured ``lmfit.Model`` and its fit result.
    """
    model = lmfit.Model(Voigt)
    model.set_param_hint("offset", value=10)
    # Voigt's gamma is the Lorentzian HWHM, so the FWHM limit is halved here.
    model.set_param_hint("gamma", value=0.02, min=lifetime_limit / 2)
    model.set_param_hint("sigma", value=0.04, min=0.001)
    model.set_param_hint("center", value=center)
    model.set_param_hint("amp", value=amp)
    result = model.fit(
        dataset_nv.y0[start:stop],
        x=dataset_nv.x0[start:stop],
        params=model.make_params(),
    )
    return model, result


fit_model_voigt, fit_result_voigt = _fit_voigt_window(67.55, 16, start_1, end_1)
fit_model_voigt_2, fit_result_voigt_2 = _fit_voigt_window(70.25, 290, start_2, end_2)
fit_model_voigt_3, fit_result_voigt_3 = _fit_voigt_window(73.7, 290, start_3, end_3)


fig = plt.figure(figsize=(10, 5))
# Two rows: the wide scan on top, and three zoomed panels side by side below.
panel = gridspec.GridSpec(2, 1, height_ratios=[0.8, 1], hspace=0.2, figure=fig)
bottom_grid = gridspec.GridSpecFromSubplotSpec(1, 3, panel[1, 0], wspace=0.08)

ax_main = plt.subplot(panel[0, 0])
ax_1, ax_2, ax_3 = (plt.subplot(bottom_grid[0, column]) for column in range(3))

# Only the left-most zoomed panel shows y tick labels.
for inner_ax in (ax_2, ax_3):
    plt.setp(inner_ax.get_yticklabels(), visible=False)

# Panel (a): the wide scan in raw counts.
ax_main.text(-0.086, 1.1, "(a)", transform=ax_main.transAxes, fontsize=12, va="top")
ax_main.plot(
    dataset_nv_large_scan.x0,
    dataset_nv_large_scan.y0,
    color="dodgerblue",
    linewidth=2,
)
ax_main.set_xlim(45, 95)
ax_main.set_ylabel("Counts (Cts/s)", fontsize=11)

# Panel (b), left: first fitted line, data and fit normalised to the fit maximum.
ax_1.text(-0.26, 1.15, "(b)", transform=ax_1.transAxes, fontsize=12, va="top")
# Dense grid over the fit window. The fit is evaluated once and reused for both
# the normalisation and the plotted curve. The grid ends at sample end_1, one
# past the fitted slice [start_1:end_1].
x_fit_1 = np.linspace(
    dataset_nv.x0[start_1].values, dataset_nv.x0[end_1].values, 1000
)
y_fit_1 = Voigt(x=x_fit_1, **fit_result_voigt.best_values)
max_1 = max(y_fit_1)
ax_1.scatter(dataset_nv.x0, dataset_nv.y0 / max_1, s=14, color="dodgerblue")
ax_1.plot(
    x_fit_1,
    y_fit_1 / max_1,
    color="navy",
    linestyle="dashed",
    label="Voigt Fit",
)
ax_1.set_xlim(67.35, 67.75)
ax_1.set_ylim(0, 1.1)
ax_1.legend(loc="upper left", fontsize=10)
width = calc_voigt_width(
    fit_result_voigt.params["sigma"], fit_result_voigt.params["gamma"]
)
# Two arrows pointing inwards at half height, marking the linewidth.
ax_1.annotate(
    "",
    xy=(67.46, 0.5),
    xycoords="data",
    xytext=(67.54, 0.5),
    textcoords="data",
    arrowprops=dict(arrowstyle="<|-", color="navy", lw=2),
)
ax_1.annotate(
    "",
    xy=(67.565, 0.5),
    xycoords="data",
    xytext=(67.655, 0.5),
    textcoords="data",
    arrowprops=dict(arrowstyle="-|>", color="navy", lw=2),
)
# Width is in GHz (frequency-axis units); shown in MHz.
ax_1.annotate(
    "{:.1f} MHz\n+/- {:.1f} MHz".format(1000 * width.n, 1000 * width.s),
    xy=(67.59, 0.7),
    xycoords="data",
    xytext=(67.59, 0.6),
    textcoords="data",
    color="navy",
    fontsize=11,
)
ax_1.set_ylabel("Normalized Counts", fontsize=11)

# Middle zoomed panel: second fitted line, normalised to the fit maximum.
# The fit is evaluated once on a dense grid and reused for the normalisation
# and the plotted curve. The grid ends at sample end_2, one past the fitted
# slice [start_2:end_2].
x_fit_2 = np.linspace(
    dataset_nv.x0[start_2].values, dataset_nv.x0[end_2].values, 1000
)
y_fit_2 = Voigt(x=x_fit_2, **fit_result_voigt_2.best_values)
max_2 = max(y_fit_2)
ax_2.scatter(dataset_nv.x0, dataset_nv.y0 / max_2, s=14, color="dodgerblue")
ax_2.plot(
    x_fit_2,
    y_fit_2 / max_2,
    color="navy",
    linestyle="dashed",
    label="Voigt Fit",
)
ax_2.set_xlim(70.025, 70.425)
ax_2.set_ylim(0, 1.1)
ax_2.legend(loc="upper left", fontsize=10)
width_2 = calc_voigt_width(
    fit_result_voigt_2.params["sigma"], fit_result_voigt_2.params["gamma"]
)
# Two arrows pointing inwards at half height, marking the linewidth.
ax_2.annotate(
    "",
    xy=(70.11, 0.5),
    xycoords="data",
    xytext=(70.19, 0.5),
    textcoords="data",
    arrowprops=dict(arrowstyle="<|-", color="navy", lw=2),
)
ax_2.annotate(
    "",
    xy=(70.242, 0.5),
    xycoords="data",
    xytext=(70.322, 0.5),
    textcoords="data",
    arrowprops=dict(arrowstyle="-|>", color="navy", lw=2),
)
# Width is in GHz (frequency-axis units); shown in MHz.
ax_2.annotate(
    "{:.1f} MHz\n+/- {:.1f} MHz".format(1000 * width_2.n, 1000 * width_2.s),
    xy=(70.255, 0.6),
    xycoords="data",
    xytext=(70.255, 0.6),
    textcoords="data",
    color="navy",
    fontsize=11,
)
ax_2.set_xlabel("Laser Excitation Frequency w.r.t. 470.4 THz (GHz)", fontsize=11)

# Right zoomed panel: third fitted line, normalised to the fit maximum.
# The fit is evaluated once on a dense grid and reused for the normalisation
# and the plotted curve. The grid ends at sample end_3, one past the fitted
# slice [start_3:end_3].
x_fit_3 = np.linspace(
    dataset_nv.x0[start_3].values, dataset_nv.x0[end_3].values, 1000
)
y_fit_3 = Voigt(x=x_fit_3, **fit_result_voigt_3.best_values)
max_3 = max(y_fit_3)
ax_3.scatter(dataset_nv.x0, dataset_nv.y0 / max_3, s=14, color="dodgerblue")
ax_3.plot(
    x_fit_3,
    y_fit_3 / max_3,
    color="navy",
    linestyle="dashed",
    label="Voigt Fit",
)
ax_3.set_xlim(73.55, 73.95)
ax_3.set_ylim(0, 1.1)
ax_3.legend(loc="upper left", fontsize=10)
# Both sigma and gamma must come from this panel's own fit (fit_result_voigt_3).
width_3 = calc_voigt_width(
    fit_result_voigt_3.params["sigma"], fit_result_voigt_3.params["gamma"]
)
# Two arrows pointing inwards at half height, marking the linewidth.
ax_3.annotate(
    "",
    xy=(73.64, 0.5),
    xycoords="data",
    xytext=(73.72, 0.5),
    textcoords="data",
    arrowprops=dict(arrowstyle="<|-", color="navy", lw=2),
)
ax_3.annotate(
    "",
    xy=(73.766, 0.5),
    xycoords="data",
    xytext=(73.846, 0.5),
    textcoords="data",
    arrowprops=dict(arrowstyle="-|>", color="navy", lw=2),
)
# Width is in GHz (frequency-axis units); shown in MHz.
ax_3.annotate(
    "{:.1f} MHz\n+/- {:.1f} MHz".format(1000 * width_3.n, 1000 * width_3.s),
    xy=(73.78, 0.6),
    xycoords="data",
    xytext=(73.78, 0.6),
    textcoords="data",
    color="navy",
    fontsize=11,
)

# Relative to the current working directory. Create the output folder first,
# because savefig raises FileNotFoundError if the target directory is missing.
figure_path = Path("../figures/fig_20_add.png")
figure_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(figure_path, dpi=300, bbox_inches="tight")
