import json
from pathlib import Path

import lmfit
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
from quantify_core.data.handling import (
    load_dataset,
    set_datadir,
)

# Folder that quantify_core resolves the dataset TUIDs loaded below against.
# The path is relative, so it depends on the current working directory at run time.
set_datadir(Path("../data"))

# Radius of curvature of the fiber mirror in µm (white-light interferometer).
fiber_roc: float = 17.3


def line(x, offset, slope):
    """Evaluate the straight line ``offset + slope * x``.

    Works element-wise when ``x`` is a NumPy array as well as on scalars; the
    script uses it as the voltage-to-cavity-length calibration model.
    """
    return offset + slope * x


def w_0(l_cav, lambda_0, roc):
    """Return the waist radius of the fundamental mode of a flat/curved cavity.

    Computes ``sqrt(lambda_0 / pi * sqrt(l_cav * (roc - l_cav)))`` for a cavity
    of length ``l_cav`` whose curved mirror has radius of curvature ``roc``.
    All lengths must be in one unit (µm in this script). For ``l_cav > roc``
    the inner root has a negative argument and NumPy yields ``nan``.
    """
    rayleigh_range = np.sqrt(l_cav * (roc - l_cav))
    return np.sqrt(lambda_0 / np.pi * rayleigh_range)


def l_clip(l_cav, lambda_0, n, d_fiber, roc):
    """Return the clipping loss of the cavity mode on a mirror of diameter ``d_fiber``.

    The beam radius after propagating ``l_cav`` from the waist ``w_0`` is
    ``w_m = w_0 * sqrt(1 + (l_cav * lambda_0 / (pi * n * w_0**2))**2)``, where
    ``n`` is the refractive index between the mirrors. The result
    ``exp(-2 * (d_fiber / (2 * w_m))**2)`` is the fraction of a Gaussian beam's
    power outside a circular aperture of diameter ``d_fiber``.
    """
    waist = w_0(l_cav, lambda_0, roc)
    # Propagation distance in units of the Rayleigh range.
    distance_ratio = (l_cav * lambda_0) / (np.pi * n * waist**2)
    w_m = waist * np.sqrt(1 + distance_ratio**2)
    return np.exp(-2 * (d_fiber / (2 * w_m)) ** 2)


def finesse(l_cav, l_fiber, l_mirror, l_add, lambda_0, n, d_fiber, roc):
    """Return the cavity finesse ``2 * pi / total_loss``.

    ``total_loss`` is the sum of the length-independent loss terms ``l_fiber``,
    ``l_mirror`` and ``l_add`` (dimensionless fractions) plus the
    length-dependent clipping loss from ``l_clip``; ``lambda_0``, ``n``,
    ``d_fiber`` and ``roc`` are only forwarded to ``l_clip``.
    """
    total_loss = l_fiber + l_mirror + l_add + l_clip(l_cav, lambda_0, n, d_fiber, roc)
    return 2 * np.pi / total_loss


def lorentz(x, offset, gamma, center, amp):
    """Return an area-normalised Lorentzian on a constant background.

    ``amp`` is the integrated peak area, ``gamma`` the full width at half
    maximum, ``center`` the peak position and ``offset`` a constant added
    everywhere.
    """
    half_width = gamma / 2
    peak = amp / np.pi * half_width / ((x - center) ** 2 + half_width**2)
    return peak + offset


def _load_qoi(experiment_dir, analysis_dir):
    """Return the parsed ``quantities_of_interest.json`` of one analysis run.

    Parameters
    ----------
    experiment_dir : str
        Experiment folder name (``<tuid>-<name>``). Its first eight characters
        are the ``YYYYMMDD`` date folder under ``../data`` that contains it.
    analysis_dir : str
        Name of the analysis sub-folder inside ``experiment_dir``.
    """
    qoi_path = (
        Path("../data")
        / experiment_dir[:8]
        / experiment_dir
        / analysis_dir
        / "quantities_of_interest.json"
    )
    # JSON is UTF-8 by specification; do not depend on the platform locale encoding.
    with open(qoi_path, "r", encoding="utf-8") as qoi_file:
        return json.load(qoi_file)


