import os
import sys
sys.path.insert(1, os.path.abspath(os.getcwd()))
from pprint import pprint
from multiprocessing import Pool, cpu_count
import pickle
import pandas as pd
import circuit_simulation.stabilizer_measurement_protocols.stabilizer_measurement_protocols as stab_protocols
from circuit_simulation.stabilizer_measurement_protocols.argument_parsing import compose_parser, group_arguments
from circuit_simulation.gates.gates import *
from circuit_simulation.circuit_simulator import QuantumCircuit
from utilities.files import get_full_path
import itertools as it
import time
import random
from plot_results.analyse_and_group_data import confidence_interval
from circuit_simulation.termcolor.termcolor import cprint
from collections import defaultdict
import numpy as np
from tqdm import tqdm
from copy import copy, deepcopy
import warnings
# Silence the pandas warning about mixed column dtypes that is emitted when reading the result csv files.
warnings.filterwarnings(action='ignore', message='.*Specify dtype option on import or set low_memory=False.*')
# NOTE(review): not used in the visible part of this file; presumably a number of decimals used elsewhere.
SUM_ACCURACY = 7


def print_signature():
    """Print the program banner: the simulator name in cyan followed by a dashed underline."""
    banner = "\nQuantum Circuit Simulator®"
    cprint(banner, color='cyan')
    print("--------------------------")


def create_file_name(filename, **kwargs):
    """
        Build a result filename from a base name, the protocol and the run parameters.

        Parameters
        ----------
        filename : str
            Base name or directory prefix. An underscore separator is inserted after it unless it already ends in
            '/' or '_'.
        **kwargs
            Must contain 'protocol'. When 'noiseless_swap' is truthy, a trailing '_swap' is removed from the
            protocol name. Every other keyword is appended as '_<key><value>' in the given order, except when its
            value is falsy (None, 0, False, ''), equal to np.inf (the default cut-off time), or its key is
            'pulse_duration' or '_node'. A value of True contributes only the key; strings are capitalized.

        Returns
        -------
        str
            The composed filename without leading or trailing underscores.
    """
    protocol = kwargs.pop('protocol')
    swap_suffix = '_swap'
    # Remove the exact suffix. str.strip('_swap') would strip any of the characters '_', 's', 'w', 'a', 'p' from
    # both ends instead (e.g. 'stringent_swap' -> 'tringent', 'plain' -> 'lain').
    if kwargs.get('noiseless_swap') and protocol.endswith(swap_suffix):
        protocol = protocol[:-len(swap_suffix)]

    # endswith() also copes with an empty base name (indexing filename[-1] would raise IndexError)
    separator = "" if filename.endswith(("/", "_")) else "_"
    filename = "{}{}{}".format(filename, separator, protocol)

    for key, value in kwargs.items():
        # Do not include if value is None, 0 or np.inf (default cut_off_time) or if key is pulse_duration or _node
        if not value or value == np.inf or key in ('pulse_duration', '_node'):
            continue
        if value is True:
            value = ""
        value = value.capitalize() if isinstance(value, str) else str(value)
        filename += "_" + str(key) + value

    return filename.strip('_')


def _get_cut_off_dataframe(file: str):
    if file is None:
        return
    if file.lower() == 'auto':
        return file
    if not os.path.exists(file):
        raise ValueError('File containing the cut-off times could not be found!')

    return pd.read_csv(file, sep=";", float_precision='round_trip')


