import json
import os
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import List, Optional
from argparse import ArgumentParser
import dacite
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import math
from collections import defaultdict

@dataclass
class DataPoint:
    """One result record for a single (prog_size, param_value) combination.

    Instances are created either by ``dacite`` from the ``data_points`` list
    of a results JSON file, or by ``average_data_points`` when several result
    files are merged.

    Field names follow ``prog{1,2}_{nocrit,withcrit}_{makespan,succ_prob}``:
    two programs, each measured in a "nocrit" and a "withcrit" configuration
    (plotted as "w\\o crit" / "w crit").
    """

    # Makespan and success probability without "crit".
    # NOTE(review): makespans are divided by 1e6 and labelled "ms" when
    # plotted, so they are presumably stored in nanoseconds -- confirm.
    prog1_nocrit_makespan: float
    prog2_nocrit_makespan: float
    prog1_nocrit_succ_prob: float
    prog2_nocrit_succ_prob: float 
    # Makespan and success probability with "crit".
    prog1_withcrit_makespan: float
    prog2_withcrit_makespan: float
    prog1_withcrit_succ_prob: float
    prog2_withcrit_succ_prob: float 
    # Program size this record belongs to; compared against meta.prog_sizes.
    prog_size: float
    param_name: str  # Name of param being varied
    param_value: float  # Value of the varied param
   
    # Standard deviations of the fields above. They default to None and are
    # filled in by average_data_points when several files are averaged.
    prog1_nocrit_makespan_sd: Optional[float] = None
    prog2_nocrit_makespan_sd: Optional[float] = None
    prog1_withcrit_makespan_sd: Optional[float] = None
    prog2_withcrit_makespan_sd: Optional[float] = None
    prog1_nocrit_succ_prob_sd: Optional[float] = None
    prog2_nocrit_succ_prob_sd: Optional[float] = None
    prog1_withcrit_succ_prob_sd: Optional[float] = None
    prog2_withcrit_succ_prob_sd: Optional[float] = None


@dataclass
class DataMeta:
    """Metadata stored under the ``meta`` key of a results JSON file.

    The code visible in this file only reads ``prog_sizes``, ``param_name``
    and ``num_clients``; the remaining fields are loaded from the JSON and
    carried along unchanged.
    NOTE(review): units and exact meaning of the simulation parameters below
    are not established by this file -- check the producer of the JSON.
    """

    timestamp: str
    sim_duration: float
    hardware: str
    qia_sga: float
    prog_sizes: List[int]  # program sizes; one group of curves is plotted per entry
    num_iterations: List[int]
    num_trials: float
    num_clients: float  # appended to the output plot file names
    linear: bool
    cc: float
    t1: float
    t2: float
    single_gate_dur: float
    two_gate_dur: float
    all_gate_dur: float
    single_gate_fid: float
    two_gate_fid: float
    all_gate_fid: float
    qnos_instr_proc_time: float
    host_instr_time: float
    host_peer_latency: float
    internal_sched_latency: float
    client_num_qubits: float
    server_num_qubits: float
    use_netschedule: bool
    bin_length: float
    param_name: str  # The parameter being varied
    link_duration: float
    link_fid: float
    seed: float

@dataclass
class Data:
    """Top-level content of one results JSON file: metadata plus its records."""

    meta: DataMeta  # run metadata (``meta`` key of the JSON)
    data_points: List[DataPoint]  # measured records (``data_points`` key of the JSON)

def relative_to_cwd(file: str) -> str:
    """Resolve *file* against the directory that contains this script.

    Despite the name, the anchor is the script's own directory (taken from
    ``__file__``), not the process working directory. An absolute *file* is
    returned unchanged because ``os.path.join`` discards earlier components.
    """
    script_dir = os.path.dirname(__file__)
    return os.path.join(script_dir, file)

