from pathlib import Path

import lmfit
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import ConnectionPatch, RegularPolygon
from matplotlib.path import Path as mPath
from quantify_core.data.handling import (
    load_dataset,
    set_datadir,
)

# Point quantify's data directory at ../data. The path is resolved against the
# current working directory, so the script is presumably meant to be run from
# its own folder.
set_datadir(Path("..", "data"))


def _first_index_at_or_above(coords, value, axis_name):
    """Return the first index along ``coords`` whose entry is >= ``value``.

    Raises:
        ValueError: if no entry of ``coords`` reaches ``value`` (the vertex lies
            beyond the scanned range), instead of an opaque IndexError.
    """
    hits = np.nonzero(np.asarray(coords) >= value)[0]
    if hits.size == 0:
        raise ValueError(
            f"Vertex coordinate {value} lies beyond the {axis_name} range "
            "of the dataset"
        )
    return hits[0]


def generate_vertices(dataset, vertices_list):
    """Convert polygon vertices from coordinate values to grid indices.

    Each vertex ``(x, y)`` is mapped to ``(i, j)``, where ``i`` is the first
    index with ``dataset.x0[i] >= x`` and ``j`` the first index with
    ``dataset.x1[j] >= y``. For ascending coordinate arrays this is the grid
    point at or just past the vertex (NOTE(review): assumes x0/x1 ascending).

    Args:
        dataset: object exposing coordinate arrays ``x0`` and ``x1``
            (presumably a loaded quantify dataset — confirm).
        vertices_list: sequence of ``(x, y)`` pairs in the units of x0/x1.

    Returns:
        list of ``(i, j)`` index tuples in the same order as ``vertices_list``.

    Raises:
        ValueError: if a vertex coordinate exceeds every value on its axis.
    """
    return [
        (
            _first_index_at_or_above(dataset.x0, vertex[0], "x0"),
            _first_index_at_or_above(dataset.x1, vertex[1], "x1"),
        )
        for vertex in vertices_list
    ]


def generate_masks(dataset, vertices_list):
    """Build boolean inside/outside masks for a polygon on the dataset grid.

    The polygon vertices are first converted to grid indices with
    ``generate_vertices``. Every grid point ``(col, row)`` (column index along
    ``x0``, row index along ``x1``) is then tested with matplotlib's
    ``Path.contains_points``.

    Returns:
        (mask_inner, mask_outer): boolean arrays of shape
        ``(len(dataset.x1), len(dataset.x0))``. ``mask_inner`` is True inside
        the polygon and ``mask_outer`` is its element-wise complement.
    """
    n_cols = len(dataset.x0)
    n_rows = len(dataset.x1)
    polygon = mPath(generate_vertices(dataset, vertices_list))

    # rows[r, c] == r and cols[r, c] == c; flattened in row-major order so the
    # test points are (c, r) pairs, i.e. (x-index, y-index).
    rows, cols = np.indices((n_rows, n_cols))
    grid_points = np.column_stack((cols.ravel(), rows.ravel()))

    inside = polygon.contains_points(grid_points).reshape(n_rows, n_cols)
    return inside, ~inside


def plot_vertices(ax, vertices_list, color):
    """Draw a closed polygon outline on ``ax`` with a marker at every vertex.

    Each vertex gets a small regular pentagon patch (radius 0.4 in data units,
    white edge, default patch face colour). Consecutive vertices are joined by
    ``ConnectionPatch`` lines in ``color``, and the last vertex is joined back
    to the first to close the outline. An empty ``vertices_list`` draws nothing.

    Args:
        ax: matplotlib Axes to draw on.
        vertices_list: sequence of ``(x, y)`` points in data coordinates.
        color: line colour of the outline.
    """
    vertices = [(vertex[0], vertex[1]) for vertex in vertices_list]
    if not vertices:
        return

    for vertex in vertices:
        marker = RegularPolygon(
            vertex, numVertices=5, radius=0.4, alpha=1, edgecolor="white"
        )
        ax.add_patch(marker)

    # Pair every vertex with its successor; the rotated list wraps the last
    # vertex back to the first so the polygon is closed.
    for start, end in zip(vertices, vertices[1:] + vertices[:1]):
        conn = ConnectionPatch(
            xyA=start,
            coordsA="data",
            axesA=ax,
            xyB=end,
            coordsB="data",
            axesB=ax,
            color=color,
        )
        ax.add_artist(conn)


def gaussian(x, offset, amplitude, center, sigma):
    """Evaluate a Gaussian peak on a constant baseline.

    Returns ``offset + amplitude * exp(-((x - center) / sigma)**2 / 2)``;
    ``x`` may be a scalar or a numpy array (element-wise evaluation).
    """
    z = (x - center) / sigma
    peak = np.exp(-0.5 * z**2)
    return amplitude * peak + offset