def _get_cut_off_time(dataframe, run_dict, circuit_args, **kwargs):
    """
        Determine the cut-off time for a run.

        The 'cut_off_time' entry is popped from ``run_dict``. A finite value, or a missing cut-off source
        (``dataframe`` is None), is returned unchanged. Otherwise the result csv of this parameter set (built from
        ``kwargs['csv_filename']`` inside ``kwargs['cp_path']``) is consulted. When that file holds (nearly) as many
        iterations as requested, its 'dur_99' value is returned.

        Returns
        -------
        tuple
            ``(cut_off_time, search_needed)``. ``search_needed`` is True when no usable value was found. In that case
            the cut-off time is np.inf and ``circuit_args['iterations']`` has been raised to at least 10000.
    """
    requested_cut_off = run_dict.pop('cut_off_time')

    if requested_cut_off != np.inf or dataframe is None:
        return requested_cut_off, False

    base_name = create_file_name(kwargs['csv_filename'], dec=circuit_args['decoherence'],
                                 prob=circuit_args['probabilistic'], node=run_dict['_node'],
                                 decoupling=run_dict['pulse_duration'],
                                 noiseless_swap=circuit_args['noiseless_swap'], **run_dict)
    file_n = os.path.join(get_full_path(kwargs['cp_path']), base_name)
    csv_path = file_n + '.csv'

    if os.path.exists(csv_path):
        stored = pd.read_csv(csv_path, sep=";", float_precision="round_trip")
        # Reuse the stored value only if it was based on (nearly, within 5%) the requested number of iterations
        if stored.loc[0, 'written_to'] * 1.05 > circuit_args['iterations']:
            found_cut_off = stored.loc[0, 'dur_99']
            print('[INFO] Found cutoff time value from file: {}'.format(found_cut_off))
            return found_cut_off, False

    # The cut-off time in the "auto" mode is calculated with at least 10000 iterations of the circuit
    circuit_args['iterations'] = max(circuit_args['iterations'], 10000)
    print('No records for cutoff time found or not enough iterations. First running for:\n{}'.format(file_n))
    return np.inf, True


def _open_existing_results_dataframe(filename, addition=""):
    if filename is None:
        return
    if not os.path.exists(filename + addition):
        return

    existing_file = pd.read_csv(filename + addition, sep=';', float_precision='round_trip')

    return existing_file


def _init_random_seed(timestamp=None, worker=0, iteration=0):
    if timestamp is None:
        timestamp = time.time()
    seed = int("{:.0f}".format(timestamp * 10 ** 7) + str(worker) + str(iteration))
    random.seed(float(seed))
    return seed


def add_column_values(dataframe, columns, values):
    """
        Add columns to ``dataframe`` in place, with a value only in the first row.

        Each column in ``columns`` is (re)set to None for all rows, after which its first row receives the
        matching entry of ``values``. Pairs are formed with ``zip``, so surplus columns or values are ignored.
    """
    for name, first_row_value in zip(columns, values):
        dataframe[name] = None
        column_position = dataframe.columns.get_loc(name)
        dataframe.iloc[0, column_position] = first_row_value


