from math import pi
from pathlib import Path

import lmfit
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
from quantify_core.data.handling import (
    load_dataset,
    set_datadir,
)

# Register "../data" (a path relative to the current working directory) as
# the quantify_core data directory.
# NOTE(review): presumably this is where the ``load_dataset`` calls further
# down look up their datasets -- confirm against the quantify_core docs.
set_datadir(Path("../data"))


def lorentz(x, offset, gamma, center, amp):
    """Evaluate a single Lorentzian peak sitting on a constant background.

    The peak is ``amp / pi * (gamma / 2) / ((x - center)**2 + (gamma / 2)**2)``,
    i.e. a Lorentzian of full width at half maximum ``gamma`` and integrated
    area ``amp``, to which ``offset`` is added.

    Parameters
    ----------
    x
        Position(s) at which to evaluate the curve.
    offset
        Constant background added to the peak.
    gamma
        Full width at half maximum of the peak.
    center
        Position of the peak maximum.
    amp
        Area under the peak (excluding the background).

    Returns
    -------
    The curve value(s) at ``x``.
    """
    half_width = gamma / 2
    detuning = x - center
    peak = amp / pi * half_width / (detuning**2 + half_width**2)
    return peak + offset


def lorentz_double(x, offset, gamma, center, amp1, amp2, splitting):
    """Evaluate two Lorentzian peaks of equal width on a constant background.

    The first peak (area ``amp1``) is located at ``center``; the second
    (area ``amp2``) at ``center + splitting``. Both share the full width at
    half maximum ``gamma``.

    The argument names are used as fit-parameter names by the parameter
    hints in this script, so they must not be renamed.

    Parameters
    ----------
    x
        Position(s) at which to evaluate the curve.
    offset
        Constant background.
    gamma
        Full width at half maximum shared by both peaks.
    center
        Position of the first peak.
    amp1, amp2
        Areas of the first and second peak.
    splitting
        Distance from the first peak to the second.

    Returns
    -------
    The curve value(s) at ``x``.
    """
    quarter_gamma_sq = gamma**2 / 4
    first_peak = amp1 / pi * gamma / 2 / ((x - center) ** 2 + quarter_gamma_sq)
    second_peak = (
        amp2 / pi * gamma / 2 / ((x - center - splitting) ** 2 + quarter_gamma_sq)
    )
    return offset + first_peak + second_peak


def lorentz_triple_double(
    x, offset, gamma, center, amp1, amp2, amp3, amp4, deltat, splitting
):
    """Evaluate two Lorentzian peaks, each with a symmetric pair of sidebands.

    Six Lorentzians of common full width at half maximum ``gamma`` are summed
    on top of a constant ``offset``:

    * a peak of area ``amp1`` at ``center`` with two sidebands of area
      ``amp2`` at ``center - deltat`` and ``center + deltat``;
    * a peak of area ``amp3`` at ``center + splitting`` with two sidebands of
      area ``amp4`` at ``center + splitting - deltat`` and
      ``center + splitting + deltat``.

    The argument names are used as fit-parameter names by the parameter
    hints in this script, so they must not be renamed.

    Parameters
    ----------
    x
        Position(s) at which to evaluate the curve.
    offset
        Constant background.
    gamma
        Full width at half maximum shared by all six peaks.
    center
        Position of the first main peak.
    amp1, amp3
        Areas of the first and second main peak.
    amp2, amp4
        Area of each sideband of the first and second main peak.
    deltat
        Distance between a main peak and each of its sidebands.
    splitting
        Distance from the first main peak to the second.

    Returns
    -------
    The curve value(s) at ``x``.
    """
    quarter_gamma_sq = gamma**2 / 4

    def peak(area, detuning):
        # One Lorentzian term, evaluated at ``detuning`` from its own centre.
        return area / pi * gamma / 2 / (detuning**2 + quarter_gamma_sq)

    # Terms are accumulated strictly left to right, first group then second.
    total = offset + peak(amp1, x - center)
    total = total + peak(amp2, x - center - deltat)
    total = total + peak(amp2, x - center + deltat)
    total = total + peak(amp3, x - center - splitting)
    total = total + peak(amp4, x - center - splitting - deltat)
    total = total + peak(amp4, x - center - splitting + deltat)
    return total