def create_pdf(filename: str):
    """Save the current matplotlib figure as ``plots/<filename>.pdf``.

    The ``plots`` directory lives next to this script and is created on
    demand. The written path is printed to stdout.
    """
    plots_dir = relative_to_cwd("plots")
    os.makedirs(plots_dir, exist_ok=True)

    output_path = os.path.join(plots_dir, f"{filename}.pdf")
    savefig_options = {"format": "pdf", "transparent": True, "dpi": 1000}
    plt.savefig(output_path, **savefig_options)
    print(f"plot written to {output_path}")

def create_png(filename: str):
    """Save the current matplotlib figure as ``plots/<filename>.png``.

    The ``plots`` directory lives next to this script and is created on
    demand. The written path is printed to stdout.
    """
    plots_dir = relative_to_cwd("plots")
    os.makedirs(plots_dir, exist_ok=True)

    output_path = os.path.join(plots_dir, f"{filename}.png")
    savefig_options = {"transparent": True, "dpi": 1000}
    plt.savefig(output_path, **savefig_options)
    print(f"plot written to {output_path}")

def get_vals(data: Data):
    """Reshape the data points into per-program-size series for plotting.

    Returns a 9-tuple of dicts, each keyed by program size (taken from
    ``data.meta.prog_sizes``):

    0. x values (``param_value`` of every data point of that size)
    1. nocrit makespan       2. nocrit success probability
    3. withcrit makespan     4. withcrit success probability
    5. nocrit makespan SD    6. nocrit success probability SD
    7. withcrit makespan SD  8. withcrit success probability SD

    Entries 1-8 map each size to ``[prog1_values, prog2_values]``, aligned
    element by element with the x values of the same size.
    """
    # Parameter values to drop from every series; currently nothing is dropped.
    excluded_values = []

    # Attribute stems, in the order the corresponding dicts are returned.
    stems = (
        "nocrit_makespan",
        "nocrit_succ_prob",
        "withcrit_makespan",
        "withcrit_succ_prob",
        "nocrit_makespan_sd",
        "nocrit_succ_prob_sd",
        "withcrit_makespan_sd",
        "withcrit_succ_prob_sd",
    )
    tables = {stem: {} for stem in stems}
    x_values = {}

    all_points = data.data_points
    for size in data.meta.prog_sizes:
        kept = [
            point for point in all_points
            if point.prog_size == size and point.param_value not in excluded_values
        ]
        x_values[size] = [point.param_value for point in kept]
        for stem in stems:
            # One list per program: index 0 is prog1, index 1 is prog2.
            tables[stem][size] = [
                [getattr(point, f"prog{prog}_{stem}") for point in kept]
                for prog in (1, 2)
            ]

    return (x_values, *[tables[stem] for stem in stems])