def _combine_results_dataframes(dataframe_1, dataframe_2):
    """
        Combines two dataframes with results into one dataframe

        The dataframe with the most rows (``dataframe_2`` on a tie) is used as the base. The 'written_to' value in
        its first row becomes the sum of the first-row 'written_to' values of both inputs. The 'total_duration' and
        'total_lde_attempts' columns are summed, and 'avg_duration'/'avg_lde_attempts' are recomputed from these
        totals. The 'ghz_fidelity' column becomes the average of both inputs weighted by their 'written_to' values.
        The same applies to 'dur_99' when the base dataframe has that column. Finally, only rows that contain at
        least one value that is not 0, None or NaN are kept.

        Parameters
        ----------
        dataframe_1 : pd.DataFrame or None
            Dataframe to be combined
        dataframe_2 : pd.DataFrame or None
            Dataframe to be combined

        Returns
        -------
        pd.DataFrame or None
            None when both inputs are None. The other input itself (not a copy) when exactly one is None.
            Otherwise a new combined dataframe; the inputs are copied with ``copy.copy`` before being modified.
    """
    if dataframe_1 is None and dataframe_2 is None:
        return None
    if dataframe_1 is None:
        return dataframe_2
    if dataframe_2 is None:
        return dataframe_1

    # new_df is the input with the most rows (dataframe_2 on a tie) and serves as the base of the result
    new_df = copy(dataframe_1) if dataframe_1.shape[0] > dataframe_2.shape[0] else copy(dataframe_2)
    other_df = copy(dataframe_2) if dataframe_1.shape[0] > dataframe_2.shape[0] else copy(dataframe_1)

    # First combine the total amount of iterations, such that it can be used later. Only the first row's
    # 'written_to' is used as the iteration count of each input.
    written_to_original = new_df.iloc[0, new_df.columns.get_loc("written_to")]
    written_to_new = other_df.iloc[0, other_df.columns.get_loc("written_to")]
    corrected_written_to = written_to_new + written_to_original
    new_df.iloc[0, new_df.columns.get_loc("written_to")] = corrected_written_to

    # Sum the totals and recompute the averages over the combined number of iterations.
    # NOTE(review): Series addition aligns on the index, so rows of new_df whose index label is missing in other_df
    # become NaN here; presumably only the first row carries these totals. Confirm against the dataframe layout.
    new_df['total_duration'] = (new_df['total_duration'] + other_df['total_duration'])
    new_df['total_lde_attempts'] = (new_df['total_lde_attempts'] + other_df['total_lde_attempts'])

    new_df['avg_lde_attempts'] = new_df['total_lde_attempts'] / corrected_written_to
    new_df['avg_duration'] = new_df['total_duration'] / corrected_written_to
    # Weighted average of 'dur_99'; assumes other_df has the column whenever new_df has it
    if 'dur_99' in new_df:
        new_df['dur_99'] = (new_df['dur_99'].mul(written_to_original) +
                            other_df['dur_99'].mul(written_to_new)) / corrected_written_to

    # Update fidelity: weighted average using the 'written_to' value of each input as weight
    other_df['ghz_fidelity'] = other_df['ghz_fidelity'].mul(written_to_new)
    new_df['ghz_fidelity'] = new_df['ghz_fidelity'].mul(written_to_original)

    new_df['ghz_fidelity'] = (new_df['ghz_fidelity'] + other_df['ghz_fidelity']) / corrected_written_to
    # Keep only the rows that have at least one value that is not 0, None or NaN.
    # NOTE(review): DataFrame.applymap is deprecated in recent pandas releases (renamed DataFrame.map); check
    # against the pandas version in use.
    new_df = new_df[(new_df.T.applymap(lambda x: x != 0 and x is not None and not pd.isna(x))).any()]

    return new_df


def add_decoherence_if_cut_off(qc: QuantumCircuit):
    """
        Let the circuit idle until its cut-off time when a finite cut-off was set but not reached.

        The remaining time (cut-off time minus the current total duration) is added to the duration of all nodes
        through ``qc._increase_duration``. It is then closed off as a sub circuit named "Waiting", with decoherence
        applied. Nothing happens when there is no positive remaining time.
    """
    cut_off_applies = qc.cut_off_time < np.inf and not qc.cut_off_time_reached
    if not cut_off_applies:
        return

    remaining = qc.cut_off_time - qc.total_duration
    if not remaining > 0:
        return

    all_nodes = list(qc.nodes.keys())
    qc._increase_duration(remaining, [], involved_nodes=all_nodes, check=False)
    qc.end_current_sub_circuit(total=True, duration=remaining, sub_circuit="Waiting", apply_decoherence=True)


def _additional_qc_arguments(**kwargs):
    additional_arguments = {
        'noise': True,
        'basis_transformation_noise': False,
        'thread_safe_printing': True,
        'no_single_qubit_error': True
    }
    kwargs.update(additional_arguments)
    return kwargs


