from pathlib import Path

import lmfit
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as st
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from quantify_core.data.handling import (
    load_dataset,
    set_datadir,
)

# Datasets are resolved relative to the current working directory.
set_datadir(Path("..") / "data")


def lorentz(x, offset, gamma, center, amp):
    """Area-normalised Lorentzian peak on a constant baseline.

    Args:
        x: Evaluation point(s); scalars and numpy/xarray arrays are accepted.
        offset: Constant baseline added to the peak.
        gamma: Full width at half maximum of the peak.
        center: Position of the peak maximum.
        amp: Integrated area under the peak (excluding ``offset``).

    Returns:
        ``amp / pi * (gamma/2) / ((x - center)**2 + (gamma/2)**2) + offset``.
    """
    half_width = gamma / 2
    # Keep the original evaluation order (amp/pi first, then * gamma/2, then
    # the division) so results are bit-for-bit unchanged.
    numerator = amp / np.pi * half_width
    denominator = (x - center) ** 2 + half_width**2
    return numerator / denominator + offset


def gaussian(x, offset, amp, center, sigma):
    """Gaussian peak of height ``amp`` and standard deviation ``sigma`` plus ``offset``.

    ``sigma`` only enters squared, so its sign does not affect the result.
    """
    z_squared = ((x - center) / sigma) ** 2
    peak = amp * np.exp(-z_squared / 2)
    return offset + peak


def double_gaussian(x, offset, amp_1, amp_2, center_1, center_2, sigma_1, sigma_2):
    """Sum of two Gaussian peaks on a constant baseline.

    The amplitudes are taken as absolute values, so both components are
    always non-negative regardless of the sign a fitter assigns to them.
    """
    total = offset
    components = ((amp_1, center_1, sigma_1), (amp_2, center_2, sigma_2))
    for amplitude, mean, width in components:
        z_squared = ((x - mean) / width) ** 2
        # Rebind rather than use += so an array baseline is never mutated.
        total = total + np.abs(amplitude) * np.exp(-z_squared / 2)
    return total


# Load and analyse the SnV data

dataset_snv_1 = load_dataset("20230316-205912-136-74ef68-fast_line_scan")

# Per-scan Lorentzian fits for dataset 1. Every repetition (row index of y0)
# not listed in exclude_scans_1 is fitted. Its linewidth (gamma, the FWHM of
# the Lorentzian) and line centre are stored, and its raw trace is kept for
# the re-centring/binning step that follows.
lorentz_centers_1, peaks_fwhm_1 = list(), list()
exclude_scans_1 = [16, 33, 35, 60, 62, 77, 89]

# The model and its starting values are identical for every scan, so build it
# once. make_params() still gives each fit its own fresh Parameters object.
fit_model_lorentz = lmfit.Model(lorentz)
fit_model_lorentz.set_param_hint("offset", value=200)
fit_model_lorentz.set_param_hint("gamma", value=0.05)
fit_model_lorentz.set_param_hint("center", value=123)
fit_model_lorentz.set_param_hint("amp", value=4000)

# traces_with_snv_1[k] is later paired with lorentz_centers_1[k], so both are
# appended in the same loop with the same index-based selection. Selecting by
# index rather than by the value of x0 keeps the pairing correct even if x0 is
# not exactly 0..N-1.
traces_with_snv_1 = list()
for rep in range(len(dataset_snv_1.x0)):
    if rep in exclude_scans_1:
        continue
    trace = dataset_snv_1.y0[rep]
    fit_result_lorentz = fit_model_lorentz.fit(
        trace,
        x=dataset_snv_1.x1,
        params=fit_model_lorentz.make_params(),
    )
    peaks_fwhm_1.append(fit_result_lorentz.params["gamma"].value)
    lorentz_centers_1.append(fit_result_lorentz.params["center"].value)
    traces_with_snv_1.append(trace)

# Counts summed over every repetition of dataset 1, including excluded scans.
summed_gaussian_1 = dataset_snv_1.y0.sum(axis=0)