def find_worst(path: str, param: str, hardware: str, num_clients: int, savefile: bool = False, timestamp=None):
    """Find and plot the result files where "withcrit" performs worst vs "nocrit".

    Every ``.json`` file in *path* whose name contains *param*, *hardware* and
    ``numclients<num_clients>`` is loaded. For each file the withcrit-minus-nocrit
    difference is averaged over all data points and over both programs:

    * makespan: the file with the largest difference is the worst case
      (withcrit is expected to be smaller);
    * success probability: the file with the smallest difference is the worst
      case (withcrit is expected to be larger).

    Both worst cases are printed and plotted.

    Args:
        path: Results directory, resolved relative to this script's directory
            (same convention as ``load_data`` and ``plot_avg``).
        param: Substring identifying the varied parameter in the file name.
        hardware: Substring identifying the hardware in the file name.
        num_clients: Client count encoded in the file name.
        savefile: Save the plots as PDF instead of showing them.
        timestamp: Prefix for saved plot names; generated when omitted.
    """
    if timestamp is None:
        # create_plots builds file names from the timestamp, so it must be a string.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # List the directory load_data will read from (script-relative). Sorted
    # because os.listdir order is arbitrary and ties go to the first file seen.
    clients_tag = f"numclients{num_clients}"
    files = sorted(
        f for f in os.listdir(relative_to_cwd(path))
        if f.endswith(".json") and param in f and hardware in f and clients_tag in f
    )
    if not files:
        print(f"No result files in {path} match param={param}, hardware={hardware}, {clients_tag}")
        return

    worst_makespan_diff = float('-inf')  # largest positive difference is worst
    worst_makespan_file = None
    worst_makespan_data = None

    worst_succprob_diff = float('inf')  # smallest (most negative) difference is worst
    worst_succprob_file = None
    worst_succprob_data = None

    for file in files:
        data = load_data(os.path.join(path, file))
        points = data.data_points
        num_data_points = len(points)
        if num_data_points == 0:
            # Nothing to average; skipping avoids a division by zero.
            continue

        avg_makespan_diff_prog1 = sum(dp.prog1_withcrit_makespan - dp.prog1_nocrit_makespan for dp in points) / num_data_points
        avg_makespan_diff_prog2 = sum(dp.prog2_withcrit_makespan - dp.prog2_nocrit_makespan for dp in points) / num_data_points
        avg_succprob_diff_prog1 = sum(dp.prog1_withcrit_succ_prob - dp.prog1_nocrit_succ_prob for dp in points) / num_data_points
        avg_succprob_diff_prog2 = sum(dp.prog2_withcrit_succ_prob - dp.prog2_nocrit_succ_prob for dp in points) / num_data_points

        # Average over the two programs.
        total_avg_makespan_diff = (avg_makespan_diff_prog1 + avg_makespan_diff_prog2) / 2
        total_avg_succprob_diff = (avg_succprob_diff_prog1 + avg_succprob_diff_prog2) / 2

        if total_avg_makespan_diff > worst_makespan_diff:
            worst_makespan_diff = total_avg_makespan_diff
            worst_makespan_file = file
            worst_makespan_data = data

        if total_avg_succprob_diff < worst_succprob_diff:
            worst_succprob_diff = total_avg_succprob_diff
            worst_succprob_file = file
            worst_succprob_data = data

    if worst_makespan_data is None or worst_succprob_data is None:
        print(f"None of the {len(files)} matching result files contains data points")
        return

    print(f"Worst average makespan difference: {worst_makespan_diff} in file {worst_makespan_file}")
    print(f"Worst average success probability difference: {worst_succprob_diff} in file {worst_succprob_file}")
    # Use the `savefile` parameter; the original read an unrelated global name.
    create_plots(timestamp, worst_makespan_data, "makespan", std_dev=False, save=savefile, plot_p1_p2='p2')
    create_plots(timestamp, worst_succprob_data, "succprob", std_dev=False, save=savefile, plot_p1_p2='p2')
    create_plots(timestamp, worst_succprob_data, "succsec", std_dev=False, save=savefile, plot_p1_p2='p2')
    
def _mean_percent_change(changes):
    """Return the mean of ``delta / baseline * 100`` over (delta, baseline) pairs.

    Pairs with a zero baseline have no defined relative change and are
    skipped; NaN is returned when no usable pair remains.
    """
    percentages = [delta / baseline * 100 for delta, baseline in changes if baseline != 0]
    if not percentages:
        return float("nan")
    return sum(percentages) / len(percentages)