def print_circuit_parameters(operational_args, circuit_args, varational_circuit_args):
    """
        Print the three argument groups, each under its capitalized parameter name, framed by lines of '#'.

        Parameters
        ----------
        operational_args, circuit_args, varational_circuit_args : dict
            The argument groups to print (each one with ``pprint``).
    """
    # The groups are listed explicitly instead of iterating over locals(). Under a trace function (e.g. a debugger)
    # in CPython <= 3.12, locals() returns the frame's own dict, which is re-synchronised during the loop. The loop
    # variables then get added to it, and iteration fails with "dictionary changed size during iteration".
    argument_groups = (
        ('operational_args', operational_args),
        ('circuit_args', circuit_args),
        ('varational_circuit_args', varational_circuit_args),
    )
    print('\n' + 80*'#')
    for args_name, args_values in argument_groups:
        print("\n{}:\n-----------------------".format(args_name.capitalize()))
        pprint(args_values)
    print('\n' + 80*'#' + '\n')


def additional_parsing_of_arguments(**args):
    """
        Post-process the parsed command line arguments.

        - Removes 'argument_file' (required key).
        - Replaces the 'single_qubit_gate_lookup' and 'two_qubit_gate_lookup' file names, when not None, by the
          unpickled contents of those files in the gates/gate_lookup_tables directory.
        - When 'gate_duration_file' is given, loads the gate durations from it.

        Returns
        -------
        dict
            The updated arguments.

        Raises
        ------
        ValueError
            When 'gate_duration_file' is given but does not exist.
    """
    # Pop the argument_file since it is no longer needed from this point
    args.pop("argument_file")

    # THIS IS NOT GENERIC, will error when directories are moved or renamed
    lookup_table_dir = os.path.join(os.path.dirname(__file__), '../gates', 'gate_lookup_tables')

    # NOTE: pickle.load executes code from the file; only use trusted lookup tables
    for lookup_key in ('single_qubit_gate_lookup', 'two_qubit_gate_lookup'):
        if args[lookup_key] is None:
            continue
        with open(os.path.join(lookup_table_dir, args[lookup_key]), 'rb') as table_file:
            args[lookup_key] = pickle.load(table_file)

    gate_duration_file = args.get('gate_duration_file')
    if gate_duration_file is not None:
        if not os.path.exists(gate_duration_file):
            raise ValueError("Cannot find file to set gate durations with. File path: {}"
                             .format(os.path.abspath(gate_duration_file)))
        set_gate_durations_from_file(gate_duration_file)

    return args


def _save_results_dataframe(fn, characteristics, succeeded, cut_off):
    """
        Store the results of a run, merged with earlier results saved under the same base filename.

        Parameters
        ----------
        fn : str or None
            Base path (without extension) of the output files. When falsy nothing is written, but the 'dur_99'
            column is still added to ``succeeded``.
        characteristics : dict
            Per-iteration characteristics (lists). When ``fn + '.pkl'`` exists, its lists (except 'index') are
            appended to these lists in place. The merged dict is then pickled back to that file. Nothing is pickled
            when ``characteristics`` is empty.
        succeeded : pd.DataFrame or None
            Results of the iterations that did not reach the cut-off time; merged into ``fn + '.csv'``.
        cut_off : pd.DataFrame or None
            Results of the iterations that reached the cut-off time; merged into ``fn + '_failed.csv'``.
    """
    succeeded = _add_interval_to_dataframe(succeeded, characteristics)

    if not fn:
        return

    # Save the characteristics to a pickle file, extended with the previously stored characteristics
    pickle_file = fn + '.pkl'
    if characteristics:
        if os.path.exists(pickle_file):
            # NOTE: pickle.load executes code from the file; this assumes the results directory is trusted
            with open(pickle_file, 'rb') as old_file:
                characteristics_old = pickle.load(old_file)
            for key, value in characteristics_old.items():
                if key != 'index':
                    # setdefault: the old file may contain keys that this run did not produce
                    characteristics.setdefault(key, []).extend(value)
        with open(pickle_file, 'wb+') as new_file:
            pickle.dump(characteristics, new_file)

    # Save the results dataframes to csv files, combined with the results already stored there
    for result, suffix in ((succeeded, '.csv'), (cut_off, '_failed.csv')):
        target_file = fn + suffix
        existing_file = _open_existing_results_dataframe(target_file)
        result = _combine_results_dataframes(result, existing_file)
        if result is not None:
            result.to_csv(target_file, sep=';', index=False)