# Load the two measurement runs by their identifiers.
dataset_splitting_1, dataset_splitting_2 = (
    load_dataset("20241218-100634-306-b252af-cavity_linewidth_sidebands"),
    load_dataset("20241217-095515-635-fa1c39-cavity_linewidth_sidebands"),
)

# The x1 axis of the first run is used as the common time axis for every
# trace in this script (including those of the second run).
time_ns = dataset_splitting_1.x1  # 0, 4, 8, etc
time_us = time_ns / 1000

# First entry of the y0 / y2 data of each run: y0 is the trace fitted with the
# two-peak model, y2 (suffix ``_td``) the trace fitted with the six-peak model.
transmission_1, transmission_1_td = (
    dataset_splitting_1.y0.values[0],
    dataset_splitting_1.y2.values[0],
)
transmission_2, transmission_2_td = (
    dataset_splitting_2.y0.values[0],
    dataset_splitting_2.y2.values[0],
)

# ---- Dataset 1 --------------------------------------------------------------
# Hand-picked starting values for the two-peak fit, written as
# <integer> / 1000 so that they are on the same scale as ``time_us``.
gamma_1, center_1 = 67 / 1000, 832666 / 1000
amp1_1, amp2_1 = 88 / 1000, 54 / 1000
splitting_1 = 350 / 1000

# Starting values for the six-peak (sideband) fit of the second trace.
center_1_td, deltat_1 = 832163 / 1000, 141 / 1000
delta12_1, delta34_1 = 3.8, 3.5

# Sample index of each guessed peak position (centre * 1000 / 4, truncated).
# NOTE(review): this relies on the time axis being sampled every 4 units
# starting at 0, as the comment on ``time_ns`` suggests -- confirm before
# reusing with other data.
peak_1 = int(center_1 / 4 * 1000)
peak_1_td = int(center_1_td / 4 * 1000)

# Only ``fit_window`` samples on either side of the guessed peak are fitted.
fit_window = 300
window_1 = slice(peak_1 - fit_window, peak_1 + fit_window)
window_1_td = slice(peak_1_td - fit_window, peak_1_td + fit_window)

fit_model_1 = lmfit.Model(lorentz_double)
fit_model_td_1 = lmfit.Model(lorentz_triple_double)

# Step 1: two-peak fit, all parameters free, seeded with the guesses above.
for _hint_name, _hint_value in (
    ("offset", 0),
    ("gamma", gamma_1),
    ("amp1", amp1_1),
    ("amp2", amp2_1),
    ("splitting", splitting_1),
    ("center", center_1),
):
    fit_model_1.set_param_hint(_hint_name, value=_hint_value)

params_1 = fit_model_1.make_params()
fit_result_double_1 = fit_model_1.fit(
    transmission_1[window_1], x=time_us[window_1], params=params_1
)

# Step 2: six-peak fit. ``gamma`` and ``splitting`` are frozen at the values
# found in step 1. The sideband areas are not fitted directly: ``amp2`` and
# ``amp4`` are tied to the main-peak areas through the auxiliary ratios
# ``delta12`` and ``delta34``.
# NOTE(review): ``delta12`` is bounded to [1, 10] but ``delta34`` has no
# bounds, so nothing keeps the ``amp3/delta34`` divisor away from zero --
# confirm whether that asymmetry is intended.
for _hint_name, _hint_kwargs in (
    ("offset", {"value": 0}),
    ("gamma", {"value": fit_result_double_1.best_values["gamma"], "vary": False}),
    ("amp1", {"value": 0.8 * amp1_1, "min": 0, "max": amp1_1}),
    ("delta12", {"value": delta12_1, "min": 1, "max": 10}),
    ("amp2", {"expr": "amp1/delta12"}),
    ("amp3", {"value": 0.8 * amp2_1, "min": 0, "max": amp2_1}),
    ("delta34", {"value": delta34_1}),
    ("amp4", {"expr": "amp3/delta34"}),
    ("deltat", {"value": deltat_1, "min": 0, "max": 1e4}),
    (
        "splitting",
        {"value": fit_result_double_1.best_values["splitting"], "vary": False},
    ),
    ("center", {"value": center_1_td}),
):
    fit_model_td_1.set_param_hint(_hint_name, **_hint_kwargs)

