from __future__ import annotations

import os

from Quantum_Network_Architecture.tasks import PacketGenerationTask,Taskset
from Quantum_Network_Architecture.networks.nodes import NodeList

from typing import Optional

# from program_scheduling.program_scheduling.network_schedule import NetworkSchedule as NodeNetworkSchedule

# from qoala.lang.ehi import EhiNetworkSchedule

import yaml

from rich.table import Table

import pandas as pd
from pandas import DataFrame

from tqdm import tqdm


class NetworkSchedule:
    """
    Packet generation attempt (PGA) schedule for a set of tasks on a quantum network.

    For every task identifier the schedule stores the start and end times of its scheduled PGAs. It
    offers views grouped by task or by node, utilisation figures, pandas profiles and rich-markup
    text renderings. Instances are ordered and compared by their ``metric`` attribute.
    """

    def __init__(self, tasks: Taskset, status, schedule_length, comp_time=-1, nodes: NodeList | None = None):
        """
        :param tasks: Taskset whose tasks already carry their scheduled ``start_times`` / ``end_times``.
        :param status: Solver status associated with this schedule.
        :param schedule_length: Number of time slots covered by the schedule.
        :param comp_time: Time the solver needed to produce the schedule (-1 if unknown).
        :param nodes: Optional NodeList. If omitted, or if all of its nodes are used by some task, the
                      nodes actually used by the tasks are taken instead; otherwise it is used as given.
        """
        self.status = status
        self.computation_time = comp_time

        self.keys = [T.identifier for T in tasks]

        self.taskset: Taskset = tasks
        self.tasks = tasks.lexicographical_order()
        self.end_times = {T.identifier: T.end_times for T in tasks}
        self.start_times = {T.identifier: T.start_times for T in tasks}
        self.number_scheduled = {T.identifier: len(T.start_times) for T in tasks}

        used_nodes = {n for t in self.tasks for n in (t.node1, t.node2)}

        if nodes is None or set(nodes).issubset(used_nodes):
            self._nodes = NodeList(used_nodes)
        else:
            self._nodes = nodes

        # NOTE(review): a caller-supplied NodeList is sorted in place, i.e. the caller's object is mutated.
        self._nodes.sort()

        self.task_dict = {T.identifier: T for T in tasks}

        self._length_of_schedule: int = schedule_length

        # Evaluation results; this class only initialises them (-1 / [] are the unset defaults).
        self.metric = -1
        self.delivered_rates = []
        self.jitter = -1

    @classmethod
    def empty(cls):
        """Return a schedule with no tasks, an empty status and a schedule length of -1."""
        return cls(Taskset(), '', -1)

    @classmethod
    def from_minizinc(cls, minizinc_output_dict: dict = None, minizinc_termination_data: dict = None,
                      status: str = '', tasks: {str: PacketGenerationTask} = None, nodes=None):
        """
        Build a schedule from MiniZinc solver output.

        The task objects in ``tasks`` are updated in place: their ``start_times``, ``end_times`` and
        ``deadlines`` are overwritten from the solver output.

        :param minizinc_output_dict: Solver output; must contain "schedule length", "Start times" and
                                     "End times" (the latter two indexed by the keys of ``tasks``).
        :param minizinc_termination_data: Solver statistics; "initTime" + "solveTime" is recorded as the
                                          computation time, or -1 if unavailable.
        :param status: Solver status string.
        :param tasks: {task key: PacketGenerationTask} for the scheduled tasks.
        :param nodes: Optional NodeList, passed through to the constructor.
        """
        try:
            computation_time = float(minizinc_termination_data["initTime"]) + float(
                minizinc_termination_data["solveTime"])
        except (KeyError, TypeError):
            computation_time = -1

        schedule_length = minizinc_output_dict["schedule length"]

        start_times = minizinc_output_dict["Start times"]
        end_times = minizinc_output_dict["End times"]
        # number_scheduled = minizinc_output_dict["Number Scheduled"]

        for key, task in tasks.items():
            task: PacketGenerationTask
            task.start_times = start_times[key]
            task.end_times = end_times[key]
            # One deadline at every multiple of the period that fits within the schedule.
            task.deadlines = [i * task.period for i in range(1, schedule_length // task.period + 1)]

        return cls(Taskset(tasks.values()), status=status, schedule_length=schedule_length,
                   comp_time=computation_time, nodes=nodes)

    # def export_to_node_network_schedule(self, nodes: [str], min_time: int | None = None, max_time: int | None = None) -> NodeNetworkSchedule:
    #
    #     node_schedule = self.get_schedule_by_nodes(nodes_to_show=nodes)
    #
    #
    #
    #     if min_time is None:
    #         min_time = 0
    #
    #     if max_time is None:
    #         max_time = self.length_of_schedule
    #
    #     node_network_schedules = []
    #
    #     for node in nodes:
    #         identifiers = []
    #         start_times = []
    #         for start, _, identifier in node_schedule[node]:
    #             if start < min_time:
    #                 continue
    #             if start > max_time:
    #                 break
    #
    #             identifiers.append(identifier)
    #             start_times.append(start)
    #
    #         node_network_schedules.append(NodeNetworkSchedule(dataset_id=-1, n_sessions=len(identifiers), sessions=identifiers, start_times=start_times))
    #
    #
    #
    #
    #     pass

    def export_lengths_of_qc_blocks(self, length_of_time_slot: float = 1.0):
        """
        Exports the length of QC blocks / PGAs in format required by Qoala EhiNetworkSchedule.

        :param length_of_time_slot: Duration of one time slot; execution times are multiplied by it.
        :return: {task.session_tuple: task.execution_time * length_of_time_slot} for every task.
        """
        return {t.session_tuple: t.execution_time * length_of_time_slot for t in self.taskset}

    def export_node_schedule_to_list(self):
        """
        Return, for every time slot, the session tuples of the PGAs starting in that slot.

        Empty slots hold the placeholder ``(-1, -1, -1, -1)``. If every slot holds exactly one session
        the result is flattened to a plain list of tuples, otherwise it is a list of lists. An empty
        schedule yields ``[]``.
        """
        _output = [[] for _ in range(self.length_of_schedule)]

        for task in self.taskset:
            task: PacketGenerationTask
            for t in task.start_times:
                _output[t].append(task.session_tuple)

        _output = [x if x else [(-1, -1, -1, -1)] for x in _output]

        # all() (unlike max()) copes with an empty schedule.
        if all(len(x) == 1 for x in _output):
            _output = [x for y in _output for x in y]

        return _output

    @property
    def length_of_schedule(self):
        """Number of time slots in the schedule."""
        return self._length_of_schedule

    @property
    def nodes(self):
        """The NodeList this schedule reports on."""
        return self._nodes

    # Schedules are ordered by ``metric``. Defining __eq__ makes instances unhashable (__hash__ is None).
    def __lt__(self, other):
        if not isinstance(other, NetworkSchedule):
            return NotImplemented
        return self.metric < other.metric

    def __gt__(self, other):
        if not isinstance(other, NetworkSchedule):
            return NotImplemented
        return self.metric > other.metric

    def __eq__(self, other):
        if not isinstance(other, NetworkSchedule):
            return NotImplemented
        return self.metric == other.metric

    def same_times(self, other) -> bool:
        """Return True if ``other`` has exactly the same start and end times for every task."""
        return self.start_times == other.start_times and self.end_times == other.end_times

    def get_schedule_by_tasks(self, tasks=None):
        """
        Return the schedule grouped by task.

        :param tasks: Task identifiers to include (default: all, in lexicographical order).
        :return: {task id: [(start_time, end_time), ...]}
        """
        if tasks is None:
            tasks = [d.identifier for d in self.tasks]

        return {
            D: [(self.start_times[D][i], self.end_times[D][i]) for i in range(self.number_scheduled[D])]
            for D in tasks
        }

    def get_schedule_by_nodes(self, nodes_to_show=None, alias: bool = True, save: bool = False,
                              save_folder: str = "node_schedule_export_default"):
        """
        Returns the network schedule on a node-by-node basis.

        :param nodes_to_show: Which nodes the schedule should be extracted for (default: all). Ids and
                              aliases may be mixed, e.g. [0, 'bob'].
        :param alias: Key the result by node alias (True, default) or by node id (False).
        :param save: (opt: False) Whether to save the extracted node schedules as a CSV file. This is done
                     for each node individually.
        :param save_folder: Where to save the node schedule(s) if save = True.
        :return: node schedules in the format {node: [(start_time, end_time, task_id)]}, sorted per node.
        """
        # TODO: Add export option to Hana NS?

        # PGAs are matched on task.nodeX.alias below, so always work with aliases internally and only
        # convert to ids at the end if requested.
        if nodes_to_show is None:
            nodes_to_show = self._nodes.node_aliases()
        else:
            nodes_to_show = [a if isinstance(a, str) else self._nodes.get_alias_from_id(a) for a in nodes_to_show]

        by_node: {str: [(int, int, str)]} = {node: [] for node in nodes_to_show}

        for task in self.tasks:
            starts = self.start_times[task.identifier]
            ends = self.end_times[task.identifier]
            for i in range(self.number_scheduled[task.identifier]):
                pga = (starts[i], ends[i], task.identifier)
                if task.node1.alias in by_node:
                    by_node[task.node1.alias].append(pga)
                if task.node2.alias in by_node:
                    by_node[task.node2.alias].append(pga)

        for pgas in by_node.values():
            pgas.sort()

        if save:
            os.makedirs(save_folder, exist_ok=True)
            for node, pgas in by_node.items():
                rows = '\n'.join(f'{pga[2]},{pga[0]},{pga[1]}' for pga in pgas)
                with open(os.path.join(save_folder, f"{node}.csv"), 'w') as F:
                    F.write('session,start_time,end_time\n' + rows)

        if alias:
            return by_node
        return {self._nodes.get_id_from_alias(x): y for x, y in by_node.items()}

    def node_utilisation(self, nodes=None):
        """
        Fraction of the schedule length each node spends in PGAs.

        :param nodes: Node ids and/or aliases to include (default: all nodes).
        :return: {node alias: utilisation}
        """
        node_schedule = self.get_schedule_by_nodes(nodes)

        # Iterate the returned schedule: its keys are aliases even if ids were requested.
        return {n: get_utilisation(self.length_of_schedule, pgas) for n, pgas in node_schedule.items()}

    def demand_utilisation(self, tasks=None):
        """
        Fraction of the schedule length each task occupies, counting each PGA as running until its end
        time plus the task's ``minimum_separation``.

        :param tasks: Task identifiers to include (default: all).
        :return: {task id: utilisation}
        """
        if tasks is None:
            tasks = [d.identifier for d in self.tasks]

        demand_schedule = self.get_schedule_by_tasks(tasks)

        return {d: get_utilisation(self.length_of_schedule,
                                   [(t[0], t[1] + self.task_dict[d].minimum_separation) for t in demand_schedule[d]])
                for d in tasks}

    def node_schedule(self, nodes=None, min_time: int = 0, max_time: int = None, alias: Optional[bool] = False):
        """
        Creates a formatted text representation of the schedule grouped by nodes. The number of time slots per line and
        the width of the time slots is controlled by setting the relevant parameters in display_config/displayConfig.yaml

        :param nodes: IDs and/or aliases of the nodes to produce the schedule for (default: all nodes)
        :param min_time:  Earliest time to display
        :param max_time:  Latest time to display
        :param alias: Print the alias of the nodes (True) or just the node_id (False, default)
        :return: text representation of the network schedule grouped by nodes, with rich module markup.
        """
        # Work with aliases throughout; ``alias`` only controls how the row labels are printed.
        if nodes is None:
            nodes = self._nodes.node_aliases()
        else:  # allows for a mix of ids and aliases to be specified, e.g. [0, 'bob']
            nodes = [a if isinstance(a, str) else self._nodes.get_alias_from_id(a) for a in nodes]

        # TODO: check for duplicate nodes.

        if max_time is None:
            max_time = self.length_of_schedule

        node_schedule = self.get_schedule_by_nodes(nodes)

        line_starts = {}
        for node in nodes:
            node_id = str(self._nodes.get_id_from_alias(node))
            line_starts[node] = node + ' (' + node_id + ')' if alias else node_id

        name_offset = max((len(n) for n in line_starts.values()), default=0)

        out_list = [
            _schedule_line_printer(f"{line_starts[node]: >{name_offset}} ", node_schedule[node],
                                   self.length_of_schedule, min_time=min_time, max_time=max_time)
            for node in nodes
        ]
        out_list.append(get_timeline(self.length_of_schedule, name_offset, min_time=min_time,
                                     max_time=min(max_time, self._length_of_schedule)))

        # Interleave the wrapped rows: line k of every row (and the timeline), then line k + 1, ...
        out_list = [out_list[j][k] for k in range(len(out_list[0])) for j in range(len(out_list))]

        if max_time < self.length_of_schedule:
            out_list.append(f"\n ... schedule continues for a further {self.length_of_schedule - max_time} time steps")

        return '\n'.join(out_list)

    def average_demand_satisfaction_profile(self, t: int | None = None) -> pd.Series:
        """
        Mean, across tasks, of each task's instantaneous rate-satisfaction profile.

        :param t: Number of time steps to evaluate (default: the schedule length).
        :return: Series indexed by time step; time steps missing from a task's profile are NaN for that
                 task and therefore skipped by the mean.
        """
        if t is None:
            t = self._length_of_schedule

        df = DataFrame(index=pd.Index(range(t)), columns=self.taskset.ids)

        # Context manager guarantees the progress bar is closed even if a profile raises.
        with tqdm(total=len(self.tasks), desc="Task") as prog:
            for task in self.taskset.lexicographical_order():
                task: PacketGenerationTask
                profile = task.instantaneous_rate_satisfaction_profile(t)

                times = [step for step, _ in profile]
                values = [value for _, value in profile]

                df[task.identifier] = pd.Series(data=values, index=times)
                prog.update()

        return df.mean(axis=1)

    def concurrent_users_profile(self, t: int | None = None) -> pd.DataFrame:
        """
        Number of active tasks at each time step.

        :param t: Number of time steps to evaluate (default: the schedule length).
        :return: DataFrame with column "number of users": for each step s in [0, t), the number of tasks
                 with ``initialisation_time <= s < expiry_time``.
        """
        if t is None:
            t = self._length_of_schedule

        df = pd.DataFrame(columns=["number of users"], index=range(t))

        # Order the tasks once instead of once per time step.
        ordered_tasks = list(self.taskset.lexicographical_order())

        df["number of users"] = [
            sum(1 for task in ordered_tasks if task.initialisation_time <= s < task.expiry_time)
            for s in range(t)
        ]

        # for node in self.all_end_nodes:
        #
        #     df[node] = [len([task for task in self.taskset.lexicographical_order() if task.initialisation_time <= s < task.expiry_time and node in task.all_end_nodes]) for s in range(t)]

        return df

    def demand_schedule(self, demands=None, min_time: int = 0, max_time=None):
        """
        Formatted text representation of the schedule grouped by demand (task), with release-time,
        deadline and expiry markers and a legend.

        :param demands: Task identifiers to show (default: all).
        :param min_time: Earliest time to display.
        :param max_time: Latest time to display (default: the schedule length).
        :return: text with rich module markup.
        """
        if demands is None:
            demands = self.taskset.ids

        if max_time is None:
            max_time = self.length_of_schedule

        demand_schedule = self.get_schedule_by_tasks(demands)

        # Width of the right-aligned row labels (1 if there are no demands).
        label_length = max((len(n) for n in demands), default=1)

        out_list = []
        for demand in demands:
            task = self.task_dict[demand]
            out_list.append(
                _schedule_line_printer(f"{demand: >{label_length}} ", demand_schedule[demand],
                                       self.length_of_schedule, node_schedule=False, min_time=min_time,
                                       max_time=max_time, include_release_deadlines=True,
                                       release_times=task.release_times,
                                       deadlines=task.deadlines,
                                       execution_time=task.execution_time,
                                       expiry_time=task.expiry_time))

        out_list.append(get_timeline(self.length_of_schedule, label_length, min_time=min_time,
                                     max_time=min(max_time, self._length_of_schedule)))

        # Interleave the wrapped rows: line k of every row (and the timeline), then line k + 1, ...
        out_list = [out_list[j][k] for k in range(len(out_list[0])) for j in range(len(out_list))]

        # Drop lines that show nothing: keep PGA bars, timeline ticks and marker arrows.
        activity_indicators = ['-', '|', '\u2195', '\u2191', '\u2193', '\u21E1']
        out_list = [row for row in out_list if any(i in row for i in activity_indicators)]

        if max_time < self.length_of_schedule:
            out_list.append(f"\n ... schedule continues for a further {self.length_of_schedule - max_time} time steps\n")

        out_list.append("[release_deadline]\u2195[/release_deadline] = Release time and deadline | "
                        + "[deadline]\u2191[/deadline] = Completion Deadline | "
                        + "[release_time]\u2193[/release_time] = Release time | "
                        + "[start_deadline]\u21E1[/start_deadline] = Start time deadline | "
                        + "[expired]/[/expired] = demand expired\n")

        return '\n'.join(out_list)

def schedule_metric(NS: NetworkSchedule):
    """Score a schedule by adding up its ``delivered_rates`` (0 when there are none)."""
    total = 0
    for rate in NS.delivered_rates:
        total += rate
    return total

def get_utilisation(makespan: int, occupied_times: [tuple]):
    """
    Fraction of ``makespan`` covered by the given intervals.

    :param makespan: Total length the utilisation is measured against.
    :param occupied_times: Intervals whose first two items are (start, end); extra items are ignored.
    :return: (sum of end times - sum of start times) / makespan
    """
    total_end = sum(interval[1] for interval in occupied_times)
    total_start = sum(interval[0] for interval in occupied_times)
    busy = total_end - total_start
    return busy / makespan


def get_timeline(length_of_schedule,
                 label_width,
                 min_time: int = 0,
                 max_time: int = None,
                 width_of_slot: int | None = None,
                 line_width: int | None = None):
    """
    Build the time-axis lines printed beneath a schedule.

    Each slot is drawn as '|' followed by its index centred in ``width_of_slot - 1`` characters. The
    axis is wrapped every ``line_width`` slots, and each line is indented by ``label_width + 1`` spaces
    so that it lines up with the labelled schedule rows.

    :param length_of_schedule: Schedule length; used as ``max_time`` when that is None.
    :param label_width: Width of the row labels of the schedule being annotated.
    :param min_time: First slot to show.
    :param max_time: Slot to stop at (exclusive).
    :param width_of_slot: Characters per slot; read from the display config if None.
    :param line_width: Slots per line; read from the display config if None.
    :return: list of lines; full lines end with '|\\n', a trailing partial line ends with '|'.
    """
    if max_time is None:
        max_time = length_of_schedule

    if width_of_slot is None or line_width is None:
        # Read the display configuration once, and close the file afterwards.
        with open("Quantum_Network_Architecture/display_config/displayConfig.yaml", 'r') as config_file:
            display_config = yaml.safe_load(config_file)
        if width_of_slot is None:
            width_of_slot = display_config.get("size of time slot")
        if line_width is None:
            line_width = display_config.get("line width")

    chars_per_line = line_width * width_of_slot
    time_stamp_lead = ' ' * (label_width + 1)

    time_stamps = ''.join('|' + f"{i: ^{width_of_slot - 1}}" for i in range(0, max_time))
    time_stamps = time_stamps[min_time * width_of_slot:max_time * width_of_slot]

    # Decide on a trailing partial line from the length actually shown (max_time - min_time slots),
    # matching how the schedule rows are wrapped; max_time alone is wrong when min_time > 0.
    full_lines, remainder = divmod(len(time_stamps), chars_per_line)

    output = [time_stamp_lead + time_stamps[i * chars_per_line:(i + 1) * chars_per_line] + '|\n'
              for i in range(full_lines)]

    if remainder:
        output.append(time_stamp_lead + time_stamps[full_lines * chars_per_line:] + '|')

    return output

def _schedule_line_printer(
        label: str, active_times: [tuple],
        length_of_network_schedule,
        node_schedule: bool = True,
        min_time: int = 0,
        max_time: int = None,
        include_release_deadlines: bool = False,
        release_times: [int] = None,
        deadlines: [int] = None,
        execution_time: int = None,
        expiry_time: int = None,
        start_time: int = 0
        ) -> [str]:
    width_of_slot = yaml.safe_load(open("Quantum_Network_Architecture/display_config/displayConfig.yaml", 'r')).get("size of time slot")
    line_width = yaml.safe_load(open("Quantum_Network_Architecture/display_config/displayConfig.yaml", 'r')).get("line width")
    time_stringed_to = start_time
    temp = []

    # expiry_time = int(expiry_time)

    if max_time is None:
        max_time = start_time + length_of_network_schedule

    max_time = min(start_time + length_of_network_schedule, max_time)

    min_time = max(start_time, min_time)

    # colour_label = type(colour) is str

    # if colour_label:
    #     colour = {'': colour}

    for AT in active_times:
        if len(AT) == 3:
            at_label = AT[2]
        else:
            at_label = label.replace(' ','').replace('-',',')


        temp += [' '* width_of_slot] * (AT[0] - time_stringed_to)

        temp_pga = '[' + f"{at_label:-^{(AT[1] - AT[0]) * width_of_slot - 2}}" + ']'

        """
        Each PGA is represented by [----X----]. Next line slices it up into time slots, 
        adds formatting and then adds it to the temp list to be split into lines and then returned.   
        """

        temp += [f"[x{at_label.replace('-','')}]" +
                 (temp_pga[i*width_of_slot:(i+1)*width_of_slot] if not all(b in temp_pga[i*width_of_slot:(i+1)*width_of_slot] for b in ["[","]"]) else f"\\"+temp_pga[i*width_of_slot:(i+1)*width_of_slot])   # If a task has execution time 1, then escape the block from being interpreted as rich markup
                 + f"[/x{at_label.replace('-','')}]" for i in range(len(temp_pga) // width_of_slot)]

        time_stringed_to = AT[1]
    if expiry_time is None:
        temp += [' ' * width_of_slot] * (max_time - time_stringed_to)
    else:
        temp += [' ' * width_of_slot] * (expiry_time - time_stringed_to)
        temp += ['[expired]' + '/'*width_of_slot + '[/expired]'] * (max_time - expiry_time+1)
    output = []

    if include_release_deadlines:
        for t in range(min_time, min(max_time, (expiry_time if expiry_time is not None else 0))):

            if temp[t] == ' '*width_of_slot:
                if t in deadlines and t in release_times:
                    temp[t] = (f'[release_deadline]\u2195[/release_deadline]' + temp[t])[:-1] #Adds an updown arrow for a coincident release time and deadline.
                elif t in deadlines:
                    temp[t] = (f"[deadline]\u2191[/deadline]" + temp[t])[:-1] #Adds an up arrow for a deadline
                elif t in release_times:
                    temp[t] = (f"[release_time]\u2193[/release_time]" + temp[t])[:-1] #Adds a down arrow for a release time
                elif execution_time is not None:
                    if t+execution_time in deadlines:
                        temp[t] = (f"[start_deadline]\u21E1[/start_deadline]" + temp[t])[:-1] #adds an up arrow for start time deadlines

            elif temp[t] == '[expired]' + ' '*width_of_slot + '[/expired]':
                if t in deadlines and t in release_times:
                    temp[t] = f'[release_deadline]\u2195[/release_deadline][expired]' + '/'*(width_of_slot-1) + '[/expired]' #Adds an updown arrow for a coincident release time and deadline.
                elif t in deadlines:
                    temp[t] = f"[deadline]\u2191[/deadline][expired]" + '/'*(width_of_slot-1) + '[/expired]' #Adds an up arrow for a deadline
                elif t in release_times:
                    temp[t] = f"[release_time]\u2193[/release_time][expired]" + '/'*(width_of_slot-1) + '[/expired]' #Adds a down arrow for a release time
                elif execution_time is not None:
                    if t+execution_time in deadlines:
                        temp[t] = f"[start_deadline]\u21E1[/start_deadline][expired]" + '/'*(width_of_slot-1) + '[/expired]' #adds an up arrow for start time deadlines

            else:
                if t in deadlines and t in release_times:
                    temp[t] = f"[release_deadline]\u2195[/release_deadline]" + ''.join(temp[t].split('-', 1))
                elif t in deadlines:
                    temp[t] = f"[deadline]\u2191[/deadline]" + ''.join(temp[t].split('-', 1))
                elif t in release_times:
                    temp[t] = f"[release_time]\u2193[/release_time]" + ''.join(temp[t].split('-', 1))
                elif execution_time is not None:
                    if t + execution_time in deadlines:
                        temp[t] = f"[start_deadline]\u21E1[/start_deadline]" + ''.join(temp[t].split('-', 1))


    temp = temp[min_time:max_time]

    #Slices into lines

    for i in range(len(temp) // line_width):
        output.append((label if node_schedule else f'[x{label.replace(" ","").replace("-","")}_label]' + label.replace('-',',') + f'[/x{label.replace(" ","").replace("-","")}_label]') + ''.join(temp[i * line_width:(i+1)*line_width]))

    if len(temp) % line_width != 0:
        output.append(label + ''.join(temp[len(temp) // line_width * line_width:]))

    if include_release_deadlines:

        test_sequence = (max_time in deadlines, max_time in release_times, execution_time is not None)

        # Checks to see if there should be a final arrow

        # TODO: Make this cleaner - match statement still not great, maybe some clever f-string???

        if max_time in deadlines and max_time in release_times:
            output[-1] += f"[release_deadline]\u2195[/release_deadline]"

        elif max_time in deadlines and max_time not in release_times:
            output[-1] += f"[deadline]\u2191[/deadline]"

        elif max_time not in deadlines and max_time in release_times:
            output[-1] += f"[release_time]\u2193[/release_time]"

        elif execution_time is not None:
            if max_time + execution_time in deadlines:
                output[-1] += f"[start_deadline]\u21E1[/start_deadline]"
        else:
            pass

    return output

#Pretty-prints a table of tasks
def loaded_tasks_printer(tasks: [PacketGenerationTask]) -> Table:
    """
    Build a rich Table summarising the given tasks, one row per task.

    :param tasks: iterable of PacketGenerationTask.
    :return: rich Table titled "Loaded Tasks" (print it with a rich Console).
    """
    task_table = Table(title="Loaded Tasks", expand=True, min_width=200)

    task_table.add_column("Id", justify="right", min_width=6)
    for heading in ("Node1", "Node2", "Execution Time", "minsep", "maxsep", "rate", "period", "rsf"):
        task_table.add_column(heading, justify='center')

    for t in tasks:
        t: PacketGenerationTask
        # Table.add_row only accepts str or rich renderables, and node1/node2 are node objects,
        # so every cell is converted to text explicitly.
        task_table.add_row(
            str(t.identifier),
            str(t.node1),
            str(t.node2),
            str(t.execution_time),
            str(t.minsep),
            str(t.maxsep),
            f'{t.rate: .2}',
            str(t.period),
            f'{t.rate_scaling: .2f}'
        )

    return task_table

if __name__ == '__main__':
    # The module only provides definitions; running it directly does nothing.
    ...