def _add_interval_to_dataframe(dataframe, characteristics):
    """
        Add a 'dur_99' column to ``dataframe`` (in place) and return it.

        The first row receives ``confidence_interval(characteristics['dur'], 0.98)[1]``. Presumably this is the
        upper bound of a 98% interval of the durations, i.e. roughly their 99th percentile; confirm in
        ``confidence_interval``. A None dataframe is returned unchanged.
    """
    if dataframe is None:
        return dataframe

    upper_duration = confidence_interval(characteristics['dur'], 0.98)[1]
    add_column_values(dataframe, ['dur_99'], [upper_duration])
    return dataframe


def main_threaded(*, iterations, fn, **kwargs):
    """
        Run ``main`` in a pool of worker processes, combine their results and save them.

        ``iterations`` workers are used when 0 < iterations < cpu_count(), otherwise cpu_count() workers. The
        iterations are divided evenly over the workers, and the last worker also runs the remainder. Each worker
        gets a deep copy of ``kwargs`` with its own 'calc_id' and 'iterations' and the shared 'iter_pw'. ``main``
        uses these to offset the random seeds per worker.

        Parameters
        ----------
        iterations : int
            Total number of iterations over all workers.
        fn : str or None
            Base filename passed to ``_save_results_dataframe``.
        **kwargs
            Passed to ``main``.
    """
    workers = iterations if 0 < iterations < cpu_count() else cpu_count()
    iterations, remaining_iterations = divmod(iterations, workers)
    kwargs['iter_pw'] = iterations

    # The context manager terminates the workers even if one of them raises, instead of leaving them behind
    with Pool(workers) as thread_pool:
        async_results = []
        for worker in range(1, workers + 1):
            thr_kwargs = deepcopy(kwargs)
            thr_kwargs['calc_id'] = worker - 1
            # The last worker also runs the iterations that do not divide evenly over the workers
            thr_kwargs['iterations'] = iterations + remaining_iterations * int(worker == workers)
            async_results.append(thread_pool.apply_async(main, kwds=thr_kwargs))
        thread_pool.close()
        worker_outputs = [res.get() for res in async_results]
        thread_pool.join()

    # Collect all the results from the workers
    succeeded = None
    cut_off = None
    print_lines_results = []
    tot_characteristics = defaultdict(list)
    for (succeeded_res, cut_off_res), print_lines, characteristics in worker_outputs:
        succeeded = _combine_results_dataframes(succeeded, succeeded_res)
        cut_off = _combine_results_dataframes(cut_off, cut_off_res)
        print_lines_results.extend(print_lines)
        for key, value in characteristics.items():
            tot_characteristics[key].extend(value)

    print(*print_lines_results)

    # Save the results dataframes to csv if requested by the user
    _save_results_dataframe(fn, tot_characteristics, succeeded, cut_off)


def main_series(fn, **kwargs):
    """
        Run ``main`` in the current process, print its output and save the results.

        Parameters
        ----------
        fn : str or None
            Base filename passed to ``_save_results_dataframe``.
        **kwargs
            Passed to ``main``. When 'progress_bar' is truthy, a tqdm bar with total 'iterations' is created and
            passed to ``main`` as ``pbar_2``.
    """
    pbar_2 = tqdm(total=kwargs['iterations']) if kwargs.get('progress_bar') else None
    (succeeded, cut_off), print_lines, characteristics = main(pbar_2=pbar_2, **kwargs)
    print(*print_lines)
    # main() draws the circuit when 'draw_circuit' is absent (its default is True). Use the same default here
    # instead of failing with a KeyError.
    if not kwargs.get('draw_circuit', True):
        print(f"Durations: {characteristics['dur']}.")
        print(f"GHZ fidelities: {characteristics['ghz_fid']}.\n")

    # Save the results to the according csv files (options: normal, cut-off)
    _save_results_dataframe(fn, characteristics, succeeded, cut_off)


