from pathlib import Path

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import scipy.constants

# import plot_basics
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from quantify_core.data.handling import (
    load_dataset_from_path,
    set_datadir,
)

set_datadir(Path("../data"))

# Frequency (THz) at which a dashed white reference line is drawn in both panels.
REFERENCE_FREQUENCY_THZ = 470.4
# Shared colour-scale limits for both intensity maps, so the panels are comparable.
INTENSITY_VMIN = 0
INTENSITY_VMAX = 15000

# load data: processed CavityLengthSweepAnalysis outputs for the two sweeps.
dat_dis_PM = load_dataset_from_path(
    Path(
        "../data/20250102/20250102-132953-655-bb9bb8-cavity_length_sweep/20250102-133946-008-dca5d8-CavityLengthSweepAnalysis/dataset_processed.hdf5"
    )
)
dat_dis_VV = load_dataset_from_path(
    Path(
        "../data/20241211/20241211-100604-836-fb3283-cavity_length_sweep/20250203-171825-873-b38007-CavityLengthSweepAnalysis/dataset_processed.hdf5"
    )
)


def _thz_nm_convert(value):
    """Convert between frequency in THz and vacuum wavelength in nm.

    lambda[nm] = c / f[THz] * 1e-3, and the identical expression maps nm back
    to THz, so this one function is both the forward and the inverse transform
    required by ``Axes.secondary_yaxis``. A zero input maps to ``inf`` without
    emitting a divide-by-zero warning.
    """
    value = np.asarray(value, dtype=float)
    with np.errstate(divide="ignore"):
        return scipy.constants.c * 1e-3 / value


def plot_dispersion_panel(fig, ax, dataset, panel_label):
    """Draw one cavity-length-sweep panel onto ``ax``.

    Plots the intensity map versus air gap and frequency, overlays each entry
    of ``dataset.fit_frequency`` as a red curve, and adds an inset horizontal
    colour bar. It also adds a dashed reference line at
    ``REFERENCE_FREQUENCY_THZ``, a wavelength axis on the right, a legend and
    an arrow annotation.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure owning ``ax``; used to create the colour bar.
    ax : matplotlib.axes.Axes
        Target axes for the panel.
    dataset
        Loaded processed dataset exposing ``air_length``, ``frequencies``,
        ``intensity`` and ``fit_frequency``. (NOTE(review): presumably an
        xarray.Dataset as returned by quantify_core -- confirm.)
    panel_label : str
        Tag such as ``"(a)"`` written at the top-left, outside the axes.

    Returns
    -------
    The secondary wavelength y-axis created on ``ax``.
    """
    ax.text(-0.24, 1.0, panel_label, transform=ax.transAxes, fontsize=12, va="top")

    # NOTE(review): intensity is flipped along axis 1 exactly as in the original
    # script; presumably this aligns its ordering with the plotted coordinates --
    # confirm against the analysis that wrote the dataset.
    img = ax.pcolormesh(
        dataset.air_length,
        dataset.frequencies,
        np.flip(dataset.intensity, axis=1),
        shading="auto",
        vmin=INTENSITY_VMIN,
        vmax=INTENSITY_VMAX,
    )

    # Horizontal colour bar placed inside the map, styled white for contrast.
    # (``pad`` is not passed: matplotlib ignores it when ``cax`` is supplied.)
    cax = inset_axes(
        ax,
        width="30%",
        height="8%",
        bbox_to_anchor=(0.58, 0.2, 0.8, 0.45),
        bbox_transform=ax.transAxes,
        loc="lower left",
    )
    cbar = fig.colorbar(img, cax=cax, orientation="horizontal")
    cbar.set_label("Intensity (Cts/s)", color="white", fontsize=14)
    cbar.ax.tick_params(color="white", labelsize=10, labelcolor="white")
    cbar.outline.set_edgecolor("white")

    # Only the first fit curve gets a legend label so "Fit" appears once.
    for index, fit in enumerate(dataset.fit_frequency):
        ax.plot(
            dataset.air_length,
            fit,
            linestyle="solid",
            color="red",
            label="Fit" if index == 0 else "_nolegend_",
        )

    ax.set_ylim([np.min(dataset.frequencies), np.max(dataset.frequencies)])
    ax.set_xlabel("Air Gap (µm)", fontsize=14)
    ax.set_ylabel("Frequency (THz)", fontsize=14)
    ax.axhline(REFERENCE_FREQUENCY_THZ, color="white", linestyle="dashed")
    ax.tick_params(axis="both", which="both", labelsize=14)

    # Wavelength axis derived from the frequency axis. lambda = c / f is not
    # linear, so a secondary axis with the exact transform is used; a twinx()
    # with only its end limits set would mislabel every tick in between.
    wavelength_ax = ax.secondary_yaxis(
        "right", functions=(_thz_nm_convert, _thz_nm_convert)
    )
    wavelength_ax.set_ylabel("Wavelength (nm)", fontsize=14)
    wavelength_ax.tick_params(axis="y", which="major", labelsize=14)

    ax.legend(loc="upper right", fontsize=14)
    # NOTE(review): the arrow is in data coordinates and identical for both
    # panels -- confirm it points at the intended feature in each dataset.
    ax.annotate(
        "",
        xy=(6.48, 483),
        xytext=(6.42, 478),
        arrowprops=dict(color="white", headwidth=8, headlength=8, width=0.8),
    )
    return wavelength_ax


fig = plt.figure(figsize=(10, 4))
panel = gridspec.GridSpec(
    1, 2, width_ratios=[1, 1], wspace=0.55, hspace=0.1, figure=fig
)
ax = fig.add_subplot(panel[0, 0])
ax_3 = fig.add_subplot(panel[0, 1])

ax_2 = plot_dispersion_panel(fig, ax, dat_dis_PM, "(a)")
ax_4 = plot_dispersion_panel(fig, ax_3, dat_dis_VV, "(b)")

fig.savefig(Path("../Fig_11_sup.png"), dpi=600, transparent=True, bbox_inches="tight")