dataset_finesse = load_dataset("20240717-160933-083-f927c1-cavity_finesse_length")
dataset_scan = load_dataset("20241014-131725-377-1e2358-cavity_finesse")
qof_finesse = _load_qoi(
    "20240717-160933-083-f927c1-cavity_finesse_length",
    "20240717-160936-673-c1a18a-CavityFinesseLength",
)

dataset_length_1 = load_dataset("20240717-163414-332-e4b792-cavity_length")
qof_length_1 = _load_qoi(
    "20240717-163414-332-e4b792-cavity_length",
    "20240717-163414-428-addbd9-CavityLengthAnalysis",
)
dataset_length_2 = load_dataset("20240717-163538-750-1748d7-cavity_length")
qof_length_2 = _load_qoi(
    "20240717-163538-750-1748d7-cavity_length",
    "20240717-163538-838-810c2c-CavityLengthAnalysis",
)
dataset_length_3 = load_dataset("20240717-163751-426-e8b8e0-cavity_length")
qof_length_3 = _load_qoi(
    "20240717-163751-426-e8b8e0-cavity_length",
    "20240717-163751-506-d8b09a-CavityLengthAnalysis",
)


# Voltage-to-length calibration: each cavity-length run contributes one
# (coarse_voltage, mean measured length) point to a straight-line fit.
_length_runs = (
    (dataset_length_1, qof_length_1),
    (dataset_length_2, qof_length_2),
    (dataset_length_3, qof_length_3),
)
volts = np.array([ds.attrs["coarse_voltage"] for ds, _ in _length_runs])
lengths = np.array([np.mean(qoi["cavity_lengths"]) for _, qoi in _length_runs])
# NOTE(review): lengths_err is never passed to the fit below (e.g. as
# weights=1 / lengths_err), so the calibration fit is unweighted. Confirm intent.
lengths_err = np.array([np.std(qoi["cavity_lengths"]) for _, qoi in _length_runs])

fit_line = lmfit.Model(line)
for _param, _guess in (("offset", 32e-6), ("slope", -10e-6)):
    fit_line.set_param_hint(_param, value=_guess)
fit_result_line = fit_line.fit(lengths, x=volts, params=fit_line.make_params())

# Map every finesse measurement's sweep value x0 to a cavity length in µm via the
# calibration line, skipping points whose finesse is NaN. Only lengths strictly
# below the mirror radius of curvature go into the fit, because w_0 has no real
# value for l_cav > roc.
finesse_values, length = [], []
finesse_values_to_fit, length_to_fit = [], []
for x0, finesse_value in zip(dataset_finesse.x0.values, qof_finesse["finesse"]):
    if np.isnan(finesse_value):
        continue
    length_um = 1e6 * line(x0, **fit_result_line.best_values)
    finesse_values.append(finesse_value)
    length.append(length_um)
    if length_um < fiber_roc:
        finesse_values_to_fit.append(finesse_value)
        length_to_fit.append(length_um)

# Parameter hints for the finesse model. Losses are dimensionless fractions;
# lambda_0, d_fiber and roc are in µm like the cavity lengths in length_to_fit.
# Only the extra loss l_add and the clipping diameter d_fiber are free.
_finesse_param_hints = {
    "l_fiber": {"value": 50e-6, "vary": False},
    "l_mirror": {"value": 280e-6, "vary": False},
    "l_add": {"value": 400e-6, "min": 10e-6, "max": 2000e-6},
    "lambda_0": {"value": 0.637, "vary": False},
    "n": {"value": 1, "vary": False},
    "d_fiber": {"value": 12, "min": 10, "max": 20},
    "roc": {"value": fiber_roc, "vary": False},
}
fit_finesse = lmfit.Model(finesse)
for _param_name, _hint in _finesse_param_hints.items():
    fit_finesse.set_param_hint(_param_name, **_hint)