bare_finesse_vincent_vega = 7860
bare_finesse_pai_mei = 8000

# Loss already accounted for in the round-trip budget (50 ppm + 280 ppm);
# presumably the two mirror transmissions — TODO confirm.
_ACCOUNTED_LOSSES = 50e-6 + 280e-6

# Additional loss in ppm implied by a finesse F: 2*pi/F minus the accounted loss.
add_loss_bare_vv = ((2 * np.pi) / bare_finesse_vincent_vega - _ACCOUNTED_LOSSES) * 1e6
add_loss_bare_pm = ((2 * np.pi) / bare_finesse_pai_mei - _ACCOUNTED_LOSSES) * 1e6

# Finesse axis on which the fitted Gaussians are evaluated for plotting.
x_fit = np.linspace(4500, 11000, 1000)

finesse_data_pai_mei_high = load_dataset(
    "20250101-231412-630-520c55-cavity_lateral_finesse_scan"
)
# y1 is used as the finesse map and y8 as the fit R^2 map; x0/x1 are the scan axes.
finesse_pai_mei_high = finesse_data_pai_mei_high.y1
rsquared_pai_mei_high = finesse_data_pai_mei_high.y8
x_values_finesse_pai_mei_high = finesse_data_pai_mei_high.x0
y_values_finesse_pai_mei_high = finesse_data_pai_mei_high.x1

# Blank out (NaN) every point whose fit has R^2 <= 0.95, vectorised instead of a
# per-element loop. The condition is "not (R^2 <= 0.95)" rather than
# "R^2 > 0.95" so that points with NaN R^2 are kept.
finesse_pai_mei_high = finesse_pai_mei_high.where(~(rsquared_pai_mei_high <= 0.95))

# (x,y) in um, starting from the top left, clockwise
INNER_VERTICES_LIST_PM_H = [(72, 138), (136, 77), (79, 11), (27, 57), (25, 82)]

OUTER_VERTICES_LIST_PM_H = [
    (68, 144.8),
    (80, 144.8),
    (144.8, 84),
    (144.8, 71),
    (80, 0),
    (74, 0),
    (15, 56),
    (15, 90),
]

# Bare-cavity region: everything outside the outer polygon.
# Diamond region: everything inside the inner polygon.
_, mask_bare_pai_mei = generate_masks(
    finesse_data_pai_mei_high, OUTER_VERTICES_LIST_PM_H
)
mask_diamond_pai_mei, _ = generate_masks(
    finesse_data_pai_mei_high, INNER_VERTICES_LIST_PM_H
)

finesse_data_vincent_vega_high = load_dataset(
    "20241218-043859-573-85575c-cavity_lateral_finesse_scan"
)

# Finesse (y1) and fit R^2 (y8) maps, transposed and with the row order reversed.
# NOTE(review): presumably to match the orientation the vertex lists below were
# drawn in — confirm.
finesse_vincent_vega_high = finesse_data_vincent_vega_high.y1.values.transpose()[::-1]
rsquared_vincent_vega_high = finesse_data_vincent_vega_high.y8.values.transpose()[::-1]
x_values_finesse_vega_high = finesse_data_vincent_vega_high.x0
y_values_finesse_vega_high = finesse_data_vincent_vega_high.x1

# Blank out (NaN) every point whose fit has R^2 <= 0.95. The boolean-mask
# assignment covers every cell whatever the grid shape; NaN R^2 compares False
# and those points are kept.
finesse_vincent_vega_high[rsquared_vincent_vega_high <= 0.95] = np.nan

OUTER_VERTICES_LIST_VV_H = [  # (x,y) in um, starting from the top left, clockwise
    (4, 44),
    (6, 72),
    (52, 109.1),
    (64, 109.1),
    (108, 48),
    (48, 0.8),
    (38, 0.8),
]

INNER_VERTICES_LIST_VV_H = [  # same convention as the outer list
    (12, 50),
    (14, 68),
    (56, 100),
    (93, 50),
    (42, 10),
]

# Bare-cavity region: everything outside the outer polygon.
# Diamond region: everything inside the inner polygon.
_, mask_bare_vincent_vega = generate_masks(
    finesse_data_vincent_vega_high, OUTER_VERTICES_LIST_VV_H
)
mask_diamond_vincent_vega, _ = generate_masks(
    finesse_data_vincent_vega_high, INNER_VERTICES_LIST_VV_H
)


fig = plt.figure(figsize=(8, 6))
panel = gridspec.GridSpec(
    2, 2, width_ratios=[1, 1], wspace=0.55, hspace=0.25, figure=fig
)