def plot_avg(path: str, param: str, hardware: str, scenario: str, num_clients: int = 2, savefile: bool = False, timestamp=None):
    """Average all matching result files, print the gains and plot the averages.

    Every ``.json`` file in *path* whose name contains *param*, *hardware*,
    *scenario* and ``numclients<num_clients>`` is loaded and merged with
    ``average_data_points``. For each program size the average percentage
    makespan decrease and success-probability increase of "withcrit" over
    "nocrit" is printed, then the averaged data is plotted with error bars.

    Args:
        path: Results directory, resolved relative to this script's directory.
        param: Substring identifying the varied parameter in the file name.
        hardware: Substring identifying the hardware in the file name.
        scenario: Substring identifying the scenario in the file name.
        num_clients: Client count encoded in the file name.
        savefile: Save the plots as PDF instead of showing them.
        timestamp: Prefix for saved plot names; generated when omitted.
    """
    if timestamp is None:
        # create_plots builds file names from the timestamp, so it must be a string.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Sorted because os.listdir order is arbitrary and the first file's
    # metadata is reused for the plots below.
    clients_tag = f"numclients{num_clients}"
    files = sorted(
        f for f in os.listdir(relative_to_cwd(path))
        if f.endswith(".json") and param in f and hardware in f and scenario in f and clients_tag in f
    )
    if not files:
        print(
            f"No result files in {path} match param={param}, hardware={hardware}, "
            f"scenario={scenario}, {clients_tag}; nothing to plot"
        )
        return

    datas = [load_data(os.path.join(path, f)) for f in files]
    avg_data_points = average_data_points(datas)

    # Reuse the first file's Data object as the carrier for the averaged points.
    plot_data = datas[0]
    plot_data.data_points = avg_data_points
    # NOTE(review): only program sizes 3 and 5 are plotted, regardless of what
    # the files contain -- confirm this restriction is intended.
    plot_data.meta.prog_sizes = [3, 5]

    grouped_data = defaultdict(list)
    for dp in avg_data_points:
        grouped_data[dp.prog_size].append(dp)

    for prog_size, data_points in grouped_data.items():
        # Relative change of withcrit with respect to the nocrit baseline.
        avg_prog1_makespan_decrease = _mean_percent_change(
            (dp.prog1_nocrit_makespan - dp.prog1_withcrit_makespan, dp.prog1_nocrit_makespan) for dp in data_points
        )
        avg_prog2_makespan_decrease = _mean_percent_change(
            (dp.prog2_nocrit_makespan - dp.prog2_withcrit_makespan, dp.prog2_nocrit_makespan) for dp in data_points
        )
        avg_prog1_succprob_increase = _mean_percent_change(
            (dp.prog1_withcrit_succ_prob - dp.prog1_nocrit_succ_prob, dp.prog1_nocrit_succ_prob) for dp in data_points
        )
        avg_prog2_succprob_increase = _mean_percent_change(
            (dp.prog2_withcrit_succ_prob - dp.prog2_nocrit_succ_prob, dp.prog2_nocrit_succ_prob) for dp in data_points
        )

        print(f"Program Size: {prog_size}")
        print(f"  prog1: Average Makespan Decrease = {avg_prog1_makespan_decrease:.2f}%")
        print(f"         Average Success Probability Increase = {avg_prog1_succprob_increase:.2f}%")
        print(f"  prog2: Average Makespan Decrease = {avg_prog2_makespan_decrease:.2f}%")
        print(f"         Average Success Probability Increase = {avg_prog2_succprob_increase:.2f}%")
        print()

    create_plots(timestamp, plot_data, "makespan", std_dev=True, save=savefile, plot_p1_p2='p2')
    create_plots(timestamp, plot_data, "succprob", std_dev=True, save=savefile, plot_p1_p2='p2')
    create_plots(timestamp, plot_data, "succsec", std_dev=True, save=savefile, plot_p1_p2='p2')