fit_result_finesse = fit_finesse.fit(
    np.array(finesse_values_to_fit),
    l_cav=np.array(length_to_fit),
    params=fit_finesse.make_params(),
)

print(lmfit.fit_report(fit_result_finesse))

# Scan time axis divided by 1000 to give µs (presumably x1 is stored in ns — TODO confirm).
time_us = dataset_scan.x1 / 1000
# Sample-index windows around the two resonances that are fitted below.
start_1, end_1 = 50050, 50550
start_2, end_2 = 85380, 85880
# Dense time grids (µs) on which the fitted Lorentzians are drawn later.
time_left = np.linspace(time_us[start_1].values, time_us[end_1].values, 1000)
time_right = np.linspace(time_us[start_2].values, time_us[end_2].values, 1000)


def _fit_lorentz_window(times, signal, start, end, center, center_min, center_max):
    """Fit ``lorentz`` to ``signal[start:end]`` against ``times[start:end]``.

    Only the initial value and bounds of ``center`` differ between the two
    resonances; all other parameter hints are shared.

    Returns
    -------
    tuple
        ``(model, result)``: the configured ``lmfit.Model`` and its fit result.
    """
    model = lmfit.Model(lorentz)
    model.set_param_hint("offset", value=0.1, min=0.01)
    model.set_param_hint("gamma", value=0.01, min=0.001)
    model.set_param_hint("center", value=center, min=center_min, max=center_max)
    model.set_param_hint("amp", value=2, min=0.02)
    result = model.fit(
        signal[start:end],
        x=times[start:end],
        params=model.make_params(),
    )
    return model, result


fit_model_lorentz, fit_result_lorentz = _fit_lorentz_window(
    time_us, dataset_scan.y0[0], start_1, end_1, 804.5, 800, 810
)
fit_model_lorentz_2, fit_result_lorentz_2 = _fit_lorentz_window(
    time_us, dataset_scan.y0[0], start_2, end_2, 1370, 1300, 1420
)

# Shift (µs) subtracted from the scan time axis in every panel-(a) plot.
time_offset = 750

# Layout: a 2 x 2 outer grid of which only the top row is used. The top-left
# cell is split again into the full scan (spanning both sub-columns) above two
# zoom panels; the top-right cell holds the finesse plot.
fig = plt.figure(figsize=(10, 8))
grid = gridspec.GridSpec(2, 2, figure=fig, width_ratios=[1, 1], wspace=0.25)
left_panels = gridspec.GridSpecFromSubplotSpec(
    2, 2, grid[0, 0], wspace=0.3, hspace=0.24
)
ax_1, ax_2, ax_3, ax_4 = (
    plt.subplot(spec)
    for spec in (left_panels[0, :], left_panels[1, 0], left_panels[1, 1], grid[0, 1])
)

# Panel (a), top: the full transmission scan with both Lorentzian fits overlaid.
_overview_fit_style = {"color": "skyblue", "linestyle": "solid"}

ax_1.text(-0.18, 1.3, "(a)", transform=ax_1.transAxes, fontsize=12, va="top")
ax_1.scatter(time_us - time_offset, dataset_scan.y0[0], s=3)
ax_1.set_xlim(0, 680)
ax_1.plot(
    time_right - time_offset,
    lorentz(time_right, **fit_result_lorentz_2.best_values),
    **_overview_fit_style,
)
ax_1.plot(
    time_left - time_offset,
    lorentz(time_left, **fit_result_lorentz.best_values),
    label="Lorentzian Fit",
    **_overview_fit_style,
)
# Double-headed arrow spanning the two fitted resonances (shifted time, µs).
ax_1.annotate(
    "",
    xy=(50, 1),
    xytext=(624, 1),
    xycoords="data",
    textcoords="data",
    arrowprops={"arrowstyle": "<|-|>", "color": "skyblue", "lw": 1},
)
ax_1.legend()

