from __future__ import annotations

from Quantum_Network_Architecture.schedules import NetworkSchedule
from Quantum_Network_Architecture.networks import Network
# from Quantum_Network_Architecture.export.to_qoala import export_network_schedule
from Quantum_Network_Architecture.tasks import PacketGenerationTask

from typing import List, Dict, TypedDict, Optional

from copy import deepcopy



from qoala.lang.ehi import UnitModule
from qoala.sim.build import build_network_from_config_netschedule
from qoala.lang.parse import QoalaParser
from qoala.lang.program import QoalaProgram
from qoala.runtime.program import BatchResult, ProgramInput, BatchInfo, ProgramBatch
from qoala.runtime.statistics import SchedulerStatistics

from Quantum_Network_Architecture.simulation import (
    QKD_ALICE_PROGRAM_100_BITS,
    QKD_BOB_PROGRAM_100_BITS,
    QKD_ALICE_PROGRAM_SINGLE_ROUND,
    QKD_BOB_PROGRAM_SINGLE_ROUND,
    BQC_SERVER_PROGRAM_100_ROUNDS,
    BQC_CLIENT_PROGRAM_100_ROUNDS,
    BQC_CLIENT_PROGRAM_SINGLE_ROUND,
    BQC_SERVER_PROGRAM_SINGLE_ROUND,
)

import netsquid as ns
import random

from dataclasses import dataclass

@dataclass
class AppResult:
    """Aggregated outcome of one network-schedule simulation run.

    As filled in by ``run_network_schedule``, both dictionaries are keyed by
    end-node alias.
    """

    # End-node alias -> batch results reported by that node's scheduler.
    batch_results: Dict[str, BatchResult]
    # End-node alias -> statistics reported by that node's scheduler.
    statistics: Dict[str, SchedulerStatistics]
    # Simulator clock value at the moment the run finished.
    total_duration: float

def load_program(path: str) -> QoalaProgram:
    """Read a Qoala program source file and parse it.

    Args:
        path: Filesystem path of the Qoala program text.

    Returns:
        The program produced by ``QoalaParser(text).parse()``.
    """
    # Explicit encoding so the result does not depend on the platform's locale default.
    with open(path, encoding="utf-8") as source_file:
        source_text = source_file.read()
    return QoalaParser(source_text).parse()

def create_batch(
    program: QoalaProgram,
    unit_module: UnitModule,
    inputs: List[ProgramInput],
    num_iterations: int,
    deadline: int = 0,
) -> BatchInfo:
    """Bundle a program and its per-instance inputs into a ``BatchInfo``.

    Args:
        program: The Qoala program every instance of the batch runs.
        unit_module: Unit module describing the executing node's hardware.
        inputs: Program inputs; one per iteration at the call sites in this module.
        num_iterations: Number of program instances in the batch.
        deadline: Forwarded unchanged to ``BatchInfo.deadline``. Defaults to
            ``0``, the value previously hard-coded here. Units are as defined
            by qoala and are not visible in this module.

    Returns:
        The constructed ``BatchInfo``.
    """
    return BatchInfo(
        program=program,
        unit_module=unit_module,
        inputs=inputs,
        num_iterations=num_iterations,
        deadline=deadline,
    )

def _default_program_dictionary() -> Dict[str, Dict[str, QoalaProgram]]:
    """Return the built-in ``demand class -> role -> program`` mapping."""
    return {
        "QKD": {
            "alice": QKD_ALICE_PROGRAM_SINGLE_ROUND,
            "bob": QKD_BOB_PROGRAM_SINGLE_ROUND,
        },
        "BQC": {
            "client": BQC_CLIENT_PROGRAM_SINGLE_ROUND,
            "server": BQC_SERVER_PROGRAM_SINGLE_ROUND,
        },
        "BQC-100": {
            "client": BQC_CLIENT_PROGRAM_100_ROUNDS,
            "server": BQC_SERVER_PROGRAM_100_ROUNDS,
        },
    }


