from pathlib import Path

import matplotlib.gridspec as gridspec
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
from quantify_core.data.handling import (
    load_dataset,
    set_datadir,
)

# Register the data directory with quantify_core. The path is relative, so it
# is interpreted against the directory the script is run from.
data_directory = Path("../data")
set_datadir(data_directory)

calc_diamond_modes = lambda x: (2 * x + 1) * (0.637 / (4 * 2.41))
calc_air_modes = lambda x: (2 * x) * (0.637 / (4 * 2.41))

# Offsets (µm) subtracted from the interferometer axes; used only for plotting.
offset_vincent_vega_x = 30  # um, for plotting (alternative value: 46)
offset_vincent_vega_y = 24  # um, for plotting (alternative value: 18)

# Raw interferometer data scaled by 1e6 (presumably metres -> µm; confirm).
interferometer_vincent_vega = (
    np.loadtxt(Path("../data/samples/vincent_vega_interferometer.txt")) * 1e6
)

# Physical extent of the interferometer field of view.
interfer_vincent_vega_width = 138.145330795  # um
interfer_vincent_vega_height = 165.774396954  # um

# One coordinate per column (x) and per row (y), shifted by the offsets.
# NOTE(review): the x axis (columns) spans the "height" extent and the y axis
# (rows) spans the "width" extent — confirm this pairing is intended (e.g. a
# rotated image) and not a swap.
x_values_inter_vincent_vega = (
    np.linspace(
        0, interfer_vincent_vega_height, interferometer_vincent_vega.shape[1]
    )
    - offset_vincent_vega_x
)
y_values_inter_vincent_vega = (
    np.linspace(
        0, interfer_vincent_vega_width, interferometer_vincent_vega.shape[0]
    )
    - offset_vincent_vega_y
)

# TUID of the lateral scan that holds both the finesse and the polarization
# splitting results. Both sets of maps come from it, so it is loaded only once.
_LATERAL_SCAN_TUID = "20241218-043859-573-85575c-cavity_lateral_finesse_scan"
# Pixels whose fit R² is at or below this value are blanked (set to NaN).
RSQUARED_THRESHOLD = 0.95


def _as_image(data_array):
    """Return ``data_array.values`` transposed and with its rows reversed.

    This is the orientation in which the maps are handed to ``pcolormesh``.
    The result may share memory with the dataset, so it is never written to.
    """
    return data_array.values.transpose()[::-1]


def _mask_poor_fits(values, rsquared, threshold=RSQUARED_THRESHOLD):
    """Return a new array equal to ``values`` but NaN where ``rsquared <= threshold``.

    The check is vectorised, so every pixel is tested whatever the grid's
    aspect ratio, and the source dataset is left untouched. Pixels whose R²
    is itself NaN compare False and are therefore kept.
    """
    return np.where(rsquared <= threshold, np.nan, values)


finesse_data_vincent_vega_high = load_dataset(_LATERAL_SCAN_TUID)
splitting_data_vincent_vega_high = finesse_data_vincent_vega_high

x_values_finesse_vega_high = finesse_data_vincent_vega_high.x0
y_values_finesse_vega_high = finesse_data_vincent_vega_high.x1
x_values_splitting_vega_high = splitting_data_vincent_vega_high.x0
y_values_splitting_vega_high = splitting_data_vincent_vega_high.x1

# Finesse map (y1), blanked using the R² channel y8.
rsquared_finesse_vincent_vega_high = _as_image(finesse_data_vincent_vega_high.y8)
finesse_vincent_vega_high = _mask_poor_fits(
    _as_image(finesse_data_vincent_vega_high.y1),
    rsquared_finesse_vincent_vega_high,
)

# Splitting (y5), linewidth (y4) and transmission (y0) maps, blanked using the
# R² channel y9. y1 is deliberately not re-read here: doing so would replace
# the masked finesse map above with the unmasked data.
rsquared_vincent_vega_high = _as_image(splitting_data_vincent_vega_high.y9)
splitting_vincent_vega_high = _mask_poor_fits(
    _as_image(splitting_data_vincent_vega_high.y5), rsquared_vincent_vega_high
)
linewidth_vincent_vega_high = _mask_poor_fits(
    _as_image(splitting_data_vincent_vega_high.y4), rsquared_vincent_vega_high
)
transmission_vincent_vega_high = _mask_poor_fits(
    _as_image(splitting_data_vincent_vega_high.y0), rsquared_vincent_vega_high
)

plt.rcParams.update({"font.size": 14})

# 3 x 2 grid of panels (a)-(f), created row by row: ax_1/ax_2 on top,
# ax_3/ax_4 in the middle, ax_5/ax_6 at the bottom.
fig = plt.figure(figsize=(10, 17.5))
panel = gridspec.GridSpec(3, 2, wspace=0.22, hspace=0.36, figure=fig)
ax_1, ax_2, ax_3, ax_4, ax_5, ax_6 = [
    plt.subplot(panel[row, col]) for row in range(3) for col in range(2)
]


# Panel (a): the light-microscope image file, shown without axes.
microscope_image = mpimg.imread(
    Path(r"../data/samples/vincent_vega_light_microscope.png")
)
ax_1.text(-0.14, 1.255, "(a)", transform=ax_1.transAxes, va="top", fontsize=18)
ax_1.imshow(microscope_image)
ax_1.axis("off")