def average_data_points(data_list: List[Data]) -> List[DataPoint]:
    """Merge several result sets into one list of averaged data points.

    Data points from all entries of *data_list* are grouped by
    ``(param_value, prog_size)``. For each group a new ``DataPoint`` is built
    holding the mean of every makespan / success-probability field and, in the
    matching ``*_sd`` field, its population standard deviation (0.0 when the
    group has a single sample). ``param_name`` is taken from the first entry
    of *data_list*. Output order follows the first appearance of each group.
    """
    metric_names = (
        'prog1_nocrit_makespan',
        'prog2_nocrit_makespan',
        'prog1_withcrit_makespan',
        'prog2_withcrit_makespan',
        'prog1_nocrit_succ_prob',
        'prog2_nocrit_succ_prob',
        'prog1_withcrit_succ_prob',
        'prog2_withcrit_succ_prob',
    )

    # (param_value, prog_size) -> {metric name -> samples collected so far}
    samples = defaultdict(lambda: {name: [] for name in metric_names})
    for run in data_list:
        for point in run.data_points:
            bucket = samples[(point.param_value, point.prog_size)]
            for name in metric_names:
                bucket[name].append(getattr(point, name))

    averaged = []
    for (param_value, prog_size), bucket in samples.items():
        # Every metric list of a group has the same length: one entry per sample.
        sample_count = len(bucket[metric_names[0]])
        stats = {}
        for name in metric_names:
            series = bucket[name]
            mean = sum(series) / sample_count
            if sample_count <= 1:
                spread = 0.0  # no spread can be estimated from a single sample
            else:
                spread = math.sqrt(sum([(x - mean) ** 2 for x in series]) / sample_count)
            stats[name] = mean
            stats[f"{name}_sd"] = spread

        averaged.append(
            DataPoint(
                prog_size=prog_size,
                param_name=data_list[0].meta.param_name,  # assumed identical across inputs
                param_value=param_value,
                **stats,
            )
        )

    return averaged


def load_data(path: str) -> Data:
    """Read a results JSON file and convert it into a ``Data`` object.

    *path* is resolved relative to this script's directory; the parsed JSON
    is mapped onto the dataclasses by ``dacite.from_dict``.
    """
    resolved_path = relative_to_cwd(path)
    with open(resolved_path, "r") as handle:
        raw = json.load(handle)

    return dacite.from_dict(Data, raw)

import matplotlib.pyplot as plt
import numpy as np

def _draw_metric(x_val_map, nocrit_map, withcrit_map, nocrit_sd_map, withcrit_sd_map,
                 std_dev, plot_p1_p2, scale=None):
    """Draw the nocrit/withcrit curves of one metric on the current axes.

    For every program size (key of *x_val_map*) and every selected program a
    pair of curves is drawn, nocrit first. With *std_dev* the curves are error
    bars using the SD maps; otherwise plain lines, and the SD maps are not
    touched (they may hold ``None`` for raw, non-averaged data). When *scale*
    is given, y values (and SDs) are divided by it.
    """
    # (index into the [prog1, prog2] lists, nocrit label, withcrit label).
    # NOTE(review): "w\o" is rendered literally with a backslash; "w/o" was
    # probably intended. The rendered text is kept unchanged here.
    selected = []
    if plot_p1_p2 in ('both', 'p1'):
        selected.append((0, "w\\o crit  p1, n={}", "w p1, n={}"))
    if plot_p1_p2 in ('both', 'p2'):
        selected.append((1, "$C_1$ w\\o crit, n={}", "$C_1$ w crit, n={}"))

    def prepared(values):
        return values if scale is None else np.array(values) / scale

    for key, x_vals in x_val_map.items():
        for idx, nocrit_label, withcrit_label in selected:
            nocrit = prepared(nocrit_map[key][idx])
            withcrit = prepared(withcrit_map[key][idx])
            if std_dev:
                # The line style comes from `fmt`; repeating it as a keyword
                # only triggers a matplotlib "redundantly defined" warning.
                plt.errorbar(
                    x_vals, nocrit,
                    yerr=prepared(nocrit_sd_map[key][idx]),
                    fmt='-s', label=nocrit_label.format(key),
                    capsize=5, markersize=5
                )
                plt.errorbar(
                    x_vals, withcrit,
                    yerr=prepared(withcrit_sd_map[key][idx]),
                    fmt='-d', label=withcrit_label.format(key),
                    capsize=5, markersize=5
                )
            else:
                plt.plot(
                    x_vals, nocrit,
                    label=nocrit_label.format(key),
                    marker="o", linestyle='-', markersize=5
                )
                plt.plot(
                    x_vals, withcrit,
                    label=withcrit_label.format(key),
                    marker="*", linestyle='-', markersize=5
                )


