from __future__ import annotations

import copy
from bisect import bisect_right
from dataclasses import dataclass
from datetime import timedelta
from math import inf, lcm, ceil
from statistics import variance
from typing import List, Dict, Union, Callable, Optional

import numpy as np
import pandas as pd
from rich.table import Table

from Quantum_Network_Architecture.exceptions import InitialisationError
from Quantum_Network_Architecture.networks import Network, HomogeneousStarNetwork
from Quantum_Network_Architecture.networks.nodes import NodeList, Node


def edf_priority(task: PacketGenerationTask, t: float):
    """Earliest-deadline-first priority of ``task`` at time ``t``.

    :param task: Task whose ``next_deadline`` is read.
    :param t: Current time, in the same units as ``next_deadline``.
    :return: Time remaining until the task's next deadline. No check is made that ``t`` has not
        passed the deadline, so the result may be negative.
    """
    time_remaining = task.next_deadline - t
    return time_remaining


class PacketGenerationTask:
    """A periodic task requesting packet generation between a set of end nodes.

    Once initialised, the task is released once per ``period`` timeslots (``period = round(1 / rate)``).
    Each release is either served by a :class:`PacketGenerationAttempt` (:meth:`add_to_schedule`) or
    dropped (:meth:`drop_packet`). Times are integer network timeslots unless stated otherwise;
    ``creation_time`` is in seconds.
    """

    def __init__(self, name,
                 end_nodes: NodeList,
                 execution_time: int,
                 rate: float,
                 number_to_schedule: int | None = None,
                 minsep: int | None = None,
                 maxsep: int | None = None,
                 demand_class: str = "N/A",
                 expiry_time: int | None = None,
                 relative_expiry: bool = False,
                 creation_time: float = 0.,
                 links: List[str] | None = None,
                 session_ids: Dict[int, int] | None = None,  # Node_id -> session_ids
                 max_number_of_instances_in_session: int | None = None,
                 priority_function: Callable[[PacketGenerationTask, float], float] = edf_priority,
                 accepted_qos_option: str = 'min'
                 ):
        """
        :param name: Identifier of the task. ``demand_source``/``demand_instance`` expect the form
            ``"<source>-<instance number>"``.
        :param end_nodes: Nodes the task runs between; the first two are exposed as ``node1``/``node2``.
        :param execution_time: Number of timeslots a single packet generation attempt occupies.
        :param rate: Requested rate in packets per timeslot; the period is ``round(1 / rate)``.
        :param number_to_schedule: Number of packets requested; stored as -1 when not given.
        :param minsep: Minimum gap between the end of one attempt and the next release (default 0).
        :param maxsep: Maximum gap used to bound the next deadline (default ``inf``).
        :param demand_class: Free-form demand class label.
        :param expiry_time: Timeslot at which the task expires (``None`` = never expires).
        :param relative_expiry: If True, ``expiry_time`` is offset by the initialisation time.
        :param creation_time: Time the task was created, in seconds.
        :param links: Link names used by the task; defaults to ``"<node1 id>--<node2 id>"``.
        :param session_ids: Node id -> session id mapping; (re)assigned by :meth:`initialise`.
        :param max_number_of_instances_in_session: Defaults to 1.
        :param priority_function: Callable ``(task, t) -> priority`` evaluated by :meth:`priority`.
        :param accepted_qos_option: Stored as-is and exposed via ``accepted_qos_option``.
        """
        self._dropped_packets: int = 0
        self._identifier = name
        self._all_end_nodes: NodeList = end_nodes
        self.execution_time = execution_time
        self.number_to_schedule = number_to_schedule if number_to_schedule is not None else -1
        self._accepted_qos_option = accepted_qos_option

        if links is None:
            self._links: list[str] = [f"{self._all_end_nodes[0].id}--{self._all_end_nodes[1].id}"]
        else:
            self._links = links

        self._rate = rate
        self.rate_scaling: float = 1.

        self.minsep = minsep if minsep is not None else 0
        self.maxsep = maxsep if maxsep is not None else inf

        # ``period`` is derived from ``_rate`` and ``rate_scaling``, so both must be set before this point.
        self.release_times = [0]
        self.deadlines = [self.period]

        self._next_release = 0
        self._next_deadline = self.period

        self._demand_class = demand_class
        self._initialisation_time: int | None = None
        self._creation_time: float = creation_time  # time the task was created in seconds.
        self._expiry_time: None | int = expiry_time
        self._relative_expiry_time: bool = relative_expiry
        self._did_schedule_in_period: list[bool] = []
        self._session_ids = session_ids
        self._max_number_of_instances_in_session = (
            max_number_of_instances_in_session if max_number_of_instances_in_session is not None else 1)
        self._priority_function: Callable[[PacketGenerationTask, float], float] = priority_function
        self._pgas = PGAList()
        self._is_terminated: bool = False

    @property
    def is_terminated(self):
        return self._is_terminated

    @property
    def utilisation(self):
        """Fraction of each period spent executing: ``execution_time / period``."""
        return self.execution_time / self.period

    @property
    def start_times(self):
        """Start timeslot of every scheduled attempt (a fresh list on each access)."""
        return self._pgas.start_times

    @property
    def end_times(self):
        """End timeslot of every scheduled attempt (a fresh list on each access)."""
        return self._pgas.end_times

    @property
    def accepted_qos_option(self):
        return self._accepted_qos_option

    @classmethod
    def from_yaml(cls, name: str, details: dict, creation_time: int = 0):
        """Build a task from one entry of a YAML task description.

        Required keys: ``node1``, ``node2``, ``execution time``, ``rate``, ``number to schedule``,
        ``minimum separation`` and ``maximum separation``. Optional keys: ``demand class`` and
        ``expiry time``. ``expiry time`` is either an int (absolute timeslot) or a mapping with optional
        ``value``/``relative`` entries; any other type means "no expiry".

        :param name: Identifier for the new task.
        :param details: Parsed YAML mapping for this task.
        :param creation_time: Creation time in seconds.
        """
        _expiry_time: int | None = None
        _relative_expiry_time: bool = False

        if "expiry time" in details:
            expiry = details["expiry time"]
            # Exact ``type`` checks: bools and other int subclasses fall through to "no expiry".
            if type(expiry) is dict:
                _expiry_time = expiry.get("value")
                _relative_expiry_time = expiry.get("relative", False)
            elif type(expiry) is int or type(expiry) is np.int64:
                _expiry_time = expiry

        rate = details["rate"]
        if type(rate) is str:
            # SECURITY NOTE(review): ``eval`` executes arbitrary Python. It is only acceptable for trusted
            # configuration (e.g. rates written as "1/100"); never pass untrusted YAML through here.
            rate = eval(rate)

        return cls(name=name,
                   end_nodes=NodeList([Node(node_id=details["node1"]), Node(node_id=details["node2"])]),
                   execution_time=details["execution time"],
                   rate=rate,
                   number_to_schedule=details["number to schedule"],
                   minsep=details["minimum separation"],
                   maxsep=details["maximum separation"],
                   demand_class=details.get("demand class", "N/A"),
                   expiry_time=_expiry_time,
                   relative_expiry=_relative_expiry_time,
                   creation_time=creation_time
                   )

    # TODO: Reorganise these properties into a better order for readability

    @property
    def max_pga_rate(self):
        """Highest achievable attempt rate: one attempt every ``execution_time + minsep`` timeslots."""
        return 1 / (self.execution_time + self.minsep)

    @property
    def proportion_of_max_rate_requested(self):
        return self.requested_rate / self.max_pga_rate

    def is_active_at(self, t: int) -> bool:
        """Return whether one of this task's attempts occupies timeslot ``t``.

        Attempts occupy the half-open interval ``[start, end)``. Assumes attempts are stored with
        ascending start times (TODO confirm that callers of ``add_to_schedule`` guarantee this).
        """
        starts = self.start_times  # materialise once: the property rebuilds the list on every access
        if not starts:
            return False

        ends = self.end_times
        if t > ends[-1]:
            return False

        # Index of the last attempt starting at or before t.
        idx = bisect_right(starts, t) - 1
        if idx < 0:
            return False

        return starts[idx] <= t < ends[idx]

    @property
    def demand_source(self):
        """Identifier with its trailing ``-<instance>`` suffix removed."""
        return self.identifier.rsplit('-', 1)[0]

    @property
    def demand_instance(self):
        """Integer suffix after the last ``-`` in the identifier."""
        return int(self.identifier.rsplit('-', 1)[-1])

    @property
    def max_number_of_instances_in_session(self):
        return self._max_number_of_instances_in_session

    @property
    def session_tuple(self):
        """``(node_id1, sid1, node_id2, sid2, ...)`` with node ids in ascending order.

        :raises InitialisationError: If no session ids have been assigned yet.
        """
        # ``sorted`` rather than ``.sort()`` so the list returned by ``node_ids`` is never mutated.
        return tuple(x for n in sorted(self.all_end_nodes.node_ids) for x in (n, self.session_ids[n]))

    @property
    def session_ids(self):
        """Node id -> session id mapping.

        :raises InitialisationError: If no mapping was given and the task has not been initialised.
        """
        if self._session_ids is None:
            raise InitialisationError(self.identifier)
        return self._session_ids

    @property
    def links(self):
        return self._links

    @property
    def number_scheduled(self):
        """Number of attempts currently in the schedule."""
        return len(self._pgas)

    @property
    def creation_time(self) -> float:
        """Time the task was created, in seconds."""
        return self._creation_time

    @property
    def scheduling_delay(self):
        # NOTE(review): ``initialise`` sets ``release_times[0]`` to the initialisation time, so this is 0 for
        # any initialised task. Possibly the first start time was intended — confirm.
        return self.release_times[0] - self._initialisation_time

    @property
    def is_expired(self) -> bool:
        """True if the task was terminated or its next release is at/after its expiry time."""
        # Termination is checked first so that terminated tasks without an expiry time also count as
        # expired (and are therefore removed by ``Taskset.remove_expired_tasks``).
        if self._is_terminated:
            return True
        if self._expiry_time is None:
            return False
        return self.next_release >= self._expiry_time

    @property
    def is_initialised(self):
        return self._initialisation_time is not None

    @property
    def node1(self) -> Node:
        return self._all_end_nodes[0]

    @property
    def node2(self) -> Node:
        return self._all_end_nodes[1]

    @property
    def all_end_nodes(self) -> NodeList:
        return self._all_end_nodes

    @property
    def number_of_nodes(self):
        return len(self._all_end_nodes)

    @property
    def demand_class(self):
        return self._demand_class

    @property
    def identifier(self) -> str:
        return self._identifier

    @property
    def rate(self):
        """Effective rate: the requested rate multiplied by ``rate_scaling``."""
        return self._rate * self.rate_scaling

    @property
    def period(self):
        """Period in timeslots, ``round(1 / rate)`` of the effective rate."""
        return round(1 / self.rate)

    @property
    def requested_rate(self):
        """Rate as originally requested, before ``rate_scaling``."""
        return self._rate

    @property
    def details(self):
        return (f"ID: {self.identifier} | node1: {self.node1.alias} | node2: {self.node2.alias} | E: {self.execution_time}"
                f"| minsep: {self.minsep} | maxsep: {self.maxsep} | R: {self.rate}")

    @property
    def next_release(self) -> int:
        return self._next_release

    @next_release.setter
    def next_release(self, value) -> None:
        """Set the next release; the next deadline is reset to one period after it."""
        self._next_release = value
        self._next_deadline = value + self.period

    @property
    def next_start_deadline(self):
        """Latest start time that still lets an attempt finish by ``next_deadline``."""
        return self._next_deadline - self.execution_time

    @property
    def next_deadline(self):
        return self._next_deadline

    def next_complete(self, t: Optional[int] = None):
        """Return a completion time of this task.

        :param t: If None, the end time of the most recent attempt (-1 if none). Otherwise, the earliest
            end time strictly after ``t``.
        :raises IndexError: If ``t`` is given and no attempt ends after it.
        """
        if t is None:
            return self.end_times[-1] if self.end_times else -1

        later_ends = [end for end in self.end_times if end > t]
        if not later_ends:
            raise IndexError(f"{self.identifier} has no packet generation attempt ending after t={t}")
        return min(later_ends)

    @property
    def instance(self):
        """Number of instances handled so far (scheduled + dropped), counting from 0.

        This is also the 0-based index of the next period, as used by ``_update_release_deadlines``.
        """
        return len(self._pgas) + self._dropped_packets

    @property
    def dropped_packets(self):
        return self._dropped_packets

    @property
    def realised_rate(self):
        """Effective rate weighted by the proportion of instances actually delivered."""
        return self.rate * self.proportion_delivered

    @property
    def proportion_delivered(self) -> float:
        """Fraction of handled instances that were scheduled rather than dropped (0.0 before any)."""
        attempts = self.instance  # already the total count; no 1-based offset to remove
        if attempts == 0:
            return 0.
        return 1 - self._dropped_packets / attempts

    @property
    def expiry_time(self):
        return self._expiry_time

    @property
    def initialisation_time(self):
        return self._initialisation_time

    @property
    def lifetime(self) -> float:
        """Timeslots between initialisation (or 0 if uninitialised) and expiry; ``inf`` without expiry."""
        if self._expiry_time is None:
            return inf

        if self._initialisation_time is None:
            return self._expiry_time
        return self._expiry_time - self._initialisation_time

    @property
    def inter_arrival_times(self):
        """Differences between consecutive attempt start times."""
        starts = self.start_times  # materialise once instead of rebuilding the list per index
        return [later - earlier for earlier, later in zip(starts, starts[1:])]

    @property
    def jitter(self):
        """Sample variance of the inter-arrival times, or None with fewer than two of them."""
        gaps = self.inter_arrival_times
        return variance(gaps) if len(gaps) >= 2 else None

    def priority(self, t: float):
        """Evaluate the configured priority function for this task at time ``t``."""
        return self._priority_function(self, t)

    @property
    def priority_function(self):
        return self._priority_function

    @priority_function.setter
    def priority_function(self, fcn: Callable):
        self._priority_function = fcn

    def __lt__(self, other):
        # Orders tasks by their latest admissible start time.
        return self.next_start_deadline < other.next_start_deadline

    def __str__(self):
        return self.identifier

    def __eq__(self, other):
        # Tasks are equal when their identifiers match; other types defer to Python's default handling.
        if not isinstance(other, PacketGenerationTask):
            return NotImplemented
        return self._identifier == other.identifier

    def __lt_running__(self, other):
        # Not a real dunder (Python never calls it implicitly); compares most recent completion times.
        return self.next_complete() < other.next_complete()

    def clear_schedule(self, t: int = 0):
        """Remove every attempt starting at or after ``t`` and recompute the next release/deadline.

        NOTE(review): ``_did_schedule_in_period`` and ``_dropped_packets`` are not rolled back here.
        """
        self._pgas.remove_after(t)
        # With every attempt removed there is no last attempt to anchor the next release to, so fall
        # back to the period-aligned computation instead of indexing an empty list.
        self._update_release_deadlines(dropped=not self._pgas)

        # Internal invariant: every attempt ends exactly execution_time after it starts.
        assert all(end - start == self.execution_time for start, end in zip(self.start_times, self.end_times))

    def rate_to_time(self, t: int | None = None):
        """
        Returns the average rate over the period from initialisation to the specified time
        :param t: Time up to which to sample. If None, then takes the latest end time.
        :return: Estimate of the average rate.
        """
        if t is None:
            t = max(self.end_times)

        packets_by_time = [p for p in self.start_times if p <= t]

        return len(packets_by_time) / (t - self.initialisation_time)

    def instantaneous_rate(self, t: int | None = None) -> float:
        """
        Returns an estimation of the current rate based on (up to) the last three periods.
        :param t: Sample time. If None, the latest release time is used.
        :return: Estimate of the rate, in packets per timeslot (0 before the first period).
        """
        if t is None:
            t = max(self.release_times)

        period = self.period
        periods_considered = (t - self.initialisation_time) // period
        current_period_start = periods_considered * period + self.initialisation_time

        # Include the current (partial) period if a packet has already started within it.
        if any(s in range(current_period_start, t + 1) for s in self.start_times):
            periods_considered += 1

        if periods_considered <= 0:
            return 0

        window = min(periods_considered, 3)
        return sum(self.did_schedule_in_period[periods_considered - window:periods_considered]) / (window * period)

    @property
    def did_schedule_in_period(self):
        """
        Returns a list of bools where entry i is whether a packet was scheduled in period i after initialisation.
        """
        return self._did_schedule_in_period

    def instantaneous_rate_satisfaction_profile(self, t: int | None = None) -> list[tuple[int, float]]:
        """``(time, instantaneous rate / effective rate)`` pairs; see :meth:`instantaneous_rate_profile`."""
        return [(s, r / self.rate) for s, r in self.instantaneous_rate_profile(t)]

    def instantaneous_rate_profile(self, t: int | None = None) -> list[tuple[int, float]]:
        """``(time, instantaneous rate)`` for every timeslot from initialisation up to (excluding)
        ``min(t, expiry_time)``.

        :param t: Upper bound; defaults to the latest end time. Tasks without an expiry time are bounded
            by ``t`` alone.
        """
        if t is None:
            t = max(self.end_times)

        horizon = t if self.expiry_time is None else min(t, self.expiry_time)
        return [(s, self.instantaneous_rate(s)) for s in range(self.initialisation_time, horizon)]

    def drop_packet(self):
        """Record a dropped instance and move the release/deadline to the next period."""
        self._dropped_packets += 1
        self._update_release_deadlines(dropped=True)
        self._did_schedule_in_period.append(False)

    def add_to_schedule(self, time: int) -> PacketGenerationAttempt:
        """Schedule an attempt starting at ``time`` against the current release/deadline.

        :return: The newly created attempt.
        :raises InitialisationError: If the task has not been initialised.
        """
        _new_pga = PacketGenerationAttempt(start_time=time, execution_time=self.execution_time,
                                           release_time=self.next_release, deadline=self.next_deadline)

        self._pgas.append(_new_pga)
        self._update_release_deadlines()
        self._did_schedule_in_period.append(True)

        return _new_pga

    def find_proper_link_names(self, network: Network):
        """Replace the stored (partial) link names with the network's full link names."""
        self._links = [network.links.partial_name_dictionary[l] for l in self._links]

    def _update_release_deadlines(self, dropped: bool = False):
        """Recompute the next release and deadline after an attempt was scheduled or a packet dropped.

        :param dropped: If True, the next release is the start of the next period. Otherwise the release is
            also held back until ``minsep`` after the last attempt ends, and the deadline is capped so the next
            attempt ends within ``maxsep + execution_time`` of it. Requires at least one attempt.
        :raises InitialisationError: If the task has not been initialised.
        """
        if not self.is_initialised:
            raise InitialisationError(self.identifier)

        # ``instance`` is the 0-based index of the next period.
        next_period_start = self.release_times[0] + self.instance * self.period

        if dropped:
            self._next_release = next_period_start
            self._next_deadline = next_period_start + self.period
        else:
            last_end = self._pgas[-1].end_time
            self._next_release = max(next_period_start, last_end + self.minsep)
            # The next attempt may start at most ``maxsep`` after the last one ends, so it must finish by
            # ``last_end + maxsep + execution_time`` (execution_time is counted once, not twice).
            self._next_deadline = min(next_period_start + self.period,
                                      last_end + self.maxsep + self.execution_time)

    def initialise(self, time: int):
        """Initialise the task at timeslot ``time``; does nothing if it is already initialised.

        Sets the first release to ``time`` (deadline one period later), converts a relative expiry time to
        an absolute one, and reads ``next_session_id`` from each end node into ``session_ids``.
        """
        if self.is_initialised:  # stops tasks being initialised multiple times
            return

        self._initialisation_time = time
        self.release_times[0] = time
        self.deadlines[0] = time + self.period
        self.next_release = time
        if self._relative_expiry_time and self._expiry_time is not None:
            self._expiry_time += self._initialisation_time

        n: Node
        self._session_ids = {n.id: n.next_session_id for n in self._all_end_nodes}

    def deinitialise(self):
        self._initialisation_time = None

    def terminate(self):
        """Mark the task as terminated; ``is_expired`` then reports True."""
        self._is_terminated = True

    def __hash__(self):
        # NOTE(review): ``__eq__`` compares identifiers only, while this hash also includes the session ids
        # (and raises InitialisationError before initialisation). Equal tasks can therefore hash differently.
        return hash(self.identifier + str(self.session_ids))