# Panel (b): interferometer height map, with white contours drawn at the
# calc_air_modes levels for mode indices 14-19.
air_mode_levels = [calc_air_modes(mode) for mode in range(14, 20)]
height_map = (
    x_values_inter_vincent_vega,
    y_values_inter_vincent_vega,
    interferometer_vincent_vega,
)

ax_2.text(-0.14, 1.255, "(b)", transform=ax_2.transAxes, va="top", fontsize=18)
img_2 = ax_2.pcolormesh(*height_map, vmin=2.0, vmax=2.8, cmap="inferno")

# Horizontal colorbar in its own axes: [left, bottom, width, height].
cax = fig.add_axes([0.55, 0.915, 0.35, 0.02])
cb = fig.colorbar(img_2, cax=cax, orientation="horizontal")
cb.set_label("Height (µm)")
cax.xaxis.set_ticks_position("bottom")

ax_2.contour(*height_map, levels=air_mode_levels, colors="white")
ax_2.set(xlim=(0, 100), ylim=(0, 100))
ax_2.set_ylabel("Y Position (µm)", labelpad=-1)

# Panel (c): cavity transmission map; NaN pixels (if any) are drawn in black.
# Copy the colormap so set_bad() does not modify matplotlib's shared instance.
cmap = plt.cm.coolwarm.copy()
cmap.set_bad(color="black")

ax_3.text(-0.16, 1.245, "(c)", transform=ax_3.transAxes, va="top", fontsize=18)
img_3 = ax_3.pcolormesh(
    x_values_splitting_vega_high,
    y_values_splitting_vega_high,
    transmission_vincent_vega_high,
    vmax=1.6,
    cmap=cmap,
)
ax_3.set_ylabel("Y Position (µm)", labelpad=-1)
ax_3.set_xlim(0, 100)
ax_3.set_ylim(0, 110)
cax = fig.add_axes([0.13, 0.634, 0.35, 0.02])  # [left, bottom, width, height]
cb = fig.colorbar(img_3, cax=cax, orientation="horizontal")
cb.set_label("Cavity Transmission Intensity (V)")
cax.xaxis.set_ticks_position("bottom")

# Panel (d): cavity linewidth map; NaN pixels (if any) are drawn in black.
# Copy the colormap so set_bad() does not modify matplotlib's shared instance.
cmap = plt.cm.gist_earth.copy()
cmap.set_bad(color="black")

ax_4.text(-0.14, 1.245, "(d)", transform=ax_4.transAxes, va="top", fontsize=18)
img_4 = ax_4.pcolormesh(
    x_values_splitting_vega_high,
    y_values_splitting_vega_high,
    linewidth_vincent_vega_high,
    vmax=6,
    cmap=cmap,
)
ax_4.set_ylabel("Y Position (µm)", labelpad=-1)
ax_4.set_xlim(0, 100)
ax_4.set_ylim(0, 110)
cax = fig.add_axes([0.55, 0.634, 0.35, 0.02])  # [left, bottom, width, height]
cb = fig.colorbar(img_4, cax=cax, orientation="horizontal")
cb.set_label("Cavity Linewidth (GHz)")
cax.xaxis.set_ticks_position("bottom")


# Panel (e): finesse map; NaN pixels (if any) are drawn in black.
# Copy the colormap so set_bad() does not modify matplotlib's shared instance.
cmap = plt.cm.viridis.copy()
cmap.set_bad(color="black")

ax_5.text(-0.16, 1.238, "(e)", transform=ax_5.transAxes, va="top", fontsize=18)
img_5 = ax_5.pcolormesh(
    x_values_finesse_vega_high,
    y_values_finesse_vega_high,
    finesse_vincent_vega_high,
    vmax=10000,
    cmap=cmap,
)
ax_5.set_xlabel("X Position (µm)")
ax_5.set_ylabel("Y Position (µm)", labelpad=-1)
ax_5.set_xlim(0, 100)
ax_5.set_ylim(0, 110)
cax = fig.add_axes([0.13, 0.352, 0.35, 0.02])  # [left, bottom, width, height]
cb = fig.colorbar(img_5, cax=cax, orientation="horizontal")
cb.set_label(r"Finesse $\mathcal{F}$")
cax.xaxis.set_ticks_position("bottom")
cax.xaxis.set_ticks([0, 3000, 6000, 9000])

# Panel (f): polarization splitting map; NaN pixels (if any) are drawn in black.
# Copy the colormap so set_bad() does not modify matplotlib's shared instance.
cmap = plt.cm.plasma.copy()
cmap.set_bad(color="black")

ax_6.text(-0.14, 1.238, "(f)", transform=ax_6.transAxes, va="top", fontsize=18)
img_6 = ax_6.pcolormesh(
    x_values_splitting_vega_high,
    y_values_splitting_vega_high,
    splitting_vincent_vega_high,
    vmax=14,
    cmap=cmap,
)
ax_6.set_xlabel("X Position (µm)")
ax_6.set_ylabel("Y Position (µm)", labelpad=-1)
ax_6.set_xlim(0, 100)
ax_6.set_ylim(0, 110)

cax = fig.add_axes([0.55, 0.352, 0.35, 0.02])  # [left, bottom, width, height]
cb = fig.colorbar(img_6, cax=cax, orientation="horizontal")
cb.set_label("Polarization Splitting (GHz)")
cax.xaxis.set_ticks_position("bottom")

# Write the complete six-panel figure; the path is relative to the working directory.
fig.savefig(Path("../Fig_13_sup.png"), dpi=600, bbox_inches="tight")