# Shift each kept trace by its own fitted Lorentzian centre and accumulate its
# counts into 5 MHz (0.005 GHz) bins. Bin edges run from -0.75 GHz up to, but
# excluding, +0.75 GHz of detuning.
bin_freq_1 = np.arange(-0.75, 0.75, 0.005)
bin_sum_1 = np.zeros(bin_freq_1.size - 1)
for scan_idx, trace in enumerate(traces_with_snv_1):
    shifted_freqs = dataset_snv_1.x1 - lorentz_centers_1[scan_idx]
    binned = st.binned_statistic(
        shifted_freqs, trace, statistic="sum", bins=bin_freq_1
    ).statistic
    bin_sum_1 += binned

# Single Gaussian fitted to the summed (not re-centred) counts.
fit_model_gauss_1 = lmfit.Model(gaussian)
for hint_name, start_value in (
    ("offset", 100),
    ("amp", 10000),
    ("center", 123),
    ("sigma", 0.2),
):
    fit_model_gauss_1.set_param_hint(hint_name, value=start_value)
fit_result_gauss_1 = fit_model_gauss_1.fit(
    summed_gaussian_1, x=dataset_snv_1.x1, params=fit_model_gauss_1.make_params()
)

# Lorentzian fitted to the re-centred, binned counts. The x values are the
# left bin edges, i.e. detuning from each scan's own fitted centre.
fit_model_lorentz_1 = lmfit.Model(lorentz)
for hint_name, start_value in (
    ("offset", 100),
    ("gamma", 0.1),
    ("center", 0),
    ("amp", 100000),
):
    fit_model_lorentz_1.set_param_hint(hint_name, value=start_value)
fit_result_lorentz_1 = fit_model_lorentz_1.fit(
    bin_sum_1, x=bin_freq_1[:-1], params=fit_model_lorentz_1.make_params()
)

# Maxima of the fitted curves on a dense 1000-point grid. They are used to
# normalise the plotted data and fits so each peaks at ~1.
_grid_gauss_1 = np.linspace(
    dataset_snv_1.x1[0].values, dataset_snv_1.x1[-1].values, 1000
)
norm_gauss_1 = max(gaussian(x=_grid_gauss_1, **fit_result_gauss_1.best_values))
_grid_lorentz_1 = np.linspace(bin_freq_1[0], bin_freq_1[-1], 1000)
norm_lorentz_1 = max(lorentz(x=_grid_lorentz_1, **fit_result_lorentz_1.best_values))

dataset_snv_2 = load_dataset("20230320-205050-772-1af654-fast_line_scan")
# Shift the frequency axis by 13600 (in the units of x1) so the line sits near
# the same values as dataset 1 (centre hints ~123). Presumably this run used a
# different reference frequency — TODO confirm.
freqiencies_corrected_2 = dataset_snv_2.x1 - 13600
lorentz_centers_2, peaks_fwhm_snv_2 = list(), list()

# Hand-picked scans (two groups, named "left" and "right") that are fitted and
# re-centred. All other repetitions are ignored for the per-scan analysis.
left_scans_2 = [
    2, 3, 10, 16, 20, 30, 32, 33, 43, 46, 49, 53, 54, 56, 58,
    66, 69, 70, 71, 72, 73, 74, 79, 80, 83, 92, 93, 98, 99,
]  # fmt: skip
right_scans_2 = [
    0, 7, 9, 12, 15, 17, 19, 21, 23, 24, 25, 26, 27, 28, 36, 38, 39, 44, 45, 48,
    52, 55, 57, 59, 61, 63, 65, 67, 68, 75, 76, 77, 78, 81, 82, 86, 87, 90, 91, 96,
]  # fmt: skip
include_scans_2 = left_scans_2 + right_scans_2
include_scans_2.sort()

# Same model and bounded starting values for every scan: build it once.
fit_model_lorentz = lmfit.Model(lorentz)
fit_model_lorentz.set_param_hint("offset", value=50)
fit_model_lorentz.set_param_hint("gamma", value=0.05, min=0, max=0.1)
fit_model_lorentz.set_param_hint("center", value=123.3, min=123.1, max=123.5)
fit_model_lorentz.set_param_hint("amp", value=4000)