# Modified list class holding all the tasks in the environment, sorted by rate (descending).
class Taskset(List[PacketGenerationTask]):
    """List of :class:`PacketGenerationTask` kept sorted by effective rate, highest first.

    Sorting is re-applied on construction and on :meth:`append`; other inherited list mutators
    (``insert``, ``extend``, ...) do not re-sort.
    """

    def __init__(self, __iterable=None):
        if __iterable is None:
            __iterable = list()
        super().__init__(__iterable)
        self.sort()

    @property
    def average_pga_rate(self):
        """Mean effective rate over the tasks (0 for an empty taskset)."""
        return sum(t.rate for t in self) / len(self) if self else 0

    @property
    def average_utilisation(self):
        """Mean utilisation over the tasks (0 for an empty taskset)."""
        return sum(t.utilisation for t in self) / len(self) if self else 0

    def append(self, __object: PacketGenerationTask) -> None:
        super().append(__object)
        self.sort()

    def sort(self, *, key: Optional[Callable[[PacketGenerationTask], float]] = (lambda x: x.rate),
             reverse: bool = True) -> None:
        """Sort in place; by default by rate, highest first."""
        super().sort(key=key, reverse=reverse)

    def lexicographical_order(self):
        return sorted(self, key=lambda x: x.identifier)

    def chronological_order(self):
        return sorted(self, key=lambda x: x.creation_time)

    def initialisation_order(self):
        """Return a new Taskset of the same task objects, reordered by session ids.

        Adjacent tasks that share a node are swapped until the earlier one has the smaller session id on
        every shared node. Only adjacent pairs are compared, and inconsistent session ids (a cycle) would
        loop forever.
        """
        # Shallow copy: reordering must not touch ``self``, but the tasks returned (e.g. by
        # ``tasks_on_node``) must be the real task objects, not deep copies of them.
        ordered = Taskset(self)

        correctly_sorted: bool = False
        while not correctly_sorted:
            correctly_sorted = True
            for i in range(len(ordered) - 1):
                first: PacketGenerationTask = ordered[i]
                second: PacketGenerationTask = ordered[i + 1]
                shared_in_order = all(
                    first.session_ids[p] < second.session_ids[p]
                    for p in first.session_ids.keys() if p in second.session_ids.keys())
                if not shared_in_order:
                    correctly_sorted = False
                    ordered.insert(i, ordered.pop(i + 1))

        return ordered

    def initialise_tasks(self, time) -> Union[Taskset, None]:
        """Initialise every not-yet-initialised task at ``time``.

        :return: The tasks initialised by this call, or None if there were none.
        """
        newly_initialised = Taskset()
        for task in self:
            if not task.is_initialised:
                task.initialise(time)
                newly_initialised.append(task)

        return newly_initialised if newly_initialised else None

    def set_priority_functions(self, fcn: Callable[[PacketGenerationTask, float], float]):
        for t in self:
            t: PacketGenerationTask
            t.priority_function = fcn

    def get_full_link_ids(self, network: Network):
        """Resolve partial link names to the network's full link ids where needed."""
        t: PacketGenerationTask

        for t in self:
            if any(l not in network.link_ids for l in t.links):
                t.find_proper_link_names(network)

    def remove_expired_tasks(self):
        """Remove expired tasks from this taskset and return them as a new Taskset."""
        t: PacketGenerationTask
        tasks_to_remove = [t for t in self if t.is_expired]
        for t in tasks_to_remove:
            self.remove(t)

        return Taskset(tasks_to_remove)

    @property
    def ids(self):
        return [t.identifier for t in self.lexicographical_order()]

    def __eq__(self, other: Taskset):
        if not isinstance(other, Taskset):
            raise TypeError(f"Comparison not supported between types Taskset and {type(other)}")

        return all(t in other for t in self) and all(s in self for s in other)

    def tasks_on_node(self, node: Node):
        """Tasks (in initialisation order) that have ``node`` among their end nodes.

        :raises TypeError: If ``node`` is not a Node.
        """
        if not isinstance(node, Node):
            raise TypeError(f"Expected a Node, got {type(node)}")

        t: PacketGenerationTask

        return [t for t in self.initialisation_order() if node in t.all_end_nodes]

    def query_name(self, identifier: str):
        """Return the task with the given identifier.

        :raises KeyError: If no task has that identifier.
        """
        task: PacketGenerationTask
        for task in self:
            if task.identifier == identifier:
                return task

        raise KeyError(f"{identifier} not in taskset")

    def as_table(self,
                 init_order: bool = False,
                 show_nodes: bool = True,
                 show_execution_time: bool = True,
                 show_minsep: bool = True,
                 show_maxsep: bool = False,
                 show_Rate: bool = True,
                 show_period: bool = True,
                 show_utilisation: bool = False,
                 show_req_rate: bool = False,
                 show_scaling: bool = False,
                 show_dropped_packets: bool = False,
                 show_no_schedule: bool = False,
                 show_prop_delivered: bool = False,
                 show_jitter: bool = False,
                 show_create_time: bool = False,
                 show_init_time: bool = False,
                 show_expiry_time: bool = True,
                 show_pids: bool = True,
                 network_timeslot_length: int = 10000
                 ) -> Table:
        """Render the taskset as a rich Table, one row per task.

        :param init_order: Order rows by :meth:`initialisation_order` instead of by identifier.
        :param network_timeslot_length: Timeslot length in nanoseconds, used to convert rates to Hz and
            timeslot times to ``timedelta``.
        Each ``show_*`` flag toggles the corresponding column.
        """
        table = Table(title=f"Taskset (Hyperperiod = {lcm(*[t.period for t in self])})")

        seconds_per_timeslot = network_timeslot_length / 1e9

        def _timeslots_to_timedelta(timeslots):
            # Uninitialised tasks / tasks without expiry carry None here; keep it as None.
            if timeslots is None:
                return None
            return timedelta(seconds=timeslots * network_timeslot_length / 1e9)

        def _format_cell(value) -> str:
            if value is None or isinstance(value, (int, np.integer, timedelta)):
                return str(value)
            if isinstance(value, str):
                return value
            return f"{value:.3g}"

        # (header, visible, value getter). Keeping the header and the value together means a cell whose
        # value is None is still rendered in its own column instead of shifting the rest of the row.
        columns = [
            ("ID", True, lambda t: t.identifier),
            ("N1", show_nodes, lambda t: t.node1.id),
            ("PID 1", show_pids, lambda t: t.session_ids[t.node1.id]),
            ("N2", show_nodes, lambda t: t.node2.id),
            ("PID 2", show_pids, lambda t: t.session_ids[t.node2.id]),
            ("Execution Time", show_execution_time, lambda t: t.execution_time),
            ("minsep", show_minsep, lambda t: t.minsep),
            ("maxsep", show_maxsep, lambda t: t.maxsep),
            ("Rate", show_Rate, lambda t: t.rate),
            ("Rate (Hz)", show_Rate, lambda t: t.rate / seconds_per_timeslot),
            ("Period", show_period, lambda t: t.period),
            ("Utilisation", show_utilisation, lambda t: t.utilisation),
            ("Requested Rate", show_req_rate, lambda t: t.requested_rate),
            ("RRF", show_scaling, lambda t: t.rate_scaling),
            ("Dropped Packets", show_dropped_packets, lambda t: t.dropped_packets),
            ("Number Scheduled", show_no_schedule, lambda t: t.number_scheduled),
            ("Proportion Delivered", show_prop_delivered, lambda t: t.proportion_delivered),
            ("Jitter", show_jitter, lambda t: t.jitter),
            ("Creation Time", show_create_time, lambda t: timedelta(seconds=t.creation_time)),
            ("Initialisation Time", show_init_time, lambda t: _timeslots_to_timedelta(t.initialisation_time)),
            ("Expiry Time", show_expiry_time, lambda t: _timeslots_to_timedelta(t.expiry_time)),
        ]
        visible = [(header, getter) for header, show, getter in columns if show]

        for header, _ in visible:
            table.add_column(header)

        rows = self.initialisation_order() if init_order else self.lexicographical_order()
        for task in rows:
            table.add_row(*[_format_cell(getter(task)) for _, getter in visible])

        return table


