
import os
from math import exp
from typing import Tuple

import matplotlib
import matplotlib.pyplot as plt
import pandas as pd

from data_processing.plotting import produce_plot_set, produce_plot, produce_plot_existing_ax

# Print DataFrames in full (every row and column) when inspecting results.
pd.set_option('display.max_rows', None, 'display.max_columns', None)

# Global styling applied to every figure created by this module.
matplotlib.rcParams['font.size'] = 15
matplotlib.rcParams['figure.autolayout'] = True

def create_plots_bqc_v2(input_directory: str, output_directory: str | None = None,
                        version: str | None = None, save_plots: bool = False) -> None:
    r"""Plot the BQC (v2) results stored in ``input_directory/raw-data.pickle``.

    Derives the ``"max number of PGAs"`` and ``"average queue time"`` columns,
    builds the standard plot sets with ``produce_plot_set``, overlays the
    expected-queue-time reference curve and, per requested packet rate, the
    effective utilisation bound (:math:`\hat U = 0.85`), shows the queue-time and
    link-utilisation figures and optionally saves every figure.

    Args:
        input_directory: Directory containing ``raw-data.pickle``.
        output_directory: Destination for saved figures; defaults to
            ``input_directory``.
        version: Optional tag appended to each output filename as ``_v{version}``.
        save_plots: Write the figures to disk when True.

    Note:
        The data file is unpickled; only point this at trusted results.
    """
    # Scheduling interval in seconds: the queue time is the latency to service
    # minus one interval, and the interval parametrises the reference curve.
    scheduling_interval = 3600
    utilisation_bound = 0.85  # drawn as \hat U

    data_file = os.path.join(input_directory, "raw-data.pickle")
    with open(data_file, 'rb') as F:
        results: pd.DataFrame = pd.read_pickle(F)

    pd.set_option("display.max_columns", None)
    pd.set_option("display.max_rows", None)

    results.sort_index(inplace=True)

    # NOTE(review): 345600 s (= 96 h) is presumably the simulated time horizon — confirm.
    results["max number of PGAs"] = results["average pga rate"] * (345600 - results["average latency to service"])
    results["average queue time"] = results["average latency to service"] - scheduling_interval

    fig1, ax1, fig1a, ax1a, fig1b, ax1b = produce_plot_set(results, "proportion minimal service")
    fig2, ax2, fig2a, ax2a, fig2b, ax2b = produce_plot_set(results, "average queue time")
    fig3, ax3, fig3a, ax3a, fig3b, ax3b = produce_plot_set(results, "average pga rate")
    fig4, ax4, fig4a, ax4a, fig4b, ax4b = produce_plot_set(results, "average pgt utilisation")
    fig5, ax5, fig5a, ax5a, fig5b, ax5b = produce_plot_set(results, "average est pgas per schedule")
    fig6, ax6, fig6a, ax6a, fig6b, ax6b = produce_plot_set(
        results, "Average utilisation of link 5--255/255--5", separate_plots=True)

    renewal_rates = results.index.get_level_values("session renewal rate").unique()
    packet_rate_level = results.index.get_level_values('requested packet rate')

    scaling = 1.1
    fig3a.set_size_inches(scaling * 6.4, scaling * 4.8)

    ax2.set_ylabel("$\\bar{t}_{queue}$ (s)")
    ax2a.ticklabel_format(style='sci', axis='y', scilimits=(0, 0))

    # Reference curve for renewal rate r with T the scheduling interval:
    # T * (1 - 1/(r T) + e^{-r T} / (1 - e^{-r T})).
    expected_queue_time = [
        scheduling_interval * (1 - 1 / (rate * scheduling_interval)
                               + exp(-rate * scheduling_interval) / (1 - exp(-rate * scheduling_interval)))
        for rate in renewal_rates
    ]
    for ax in (ax2, ax2a, ax2b):
        ax.plot(renewal_rates, expected_queue_time, ':')

    ax1a.set_ylabel('Proportion of sessions obtaining\nminimal service, $p^{MS}$ (a.u.)')
    ax2a.set_ylabel("Average time spent in\ndemand queue, $\\bar t_{queue}$ (s)")

    fig1a.set_layout_engine('tight')
    ax1a.legend(fontsize=13)
    fig1a.set_size_inches(6.4, 5)

    ax2a.legend(fontsize=13)
    fig2a.set_size_inches(6.4, 5)

    for i, value in enumerate(packet_rate_level.unique()):
        subset = results.loc[packet_rate_level == value]
        # Ignore iterations in which no session was created.
        subset = subset.loc[subset["number of sessions"] != 0]
        mean_util = subset.groupby("session renewal rate")["average pgt utilisation"].mean()
        # reindex() gives NaN (a gap in the line) for renewal rates whose
        # iterations were all filtered out, instead of raising KeyError.
        effective_util_cap = list(utilisation_bound - mean_util.reindex(renewal_rates))

        ax6a[i].plot(renewal_rates, effective_util_cap, ':',
                     color='tab:orange', label=r"effective $\hat U$")
        ax6a[i].plot(renewal_rates, len(renewal_rates) * [utilisation_bound], ':',
                     label=r'$\hat U$', color="tab:red")
        ax6a[i].legend()

    # One wide figure holding a panel per packet rate.
    fig6a.set_size_inches(2.5 * 6.4, 1 * 4.8)

    # Only these two are displayed; the other figures are built for saving.
    fig2a.show()
    fig6a.show()

    output_dir = input_directory if output_directory is None else output_directory
    suffix = f"_v{version}" if version is not None else ""

    if save_plots:
        fig1.savefig(os.path.join(output_dir, f"prop-min-service_eb{suffix}.png"), dpi=400)
        fig1a.savefig(os.path.join(output_dir, f"prop-min-service_eb_filtered{suffix}.pdf"),
                      format='pdf', bbox_inches='tight', transparent=True)
        fig1b.savefig(os.path.join(output_dir, f"prop-min-service_eb_filtered_expired{suffix}.png"), dpi=400)
        fig2.savefig(os.path.join(output_dir, f"latency_eb{suffix}.png"), dpi=400)
        fig2a.savefig(os.path.join(output_dir, f"latency_eb_filtered{suffix}.pdf"),
                      format='pdf', bbox_inches='tight', transparent=True)
        fig2b.savefig(os.path.join(output_dir, f"latency_eb_filtered_expired{suffix}.png"), dpi=400)
        fig3a.savefig(os.path.join(output_dir, f"ave-pgt-rate{suffix}.png"), dpi=400)
        fig4a.savefig(os.path.join(output_dir, f"ave-pgt-util{suffix}.png"), dpi=400)
        fig5a.savefig(os.path.join(output_dir, f"ave-pgts-per-schedule{suffix}.png"), dpi=400)
        fig6a.savefig(os.path.join(output_dir, f"ave-util-server-link{suffix}.png"), dpi=400)