# traces_with_snv_2[k] is later paired with lorentz_centers_2[k], so both are
# appended in the same loop with the same index-based selection. This keeps
# the pairing correct even if x0 is not exactly 0..N-1.
traces_with_snv_2 = list()
for rep in range(len(dataset_snv_2.x0)):
    if rep not in include_scans_2:
        continue
    trace = dataset_snv_2.y0[rep]
    fit_result_lorentz = fit_model_lorentz.fit(
        trace,
        x=freqiencies_corrected_2,
        params=fit_model_lorentz.make_params(),
    )
    lorentz_centers_2.append(fit_result_lorentz.params["center"].value)
    peaks_fwhm_snv_2.append(fit_result_lorentz.params["gamma"].value)
    traces_with_snv_2.append(trace)

# Counts summed over every repetition of dataset 2, not only the included ones.
summed_gaussian_2 = dataset_snv_2.y0.sum(axis=0)

# Shift each included trace by its own fitted Lorentzian centre and accumulate
# its counts into 5 MHz (0.005 GHz) bins. Bin edges run from -0.75 GHz up to,
# but excluding, +0.75 GHz of detuning.
bin_freq_2 = np.arange(-0.75, 0.75, 0.005)
bin_sum_2 = np.zeros(bin_freq_2.size - 1)
for scan_idx, trace in enumerate(traces_with_snv_2):
    shifted_freqs = freqiencies_corrected_2 - lorentz_centers_2[scan_idx]
    binned = st.binned_statistic(
        shifted_freqs, trace, statistic="sum", bins=bin_freq_2
    ).statistic
    bin_sum_2 += binned

# Two-component Gaussian fitted to the summed (not re-centred) counts.
# Both amplitudes are bounded below by zero.
fit_model_gauss_2 = lmfit.Model(double_gaussian)
for hint_name, hint_kwargs in (
    ("offset", dict(value=5000)),
    ("amp_1", dict(value=100000, min=0)),
    ("center_1", dict(value=123.2)),
    ("sigma_1", dict(value=0.1)),
    ("amp_2", dict(value=500000, min=0)),
    ("center_2", dict(value=123.35)),
    ("sigma_2", dict(value=0.1)),
):
    fit_model_gauss_2.set_param_hint(hint_name, **hint_kwargs)
fit_result_gauss_2 = fit_model_gauss_2.fit(
    summed_gaussian_2,
    x=freqiencies_corrected_2,
    params=fit_model_gauss_2.make_params(),
)

# Lorentzian fitted to the re-centred, binned counts. The x values are the
# left bin edges, i.e. detuning from each scan's own fitted centre.
fit_model_lorentz_2 = lmfit.Model(lorentz)
for hint_name, hint_kwargs in (
    ("offset", dict(value=100)),
    ("gamma", dict(value=0.1)),
    ("center", dict(value=0)),
    ("amp", dict(value=100000)),
):
    fit_model_lorentz_2.set_param_hint(hint_name, **hint_kwargs)
fit_result_lorentz_2 = fit_model_lorentz_2.fit(
    bin_sum_2, x=bin_freq_2[:-1], params=fit_model_lorentz_2.make_params()
)

# Maxima of the fitted curves on a dense 1000-point grid. They are used to
# normalise the plotted data and fits so each peaks at ~1.
_grid_gauss_2 = np.linspace(
    freqiencies_corrected_2[0].values, freqiencies_corrected_2[-1].values, 1000
)
norm_gauss_2 = max(double_gaussian(x=_grid_gauss_2, **fit_result_gauss_2.best_values))
_grid_lorentz_2 = np.linspace(bin_freq_2[0], bin_freq_2[-1], 1000)
norm_lorentz_2 = max(lorentz(x=_grid_lorentz_2, **fit_result_lorentz_2.best_values))
dataset_snv_3 = load_dataset("20230320-230754-584-58d6c6-fast_line_scan")
# Same 13600 shift of the frequency axis as applied to dataset 2 (units of x1).
freqiencies_corrected_3 = dataset_snv_3.x1 - 13600