def main(*, iterations, protocol, stabilizer_type, threaded=False, gate_duration_file=None, cutoff_search=False,
         color=False, draw_circuit=True, save_latex_pdf=False, to_console=False, pbar_2=None, seed_number=None,
         calc_id=0, iter_pw=0, **kwargs):
    """
        Simulate ``iterations`` runs of a stabilizer measurement protocol and collect the results.

        Parameters
        ----------
        iterations : int
            Number of protocol runs performed by this call.
        protocol : str
            Name of the protocol function in ``stab_protocols``; also passed to ``create_quantum_circuit``.
        stabilizer_type : str
            "Z" selects ``CZ_gate`` as the operation passed to the protocol, anything else ``CNOT_gate``.
        threaded : bool
            When True, ``set_gate_durations_from_file(gate_duration_file)`` is called in this process.
        gate_duration_file : str or None
            See ``threaded``.
        cutoff_search, to_console
            Not used in this function's body.
        color : bool
            Passed as ``no_color=not color`` to ``qc.draw_circuit``.
        draw_circuit : bool
            Draw each circuit and add the GHZ fidelity and the duration of each run to the print lines.
        save_latex_pdf : bool
            Call ``qc.draw_circuit_latex()`` after each run.
        pbar_2 : tqdm or None
            Progress bar that is advanced once per run and closed at the end.
        seed_number : int or None
            When given, ``random`` is re-seeded with this value before every run. Otherwise run ``i`` is seeded with
            ``calc_id * iter_pw + i``.
        calc_id, iter_pw : int
            Index of this worker and the number of iterations per worker, used to offset the seeds per worker.
        **kwargs
            Passed, with the settings of ``_additional_qc_arguments`` added, to ``create_quantum_circuit``.

        Returns
        -------
        tuple
            ``((succeeded, failed), print_lines, characteristics)``. ``succeeded``/``failed`` are the combined
            results dataframes of the runs that did not/did reach the cut-off time (None when there were none).
            ``print_lines`` holds the collected output lines. ``characteristics`` holds the 'dur' and 'ghz_fid'
            lists of the runs that did not reach the cut-off time.
    """
    first_seed = calc_id * iter_pw
    seeds_used = [*range(first_seed, first_seed + iterations)]
    results_dataframe_failed = None
    results_dataframe_succeed = None
    total_print_lines = []
    characteristics = {'dur': [], 'ghz_fid': []}

    # The per-circuit progress bar is disabled because a second tqdm bar does not work properly within PyCharm.
    # In a normal terminal it can be enabled (when pbar_2 is given) with:
    #     pbar = tqdm(total=100, position=1, desc='Current circuit simulation')
    pbar = None

    # Set the gate durations (when threaded, each worker process needs its own modified copy of the gate durations)
    if threaded:
        set_gate_durations_from_file(gate_duration_file)

    # Get the QuantumCircuit object corresponding to the protocol and the protocol method by its name
    kwargs = _additional_qc_arguments(**kwargs)
    qc = stab_protocols.create_quantum_circuit(protocol, pbar, **kwargs)
    protocol_method = getattr(stab_protocols, protocol)
    # The operation used by the protocol does not change between iterations
    operation = CZ_gate if stabilizer_type == "Z" else CNOT_gate

    # Run iterations of the protocol
    for iteration in range(iterations):
        if pbar:
            pbar.reset()
        if pbar_2:
            pbar_2.update(1)

        random.seed(int(seed_number) if seed_number is not None else int(seeds_used[iteration]))

        # Run the user requested protocol
        protocol_method(qc, operation=operation)
        qc.end_current_sub_circuit(total=True, forced_level=True, apply_decoherence=True)
        add_decoherence_if_cut_off(qc)

        if draw_circuit:
            qc.draw_circuit(no_color=not color, color_nodes=True)
        if save_latex_pdf:
            qc.draw_circuit_latex()

        results_dataframe = qc._create_results_dataframe(protocol)

        if pbar is not None:
            pbar.update(10)

        # Fuse the results dataframes obtained in each iteration; only successful runs add characteristics
        if qc.cut_off_time_reached:
            results_dataframe_failed = _combine_results_dataframes(results_dataframe_failed, results_dataframe)
        else:
            characteristics['dur'].append(qc.total_duration)
            characteristics['ghz_fid'].append(qc.ghz_fidelity)
            results_dataframe_succeed = _combine_results_dataframes(results_dataframe_succeed, results_dataframe)

        total_print_lines.extend(qc.print_lines)
        if draw_circuit:
            total_print_lines.append("\nGHZ fidelity: {} ".format(qc.ghz_fidelity))
            total_print_lines.append("\nTotal circuit duration: {} s".format(qc.total_duration))
        qc.reset()

    if pbar_2:
        pbar_2.close()
    if pbar is not None:
        pbar.close()
    return (results_dataframe_succeed, results_dataframe_failed), total_print_lines, characteristics