@dataclass
class PacketGenerationAttempt:
    """One scheduled packet generation attempt.

    The attempt occupies the half-open timeslot interval ``[start_time, end_time)`` with
    ``end_time = start_time + execution_time``. ``release_time`` and ``deadline`` record the owning
    task's release and deadline when the attempt was created.
    """
    start_time: int
    execution_time: int
    release_time: int
    deadline: int

    @property
    def end_time(self):
        """First timeslot after the attempt (exclusive end)."""
        return self.start_time + self.execution_time

    @property
    def missed_deadline(self):
        """True if the attempt finishes after its deadline.

        Finishing exactly at the deadline is on time: ``end_time`` is exclusive, and starting at
        ``deadline - execution_time`` (a task's ``next_start_deadline``) ends exactly at ``deadline``.
        """
        return self.end_time > self.deadline


class PGAList(List[PacketGenerationAttempt]):
    """List of PacketGenerationAttempt with accessors for their start and end times."""

    def remove_after(self, t: int):
        """Remove, in place, every attempt that starts at or after timeslot ``t``."""
        # Slice assignment instead of ``remove`` inside a loop over ``self``: removing while iterating
        # skips the element after each removed one, which left some late attempts in the list.
        self[:] = [pga for pga in self if pga.start_time < t]

    @property
    def start_times(self):
        """Start timeslot of each attempt, in list order."""
        return [pga.start_time for pga in self]

    @property
    def end_times(self):
        """End timeslot (``start_time + execution_time``) of each attempt, in list order."""
        return [pga.end_time for pga in self]