# Per-scan Lorentzian fits for dataset 3. Every repetition (row index of y0)
# not listed in exclude_scans_3 is fitted, and its trace is kept for binning.
lorentz_centers_3, peaks_fwhm_snv_3 = list(), list()
exclude_scans_3 = [
    6, 7, 9, 14, 17, 23, 24, 25, 28, 29, 36, 41, 45, 46, 47, 48, 49, 53, 58, 59, 61,
    62, 63, 64, 65, 66, 67, 75, 76, 77, 81, 82, 83, 85, 86, 87, 89, 90, 94, 95, 97, 99,
]  # fmt: skip

# Same model and starting values for every scan: build it once.
# gamma and amp are kept non-negative.
fit_model_lorentz = lmfit.Model(lorentz)
fit_model_lorentz.set_param_hint("offset", value=200)
fit_model_lorentz.set_param_hint("gamma", value=0.05, min=0)
fit_model_lorentz.set_param_hint("center", value=123.4)
fit_model_lorentz.set_param_hint("amp", value=4000, min=0)

# traces_with_snv_3[k] is later paired with lorentz_centers_3[k], so both are
# appended in the same loop with the same index-based selection. This keeps
# the pairing correct even if x0 is not exactly 0..N-1.
traces_with_snv_3 = list()
for rep in range(len(dataset_snv_3.x0)):
    if rep in exclude_scans_3:
        continue
    trace = dataset_snv_3.y0[rep]
    fit_result_lorentz = fit_model_lorentz.fit(
        trace,
        x=freqiencies_corrected_3,
        params=fit_model_lorentz.make_params(),
    )
    peaks_fwhm_snv_3.append(fit_result_lorentz.params["gamma"].value)
    lorentz_centers_3.append(fit_result_lorentz.params["center"].value)
    traces_with_snv_3.append(trace)

# Counts summed over every repetition of dataset 3, including excluded scans.
summed_gaussian_3 = dataset_snv_3.y0.sum(axis=0)

# Shift each kept trace by its own fitted Lorentzian centre and accumulate its
# counts into 5 MHz (0.005 GHz) bins. Bin edges run from -0.75 GHz up to, but
# excluding, +0.75 GHz of detuning.
bin_freq_3 = np.arange(-0.75, 0.75, 0.005)
bin_sum_3 = np.zeros(bin_freq_3.size - 1)
for scan_idx, trace in enumerate(traces_with_snv_3):
    shifted_freqs = freqiencies_corrected_3 - lorentz_centers_3[scan_idx]
    binned = st.binned_statistic(
        shifted_freqs, trace, statistic="sum", bins=bin_freq_3
    ).statistic
    bin_sum_3 += binned

# Two-component Gaussian fitted to the summed (not re-centred) counts.
fit_model_gauss_3 = lmfit.Model(double_gaussian)
for hint_name, hint_kwargs in (
    ("offset", dict(value=10000)),
    ("amp_1", dict(value=100000)),
    ("center_1", dict(value=123.3)),
    ("sigma_1", dict(value=0.21)),
    ("amp_2", dict(value=70000)),
    ("center_2", dict(value=123.5)),
    ("sigma_2", dict(value=0.1)),
):
    fit_model_gauss_3.set_param_hint(hint_name, **hint_kwargs)
fit_result_gauss_3 = fit_model_gauss_3.fit(
    summed_gaussian_3,
    x=freqiencies_corrected_3,
    params=fit_model_gauss_3.make_params(),
)

# Lorentzian fitted to the re-centred, binned counts (x = left bin edges).
# gamma and amp are kept non-negative.
fit_model_lorentz_3 = lmfit.Model(lorentz)
for hint_name, hint_kwargs in (
    ("offset", dict(value=100)),
    ("gamma", dict(value=0.1, min=0)),
    ("center", dict(value=0)),
    ("amp", dict(value=500, min=0)),
):
    fit_model_lorentz_3.set_param_hint(hint_name, **hint_kwargs)
fit_result_lorentz_3 = fit_model_lorentz_3.fit(
    bin_sum_3, x=bin_freq_3[:-1], params=fit_model_lorentz_3.make_params()
)