def create_plots_util_bound_comparison(input_directory1: str, input_directory2: str, output_directory: str | None = None,
                                       version: str | None = None, save_plots: bool = False) -> None:
    r"""Overlay two runs that differ in the utilisation bound :math:`\hat U`.

    The baseline plots are built from ``input_directory1/raw-data.pickle``; its
    first line on the proportion-minimal-service axes is relabelled
    ``$\hat U = 0.41$``. The zero-packet-rate runs of
    ``input_directory2/raw-data.pickle`` are then drawn on top (mean ± 1 std per
    session renewal rate, labelled ``$\hat U = 0.85$``) for four quantities, and
    the four figures are shown.

    Args:
        input_directory1: Directory with the baseline ``raw-data.pickle``.
        input_directory2: Directory with the comparison ``raw-data.pickle``.
        output_directory: Accepted for interface parity; currently unused.
        version: Accepted for interface parity; currently unused.
        save_plots: Accepted for interface parity; nothing is saved yet.

    Note:
        The data files are unpickled; only point this at trusted results.
    """
    with open(os.path.join(input_directory1, "raw-data.pickle"), 'rb') as F:
        results: pd.DataFrame = pd.read_pickle(F)

    pd.set_option("display.max_columns", None)
    pd.set_option("display.max_rows", None)

    results.sort_index(inplace=True)

    _, _, fig1a, ax1a, _, _ = produce_plot_set(results, "proportion minimal service")
    _, _, fig2a, ax2a, _, _ = produce_plot_set(results, "Average utilisation of link 5--255/255--5")
    fig3, ax3 = produce_plot(results, "average latency to service", filter_non_zero_sessions=True)
    fig4, ax4 = produce_plot(results, "number of sessions", filter_non_zero_sessions=True)

    with open(os.path.join(input_directory2, "raw-data.pickle"), 'rb') as F:
        results1: pd.DataFrame = pd.read_pickle(F)

    # Line2D labels must be changed through set_label(); legend() reads
    # get_label(), so assigning a plain ``.label`` attribute had no effect.
    ax1a.lines[0].set_label(r"$\hat U = 0.41$")

    # Only the zero-packet-rate runs of the second data set are compared.
    zero_rate = results1.loc[results1.index.get_level_values("requested packet rate") == 0]

    def _overlay(ax, column):
        """Draw mean ± 1 std of *column* per session renewal rate on *ax*."""
        stats = zero_rate[column].groupby(level='session renewal rate').agg(['mean', 'std'])
        ax.plot(stats.index, stats['mean'], label=r'$\hat U = 0.85$', marker='.', markersize=12)
        ax.fill_between(stats.index, list(stats['mean'] + stats['std']),
                        list(stats['mean'] - stats['std']), alpha=0.2)

    _overlay(ax1a, 'proportion minimal service')
    ax1a.legend()
    _overlay(ax2a, 'Average utilisation of link 5--255/255--5')
    _overlay(ax3, 'average latency to service')
    _overlay(ax4, 'number of sessions')

    fig1a.show()
    fig2a.show()
    fig3.show()
    fig4.show()

    # NOTE(review): output_directory / version / save_plots are not used yet —
    # saving is not implemented for this comparison.

