from pathlib import Path

import lmfit
import matplotlib.gridspec as gridspec
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from quantify_core.data.handling import (
    load_dataset,
    set_datadir,
)

# Configure quantify-core's data directory before any load_dataset call.
# The path is relative, so it resolves against the current working
# directory (not against this file's location).
set_datadir(Path("..") / "data")


def gaussian(x, offset, amp, center, sigma):
    """Evaluate a Gaussian on a constant baseline.

    Computes ``offset + amp * exp(-((x - center) / sigma) ** 2 / 2)``. A
    negative ``amp`` yields a dip below ``offset`` rather than a peak.
    The signature is introspected by ``lmfit.Model``: ``x`` is the
    independent variable and the remaining arguments are fit parameters.

    Args:
        x: Scalar or array of abscissa values.
        offset: Constant baseline.
        amp: Height of the Gaussian relative to the baseline.
        center: Position of the extremum, in the same units as ``x``.
        sigma: Standard deviation, in the same units as ``x`` (must be
            non-zero).

    Returns:
        The model evaluated at ``x``, with the same shape as ``x``.
    """
    # Squared distance from the centre, measured in units of sigma.
    z_squared = ((x - center) / sigma) ** 2
    return offset + amp * np.exp(-z_squared / 2)


# Load the 2-D scan shown in panel (b) and the ODMR sweep shown in panel (c).
dataset_map = load_dataset("20231113-160321-694-b91829-cyclops_2d_scan")
dataset_odmr = load_dataset("20230713-092622-621-cd46b1-odmr_green")

# ODMR signal summed over axis 0 (presumably repeated sweeps -- TODO confirm),
# computed once and reused for the fit, its weights and the error bars.
odmr_counts = dataset_odmr.y0.sum(axis=0)
# sqrt(N) uncertainty per point, i.e. assuming Poisson counting statistics;
# this is also what is drawn as error bars in panel (c).
odmr_counts_err = np.sqrt(odmr_counts)

# Gaussian-dip model; initial values and bounds are hand-tuned for this sweep.
fit_model = lmfit.Model(gaussian)
fit_model.set_param_hint("offset", value=63500, min=62500, max=64500)
fit_model.set_param_hint("amp", value=-1000, min=-10000, max=0)
fit_model.set_param_hint("center", value=2.877e9, min=2.86e9, max=2.89e9)
fit_model.set_param_hint("sigma", value=10e6, min=0.5e6, max=20e6)
params = fit_model.make_params()
# lmfit multiplies the residual (data - model) by ``weights``, so they must be
# the inverse of the per-point uncertainty, not the uncertainty itself.
result = fit_model.fit(
    odmr_counts,
    params=params,
    x=dataset_odmr.x1,
    weights=1 / odmr_counts_err,
)

# Dense frequency grid (Hz) on which the smooth fitted curve is drawn.
freq_plot = np.linspace(
    np.min(dataset_odmr.x1).values, np.max(dataset_odmr.x1).values, 1000
)

fig = plt.figure(figsize=(10, 3))
panel = gridspec.GridSpec(1, 3, width_ratios=[0.8, 1, 1], wspace=0.33, figure=fig)

# Attach axes to ``fig`` explicitly rather than via pyplot's current figure.
ax_1 = fig.add_subplot(panel[0, 0])
ax_2 = fig.add_subplot(panel[0, 1])
ax_3 = fig.add_subplot(panel[0, 2])

# Panel (a): static image read from disk, shown without axes.
ax_1.text(0, 1.1, "(a)", transform=ax_1.transAxes, fontsize=12, va="top")
ax_1.imshow(mpimg.imread(Path(r"../data/samples/Ringo.png")))
ax_1.axis("off")

# Panel (b): 2-D map with both position axes shifted to start at zero.
ax_2.text(-0.24, 1.1, "(b)", transform=ax_2.transAxes, fontsize=12, va="top")
img = ax_2.pcolormesh(
    dataset_map.x0 - dataset_map.x0.min(),
    dataset_map.x1 - dataset_map.x1.min(),
    dataset_map.y0,
    vmax=40,
)
ax_2.set_xlim(0, 10)
ax_2.set_ylim(0, 10)
ax_2.set_xlabel("X Position (µm)", fontsize=12)
ax_2.set_ylabel("Y Position (µm)", fontsize=12)
# Colorbar drawn inside the map (white styling so it stays readable on top
# of the image).
axinset = inset_axes(
    ax_2,
    width="10%",
    height="30%",
    bbox_to_anchor=(-0.13, -0.5, 0.94, 0.92),
    bbox_transform=ax_2.transAxes,
    loc="upper right",
)
cbar = fig.colorbar(img, cax=axinset, orientation="vertical")
cbar.set_label("Counts (kCts/s)", color="white", fontsize=10)
cbar.ax.tick_params(color="white", labelsize=10, labelcolor="white")
cbar.outline.set_edgecolor("white")

# Panel (c): ODMR data and Gaussian fit, normalised by the fitted baseline so
# that the off-resonance level sits at 1.
ax_3.text(-0.29, 1.1, "(c)", transform=ax_3.transAxes, fontsize=12, va="top")
norm = result.params["offset"].value
ax_3.errorbar(
    dataset_odmr.x1 / 1e9,
    odmr_counts / norm,
    yerr=odmr_counts_err / norm,
    fmt="o",
    color="yellowgreen",
)
ax_3.plot(
    freq_plot / 1e9,
    gaussian(freq_plot, **result.best_values) / norm,
    color="peru",
    linewidth=2,
    linestyle="dashed",
    label="Gaussian Fit",
    zorder=3,
)
ax_3.set_xlabel("Microwave Frequency (GHz)", fontsize=12)
ax_3.set_ylabel("Contrast", fontsize=12)
ax_3.legend(loc="lower left", fontsize=10)
# Inward-pointing arrow pair marking the linewidth. Coordinates are hard-coded
# in data units (GHz, contrast) and must be re-tuned for other datasets.
ax_3.annotate(
    "",
    xy=(2.859, 0.985),
    xycoords="data",
    xytext=(2.8735, 0.985),
    textcoords="data",
    arrowprops=dict(arrowstyle="<|-", color="peru", lw=2),
)
ax_3.annotate(
    "",
    xy=(2.8805, 0.985),
    xycoords="data",
    xytext=(2.895, 0.985),
    textcoords="data",
    arrowprops=dict(arrowstyle="-|>", color="peru", lw=2),
)

# FWHM of a Gaussian is 2*sqrt(2 ln 2)*sigma; reported in MHz.
fwhm_factor = 2 * np.sqrt(2 * np.log(2))
width = fwhm_factor * result.params["sigma"].value / 1e6
sigma_stderr = result.params["sigma"].stderr
if sigma_stderr is None:
    # lmfit leaves stderr as None when it cannot estimate the covariance;
    # report the width alone instead of crashing on None arithmetic.
    width_err = None
    width_label = f"{width:.1f} MHz"
else:
    width_err = fwhm_factor * sigma_stderr / 1e6
    width_label = f"{width:.1f} MHz\n+/- {width_err:.1f} MHz"
ax_3.annotate(
    width_label,
    xy=(2.887, 0.976),
    xycoords="data",
    xytext=(2.887, 0.976),
    textcoords="data",
    color="peru",
    fontsize=11,
)

output_path = Path("../figures/fig_21_add.png")
# Create the output directory if needed so savefig does not fail on a
# fresh checkout.
output_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(output_path, dpi=300, bbox_inches="tight")