# Left column: finesse maps (a, c); right column: histograms (b, d).
ax_1 = fig.add_subplot(panel[0, 0])
ax_2 = fig.add_subplot(panel[0, 1])
ax_3 = fig.add_subplot(panel[1, 0])
ax_4 = fig.add_subplot(panel[1, 1])

# Private copy of viridis that draws NaN ("bad") cells in black. Calling
# set_bad on plt.cm.viridis would mutate the shared, globally registered
# colormap used by every other "viridis" plot in the process.
cmap = plt.colormaps["viridis"].with_extremes(bad="black")

ax_1.text(-0.3, 1.05, "(a)", transform=ax_1.transAxes, fontsize=12, va="top")
# Vincent Vega finesse map; NaN cells (rejected fits) use the colormap's "bad" colour.
img = ax_1.pcolormesh(
    x_values_finesse_vega_high,
    y_values_finesse_vega_high,
    finesse_vincent_vega_high,
    vmax=10000,
    cmap=cmap,
)
# Anchor the colorbar to ax_1 explicitly. Before Matplotlib 3.6, omitting ax
# took the space from the *current* Axes (the last subplot created) instead of
# the mappable's Axes.
cbar = fig.colorbar(img, ax=ax_1)
cbar.set_label(r"Finesse $\mathcal{F}$")
ax_1.set_xlabel("X Position (µm)")
ax_1.set_ylabel("Y Position (µm)")
# Black: outer polygon (points outside it form the bare-cavity region);
# white: inner polygon (points inside it form the diamond region).
plot_vertices(ax_1, OUTER_VERTICES_LIST_VV_H, "black")
plot_vertices(ax_1, INNER_VERTICES_LIST_VV_H, "white")
ax_1.set_xlim(0, 110)
ax_1.set_ylim(0, 110)

# Red cross at the position labelled "Wave pattern". The y value is written as
# 110 - 48, presumably to follow the row flip applied to this map (TODO confirm).
# The label is not displayed because ax_1 never gets a legend.
ax_1.plot(50, (110 - 48), "rx", label="Wave pattern Vincent Vega")

# Divisors applied to the panel (b) histogram data (1 = raw finesse values).
norm_bare_vincent_vega = 1
norm_diamond_vincent_vega = 1

ax_2.text(-0.3, 1.05, "(b)", transform=ax_2.transAxes, fontsize=12, va="top")

# Settings shared by both histograms. Cells outside a mask are multiplied to 0,
# which lies below the lower histogram edge of 1 and is therefore not counted.
_hist_kwargs_VV = {"bins": 120, "range": (1, 12000), "alpha": 0.7}
_palette_VV = plt.colormaps.get_cmap("viridis")

_bare_label_VV = (
    "Bare Cavity\n"
    + r"$\mathcal{L}_{\rm add}$="
    + f"{round(add_loss_bare_vv, -1):.0f} ppm"
)
counts_VV, bin_edges_VV, _ = ax_2.hist(
    (mask_bare_vincent_vega * finesse_vincent_vega_high).flatten()
    / norm_bare_vincent_vega,
    label=_bare_label_VV,
    color=_palette_VV(220),
    **_hist_kwargs_VV,
)
bin_centers_VV = 0.5 * (bin_edges_VV[:-1] + bin_edges_VV[1:])

# Gaussian fit of the bare-cavity counts against the bin centres.
fit_model = lmfit.Model(gaussian)
for _param_name, _param_hint in {
    "offset": {"value": 0, "min": 0},
    "amplitude": {"value": max(counts_VV), "min": 0},
    "center": {"value": 7500, "min": 5e3, "max": 1e4},
    "sigma": {"value": 1e2, "min": 0},
}.items():
    fit_model.set_param_hint(_param_name, **_param_hint)
params_VV = fit_model.make_params()
fit_result_VV = fit_model.fit(counts_VV, x=bin_centers_VV, params=params_VV)

ax_2.hist(
    (mask_diamond_vincent_vega * finesse_vincent_vega_high).flatten()
    / norm_diamond_vincent_vega,
    label="Diamond",
    color=_palette_VV(60),
    **_hist_kwargs_VV,
)
ax_2.plot(
    x_fit,
    gaussian(x_fit, **fit_result_VV.best_values),
    linestyle="solid",
    color="grey",
    label="Gaussian Fit",
)
ax_2.legend(fontsize=8)
ax_2.set_xlabel(r"Finesse $\mathcal{F}$")
ax_2.set_ylabel("Occurrences")
ax_2.set_xlim(0, 11000)

