from pathlib import Path

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
from quantify_core.data.handling import (
    load_dataset,
    set_datadir,
)

# quantify-core resolves every TUID below relative to this data directory,
# so it must be configured before any dataset is loaded.
set_datadir(Path("../data"))

# One 2D scan per dose; the variable-name suffix encodes the dose
# (2e13, 4e13, 1e14). Datasets are loaded in this order by TUID.
dataset_map_2e13, dataset_map_4e13, dataset_map_1e14 = (
    load_dataset(tuid)
    for tuid in (
        "20230710-100406-778-462f83-cyclops_2d_scan",
        "20230707-133432-166-041a07-cyclops_2d_scan",
        "20230707-180646-804-104f3f-cyclops_2d_scan",
    )
)


def _plot_dose_map(fig, ax, dataset, dose_label):
    """Draw one 2D scan as a colour map with a horizontal colorbar on top.

    The X/Y coordinates are shifted so each map starts at (0, 0), and
    ``dose_label`` is drawn in a white box at data position (1, 1).

    Args:
        fig: Figure that owns ``ax``; used to create the colorbar.
        ax: Axes to draw the map into.
        dataset: Loaded scan dataset. Its ``x0``/``x1`` arrays are used as
            the X/Y positions and ``y0`` as the colour values, all passed
            straight to ``pcolormesh`` (assumes they already have a shape
            ``pcolormesh`` accepts — TODO confirm for new datasets).
        dose_label: Annotation text for the panel (mathtext allowed).

    Returns:
        The colorbar created for this panel.
    """
    # Offset positions so every panel's axes start at zero regardless of
    # where on the sample the scan was taken. ``.min()`` is a vectorised
    # reduction; the builtin ``min`` would iterate element by element.
    x = dataset.x0 - dataset.x0.min()
    y = dataset.x1 - dataset.x1.min()
    mesh = ax.pcolormesh(x, y, dataset.y0)

    # Thin colorbar strip attached above the map (5% size, 4% padding).
    cax = make_axes_locatable(ax).append_axes("top", size="5%", pad="4%")
    colorbar = fig.colorbar(mesh, cax=cax, orientation="horizontal")
    # Negative labelpad shifts the label towards the colorbar; -38 appears
    # to be hand-tuned for this figure size.
    colorbar.set_label("Counts (kCts/s)", labelpad=-38)
    cax.xaxis.set_ticks_position("top")

    ax.set_xlabel("X Position (µm)")
    ax.set_ylabel("Y Position (µm)")
    ax.text(
        1,
        1,
        dose_label,
        color="black",
        bbox=dict(facecolor="white", edgecolor="white", boxstyle="round,pad=0.2"),
    )
    return colorbar


fig = plt.figure(figsize=(10, 3))
# Three equal-width panels side by side.
panel = gridspec.GridSpec(1, 3, wspace=0.3, figure=fig)

# (dataset, annotation) for each panel, left to right.
dose_maps = [
    (dataset_map_2e13, r"$2 \times 10^{13} \rm \; e^{-}/cm^2$"),
    (dataset_map_4e13, r"$4 \times 10^{13} \rm \; e^{-}/cm^2$"),
    (dataset_map_1e14, r"$1 \times 10^{14} \rm \; e^{-}/cm^2$"),
]

# Create all main axes before any colorbar axes, as the original layout did.
axes = [fig.add_subplot(panel[0, column]) for column in range(len(dose_maps))]
for ax, (dataset, dose_label) in zip(axes, dose_maps):
    _plot_dose_map(fig, ax, dataset, dose_label)

output_path = Path("../figures/fig_14_add.png")
# savefig does not create missing directories; do it here so a fresh
# checkout does not fail with FileNotFoundError.
output_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(output_path, dpi=300, bbox_inches="tight")