def create_plots_mda_v2(*input_data_files, output_directory: str | None = None, version: str | None = None,
                        save_plots: bool = False) -> None:
    r"""Plot the MDA (v2) results merged from one or more pickled DataFrames.

    Concatenates every file in *input_data_files*, derives ``"max number of
    PGAs"``, builds the standard plot sets plus a two-panel summary figure,
    overlays the expected-queue-time curve, and for each requested packet rate
    draws the PGA cap / effective PGA cap and the :math:`\hat U` / effective
    :math:`\hat U` reference lines. Prints the number of zero-session iterations
    and the mean/variance of the proportion of minimal service, shows a subset of
    the figures and optionally saves them.

    Args:
        *input_data_files: Paths to pickled results DataFrames (at least one).
        output_directory: Destination for saved figures; defaults to the
            directory of the first input file.
        version: Optional tag appended to each output filename as ``_v{version}``.
        save_plots: Write the figures to disk when True.

    Raises:
        ValueError: If no input file is given.

    Note:
        The data files are unpickled; only point this at trusted results.
    """
    if not input_data_files:
        raise ValueError("create_plots_mda_v2 needs at least one input data file")

    scheduling_interval = 300  # seconds
    pga_cap = 1500
    utilisation_bound = 0.85  # drawn as \hat U

    frames = []
    for data_file in input_data_files:
        with open(data_file, 'rb') as F:
            frames.append(pd.read_pickle(F))
    # A single concat avoids re-copying the growing frame for every file.
    results: pd.DataFrame = pd.concat(frames)

    pd.set_option("display.max_columns", None)
    pd.set_option("display.max_rows", None)

    results.sort_index(inplace=True)

    # NOTE(review): 2100 s is presumably the simulated time horizon of these runs — confirm.
    results["max number of PGAs"] = results["average pga rate"] * (2100 - results["average latency to service"])

    fig1, ax1, fig1a, ax1a, fig1b, ax1b = produce_plot_set(results, "proportion minimal service")
    fig2, ax2, fig2a, ax2a, fig2b, ax2b = produce_plot_set(results, "average queue time")
    fig3, ax3, fig3a, ax3a, fig3b, ax3b = produce_plot_set(results, "average pga rate")
    fig4, ax4, fig4a, ax4a, fig4b, ax4b = produce_plot_set(results, "average pgt utilisation")
    fig5, ax5, fig5a, ax5a, fig5b, ax5b = produce_plot_set(results, "average est pgas per schedule",
                                                           separate_plots=True)
    fig6, ax6, fig6a, ax6a, fig6b, ax6b = produce_plot_set(results, "Average utilisation of link 5--255/255--5")
    fig7, ax7, fig7a, ax7a, fig7b, ax7b = produce_plot_set(results, "average max resource utilisation",
                                                           separate_plots=True)
    fig8, ax8, fig8a, ax8a, fig8b, ax8b = produce_plot_set(results, "max number of PGAs")

    # Two-panel summary: (a) proportion of minimal service, (b) latency to service.
    fig9, (ax9, ax9a) = plt.subplots(1, 2)
    produce_plot_existing_ax(results, ax9, "proportion minimal service")
    produce_plot_existing_ax(results, ax9a, "average latency to service")
    ax9.text(0.5, -0.2, '(a)', transform=ax9.transAxes, size=10, weight='bold', ha='center')
    ax9a.text(0.5, -0.2, '(b)', transform=ax9a.transAxes, size=10, weight='bold', ha='center')
    ax9a.legend(loc='upper left')
    fig9.set_size_inches(7, 6)
    fig9.set_layout_engine('tight')
    ax9.set_ylabel('Proportion of sessions obtaining\nminimal service, $p^{MS}$ (a.u.)')
    # NOTE(review): this panel plots "average latency to service" but is labelled t_queue — confirm.
    ax9a.set_ylabel(r"$\bar t_{queue}$ (s)")

    # Iterations per (packet rate, renewal rate) in which no session was created.
    number_of_zero_iterations = results.groupby(level=["requested packet rate", "session renewal rate"]).apply(
        lambda x: (x["number of sessions"] == 0).sum())
    print(number_of_zero_iterations)

    renewal_rates = results.index.get_level_values("session renewal rate").unique()
    packet_rate_level = results.index.get_level_values('requested packet rate')

    # Reference curve for renewal rate r with T the scheduling interval:
    # T * (1 - 1/(r T) + e^{-r T} / (1 - e^{-r T})).
    expected_queue_time = [
        scheduling_interval * (1 - 1 / (rate * scheduling_interval)
                               + exp(-rate * scheduling_interval) / (1 - exp(-rate * scheduling_interval)))
        for rate in renewal_rates
    ]
    for ax in (ax2, ax2a, ax2b):
        ax.plot(renewal_rates, expected_queue_time, ':')

    ax1a.set_ylabel('Proportion of sessions obtaining\nminimal service, $p^{MS}$ (a.u.)')
    ax2a.set_ylabel("Average time spent in\ndemand queue, $\\bar t_{queue}$ (s)")

    fig1a.set_layout_engine('tight')
    ax1a.legend(fontsize=13)
    fig1a.set_size_inches(6.4, 5)

    ax2a.legend(fontsize=13)
    fig2a.set_size_inches(6.4, 5)

    for i, value in enumerate(packet_rate_level.unique()):
        subset = results.loc[packet_rate_level == value]
        # Ignore iterations in which no session was created.
        subset = subset.loc[subset["number of sessions"] != 0]
        grouped = subset.groupby("session renewal rate")
        # reindex() gives NaN (a gap in the line) for renewal rates whose
        # iterations were all filtered out; previously a KeyError silently
        # skipped the whole packet rate via a bare ``except``.
        mean_pga_rate = grouped["average pga rate"].mean().reindex(renewal_rates)
        mean_util = grouped["average pgt utilisation"].mean().reindex(renewal_rates)

        effective_pga_cap = list(pga_cap - scheduling_interval * mean_pga_rate)
        effective_util_cap = list(utilisation_bound - mean_util)

        # Match the data line's colour when the panel has one; otherwise use the default cycle.
        existing_lines = ax5a[i].get_lines()
        colour_kwargs = {"color": existing_lines[0].get_color()} if existing_lines else {}
        ax5a[i].plot(renewal_rates, effective_pga_cap, ':', label="effective PGA cap", **colour_kwargs)
        ax5a[i].plot(renewal_rates, [pga_cap] * len(renewal_rates), ':', color='tab:red', label='PGA cap')
        ax5a[i].legend()

        ax7a[i].plot(renewal_rates, effective_util_cap, ':',
                     color='tab:orange', label=r"effective $\hat U$")
        ax7a[i].plot(renewal_rates, len(renewal_rates) * [utilisation_bound], ':',
                     label=r'$\hat U$', color="tab:red")

    # Wide figures holding one panel per packet rate.
    fig5a.set_size_inches(2.5 * 6.4, 1 * 4.8)
    fig7a.set_size_inches(2.5 * 6.4, 1 * 4.8)

    ax8a.legend(ncol=4)

    fig1a.show()
    fig2a.show()
    fig3a.show()
    fig8a.show()

    print(results.groupby(level=["requested packet rate", "session renewal rate"])['proportion minimal service'].agg(
        ['mean', 'var']))

    if save_plots:
        output_dir = os.path.dirname(input_data_files[0]) if output_directory is None else output_directory
        suffix = f"_v{version}" if version is not None else ""

        fig1.savefig(os.path.join(output_dir, f"prop-min-service_eb{suffix}.png"), dpi=400)
        fig1a.savefig(os.path.join(output_dir, f"prop-min-service_eb_filtered{suffix}.pdf"),
                      format='pdf', bbox_inches='tight', transparent=True)
        fig1b.savefig(os.path.join(output_dir, f"prop-min-service_eb_filtered_expired{suffix}.png"), dpi=400)
        fig2.savefig(os.path.join(output_dir, f"latency_eb{suffix}.png"), dpi=400)
        fig2a.savefig(os.path.join(output_dir, f"latency_eb_filtered{suffix}.pdf"),
                      format='pdf', bbox_inches='tight', transparent=True)
        fig2b.savefig(os.path.join(output_dir, f"latency_eb_filtered_expired{suffix}.png"), dpi=400)
        fig3a.savefig(os.path.join(output_dir, f"ave-pgt-rate{suffix}.png"), dpi=400)
        fig4a.savefig(os.path.join(output_dir, f"ave-pgt-util{suffix}.png"), dpi=400)
        fig5a.savefig(os.path.join(output_dir, f"ave-pgts-per-schedule{suffix}.png"), dpi=400)