# Maxima of the fitted curves on a dense 1000-point grid. They are used to
# normalise the plotted data and fits so each peaks at ~1.
_grid_gauss_3 = np.linspace(
    freqiencies_corrected_3[0].values, freqiencies_corrected_3[-1].values, 1000
)
norm_gauss_3 = max(double_gaussian(x=_grid_gauss_3, **fit_result_gauss_3.best_values))
_grid_lorentz_3 = np.linspace(bin_freq_3[0], bin_freq_3[-1], 1000)
norm_lorentz_3 = max(lorentz(x=_grid_lorentz_3, **fit_result_lorentz_3.best_values))

fig = plt.figure(figsize=(12, 6))
# 2 x 3 grid: spectra in the top row, taller per-scan heat maps in the bottom
# row, with no vertical gap between the two rows.
panel = gridspec.GridSpec(
    nrows=2, ncols=3, height_ratios=[1.5, 2], hspace=0.0, wspace=0.18, figure=fig
)

# Create the axes column by column (top, then bottom) so each heat map shares
# its x axis with the spectrum directly above it.
_top_axes, _bottom_axes = [], []
for _col in range(3):
    _upper = plt.subplot(panel[0, _col])
    _top_axes.append(_upper)
    _bottom_axes.append(plt.subplot(panel[1, _col], sharex=_upper))
ax_1, ax_3, ax_5 = _top_axes
ax_2, ax_4, ax_6 = _bottom_axes

# Only the bottom row shows x tick labels; the top row shares that axis.
for _upper in _top_axes:
    plt.setp(_upper.get_xticklabels(), visible=False)

# --- Panel (a): dataset 1 ---------------------------------------------------
# Top: summed counts with the Gaussian fit, and re-centred summed counts with
# the Lorentzian fit, each normalised by its fit maximum.
ax_1.text(-0.2, 1.1, "(a)", transform=ax_1.transAxes, fontsize=12, va="top")
freq_grid_1 = np.linspace(dataset_snv_1.x1[0].values, dataset_snv_1.x1[-1].values, 1000)
detuning_grid_1 = np.linspace(bin_freq_1[0], bin_freq_1[-2], 1000)
# The re-centred data are shifted onto the Gaussian centre so both line shapes
# overlay on the same frequency axis.
gauss_center_1 = fit_result_gauss_1.params["center"].value
ax_1.scatter(
    dataset_snv_1.x1,
    summed_gaussian_1 / norm_gauss_1,
    s=15,
    color="dodgerblue",
    label="Summed Counts",
)
ax_1.plot(
    freq_grid_1,
    gaussian(x=freq_grid_1, **fit_result_gauss_1.best_values) / norm_gauss_1,
    color="navy",
    linestyle="dashed",
    label="(Double) Gaussian Fit",
)
ax_1.scatter(
    bin_freq_1[:-1] + gauss_center_1,
    bin_sum_1 / norm_lorentz_1,
    s=15,
    color="yellowgreen",
    label="Centered Summed Counts",
)
ax_1.plot(
    detuning_grid_1 + gauss_center_1,
    lorentz(x=detuning_grid_1, **fit_result_lorentz_1.best_values) / norm_lorentz_1,
    color="forestgreen",
    linestyle="dashed",
    label="Lorentzian Fit",
)
ax_1.set_xlim(122.7, 123.3)
# Width markers; positions are hand-placed for this dataset.
ax_1.annotate(
    "",
    xy=(123.09, 0.5),
    xycoords="data",
    xytext=(123.17, 0.5),
    textcoords="data",
    arrowprops=dict(arrowstyle="-|>", color="navy", lw=2),
)
ax_1.annotate(
    "",
    xy=(122.88, 0.5),
    xycoords="data",
    xytext=(122.96, 0.5),
    textcoords="data",
    arrowprops=dict(arrowstyle="<|-", color="navy", lw=2),
)
# Gaussian FWHM = 2*sqrt(2 ln 2)*sigma, shown in MHz. sigma only enters the
# model squared, so the fit may return it negative; report its magnitude.
gauss_fwhm_mhz_1 = (
    1000 * 2 * np.sqrt(2 * np.log(2)) * abs(fit_result_gauss_1.params["sigma"].value)
)
ax_1.annotate(
    "{:.1f} MHz".format(gauss_fwhm_mhz_1),
    xy=(123.12, 0.58),
    xycoords="data",
    xytext=(123.12, 0.58),
    color="navy",
    textcoords="data",
    fontsize=11,
)
ax_1.annotate(
    "",
    xy=(122.99, 0.5),
    xycoords="data",
    xytext=(123.055, 0.5),
    textcoords="data",
    arrowprops=dict(arrowstyle="<|-|>", color="forestgreen", lw=1),
)
# Lorentzian FWHM (gamma) in MHz. A negative gamma paired with a negative amp
# describes the same curve, so report |gamma|. lmfit leaves stderr as None
# when it cannot estimate the covariance; label that instead of crashing.
gamma_1 = fit_result_lorentz_1.params["gamma"]
if gamma_1.stderr is None:
    lorentz_label_1 = "{:.1f} MHz\n+/- n/a".format(1000 * abs(gamma_1.value))