def _programs_and_inputs_for_task(
    task: PacketGenerationTask,
    programs: Dict[str, Dict[str, QoalaProgram]],
) -> tuple[QoalaProgram, QoalaProgram, ProgramInput, ProgramInput]:
    """Build fresh copies of both programs of ``task`` together with their inputs.

    ``task.node1`` takes the first role (QKD "alice" / BQC "client") and
    ``task.node2`` the second role (QKD "bob" / BQC "server").

    Returns:
        ``(node1_program, node2_program, node1_input, node2_input)``.

    Raises:
        NotImplementedError: if ``task.demand_class`` is not "QKD", "BQC" or "BQC-100".
    """
    demand_class = task.demand_class
    node1, node2 = task.node1, task.node2

    if demand_class == "QKD":
        node1_input = ProgramInput({"bob_id": node2.id})
        node2_input = ProgramInput({"alice_id": node1.id})
        # Deep copies, so the per-task edits below never touch the shared templates.
        node1_program = deepcopy(programs["QKD"]["alice"])
        node2_program = deepcopy(programs["QKD"]["bob"])
        node1_program.meta.name = node1.alias
        node2_program.meta.name = node2.alias
    elif demand_class in ("BQC", "BQC-100"):
        # In the default dictionary "BQC" is the single-round variant and
        # "BQC-100" the 100-round variant.
        node1_input = ProgramInput({"server_id": node2.id, "alpha": 8, "beta": 8, "theta1": 0, "theta2": 0})
        node2_input = ProgramInput({"client_id": node1.id})
        node2_program = deepcopy(programs[demand_class]["server"])
        node1_program = deepcopy(programs[demand_class]["client"])
        # NOTE(review): unlike QKD, meta.name is not overwritten for BQC — confirm this is intended.
    else:
        raise NotImplementedError(f"Unsupported demand class: {demand_class!r}")

    # Point each program's first EPR socket and classical socket at the peer node.
    node1_program.meta.epr_sockets[0] = node2.alias
    node1_program.meta.csockets[0] = node2.alias
    node2_program.meta.epr_sockets[0] = node1.alias
    node2_program.meta.csockets[0] = node1.alias

    return node1_program, node2_program, node1_input, node2_input


def _local_and_remote_batches(task, node, batches):
    """Return ``(local_batch, remote_batch)`` of ``task`` as seen from ``node``.

    ``node`` is treated as ``task.node1`` (the "alice" batch) when it compares
    equal to it; otherwise it is treated as ``task.node2`` (the "bob" batch).
    """
    task_batches = batches[task.identifier]
    if task.node1 == node:
        return task_batches["alice"], task_batches["bob"]
    return task_batches["bob"], task_batches["alice"]