def create_plots_mda_extended(input_directory1, input_directory2, output_directory: str | None = None,
                              version: str | None = None, save_plots: bool = False) -> None:
    r"""Plot the MDA results merged from two result directories.

    Concatenates ``raw-data.pickle`` from both directories, builds the standard
    plot sets, and for each requested packet rate draws the PGA cap / effective
    PGA cap and the :math:`\hat U` / effective :math:`\hat U` reference lines.
    Prints the number of zero-session iterations and the mean/variance of the
    proportion of minimal service, shows the figures and optionally saves them.

    Args:
        input_directory1: First directory containing ``raw-data.pickle``.
        input_directory2: Second directory containing ``raw-data.pickle``.
        output_directory: Destination for saved figures; defaults to
            ``input_directory1``.
        version: Optional tag appended to each output filename as ``_v{version}``.
        save_plots: Write the figures to disk when True.

    Note:
        The data files are unpickled; only point this at trusted results.
    """
    scheduling_interval = 300  # seconds
    pga_cap = 1500
    utilisation_bound = 0.85  # drawn as \hat U

    frames = []
    for directory in (input_directory1, input_directory2):
        with open(os.path.join(directory, "raw-data.pickle"), 'rb') as F:
            frames.append(pd.read_pickle(F))
    results: pd.DataFrame = pd.concat(frames)

    pd.set_option("display.max_columns", None)
    pd.set_option("display.max_rows", None)

    results.sort_index(inplace=True)

    fig1, ax1, fig1a, ax1a, fig1b, ax1b = produce_plot_set(results, "proportion minimal service")
    fig2, ax2, fig2a, ax2a, fig2b, ax2b = produce_plot_set(results, "average latency to service")
    fig3, ax3, fig3a, ax3a, fig3b, ax3b = produce_plot_set(results, "average pga rate")
    fig4, ax4, fig4a, ax4a, fig4b, ax4b = produce_plot_set(results, "average pgt utilisation")
    fig5, ax5, fig5a, ax5a, fig5b, ax5b = produce_plot_set(results, "average est pgas per schedule",
                                                           separate_plots=True)
    fig6, ax6, fig6a, ax6a, fig6b, ax6b = produce_plot_set(results, "Average utilisation of link 5--255/255--5")
    fig7, ax7, fig7a, ax7a, fig7b, ax7b = produce_plot_set(results, "average max resource utilisation",
                                                           separate_plots=True)

    # Iterations per (packet rate, renewal rate) in which no session was created.
    number_of_zero_iterations = results.groupby(level=["requested packet rate", "session renewal rate"]).apply(
        lambda x: (x["number of sessions"] == 0).sum())
    print(number_of_zero_iterations)

    ax1a.set_ylabel("Proportion of sessions\nobtaining minimal service")

    renewal_rates = results.index.get_level_values("session renewal rate").unique()
    packet_rate_level = results.index.get_level_values('requested packet rate')

    for i, value in enumerate(packet_rate_level.unique()):
        subset = results.loc[packet_rate_level == value]
        # Ignore iterations in which no session was created.
        subset = subset.loc[subset["number of sessions"] != 0]
        grouped = subset.groupby("session renewal rate")
        # reindex() gives NaN (a gap in the line) for renewal rates whose
        # iterations were all filtered out; previously a KeyError silently
        # skipped the whole packet rate via a bare ``except``.
        mean_pga_rate = grouped["average pga rate"].mean().reindex(renewal_rates)
        mean_util = grouped["average pgt utilisation"].mean().reindex(renewal_rates)

        effective_pga_cap = list(pga_cap - scheduling_interval * mean_pga_rate)
        effective_util_cap = list(utilisation_bound - mean_util)

        # Match the data line's colour when the panel has one; otherwise use the default cycle.
        existing_lines = ax5a[i].get_lines()
        colour_kwargs = {"color": existing_lines[0].get_color()} if existing_lines else {}
        ax5a[i].plot(renewal_rates, effective_pga_cap, ':', label="effective PGA cap", **colour_kwargs)
        ax5a[i].plot(renewal_rates, [pga_cap] * len(renewal_rates), ':', color='tab:red', label='PGA cap')
        ax5a[i].legend()

        ax7a[i].plot(renewal_rates, effective_util_cap, ':',
                     color='tab:orange', label=r"effective $\hat U$")
        ax7a[i].plot(renewal_rates, len(renewal_rates) * [utilisation_bound], ':',
                     label=r'$\hat U$', color="tab:red")

    # Wide figures holding one panel per packet rate.
    fig5a.set_size_inches(2.5 * 6.4, 1 * 4.8)
    fig7a.set_size_inches(2.5 * 6.4, 1 * 4.8)

    fig1a.show()
    fig2a.show()
    fig3a.show()
    fig4a.show()
    fig4b.show()
    fig5a.show()
    fig6a.show()
    fig7a.show()

    print(results.groupby(level=["requested packet rate", "session renewal rate"])['proportion minimal service'].agg(
        ['mean', 'var']))

    output_dir = input_directory1 if output_directory is None else output_directory
    suffix = f"_v{version}" if version is not None else ""

    if save_plots:
        fig1.savefig(os.path.join(output_dir, f"prop-min-service_eb{suffix}.png"), dpi=400)
        fig1a.savefig(os.path.join(output_dir, f"prop-min-service_eb_filtered{suffix}.png"), dpi=400)
        fig1b.savefig(os.path.join(output_dir, f"prop-min-service_eb_filtered_expired{suffix}.png"), dpi=400)
        fig2.savefig(os.path.join(output_dir, f"latency_eb{suffix}.png"), dpi=400)
        fig2a.savefig(os.path.join(output_dir, f"latency_eb_filtered{suffix}.png"), dpi=400)
        fig2b.savefig(os.path.join(output_dir, f"latency_eb_filtered_expired{suffix}.png"), dpi=400)
        fig3a.savefig(os.path.join(output_dir, f"ave-pgt-rate{suffix}.png"), dpi=400)
        fig4a.savefig(os.path.join(output_dir, f"ave-pgt-util{suffix}.png"), dpi=400)
        fig5a.savefig(os.path.join(output_dir, f"ave-pgts-per-schedule{suffix}.png"), dpi=400)