else:
    lorentz_label_1 = "{:.1f} MHz\n+/- {:.1f} MHz".format(
        1000 * abs(gamma_1.value), 1000 * gamma_1.stderr
    )
ax_1.annotate(
    lorentz_label_1,
    xy=(122.75, 0.58),
    xycoords="data",
    xytext=(122.75, 0.58),
    textcoords="data",
    color="forestgreen",
    fontsize=11,
)
ax_1.set_ylabel("Normalized Counts", fontsize=12)
ax_1.set_ylim(-0.05, 1.18)

# Bottom: counts of every repetition as a heat map (in kCts/s), with a white
# inset colorbar in the upper part of the axes.
img = ax_2.pcolormesh(
    dataset_snv_1.x1,
    dataset_snv_1.x0,
    dataset_snv_1.y0 / 1000,
)
axinset = inset_axes(
    ax_2,
    width="10%",
    height="30%",
    bbox_to_anchor=(0.68, 0, 0.95, 0.95),
    bbox_transform=ax_2.transAxes,
    loc="upper left",
)
cbar = fig.colorbar(img, cax=axinset, orientation="vertical")
cbar.set_label("Counts (kCts/s)", color="white", fontsize=10)
cbar.ax.tick_params(color="white", labelsize=10, labelcolor="white")
cbar.outline.set_edgecolor("white")
ax_2.set_xticks([122.8, 123.0, 123.2])
ax_2.set_ylabel("Scan Repetition", fontsize=12)