print("fit report Vincent Vega:")
print(fit_result_VV.fit_report())
ax_3.text(-0.3, 1.05, "(c)", transform=ax_3.transAxes, fontsize=12, va="top")
# Pai Mei finesse map; NaN cells (rejected fits) use the colormap's "bad" colour.
img = ax_3.pcolormesh(
    x_values_finesse_pai_mei_high,
    y_values_finesse_pai_mei_high,
    finesse_pai_mei_high,
    vmax=10000,
    cmap=cmap,
)
# Anchor the colorbar to ax_3 explicitly. Before Matplotlib 3.6, omitting ax
# took the space from the *current* Axes instead of the mappable's Axes.
cbar = fig.colorbar(img, ax=ax_3)
cbar.set_label(r"Finesse $\mathcal{F}$")
ax_3.set_xlabel("X Position (µm)")
ax_3.set_ylabel("Y Position (µm)")
# Black: outer polygon (points outside it form the bare-cavity region);
# white: inner polygon (points inside it form the diamond region).
plot_vertices(ax_3, OUTER_VERTICES_LIST_PM_H, "black")
plot_vertices(ax_3, INNER_VERTICES_LIST_PM_H, "white")
ax_3.set_xlim(0, 145)
ax_3.set_ylim(0, 145)

# Red cross at the position labelled "Wave pattern"; the label is not displayed
# because ax_3 never gets a legend.
ax_3.plot(62, 34, "rx", label="Wave pattern Pai Mei")

# Divisors applied to the panel (d) histogram data (1 = raw finesse values).
# NOTE(review): "bare_mei" presumably was meant to read "pai_mei".
norm_bare_bare_mei = 1
norm_diamond_bare_mei = 1

ax_4.text(-0.3, 1.05, "(d)", transform=ax_4.transAxes, fontsize=12, va="top")

# Settings shared by both histograms. Cells outside a mask are multiplied to 0,
# which lies below the lower histogram edge of 1 and is therefore not counted.
_hist_kwargs_PM = {"bins": 120, "range": (1, 12000), "alpha": 0.7}
_palette_PM = plt.colormaps.get_cmap("viridis")

_bare_label_PM = (
    "Bare Cavity\n"
    + r"$\mathcal{L}_{\rm add}$="
    + f"{round(add_loss_bare_pm, -1):.0f} ppm"
)
_bare_values_PM = (mask_bare_pai_mei * finesse_pai_mei_high).values.flatten()
counts_PM, bin_edges_PM, _ = ax_4.hist(
    _bare_values_PM / norm_bare_bare_mei,
    label=_bare_label_PM,
    color=_palette_PM(220),
    **_hist_kwargs_PM,
)
bin_centers_PM = 0.5 * (bin_edges_PM[:-1] + bin_edges_PM[1:])

# Gaussian fit of the bare-cavity counts against the bin centres.
fit_model = lmfit.Model(gaussian)
for _param_name, _param_hint in {
    "offset": {"value": 0, "min": 0},
    "amplitude": {"value": max(counts_PM), "min": 0},
    "center": {"value": 7500, "min": 5e3, "max": 1e4},
    "sigma": {"value": 1e2, "min": 0},
}.items():
    fit_model.set_param_hint(_param_name, **_param_hint)
params_PM = fit_model.make_params()
fit_result_PM = fit_model.fit(counts_PM, x=bin_centers_PM, params=params_PM)

print("fit report Pai Mei:")
print(fit_result_PM.fit_report())

_diamond_values_PM = (mask_diamond_pai_mei * finesse_pai_mei_high).values.flatten()
ax_4.hist(
    _diamond_values_PM / norm_diamond_bare_mei,
    label="Diamond",
    color=_palette_PM(60),
    **_hist_kwargs_PM,
)
ax_4.plot(
    x_fit,
    gaussian(x_fit, **fit_result_PM.best_values),
    linestyle="solid",
    color="grey",
    label="Gaussian Fit",
)
ax_4.legend(fontsize=8)
ax_4.set_xlabel(r"Finesse $\mathcal{F}$")
ax_4.set_ylabel("Occurrences")
ax_4.set_xlim(0, 11000)

# Colours named for diamond / air modes. NOTE(review): not referenced anywhere
# in this script; kept in case other code imports them.
diamond_mode_color = "navy"
air_mode_color = "deepskyblue"

# Additional loss in ppm implied by each fitted bare-cavity finesse centre:
# 2*pi/F - (50 ppm + 280 ppm), the same expression as add_loss_bare_vv/_pm.
for _cavity_label, _fit_result in (
    ("Vincent Vega", fit_result_VV),
    ("Pai-Mei", fit_result_PM),
):
    _fitted_finesse = _fit_result.params["center"].value
    print(
        f"Additional losses bare cavity {_cavity_label}: "
        f"{((2 * np.pi) / _fitted_finesse - (50e-6 + 280e-6)) * 1e6} ppm"
    )

# Output path is resolved against the current working directory.
fig.savefig(Path("../Fig_3.png"), dpi=600, bbox_inches="tight")