def create_plots_cka_extended(input_directory: str, input_directory2: str, output_directory: str | None = None,
                              version: str | None = None, save_plots: bool = False) -> None:
    r"""Plot the CKA results merged from two result directories.

    Concatenates ``raw-data.pickle`` from both directories, builds the standard
    plot sets, overlays the expected-latency reference curve (one scheduling
    interval plus the expected queue time) and, per requested packet rate, the
    :math:`\hat U` / effective :math:`\hat U` reference lines, shows the figures
    and optionally saves them.

    Args:
        input_directory: First directory containing ``raw-data.pickle``.
        input_directory2: Second directory containing ``raw-data.pickle``.
        output_directory: Destination for saved figures; defaults to
            ``input_directory``.
        version: Optional tag appended to each output filename as ``_v{version}``.
        save_plots: Write the figures to disk when True.

    Note:
        The data files are unpickled; only point this at trusted results.
    """
    scheduling_interval = 3600  # seconds
    utilisation_bound = 0.85  # drawn as \hat U

    frames = []
    for directory in (input_directory, input_directory2):
        with open(os.path.join(directory, "raw-data.pickle"), 'rb') as F:
            frames.append(pd.read_pickle(F))
    results: pd.DataFrame = pd.concat(frames)

    pd.set_option("display.max_columns", None)
    pd.set_option("display.max_rows", None)

    results.sort_index(inplace=True)

    fig1, ax1, fig1a, ax1a, fig1b, ax1b = produce_plot_set(results, "proportion minimal service")
    fig2, ax2, fig2a, ax2a, fig2b, ax2b = produce_plot_set(results, "average latency to service")
    fig3, ax3, fig3a, ax3a, fig3b, ax3b = produce_plot_set(results, "average pga rate")
    fig4, ax4, fig4a, ax4a, fig4b, ax4b = produce_plot_set(results, "average pgt utilisation")
    fig5, ax5, fig5a, ax5a, fig5b, ax5b = produce_plot_set(results, "average est pgas per schedule")
    fig6, ax6, fig6a, ax6a, fig6b, ax6b = produce_plot_set(results, "Average utilisation of link 5--255/255--5",
                                                           separate_plots=True)

    renewal_rates = results.index.get_level_values("session renewal rate").unique()
    packet_rate_level = results.index.get_level_values('requested packet rate')

    scaling = 1.1
    ax1a.set_ylabel("Proportion of sessions\nobtaining minimal service")
    fig3a.set_size_inches(scaling * 6.4, scaling * 4.8)

    # NOTE(review): these axes plot "average latency to service" but are
    # labelled t_queue — confirm the intended label.
    ax2.set_ylabel("$\\bar{t}_{queue}$ (s)")
    ax2a.set_ylabel("$\\bar{t}_{queue}$ (s)")
    ax2a.ticklabel_format(style='sci', axis='y', scilimits=(0, 0))

    # Reference curve for renewal rate r with T the scheduling interval:
    # T + T * (1 - 1/(r T) + e^{-r T} / (1 - e^{-r T})).
    expected_latency = [
        scheduling_interval + scheduling_interval * (
            1 - 1 / (rate * scheduling_interval)
            + exp(-rate * scheduling_interval) / (1 - exp(-rate * scheduling_interval)))
        for rate in renewal_rates
    ]
    for ax in (ax2, ax2a, ax2b):
        ax.plot(renewal_rates, expected_latency, ':')

    for i, value in enumerate(packet_rate_level.unique()):
        subset = results.loc[packet_rate_level == value]
        # Ignore iterations in which no session was created.
        subset = subset.loc[subset["number of sessions"] != 0]
        mean_util = subset.groupby("session renewal rate")["average pgt utilisation"].mean()
        # reindex() gives NaN (a gap in the line) for renewal rates whose
        # iterations were all filtered out, instead of raising KeyError.
        effective_util_cap = list(utilisation_bound - mean_util.reindex(renewal_rates))

        ax6a[i].plot(renewal_rates, effective_util_cap, ':',
                     color='tab:orange', label=r"effective $\hat U$")
        ax6a[i].plot(renewal_rates, len(renewal_rates) * [utilisation_bound], ':',
                     label=r'$\hat U$', color="tab:red")
        ax6a[i].legend()

    # One wide figure holding a panel per packet rate.
    fig6a.set_size_inches(2.5 * 6.4, 1 * 4.8)

    fig1a.show()
    fig2a.show()
    fig3a.show()
    fig4a.show()
    fig6a.show()

    output_dir = input_directory if output_directory is None else output_directory
    suffix = f"_v{version}" if version is not None else ""

    if save_plots:
        fig1.savefig(os.path.join(output_dir, f"prop-min-service_eb{suffix}.png"), dpi=400)
        fig1a.savefig(os.path.join(output_dir, f"prop-min-service_eb_filtered{suffix}.png"), dpi=400)
        fig1b.savefig(os.path.join(output_dir, f"prop-min-service_eb_filtered_expired{suffix}.png"), dpi=400)
        fig2.savefig(os.path.join(output_dir, f"queue_eb{suffix}.png"), dpi=400)
        fig2a.savefig(os.path.join(output_dir, f"queue_eb_filtered{suffix}.png"), dpi=400)
        fig2b.savefig(os.path.join(output_dir, f"queue_eb_filtered_expired{suffix}.png"), dpi=400)
        fig3a.savefig(os.path.join(output_dir, f"ave-pgt-rate{suffix}.png"), dpi=400)
        fig4a.savefig(os.path.join(output_dir, f"ave-pgt-util{suffix}.png"), dpi=400)
        fig5a.savefig(os.path.join(output_dir, f"ave-pgts-per-schedule{suffix}.png"), dpi=400)
        fig6a.savefig(os.path.join(output_dir, f"ave-util-server-link{suffix}.png"), dpi=400)