# --- Panel (b): dataset 2 ---------------------------------------------------
# Top: summed counts with the double-Gaussian fit, and re-centred summed counts
# with the Lorentzian fit, each normalised by its fit maximum.
ax_3.text(-0.15, 1.1, "(b)", transform=ax_3.transAxes, fontsize=12, va="top")
freq_grid_2 = np.linspace(
    freqiencies_corrected_2[0].values, freqiencies_corrected_2[-1].values, 1000
)
detuning_grid_2 = np.linspace(bin_freq_2[0], bin_freq_2[-2], 1000)
# The re-centred data are overlaid at the second Gaussian component's centre.
gauss_center_2 = fit_result_gauss_2.params["center_2"].value
ax_3.scatter(
    freqiencies_corrected_2,
    summed_gaussian_2 / norm_gauss_2,
    s=15,
    color="dodgerblue",
    label="Summed Counts",
)
ax_3.plot(
    freq_grid_2,
    double_gaussian(x=freq_grid_2, **fit_result_gauss_2.best_values) / norm_gauss_2,
    color="navy",
    linestyle="dashed",
    label="(Double) Gaussian Fit",
)
ax_3.scatter(
    bin_freq_2[:-1] + gauss_center_2,
    bin_sum_2 / norm_lorentz_2,
    s=15,
    color="yellowgreen",
    label="Centered Summed Counts",
)
ax_3.plot(
    detuning_grid_2 + gauss_center_2,
    lorentz(x=detuning_grid_2, **fit_result_lorentz_2.best_values) / norm_lorentz_2,
    color="forestgreen",
    linestyle="dashed",
    label="Lorentzian Fit",
)
ax_3.set_xlim(123, 123.6)
# Width markers; positions are hand-placed for this dataset.
ax_3.annotate(
    "",
    xy=(123.13, 0.5),
    xycoords="data",
    xytext=(123.21, 0.5),
    textcoords="data",
    arrowprops=dict(arrowstyle="<|-", color="navy", lw=2),
)
ax_3.annotate(
    "",
    xy=(123.37, 0.5),
    xycoords="data",
    xytext=(123.45, 0.5),
    textcoords="data",
    arrowprops=dict(arrowstyle="-|>", color="navy", lw=2),
)
# NOTE(review): this width is hard-coded, not derived from fit_result_gauss_2.
# Confirm it still matches the current double-Gaussian fit.
ax_3.annotate(
    "176.0 MHz",
    xy=(123.4, 0.58),
    xycoords="data",
    xytext=(123.4, 0.58),
    color="navy",
    textcoords="data",
    fontsize=11,
)
ax_3.annotate(
    "",
    xy=(123.315, 0.5),
    xycoords="data",
    xytext=(123.37, 0.5),
    textcoords="data",
    arrowprops=dict(arrowstyle="<|-|>", color="forestgreen", lw=1),
)
# Lorentzian FWHM (gamma) in MHz. gamma is unbounded in this fit and a negative
# gamma with a negative amp gives the same curve, so report |gamma|. lmfit
# leaves stderr as None when it cannot estimate the covariance.
gamma_2 = fit_result_lorentz_2.params["gamma"]
if gamma_2.stderr is None:
    lorentz_label_2 = "{:.1f} MHz\n+/- n/a".format(1000 * abs(gamma_2.value))
else:
    lorentz_label_2 = "{:.1f} MHz\n+/- {:.1f} MHz".format(
        1000 * abs(gamma_2.value), 1000 * gamma_2.stderr
    )
ax_3.annotate(
    lorentz_label_2,
    xy=(123.04, 0.75),
    xycoords="data",
    xytext=(123.04, 0.75),
    textcoords="data",
    color="forestgreen",
    fontsize=11,
)
# Shared legend for the whole figure, placed below the middle column.
ax_3.legend(
    loc="upper center",
    bbox_to_anchor=(0.5, -1.6),
    ncol=4,
    fancybox=True,
    shadow=True,
    fontsize=11,
)
ax_3.set_ylim(-0.05, 1.18)

# Bottom: counts of every repetition as a heat map (in kCts/s) with an inset
# colorbar.
img = ax_4.pcolormesh(
    freqiencies_corrected_2,
    dataset_snv_2.x0,
    dataset_snv_2.y0 / 1000,
)
axinset = inset_axes(
    ax_4,
    width="10%",
    height="30%",
    bbox_to_anchor=(0.68, 0, 0.95, 0.95),
    bbox_transform=ax_4.transAxes,
    loc="upper left",
)
cbar = fig.colorbar(img, cax=axinset, orientation="vertical")
cbar.set_label("Counts (kCts/s)", color="white", fontsize=10)
cbar.ax.tick_params(color="white", labelsize=10, labelcolor="white")
cbar.outline.set_edgecolor("white")
ax_4.set_xticks([123.1, 123.3, 123.5])