def create_plots(timestamp, data: Data, plottype: str, std_dev: bool = False, plot_p1_p2: str = 'both', save=True):
    """Plot makespan and/or success probability against the varied parameter.

    Args:
        timestamp: Prefix of the saved file names; a fresh timestamp is
            generated when it is ``None``.
        data: Result set to plot; one group of curves per ``meta.prog_sizes`` entry.
        plottype: ``"makespan"``, ``"succprob"`` or ``""`` (both). Any other
            value (e.g. ``"succsec"``) produces no plot.
        std_dev: Draw error bars from the ``*_sd`` fields, which must then be
            populated (as done by ``average_data_points``).
        plot_p1_p2: ``'p1'``, ``'p2'`` or ``'both'`` -- which program(s) to draw.
        save: Write a PDF via ``create_pdf`` if true, otherwise show the figure.
    """
    meta = data.meta
    (x_val_map,
     nocrit_makespan_map, nocrit_succprob_map,
     withcrit_makespan_map, withcrit_succprob_map,
     nocrit_makespan_sd_map, nocrit_succprob_sd_map,
     withcrit_makespan_sd_map, withcrit_succprob_sd_map) = get_vals(data)

    label_fontsize = 14

    def finish(ylabel, metric_tag):
        # Legend, axis labels, then save or show, and clear the axes for the next plot.
        plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.1), ncol=2, fancybox=True, shadow=True, fontsize=10)
        plt.ylabel(ylabel, fontsize=label_fontsize)
        # NOTE(review): the x label is fixed even though meta.param_name may vary.
        plt.xlabel("Bin length (multiple of expected entanglement generation time)", fontsize=label_fontsize-1)
        if save:
            stamp = timestamp if timestamp is not None else datetime.now().strftime("%Y%m%d_%H%M%S")
            create_pdf(f"{stamp}_{meta.param_name}_{metric_tag}_numclients{meta.num_clients}")
        else:
            plt.show()
        plt.cla()

    if plottype == "makespan" or plottype == "":
        # Makespans are divided by 1e6 to match the "(ms)" axis label
        # (presumably stored in ns -- TODO confirm).
        _draw_metric(
            x_val_map, nocrit_makespan_map, withcrit_makespan_map,
            nocrit_makespan_sd_map, withcrit_makespan_sd_map,
            std_dev, plot_p1_p2, scale=1e6,
        )
        finish("Avg Makespan (ms)", "makespan")

    if plottype == "succprob" or plottype == "":
        plt.ylim(0.0, 1.01)
        _draw_metric(
            x_val_map, nocrit_succprob_map, withcrit_succprob_map,
            nocrit_succprob_sd_map, withcrit_succprob_sd_map,
            std_dev, plot_p1_p2,
        )
        finish("Success Probability", "succprob")

# Command-line entry point: for every (param, hardware, scenario) combination
# given on the command line, average the matching result files and plot them.
if __name__ == "__main__":
    # A single timestamp is shared by all plots written during this run.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    parser = ArgumentParser()
    # Results folder, resolved relative to this script's directory by the loaders.
    parser.add_argument("--folder", "-f", type=str, required=True)
    # Save plots as PDF instead of showing them interactively.
    parser.add_argument("--save", "-s", action="store_true", default=False)
    # Substrings used to select result files by name; each accepts several values.
    parser.add_argument("--params", type=str, nargs="+", required=True)
    parser.add_argument("--hardware", type=str, nargs="+", required=True)
    parser.add_argument("--scenario", type=str, nargs="+", required=True)
    parser.add_argument("--num_clients", "-c", type=int, default=2)
    
    args = parser.parse_args()
    folder = args.folder
    # NOTE(review): module-level name; keep the spelling `saveFile`, other code
    # in this file may refer to it.
    saveFile = args.save
    params = args.params
    hardware = args.hardware
    scenarios = args.scenario
    num_clients = args.num_clients
    
    for param in params:
        for hw in hardware:
            for scenario in scenarios:
                # Alternative analysis (disabled): plot only the worst-case files.
                # find_worst(folder, param, hw, num_clients, saveFile, timestamp)
                plot_avg(folder, param, hw, scenario, num_clients, saveFile, timestamp)

        