def compare_dummy_to_actual_scheduling(input1: str, input2: str):
    r"""Overlay the second run's proportion of minimal service on the first run's panels.

    Builds per-packet-rate panels of ``"proportion minimal service"`` from
    ``input1/raw-data.pickle``. For each requested packet rate in
    ``input2/raw-data.pickle`` (paired with the panels in order), draws that
    run's mean ± 1 std per session renewal rate, labelled ``"(Dummy)"``. The
    resulting figure is shown.

    Args:
        input1: Directory with the ``raw-data.pickle`` that defines the panels.
        input2: Directory with the ``raw-data.pickle`` that is overlaid.

    Note:
        The data files are unpickled; only point this at trusted results.
    """
    with open(os.path.join(input1, 'raw-data.pickle'), 'rb') as F:
        results: pd.DataFrame = pd.read_pickle(F)

    with open(os.path.join(input2, 'raw-data.pickle'), 'rb') as F:
        results2: pd.DataFrame = pd.read_pickle(F)

    fig1, ax1, fig1a, ax1a, fig1b, ax1b = produce_plot_set(results, "proportion minimal service", separate_plots=True)

    packet_rate_level = results2.index.get_level_values('requested packet rate')

    # zip() stops at the shorter sequence, so packet rates of input2 without a
    # matching panel are skipped. A bare ``except: pass`` used to do this while
    # also hiding every other error.
    for ax, value in zip(ax1a, packet_rate_level.unique()):
        subset = results2.loc[packet_rate_level == value]
        # NOTE(review): keeps only iterations with at least one expired session,
        # unlike the "number of sessions" filter used elsewhere — confirm intent.
        subset = subset.loc[subset["number of expired sessions"] != 0]
        stats = subset.groupby("session renewal rate")["proportion minimal service"].agg(['mean', 'std'])

        ax.plot(stats.index, stats["mean"], marker='.', label=f"$R = {value}$ Hz (Dummy)")
        ax.fill_between(stats.index, list(stats['mean'] - stats['std']),
                        list(stats['mean'] + stats['std']), alpha=0.2)

        ax.set_xlabel(r"Session renewal rate $\lambda$ $(s^{-1})$")
        ax.set_ylabel('proportion minimal service')
        ax.legend()

    # One wide figure holding a panel per packet rate.
    fig1a.set_size_inches(2.5 * 6.4, 1 * 4.8)

    fig1a.show()
    # fig2a.show()

def create_quad_plot(input_cka: str, input_mda: str,
                     output_path: str = '/home/tbeauchamp/quad-plot.png', dpi: int = 150):
    """Draw a 2x2 comparison figure of MDA and CKA results, show it, and save it.

    Layout:
      * top row (panels "A", "B") uses ``input_mda``;
      * bottom row (panels "C", "D") uses ``input_cka``;
      * left column plots 'proportion minimal service';
      * right column plots 'average latency to service' with scientific-notation
        y tick labels.

    :param input_cka: path to a pickled results DataFrame file (not a directory).
    :param input_mda: path to a pickled results DataFrame file (not a directory).
    :param output_path: destination PNG. Defaults to the previously hard-coded location.
    :param dpi: resolution passed to ``savefig``. Defaults to the previous value of 150.
    """
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, layout="tight")
    fig: plt.Figure

    with open(input_cka, 'rb') as F:
        results_cka: pd.DataFrame = pd.read_pickle(F)

    with open(input_mda, 'rb') as F:
        results_mda: pd.DataFrame = pd.read_pickle(F)

    produce_plot_existing_ax(results_mda, ax1, "proportion minimal service")
    produce_plot_existing_ax(results_mda, ax2, "average latency to service")

    produce_plot_existing_ax(results_cka, ax3, "proportion minimal service")
    produce_plot_existing_ax(results_cka, ax4, "average latency to service")

    ax2.ticklabel_format(style='sci', axis='y', scilimits=(0, 0))
    ax4.ticklabel_format(style='sci', axis='y', scilimits=(0, 0))

    # NOTE(review): these panels plot 'average latency to service', but elsewhere in
    # this file queue time is derived by subtracting a fixed offset from that
    # latency. Confirm that labelling them as mean queue time is intended.
    ax2.set_ylabel('$\\bar t_{queue}$ (s)')
    ax4.set_ylabel('$\\bar t_{queue}$ (s)')

    ax1.set_ylabel("Proportion of sessions\nobtaining minimal service")
    ax3.set_ylabel("Proportion of sessions\nobtaining minimal service")

    fig.set_size_inches(10, 7.5)

    ax1.set_title("A")
    ax2.set_title("B")
    ax3.set_title("C")
    ax4.set_title("D")

    fig.show()

    fig.savefig(output_path, dpi=dpi)