params_1_td = fit_model_td_1.make_params()
fit_result_1_td = fit_model_td_1.fit(
    transmission_1_td[window_1_td], x=time_us[window_1_td], params=params_1_td
)

# ---- Dataset 2 --------------------------------------------------------------
# Hand-picked starting values for the two-peak fit, written as
# <integer> / 1000 so that they are on the same scale as ``time_us``.
gamma_2, center_2 = 40 / 1000, 680556 / 1000
amp1_2, amp2_2 = 120 / 1000, 60 / 1000
splitting_2 = 80 / 1000

# Starting values for the six-peak (sideband) fit of the second trace.
center_2_td, deltat_2 = 677102 / 1000, 130 / 1000
delta12_2, delta34_2 = 4, 4

# Sample index of each guessed peak position (centre * 1000 / 4, truncated).
# NOTE(review): this relies on the time axis being sampled every 4 units
# starting at 0, as the comment on ``time_ns`` suggests -- confirm before
# reusing with other data.
peak_2 = int(center_2 / 4 * 1000)
peak_2_td = int(center_2_td / 4 * 1000)

# Only ``fit_window`` samples on either side of the guessed peak are fitted.
fit_window = 300
window_2 = slice(peak_2 - fit_window, peak_2 + fit_window)
window_2_td = slice(peak_2_td - fit_window, peak_2_td + fit_window)

fit_model_2 = lmfit.Model(lorentz_double)
fit_model_td_2 = lmfit.Model(lorentz_triple_double)

# Step 1: two-peak fit, all parameters free, seeded with the guesses above.
for _hint_name, _hint_value in (
    ("offset", 0),
    ("gamma", gamma_2),
    ("amp1", amp1_2),
    ("amp2", amp2_2),
    ("splitting", splitting_2),
    ("center", center_2),
):
    fit_model_2.set_param_hint(_hint_name, value=_hint_value)

params_2 = fit_model_2.make_params()
fit_result_double_2 = fit_model_2.fit(
    transmission_2[window_2], x=time_us[window_2], params=params_2
)

# Step 2: six-peak fit. ``gamma`` and ``splitting`` are frozen at the values
# found in step 1. The sideband areas are not fitted directly: ``amp2`` and
# ``amp4`` are tied to the main-peak areas through the auxiliary ratios
# ``delta12`` and ``delta34``.
# NOTE(review): ``delta12`` is bounded to [1, 10] but ``delta34`` has no
# bounds, so nothing keeps the ``amp3/delta34`` divisor away from zero --
# confirm whether that asymmetry is intended.
for _hint_name, _hint_kwargs in (
    ("offset", {"value": 0}),
    ("gamma", {"value": fit_result_double_2.best_values["gamma"], "vary": False}),
    ("amp1", {"value": 0.8 * amp1_2, "min": 0, "max": amp1_2}),
    ("delta12", {"value": delta12_2, "min": 1, "max": 10}),
    ("amp2", {"expr": "amp1/delta12"}),
    ("amp3", {"value": 0.8 * amp2_2, "min": 0, "max": amp2_2}),
    ("delta34", {"value": delta34_2}),
    ("amp4", {"expr": "amp3/delta34"}),
    ("deltat", {"value": deltat_2, "min": 0, "max": 1e4}),
    (
        "splitting",
        {"value": fit_result_double_2.best_values["splitting"], "vary": False},
    ),
    ("center", {"value": center_2_td}),
):
    fit_model_td_2.set_param_hint(_hint_name, **_hint_kwargs)

params_2_td = fit_model_td_2.make_params()
fit_result_2_td = fit_model_td_2.fit(
    transmission_2_td[window_2_td], x=time_us[window_2_td], params=params_2_td
)

# Print the lmfit report of every fit: dataset 1 (two-peak, six-peak), then
# dataset 2 (two-peak, six-peak).
for _fit_result in (
    fit_result_double_1,
    fit_result_1_td,
    fit_result_double_2,
    fit_result_2_td,
):
    print(lmfit.fit_report(_fit_result))