def run_network_schedule(
    schedule: NetworkSchedule,
    network: Network,
    seed: int | None = None,
    program_dictionary: Optional[Dict[str, Dict[str, QoalaProgram]]] = None,
    max_time: int | None = None,
    max_instances: Optional[int] = None,
) -> AppResult:
    """Simulate ``schedule`` on ``network`` with Qoala/NetSquid and collect the results.

    Steps:
        1. Reset the NetSquid simulator.
        2. Build the Qoala network from ``network.network_config`` and the
           exported schedule.
        3. Submit one batch per task on each of the task's two nodes.
           ``task.node1`` runs the "alice"/"client" program and ``task.node2``
           runs the "bob"/"server" program.
        4. Register the remote batch ids and process ids on every end node.
        5. Run the simulation and gather per-node scheduler output.

    Args:
        schedule: Network schedule whose ``taskset`` provides the tasks to run.
        network: Provides the Qoala network config, the timeslot duration and
            the end nodes.
        seed: NetSquid random seed. If ``None``, a seed is drawn with
            ``random.randint(0, 1000)``.
        program_dictionary: Mapping ``demand class -> role -> QoalaProgram``.
            Defaults to the built-in QKD/BQC/BQC-100 programs. Programs are
            deep-copied before being modified, so this mapping is not mutated.
        max_time: Passed as ``end_time`` to ``ns.sim_run``.
        max_instances: Number of program instances per batch. If ``None``,
            each task's ``max_number_of_instances_in_session`` is used.

    Returns:
        An :class:`AppResult` with batch results and scheduler statistics per
        end-node alias, and the simulation time when the run stopped.

    Raises:
        NotImplementedError: for a task whose ``demand_class`` is not "QKD",
            "BQC" or "BQC-100".
        AssertionError: if a task's session ids do not match the batch ids
            returned when its batches were submitted.
    """
    # The module-level import of this function is commented out at the top of the
    # file (possibly to break an import cycle — TODO confirm). Without this local
    # import, the call below raised NameError.
    from Quantum_Network_Architecture.export.to_qoala import export_network_schedule

    programs = _default_program_dictionary() if program_dictionary is None else program_dictionary

    ns.sim_reset()
    ns.set_qstate_formalism(ns.QFormalism.DM)
    # Only draw from Python's global RNG when the caller did not supply a seed.
    ns.set_random_state(seed=random.randint(0, 1000) if seed is None else seed)

    qoala_network = build_network_from_config_netschedule(
        network.network_config,
        export_network_schedule(schedule, network.timeslot_duration),
    )

    # task identifier -> {"alice": batch on task.node1, "bob": batch on task.node2}
    batches: Dict[str, Dict[str, ProgramBatch]] = {}

    for task in schedule.taskset.initialisation_order():
        procnode1 = qoala_network.nodes[task.node1.alias]
        procnode2 = qoala_network.nodes[task.node2.alias]

        alice_program, bob_program, alice_input, bob_input = _programs_and_inputs_for_task(task, programs)

        n_instances = task.max_number_of_instances_in_session if max_instances is None else max_instances

        task_batches: Dict[str, ProgramBatch] = {}
        batches[task.identifier] = task_batches

        unit_module1 = UnitModule.from_full_ehi(procnode1.memmgr.get_ehi())
        batch_info1 = create_batch(alice_program, unit_module1, [alice_input] * n_instances, n_instances)
        task_batches["alice"] = procnode1.submit_batch(batch_info1)

        unit_module2 = UnitModule.from_full_ehi(procnode2.memmgr.get_ehi())
        batch_info2 = create_batch(bob_program, unit_module2, [bob_input] * n_instances, n_instances)
        task_batches["bob"] = procnode2.submit_batch(batch_info2)

    # Sanity check: the session ids recorded on each task must equal the batch ids
    # returned by submit_batch. An explicit raise is used so this is not stripped under -O.
    for task in schedule.taskset:
        task_batches = batches[task.identifier]
        ids_match = (
            task.session_ids[task.node1.id] == task_batches["alice"].batch_id
            and task.session_ids[task.node2.id] == task_batches["bob"].batch_id
        )
        if not ids_match:
            raise AssertionError(
                f"Session ids of task {task.identifier!r} do not match the submitted batch ids"
            )

    for node in network.end_nodes:
        procnode = qoala_network.nodes[node.alias]

        # Build both mappings in a single pass, so tasks_on_node() is iterated only once.
        remote_pids = {}
        remote_batches = {}
        for task in schedule.taskset.tasks_on_node(node):
            local_batch, remote_batch = _local_and_remote_batches(task, node, batches)
            remote_pids[local_batch.batch_id] = [instance.pid for instance in remote_batch.instances]
            remote_batches[local_batch.batch_id] = remote_batch.batch_id

        procnode.initialize_processes(remote_pids=remote_pids, remote_batches=remote_batches, linear=True)

    qoala_network.start()
    ns.sim_run(end_time=max_time)

    results = {}
    statistics = {}
    for alias in network.end_nodes.node_aliases:
        scheduler = qoala_network.nodes[alias].scheduler
        results[alias] = scheduler.get_batch_results()
        statistics[alias] = scheduler.get_statistics()

    return AppResult(results, statistics, ns.sim_time())