def create_dual_plot(input_cka: str, input_mda: str,
                     output_dir: str = '/Users/tbeauchamp/PycharmProjects/quantum-network-architecture-simulation-results'):
    """Plot 'proportion minimal service' for MDA and CKA side by side, then save it.

    The left panel ("Peer-to-peer MDA") uses ``input_mda``. The right panel
    ("Client-Server CKA") uses ``input_cka``. The figure is shown and written
    twice into ``output_dir``:
      * ``dual-plot-titled-low-res.png`` at 150 dpi;
      * ``dual-plot-titled-high-res-no-bg.png`` at 800 dpi with a transparent
        background.

    Side effect: sets the global ``plt.rcParams['text.usetex'] = True`` and does
    not restore it. Rendering therefore requires a working LaTeX installation, and
    the setting persists for later figures.

    :param input_cka: path to a pickled results DataFrame file (not a directory).
    :param input_mda: path to a pickled results DataFrame file (not a directory).
    :param output_dir: directory for the PNGs. Defaults to the previously
        hard-coded location.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, layout="tight")
    fig: plt.Figure

    with open(input_cka, 'rb') as F:
        results_cka: pd.DataFrame = pd.read_pickle(F)

    with open(input_mda, 'rb') as F:
        results_mda: pd.DataFrame = pd.read_pickle(F)

    produce_plot_existing_ax(results_mda, ax1, "proportion minimal service")
    produce_plot_existing_ax(results_cka, ax2, "proportion minimal service")

    # Both panels show the same proportion metric, so they get the same label and
    # default tick format. The scientific formatting and t_queue label previously
    # applied to ax2 were leftovers from the quad plot's latency panels, and that
    # label was overwritten immediately anyway.
    ax1.set_ylabel("Proportion of sessions\nobtaining minimal service")
    ax2.set_ylabel("Proportion of sessions\nobtaining minimal service")

    fig.set_size_inches(8, 4)

    plt.rcParams['text.usetex'] = True

    ax1.set_title(r"Peer-to-peer MDA")
    ax2.set_title(r"Client-Server CKA")

    fig.show()

    fig.savefig(os.path.join(output_dir, 'dual-plot-titled-low-res.png'), dpi=150)
    fig.savefig(
        os.path.join(output_dir, 'dual-plot-titled-high-res-no-bg.png'),
        dpi=800, transparent=True)

def create_rel_queuing_times_plot(input_cka: str, input_mda: str):
    """Compare CKA and MDA queue times on one figure, each rescaled separately.

    Adds a 'relative queue time' column to each loaded DataFrame:
      * CKA: ('average latency to service' - 3600) / 345600
      * MDA: ('average latency to service' - 300) / 2100

    NOTE(review): the offsets and divisors are presumably each scenario's fixed
    service delay and evaluation window; confirm against the simulation config.

    CKA curves are drawn on the main axes (bottom x axis). MDA curves are drawn
    with dotted lines on a twin axes that shares the y axis (top x axis). A single
    combined legend, MDA entries first, is placed to the right of the plot. The
    figure is shown, not saved.

    :param input_cka: path to a pickled CKA results DataFrame file.
    :param input_mda: path to a pickled MDA results DataFrame file.
    """
    with open(input_cka, 'rb') as F:
        results_cka: pd.DataFrame = pd.read_pickle(F)

    with open(input_mda, 'rb') as F:
        results_mda: pd.DataFrame = pd.read_pickle(F)

    fig, ax = plt.subplots()
    ax: plt.Axes

    results_cka["relative queue time"] = (results_cka["average latency to service"] - 3600)/345600
    results_mda["relative queue time"] = (results_mda["average latency to service"] - 300)/2100

    ax1 = ax.twiny()

    produce_plot_existing_ax(results_cka, ax, "relative queue time")
    produce_plot_existing_ax(results_mda, ax1, "relative queue time")

    # Dotted lines distinguish the MDA (twin-axes) curves from the CKA ones.
    for ln in ax1.lines:
        ln.set_linestyle(':')

    ax1.set_xlabel('MDA ' + ax1.get_xlabel())
    ax.set_xlabel('CKA ' + ax.get_xlabel())

    lbls = ['MDA ' + ln.get_label() for ln in ax1.lines] + ['CKA ' + ln.get_label() for ln in ax.lines]
    lines = [*ax1.lines, *ax.lines]

    ax.legend(lines, lbls, loc='center left', bbox_to_anchor=(1, 0.5), fontsize=10)

    # Only hide the twin axes' own legend if the plotting helper created one;
    # ``get_legend()`` returns None otherwise.
    ax1_legend = ax1.get_legend()
    if ax1_legend is not None:
        ax1_legend.set_visible(False)

    fig.set_size_inches(6.4*1.25, 4.8)

    fig.show()








if __name__ == '__main__':
    # create_plots_bqc_v2('/home/tbeauchamp/quantum-network-arx-evaluation/many-user-cs-bqc/20240408-1130', save_plots=True, version='2')
    # create_plots_util_bound_comparison('/home/tbeauchamp/quantum-network-arx-evaluation/many-user-cs-cka/20240416+17/', '/home/tbeauchamp/quantum-network-arx-evaluation/many-user-cs-bqc/20240408-1130')
    # create_plots_mda_v2('/home/tbeauchamp/quantum-network-arx-evaluation/many-user-p2p-qkd/20240327-0920')
    # create_plots_mda_v2('/home/tbeauchamp/quantum-network-arx-evaluation/many-user-p2p-qkd/20240419-1753/raw-data.pickle', save_plots=False, version='4')
    # create_plots_mda_v2('/home/tbeauchamp/quantum-network-architecture-simulation/quantum-network-arx-evaluation/many-user-p2p-qkd/20240424-1615')
    # compare_dummy_to_actual_scheduling('/home/tbeauchamp/quantum-network-arx-evaluation/many-user-p2p-qkd/20240419-1753', '/home/tbeauchamp/quantum-network-arx-evaluation/many-user-p2p-qkd/20240424-1755')
    # compare_dummy_to_actual_scheduling(
    #     '/home/tbeauchamp/quantum-network-arx-evaluation/many-user-p2p-qkd/20240419-1753',
    #     '/home/tbeauchamp/quantum-network-arx-evaluation/many-user-p2p-mda/20240429-1008'
    # )
    # create_plots_mda_v2('/home/tbeauchamp/quantum-network-arx-evaluation/many-user-p2p-mda/20240425-1350')
    # create_plots_mda_v2('/home/tbeauchamp/quantum-network-arx-evaluation/many-user-p2p-mda/20240425-1651')
    # create_plots_mda_v2('/home/tbeauchamp/quantum-network-arx-evaluation/many-user-p2p-mda/20240430-1618')  # MDA with max duration 3000
    # create_plots_mda_v2('/home/tbeauchamp/quantum-network-arx-evaluation/many-user-p2p-mda/20240501-1155')  # MDA with max duration 2400
    # create_plots_mda_v2('/home/tbeauchamp/quantum-network-arx-evaluation/many-user-p2p-mda/20240501-1156')  # MDA with max duration 1800

    # create_plots_mda_v2('/home/tbeauchamp/quantum-network-arx-evaluation/many-user-p2p-mda/20240501-1354')  # MDA with SI 600s and PGA cap 2000
    # create_plots_mda_v2('/home/tbeauchamp/quantum-network-arx-evaluation/many-user-p2p-mda/20240501-1355')  # MDA with SI 150s and PGA cap 1000

    # create_plots_mda_v2('/home/tbeauchamp/quantum-network-arx-evaluation/many-user-p2p-mda/20240501-1524')  # MDA with succ prob of 0.0001
    # create_plots_mda_v2('/home/tbeauchamp/quantum-network-arx-evaluation/many-user-p2p-mda/20240501-1525')  # MDA with succ prob of 0.0002
    # create_plots_mda_v2('/home/tbeauchamp/quantum-network-arx-evaluation/many-user-p2p-mda/20240501-1526')  # MDA with succ prob of 0.000025

    # create_plots_mda_v2('/home/tbeauchamp/quantum-network-arx-evaluation/many-user-p2p-mda/20240501-1645')  # MDA with ppacket 0.3
    # create_plots_mda_v2('/home/tbeauchamp/quantum-network-arx-evaluation/many-user-p2p-mda/20240501-1646')  # MDA with ppacket 0.1



    # create_plots_mda_v2('/home/tbeauchamp/quantum-network-arx-evaluation/many-user-p2p-mda/20240502-1022/') # Increased PGA Cap to 2000 w/ SI of 300

    # create_plots_mda_v2('/home/tbeauchamp/quantum-network-arx-evaluation/many-user-p2p-mda/20240502-1104/') # Extended lambda range
    # create_plots_mda_extended('/home/tbeauchamp/quantum-network-arx-evaluation/many-user-p2p-mda/20240429-1008/', '/home/tbeauchamp/quantum-network-arx-evaluation/many-user-p2p-mda/20240502-1104/')

    # create_plots_cka_extended('/home/tbeauchamp/quantum-network-arx-evaluation/many-user-cs-bqc/20240408-1130', '/home/tbeauchamp/quantum-network-arx-evaluation/many-user-cs-cka/20240502-1105')

    # create_plots_mda_v2('/Users/tbeauchamp/PycharmProjects/quantum-network-architecture-simulation-results/many-user-p2p-mda/20240506-0945/raw-data.pickle', '/Users/tbeauchamp/PycharmProjects/quantum-network-architecture-simulation-results/many-user-p2p-qkd/20240419-1753/raw-data.pickle')

    # create_quad_plot(input_mda='/home/tbeauchamp/quantum-network-arx-evaluation/many-user-p2p-qkd/20240419-1753/raw-data.pickle', input_cka='/home/tbeauchamp/quantum-network-arx-evaluation/many-user-cs-bqc/20240408-1130/raw-data.pickle')
    # create_dual_plot(
    #     input_mda='/home/tbeauchamp/quantum-network-arx-evaluation/many-user-p2p-qkd/20240419-1753/raw-data.pickle',
    #     input_cka='/home/tbeauchamp/quantum-network-arx-evaluation/many-user-cs-bqc/20240408-1130/raw-data.pickle')
    # create_dual_plot(
    #     input_cka='/Users/tbeauchamp/PycharmProjects/quantum-network-architecture-simulation-results/many-user-cs-bqc/20240408-1130/raw-data.pickle',
    #     input_mda='/Users/tbeauchamp/PycharmProjects/quantum-network-architecture-simulation-results/many-user-p2p-qkd/20240419-1753/raw-data.pickle'
    #                  )

    # create_rel_queuing_times_plot(
    #     input_cka='/Users/tbeauchamp/PycharmProjects/quantum-network-architecture-simulation-results/many-user-cs-bqc/20240408-1130/raw-data.pickle',
    #     input_mda='/Users/tbeauchamp/PycharmProjects/quantum-network-architecture-simulation-results/many-user-p2p-qkd/20240419-1753/raw-data.pickle'
    #                  )

    # create_plots_mda_v2('/Users/thomasbeauchamp/PycharmProjects/quantum-network-architecture-simulation-results/many-user-p2p-mda/20240626-1116/raw-data.pickle')

    # create_plots_mda_v2('/Users/thomasbeauchamp/PycharmProjects/quantum-network-architecture-simulation-results/many-user-p2p-mda/20240709-1100/raw-data.pickle', save_plots=True) # FINAL DATA FOR MDA
    # create_plots_bqc_v2 takes the run DIRECTORY and appends "raw-data.pickle" itself;
    # passing the pickle path would make it open ".../raw-data.pickle/raw-data.pickle".
    create_plots_bqc_v2('/home/tbeauchamp/quantum-network-arx-evaluation/many-user-cs-cka/20240711-1000', save_plots=True) # FINAL DATA FOR CKA