def _find_or_register_node(node_df, node_dict, check, nodes_file):
    """
        Return the name of the node whose parameters equal ``check``, registering a new node if none matches.

        A new node is named "Unspec<number of known nodes>" with nickname "Unspecified". It is added to
        ``node_dict`` (in place) and to the node dataframe, which is then written to ``nodes_file``.

        Returns
        -------
        tuple
            ``(node_name, node_df)``, where ``node_df`` is the (possibly extended) node dataframe.
    """
    try:
        position = list(node_dict.values()).index(check)
    except ValueError:
        position = None
    if position is not None:
        return list(node_dict.keys())[position], node_df

    node_name = "Unspec" + str(len(node_dict))
    new_row = pd.DataFrame([["Unspecified"] + list(check.values())], columns=["nickname"] + list(check.keys()),
                           index=[node_name])
    # DataFrame.append was removed in pandas 2.0; pd.concat is the equivalent available in all versions
    node_df = pd.concat([node_df, new_row])
    node_df.to_csv(nodes_file, sep=";", index_label="name")
    node_dict[node_name] = check
    return node_name, node_df


def run_for_arguments(operational_args, circuit_args, var_circuit_args, **kwargs):
    """
        Run the simulation for every combination of the variational circuit arguments.

        Parameters
        ----------
        operational_args : dict
            Operational settings. This function reads 'cut_off_file', 'csv_filename', 'cp_path', 'force_run' and
            'threaded', and passes the dict on to the run functions.
        circuit_args : dict
            Circuit settings passed to the run functions. Its 'iterations' and 'cutoff_search' entries are
            overwritten for every run.
        var_circuit_args : dict
            Maps each variational argument name to a list of values; every combination is run. Its 'seed_number'
            entry is replaced by [None] when more than one iteration is requested.
        **kwargs
            Not used.

        Returns
        -------
        list of str
            Result base paths (without extension) of all passes except the automatic cut-off search passes.
    """
    filenames = []
    fn = None
    cut_off_dataframe = _get_cut_off_dataframe(operational_args['cut_off_file'])
    # The automatic cut-off search is only enabled by the string 'auto' (case-insensitive, as accepted by
    # _get_cut_off_dataframe). Comparing a loaded DataFrame with == 'auto' gives an element-wise result whose truth
    # value is ambiguous and raises a ValueError.
    auto_cut_off = isinstance(cut_off_dataframe, str) and cut_off_dataframe.lower() == 'auto'

    nodes_file = get_full_path("circuit_simulation/node/nodes.csv")
    node_df = pd.read_csv(nodes_file, sep=";", float_precision="round_trip", index_col="name")
    node_dict = node_df.drop(columns="nickname").to_dict("index")
    # Node parameter names, taken from the first node in the file
    parameters_in_df = list(node_dict[list(node_dict.keys())[0]].keys())

    iterations = circuit_args['iterations']
    # A fixed seed is only used for single-iteration runs (note: this updates the caller's dict)
    var_circuit_args['seed_number'] = [None] if iterations > 1 else var_circuit_args['seed_number']

    # Loop over command line arguments
    for run in it.product(*(it.product([key], var_circuit_args[key]) for key in var_circuit_args.keys())):
        run_dict = dict(run)

        # Set run_dict values based on circuit arguments
        run_dict['lde_success'] = run_dict['lde_success'] if circuit_args['probabilistic'] else 1
        run_dict['fixed_lde_attempts'] = run_dict['fixed_lde_attempts'] if run_dict['pulse_duration'] > 0 else 0
        run_dict['pm'] = (run_dict['pg'] if circuit_args['pm_equals_pg'] else run_dict['pm'])
        run_dict['protocol'] = (run_dict['protocol'] + "_swap" if circuit_args['use_swap_gates']
                                else run_dict['protocol'])

        check = {k: circuit_args[k] for k in parameters_in_df if k in circuit_args}
        run_dict['_node'], node_df = _find_or_register_node(node_df, node_dict, check, nodes_file)

        # If the cut-off time is not found in auto mode, a first pass runs simulations to determine it and the loop
        # then reruns with that cut-off time. All other cases run exactly one pass.
        first_pass = True
        while first_pass or (run_dict['cut_off_time'] == np.inf and auto_cut_off):
            first_pass = False
            circuit_args['iterations'] = iterations
            run_dict['cut_off_time'], circuit_args['cutoff_search'] = _get_cut_off_time(cut_off_dataframe, run_dict,
                                                                                        circuit_args,
                                                                                        **operational_args)
            searching_cut_off = run_dict['cut_off_time'] == np.inf and auto_cut_off

            if operational_args['csv_filename']:
                # Create parameter specific filename
                fn = create_file_name(operational_args['csv_filename'], dec=circuit_args['decoherence'],
                                      prob=circuit_args['probabilistic'], node=run_dict['_node'],
                                      decoupling=run_dict['pulse_duration'],
                                      noiseless_swap=circuit_args['noiseless_swap'], **run_dict)

                fn = os.path.join(get_full_path(operational_args['cp_path']), fn)
                if not searching_cut_off:
                    filenames.append(fn)

                # Check if parameter settings has not yet been evaluated, else skip
                if not operational_args['force_run'] and os.path.exists(fn + ".csv"):
                    data = pd.read_csv(fn + '.csv', sep=";", float_precision='round_trip')
                    res_iterations = int(circuit_args['iterations'] - data.loc[0, 'written_to'])
                    # iterations within 5% margin
                    if not circuit_args['probabilistic'] or circuit_args['iterations'] * 0.05 >= res_iterations:
                        print("\n[INFO] Skipping circuit for file '{}', since it already exists.".format(fn))
                        # A repeated pass would see exactly the same inputs and files and skip again, so stop here
                        # instead of looping forever in auto mode
                        break
                    print("\nFile found with too less iterations. Running for {} iterations\n".format(
                        res_iterations))
                    circuit_args['iterations'] = res_iterations

            print("\nRunning {} iteration(s) with values for the variational arguments:"
                  .format(circuit_args['iterations']))
            pprint({**run_dict})

            if operational_args['threaded']:
                main_threaded(fn=fn, **operational_args, **run_dict, **circuit_args)
            else:
                main_series(fn=fn, **operational_args, **run_dict, **circuit_args)

    return filenames


if __name__ == "__main__":
    # Parse the command line, then post-process the values (gate lookup tables, gate durations)
    argument_parser = compose_parser()
    parsed_args = vars(argument_parser.parse_args())
    parsed_args = additional_parsing_of_arguments(**parsed_args)

    # Split the arguments into their groups and show them to the user
    grouped_arguments = group_arguments(argument_parser, **parsed_args)
    print_signature()
    print_circuit_parameters(*grouped_arguments)

    # Loop over all possible combinations of the user determined parameters
    run_for_arguments(*grouped_arguments, **parsed_args)