# --- Panel (c): dataset 3 ---------------------------------------------------
# Top: summed counts with the double-Gaussian fit, and re-centred summed counts
# with the Lorentzian fit, each normalised by its fit maximum.
ax_5.text(-0.15, 1.1, "(c)", transform=ax_5.transAxes, fontsize=12, va="top")
freq_grid_3 = np.linspace(
    freqiencies_corrected_3[0].values, freqiencies_corrected_3[-1].values, 1000
)
detuning_grid_3 = np.linspace(bin_freq_3[0], bin_freq_3[-2], 1000)
# The re-centred data are overlaid at the first Gaussian component's centre.
gauss_center_3 = fit_result_gauss_3.params["center_1"].value
ax_5.scatter(
    freqiencies_corrected_3,
    summed_gaussian_3 / norm_gauss_3,
    s=15,
    color="dodgerblue",
    label="Summed Counts",
)
ax_5.plot(
    freq_grid_3,
    double_gaussian(x=freq_grid_3, **fit_result_gauss_3.best_values) / norm_gauss_3,
    color="navy",
    linestyle="dashed",
    label="(Double) Gaussian Fit",
)
ax_5.scatter(
    bin_freq_3[:-1] + gauss_center_3,
    bin_sum_3 / norm_lorentz_3,
    s=15,
    color="yellowgreen",
    label="Centered Summed Counts",
)
ax_5.plot(
    detuning_grid_3 + gauss_center_3,
    lorentz(x=detuning_grid_3, **fit_result_lorentz_3.best_values) / norm_lorentz_3,
    color="forestgreen",
    linestyle="dashed",
    label="Lorentzian Fit",
)
# Width markers; positions are hand-placed for this dataset.
ax_5.annotate(
    "",
    xy=(123.205, 0.5),
    xycoords="data",
    xytext=(123.285, 0.5),
    textcoords="data",
    arrowprops=dict(arrowstyle="<|-", color="navy", lw=2),
)
ax_5.annotate(
    "",
    xy=(123.484, 0.5),
    xycoords="data",
    xytext=(123.564, 0.5),
    textcoords="data",
    arrowprops=dict(arrowstyle="-|>", color="navy", lw=2),
)
# NOTE(review): this width is hard-coded, not derived from fit_result_gauss_3.
# Confirm it still matches the current double-Gaussian fit.
ax_5.annotate(
    "211.0 MHz",
    xy=(123.51, 0.58),
    xycoords="data",
    xytext=(123.51, 0.58),
    color="navy",
    textcoords="data",
    fontsize=11,
)
ax_5.annotate(
    "",
    xy=(123.315, 0.5),
    xycoords="data",
    xytext=(123.365, 0.5),
    textcoords="data",
    arrowprops=dict(arrowstyle="<|-|>", color="forestgreen", lw=1),
)
# Lorentzian FWHM (gamma) in MHz. lmfit leaves stderr as None when it cannot
# estimate the covariance; label that instead of crashing. abs() keeps the
# label consistent with the other panels.
gamma_3 = fit_result_lorentz_3.params["gamma"]
if gamma_3.stderr is None:
    lorentz_label_3 = "{:.1f} MHz\n+/- n/a".format(1000 * abs(gamma_3.value))
else:
    lorentz_label_3 = "{:.1f} MHz\n+/- {:.1f} MHz".format(
        1000 * abs(gamma_3.value), 1000 * gamma_3.stderr
    )
ax_5.annotate(
    lorentz_label_3,
    xy=(123.11, 0.91),
    xycoords="data",
    xytext=(123.11, 0.91),
    textcoords="data",
    color="forestgreen",
    fontsize=11,
)
ax_5.set_xlim(123.1, 123.7)
ax_5.set_ylim(-0.05, 1.18)

# Bottom: counts of every repetition as a heat map (in kCts/s) with an inset
# colorbar.
img = ax_6.pcolormesh(
    freqiencies_corrected_3,
    dataset_snv_3.x0,
    dataset_snv_3.y0 / 1000,
)
axinset = inset_axes(
    ax_6,
    width="10%",
    height="30%",
    bbox_to_anchor=(0.68, 0, 0.95, 0.95),
    bbox_transform=ax_6.transAxes,
    loc="upper left",
)
cbar = fig.colorbar(img, cax=axinset, orientation="vertical")
cbar.set_label("Counts (kCts/s)", color="white", fontsize=10)
cbar.ax.tick_params(color="white", labelsize=10, labelcolor="white")
cbar.outline.set_edgecolor("white")

# Common x-axis label for all three panels.
fig.text(
    0.5,
    0.03,
    "Laser Excitation Frequency w.r.t. 484.4 THz (GHz)",
    ha="center",
    fontsize=12,
)

output_path = Path("../figures/fig_18_add.png")
# savefig does not create missing directories, so make sure the target exists.
output_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(output_path, dpi=300, bbox_inches="tight")