def _plot_resonance_zoom(ax, xlim, time_grid, best_values, arrows):
    """Draw a zoom on one resonance: scan data, Lorentzian fit and arrows.

    Uses the module-level ``time_us``, ``time_offset`` and ``dataset_scan``.
    ``arrows`` is a sequence of ``(xy_x, xytext_x, arrowstyle)`` triples, each
    drawn as an annotation arrow at transmission 0.8 in data coordinates.
    """
    ax.set_xlim(*xlim)
    ax.scatter(time_us - time_offset, dataset_scan.y0[0], s=6)
    ax.plot(
        time_grid - time_offset,
        lorentz(time_grid, **best_values),
        color="skyblue",
        linestyle="solid",
        label="Fit",
    )
    for xy_x, xytext_x, style in arrows:
        ax.annotate(
            "",
            xy=(xy_x, 0.8),
            xycoords="data",
            xytext=(xytext_x, 0.8),
            textcoords="data",
            arrowprops=dict(arrowstyle=style, color="skyblue", lw=1),
        )


# Panel (a), bottom-left: zoom on the first resonance.
_plot_resonance_zoom(
    ax_2,
    (54, 54.6),
    time_left,
    fit_result_lorentz.best_values,
    [(54.1, 54.29, "<|-"), (54.33, 54.52, "-|>")],
)
# Panel (a), bottom-right: zoom on the second resonance.
_plot_resonance_zoom(
    ax_3,
    (619.7, 620.3),
    time_right,
    fit_result_lorentz_2.best_values,
    [(619.785, 619.975, "<|-"), (620.02, 620.21, "-|>")],
)

# Panel (b): measured finesse versus cavity length, with the model fit.
ax_4.text(-0.22, 1.14, "(b)", transform=ax_4.transAxes, fontsize=12, va="top")
ax_4.errorbar(
    length,
    finesse_values,
    # NOTE(review): sqrt(F) is used as the error bar; confirm this is the
    # intended uncertainty estimate for the finesse.
    yerr=np.sqrt(finesse_values),
    fmt="o",
    markersize=2,
    color="navy",
    label="Measured",
)
# Draw the fit over sorted lengths so the line is monotonic in x even if the
# sweep was not recorded in order of increasing length.
l_cav_sorted = np.sort(np.asarray(length_to_fit))
ax_4.plot(
    l_cav_sorted,
    finesse(l_cav_sorted, **fit_result_finesse.best_values),
    color="skyblue",
    label="Fit",
)
ax_4.legend()
ax_4.set_xlabel(r"Cavity Length $L_{cav}$ (µm)")
ax_4.set_ylabel(r"Finesse $\mathcal{F}$")

# Secondary top axis: longitudinal mode number q = L / (lambda / 2), with L and
# lambda in µm; lambda is the (fixed) lambda_0 used by the finesse fit.
half_wavelength = fit_result_finesse.best_values["lambda_0"] / 2


def _length_to_mode_number(l_cav):
    """Forward map for the secondary axis: cavity length (µm) -> mode number."""
    return l_cav / half_wavelength


def _mode_number_to_length(mode_number):
    """Inverse map for the secondary axis: mode number -> cavity length (µm)."""
    return mode_number * half_wavelength


# matplotlib expects functions=(forward, inverse); the inverse must undo the forward.
ax_4_2 = ax_4.secondary_xaxis(
    "top", functions=(_length_to_mode_number, _mode_number_to_length)
)
ax_4_2.set_xlabel(r"Fundamental Mode Number")

# Shared axis labels for the panel-(a) plots, placed in figure coordinates.
for _x_pos, _y_pos, _label, _extra in (
    (0.29, 0.48, "Scanning Time (µs)", {}),
    (0.07, 0.6, "Cavity Transmission (V)", {"rotation": 90}),
):
    fig.text(_x_pos, _y_pos, _label, ha="center", **_extra)

output_path = Path("../Fig_9_sup.png")
fig.savefig(output_path, dpi=600, bbox_inches="tight")