# ---- Figure: dataset 1, two-peak fit (a) next to the six-peak fit (b) -------
fig = plt.figure(figsize=(6, 3))
# One row, two equally wide panels, bound explicitly to ``fig``.
panel = gridspec.GridSpec(
    1, 2, width_ratios=[1, 1], wspace=0.12, hspace=0.1, figure=fig
)
# Create the axes on ``fig`` itself rather than on pyplot's "current figure".
ax_1 = fig.add_subplot(panel[0, 0])
ax_2 = fig.add_subplot(panel[0, 1])

# Best-fit values pulled out of the fit results. Only the dataset-1 values are
# used by the plotting code below; the dataset-2 values are extracted but not
# plotted in this section.
center_1_fit = fit_result_double_1.best_values["center"]
center_1_td_fit = fit_result_1_td.best_values["center"]
center_2_fit = fit_result_double_2.best_values["center"]
center_2_td_fit = fit_result_2_td.best_values["center"]

splitting_1_fit = fit_result_double_1.best_values["splitting"]
splitting_2_fit = fit_result_double_2.best_values["splitting"]
deltat_1_fit = fit_result_1_td.best_values["deltat"]
deltat_2_fit = fit_result_2_td.best_values["deltat"]

# Both traces are shifted horizontally so that the fitted ``center`` lands at
# x = 0.75 - splitting / 2. The two peaks of the doublet are then symmetric
# about x = 0.75, the middle of the [0, 1.5] x-range used for both panels.
left_peak_x = 0.75 - splitting_1_fit / 2
x_panel_a = time_us - center_1_fit + 0.75 - splitting_1_fit / 2
x_panel_b = time_us - center_1_td_fit + 0.75 - splitting_1_fit / 2

# Panel (a): data and two-peak fit, with dotted markers at both peak positions.
ax_1.text(-0.26, 1.1, "(a)", transform=ax_1.transAxes, fontsize=12, va="top")
ax_1.axvline(x=left_peak_x, ls=":", color="skyblue", label="Splitting", alpha=0.5)
ax_1.axvline(x=left_peak_x + splitting_1_fit, ls=":", color="skyblue", alpha=0.5)
ax_1.set_ylim(0, 1)
ax_1.set_xlim(0, 1.5)
ax_1.scatter(x_panel_a, transmission_1, s=2)
ax_1.plot(
    x_panel_a,
    lorentz_double(time_us, **fit_result_double_1.best_values),
    color="skyblue",
    linestyle="solid",
    linewidth=1,
    label="Fit",
)

# Panel (b): data and six-peak fit, with dashed markers at the two sidebands
# of the first peak.
ax_2.text(-0.1, 1.1, "(b)", transform=ax_2.transAxes, fontsize=12, va="top")
ax_2.plot(
    x_panel_b,
    lorentz_triple_double(time_us, **fit_result_1_td.best_values),
    color="skyblue",
    linewidth=1,
    linestyle="solid",
    label="Fit",
)

# x = 10 lies outside the x-limits set below, so this line is never visible.
# NOTE(review): presumably it exists only to add a "Splitting" entry to the
# legend of panel (b), matching the dotted lines drawn in panel (a).
ax_2.axvline(x=10, ls=":", color="skyblue", label="Splitting", alpha=0.5)
ax_2.axvline(
    x=left_peak_x - deltat_1_fit,
    ls="--",
    color="skyblue",
    label="Sidebands",
    alpha=0.5,
)
ax_2.axvline(x=left_peak_x + deltat_1_fit, ls="--", color="skyblue", alpha=0.5)
ax_2.set_ylim(0, 1)
ax_2.set_xlim(0, 1.5)
ax_2.scatter(x_panel_b, transmission_1_td, s=2)

ax_2.legend()
# Both panels use the same y-limits, so the tick labels of (b) are hidden.
ax_2.set_yticklabels([])

# Axis titles shared by both panels, placed in figure coordinates.
fig.text(0.5, -0.02, "Scanning Time (µs)", ha="center")
fig.text(0.05, 0.25, "Cavity Transmission (V)", ha="center", rotation=90)

fig.savefig(Path("../Fig_10_sup.png"), dpi=600, bbox_inches="tight")
