from circuit_simulation.circuit_simulator import *
from tqdm import tqdm
# Progress bar shared by the protocol functions in this module. It is assigned by create_quantum_circuit and may be
# None, in which case every progress update/reset below is skipped.
PBAR: "tqdm | None" = None


def create_quantum_circuit(protocol, pbar, **kwargs):
    """
        Initialises a QuantumCircuit object corresponding to the protocol requested.

        Parameters
        ----------
        protocol : str
            Name of the protocol for which the QuantumCircuit object should be initialised. Supported values are
            'monolithic', 'plain', 'plain_swap', 'modicum', 'modicum_swap', 'expedient' and 'expedient_swap'.
        pbar : tqdm or None
            Progress bar stored in the module-level ``PBAR`` so the protocol functions can report progress.
            Pass None to disable progress reporting.

        For other parameters, please see QuantumCircuit class for more information

        Returns
        -------
        qc : QuantumCircuit
            Circuit with the nodes and sub circuits used by the requested protocol already defined.

        Raises
        ------
        ValueError
            If ``protocol`` is not one of the supported protocol names.
    """
    global PBAR
    PBAR = pbar

    if protocol == 'monolithic':
        # The monolithic protocol always runs with these two settings fixed, so caller-supplied values are
        # discarded. Removing them avoids a duplicate keyword argument below; the None default avoids a KeyError
        # when the caller did not pass them at all.
        kwargs.pop('basis_transformation_noise', None)
        kwargs.pop('no_single_qubit_error', None)
        qc = QuantumCircuit(9, 2, basis_transformation_noise=True, no_single_qubit_error=False, **kwargs)
        qc.define_node("A", qubits=[1, 3, 5, 7, 0], amount_data_qubits=4)
        qc.define_sub_circuit("A")
        return qc

    if protocol in ('plain', 'plain_swap', 'modicum', 'modicum_swap'):
        # Four nodes with three qubits each
        qc = QuantumCircuit(16, 2, **kwargs)

        qc.define_node("A", qubits=[14, 7, 6])
        qc.define_node("B", qubits=[12, 5, 4])
        qc.define_node("C", qubits=[10, 3, 2])
        qc.define_node("D", qubits=[8, 1, 0])

    elif protocol in ('expedient', 'expedient_swap'):
        # Four nodes with four qubits each
        qc = QuantumCircuit(20, 2, **kwargs)

        qc.define_node("A", qubits=[18, 11, 10, 9])
        qc.define_node("B", qubits=[16, 8, 7, 6])
        qc.define_node("C", qubits=[14, 5, 4, 3])
        qc.define_node("D", qubits=[12, 2, 1, 0])

    else:
        # Previously fell through and failed with an UnboundLocalError on `qc`
        raise ValueError(
            f"Unknown protocol '{protocol}'. Expected one of: monolithic, plain, plain_swap, modicum, "
            f"modicum_swap, expedient, expedient_swap."
        )

    # Sub circuits shared by all four-node protocols
    qc.define_sub_circuit("ABCD")
    qc.define_sub_circuit("AB")
    qc.define_sub_circuit("CD", concurrent_sub_circuits="AB")
    qc.define_sub_circuit("AC")
    # The plain protocols never start a "BD" sub circuit, so it is only defined for the others
    if 'plain' not in protocol:
        qc.define_sub_circuit("BD", concurrent_sub_circuits="AC")
    qc.define_sub_circuit("A")
    qc.define_sub_circuit("B")
    qc.define_sub_circuit("C")
    qc.define_sub_circuit("D", concurrent_sub_circuits=["A", "B", "C"])

    return qc


def monolithic(qc: QuantumCircuit, *, operation):
    """
        Runs the monolithic protocol: qubit 0 of the single node "A" is set to the ``ket_p`` state, after which
        node "A" performs the stabilizer measurement for ``operation`` on its own.

        Parameters
        ----------
        qc : QuantumCircuit
            Circuit created by create_quantum_circuit('monolithic', ...).
        operation :
            Operation passed on to ``qc.stabilizer_measurement``.
    """
    qc.set_qubit_states({0: ket_p})
    qc.stabilizer_measurement(operation, nodes=["A"])

    # Progress reporting is optional
    if PBAR is not None:
        PBAR.update(90)


def plain(qc: QuantumCircuit, *, operation):
    """
        Runs the Plain protocol using qubit indices directly.

        Bell pairs are created on the AB and CD links, followed by a single selection on the AC link. When the
        selection reports failure, an X correction is applied in node A and in node B. Finally all four nodes
        take part in the stabilizer measurement for ``operation``.

        Parameters
        ----------
        qc : QuantumCircuit
            Circuit created by create_quantum_circuit('plain', ...).
        operation :
            Operation used for the single selection and the stabilizer measurement.
    """
    # Bell pairs on the AB and CD links
    qc.start_sub_circuit("AB")
    qc.create_bell_pair(7, 5)
    qc.start_sub_circuit("CD")
    qc.create_bell_pair(3, 1)

    # Single selection between nodes A and C
    qc.start_sub_circuit("AC")
    selection_passed = qc.single_selection(operation, 6, 2)
    if not selection_passed:
        # Corrections in node A and node B
        qc.start_sub_circuit("A")
        qc.X(7)
        qc.start_sub_circuit("B")
        qc.X("B-e")

    qc.stabilizer_measurement(operation, nodes=["B", "A", "D", "C"])

    if PBAR is not None:
        PBAR.update(90)


def plain_swap(qc: QuantumCircuit, *, operation):
    """
        Runs the Plain protocol using named qubits and SWAP gates.

        After each Bell pair is created, the half held by node A (on the AB link) and by node C (on the CD link)
        is swapped from "X-e" to "X-e+1". A single selection on the AC link follows, with X corrections in
        nodes A and B when it reports failure. Finally the stabilizer measurement for ``operation`` runs with
        ``swap=True``.

        Parameters
        ----------
        qc : QuantumCircuit
            Circuit created by create_quantum_circuit('plain_swap', ...).
        operation :
            Operation used for the single selection and the stabilizer measurement.
    """
    # AB link, A's half moved to its "e+1" qubit
    qc.start_sub_circuit("AB")
    qc.create_bell_pair("B-e", "A-e")
    qc.SWAP("A-e", "A-e+1")

    # CD link, C's half moved to its "e+1" qubit
    qc.start_sub_circuit("CD")
    qc.create_bell_pair("D-e", "C-e")
    qc.SWAP("C-e", "C-e+1")

    # Single selection between nodes C and A; correct nodes A and B on failure
    qc.start_sub_circuit("AC")
    if not qc.single_selection(operation, "C-e", "A-e", swap=True):
        qc.start_sub_circuit("A")
        qc.X("A-e+1")
        qc.start_sub_circuit("B")
        qc.X("B-e")

    qc.stabilizer_measurement(operation, nodes=["C", "D", "A", "B"], swap=True)

    if PBAR is not None:
        PBAR.update(90)


def modicum(qc: QuantumCircuit, *, operation):
    """
        Runs the Modicum protocol, retrying the four-node state preparation until its parity checks agree.

        Each attempt creates Bell pairs on the AC and BD links. It then creates a pair on AB whose halves are
        targeted by reversed CNOTs from "A-e"/"B-e" and measured in Z. Finally a pair on CD is consumed by a CZ
        single selection. An attempt is accepted when "both AB outcomes are equal" matches the CD selection
        result, or unconditionally once ``qc.cut_off_time_reached`` is set. After acceptance, X corrections are
        applied on "B-e" and "D-e" if the CD selection result is falsy. The stabilizer measurement for
        ``operation`` follows.

        Parameters
        ----------
        qc : QuantumCircuit
            Circuit created by create_quantum_circuit('modicum', ...).
        operation :
            Operation passed on to ``qc.stabilizer_measurement``.
    """
    attempt_accepted = False
    while not attempt_accepted:
        # Every attempt restarts the progress bar
        if PBAR is not None:
            PBAR.reset()

        qc.start_sub_circuit("AC")
        qc.create_bell_pair("A-e", "C-e")

        qc.start_sub_circuit("BD")
        qc.create_bell_pair("D-e", "B-e")

        if PBAR is not None:
            PBAR.update(40)

        qc.start_sub_circuit("AB")
        qc.create_bell_pair("B-e+1", "A-e+1")
        qc.apply_gate(CNOT_gate, cqubit="A-e", tqubit="A-e+1", reverse=True)
        qc.apply_gate(CNOT_gate, cqubit="B-e", tqubit="B-e+1", reverse=True)
        ab_outcomes = qc.measure(["A-e+1", "B-e+1"], basis="Z")

        qc.start_sub_circuit("CD")
        qc.create_bell_pair("C-e+1", "D-e+1")
        cd_result = qc.single_selection(CZ_gate, "C-e+1", "D-e+1", "C-e", "D-e", create_bell_pair=False)

        qc.start_sub_circuit("ABCD", forced_level=True)
        if qc.cut_off_time_reached:
            # Past the cut-off time the attempt is accepted as-is
            attempt_accepted = True
        else:
            ab_outcomes_equal = len(set(ab_outcomes)) == 1
            attempt_accepted = ab_outcomes_equal == cd_result

        if attempt_accepted:
            if not cd_result:
                qc.X("B-e")
                qc.X("D-e")
            if PBAR is not None:
                PBAR.update(40)

    qc.stabilizer_measurement(operation, nodes=["C", "A", "B", "D"], swap=True)

    if PBAR is not None:
        PBAR.update(20)


def modicum_swap(qc: QuantumCircuit, *, operation):
    """
        Runs the Modicum protocol using SWAP gates, retrying the preparation until its parity checks agree.

        Each attempt creates Bell pairs on the AC and BD links and moves every half from "X-e" to "X-e+1". It then
        creates a pair on AB that is coupled to the stored halves with an H-CZ-H sequence and measured in Z.
        Finally a pair on CD is consumed by a CZ single selection against the stored C/D halves. An attempt is
        accepted when "both AB outcomes are equal" matches the CD selection result, or unconditionally once
        ``qc.cut_off_time_reached`` is set. After acceptance, X corrections are applied on "B-e+1" and "D-e+1" if
        the CD selection result is falsy. The stabilizer measurement for ``operation`` follows.

        Parameters
        ----------
        qc : QuantumCircuit
            Circuit created by create_quantum_circuit('modicum_swap', ...).
        operation :
            Operation passed on to ``qc.stabilizer_measurement``.
    """
    attempt_accepted = False
    while not attempt_accepted:
        # Every attempt restarts the progress bar
        if PBAR is not None:
            PBAR.reset()

        qc.start_sub_circuit("AC")
        qc.create_bell_pair("A-e", "C-e")
        qc.SWAP("A-e", "A-e+1", efficient=True)
        qc.SWAP("C-e", "C-e+1", efficient=True)

        qc.start_sub_circuit("BD")
        qc.create_bell_pair("D-e", "B-e")
        qc.SWAP("B-e", "B-e+1", efficient=True)
        qc.SWAP("D-e", "D-e+1", efficient=True)

        if PBAR is not None:
            PBAR.update(40)

        qc.start_sub_circuit("AB")
        qc.create_bell_pair("B-e", "A-e")

        # H - CZ - H on each electron, targeting the stored "e+1" qubit ("Option I").
        # An alternative ("Option II") would instead apply CNOT_gate with cqubit="X-e+1", tqubit="X-e",
        # electron_is_target=True, reverse=True.
        qc.H("A-e")
        qc.H("B-e")
        qc.apply_gate(CZ_gate, cqubit="A-e", tqubit="A-e+1")
        qc.apply_gate(CZ_gate, cqubit="B-e", tqubit="B-e+1")
        qc.H("A-e")
        qc.H("B-e")

        ab_outcomes = qc.measure(["A-e", "B-e"], basis="Z")

        qc.start_sub_circuit("CD")
        qc.create_bell_pair("C-e", "D-e")
        cd_result = qc.single_selection(CZ_gate, "C-e", "D-e", "C-e+1", "D-e+1", create_bell_pair=False)

        qc.start_sub_circuit("ABCD", forced_level=True)
        if qc.cut_off_time_reached:
            # Past the cut-off time the attempt is accepted as-is
            attempt_accepted = True
        else:
            ab_outcomes_equal = len(set(ab_outcomes)) == 1
            attempt_accepted = ab_outcomes_equal == cd_result

        if attempt_accepted:
            if not cd_result:
                qc.X("B-e+1")
                qc.X("D-e+1")
            if PBAR is not None:
                PBAR.update(40)

    qc.stabilizer_measurement(operation, nodes=["C", "A", "B", "D"], swap=True)

    if PBAR is not None:
        PBAR.update(20)


def expedient(qc: QuantumCircuit, *, operation):
    """
        Runs the Expedient protocol (Table D.1, thesis Naomi Nickerson) using qubit indices directly.

        Each attempt performs the following steps:
          1. Purify Bell pairs on the AB and CD links, each with a CZ and then a CNOT double selection, retrying
             each link until both rounds succeed.
          2. Run single-dot steps on the AC and BD links, then compare their outcomes.
          3. Run a second pair of single-dot steps on AC and BD.
        The whole attempt restarts whenever a comparison or a single-dot step fails. Afterwards the stabilizer
        measurement for ``operation`` is performed on all four nodes.

        Parameters
        ----------
        qc : QuantumCircuit
            Circuit created by create_quantum_circuit('expedient', ...).
        operation :
            Operation passed on to ``qc.stabilizer_measurement``.
    """
    ghz_success = False
    while not ghz_success:
        # Every attempt restarts the progress bar
        PBAR.reset() if PBAR is not None else None

        # Step 1-2 from Table D.1 (Thesis Naomi Nickerson)
        qc.start_sub_circuit("AB")
        success_ab = False
        # Retry the AB link until both double-selection rounds succeed; a failed CZ round skips the CNOT round
        while not success_ab:
            qc.create_bell_pair(11, 8)
            success_ab = qc.double_selection(CZ_gate, 10, 7)
            if not success_ab:
                continue
            success_ab = qc.double_selection(CNOT_gate, 10, 7)

        PBAR.update(20) if PBAR is not None else None

        # Step 1-2 from Table D.1 (Thesis Naomi Nickerson)
        qc.start_sub_circuit("CD")
        success_cd = False
        # Same retry scheme for the CD link
        while not success_cd:
            qc.create_bell_pair(5, 2)
            success_cd = qc.double_selection(CZ_gate, 4, 1)
            if not success_cd:
                continue
            success_cd = qc.double_selection(CNOT_gate, 4, 1)

        PBAR.update(20) if PBAR is not None else None

        # Step 3-5 from Table D.1 (Thesis Naomi Nickerson)
        qc.start_sub_circuit('AC')
        outcome_ac = qc.single_dot(CZ_gate, "A-e+1", "C-e+1")
        qc.start_sub_circuit('BD')
        outcome_bd = qc.single_dot(CZ_gate, "B-e+1", "D-e+1")
        qc.start_sub_circuit("ABCD")
        # NOTE(review): unlike modicum, modicum_swap and expedient_swap, this check does not accept the attempt
        # once qc.cut_off_time_reached is set, and "ABCD" is started without forced_level=True. Confirm whether
        # this difference is intentional.
        ghz_success = outcome_bd == outcome_ac
        if not ghz_success:
            # Mismatching outcomes: restart the whole attempt
            continue
        if not outcome_bd:
            qc.X("A-e+1")
            qc.X("B-e+1")

        PBAR.update(20) if PBAR is not None else None

        # Step 6-8 from Table D.1 (Thesis Naomi Nickerson)
        qc.start_sub_circuit("AC", forced_level=True)
        ghz_success_1 = qc.single_dot(CZ_gate, 10, 4)
        qc.start_sub_circuit("BD")
        ghz_success_2 = qc.single_dot(CZ_gate, 7, 1)
        # Either single-dot step failing restarts the whole attempt
        if any([not ghz_success_1, not ghz_success_2]):
            ghz_success = False
            continue

        # Only reached on the final (successful) attempt, since ghz_success is True here
        PBAR.update(20) if PBAR is not None else None

    # Step 9 from Table D.1 (Thesis Naomi Nickerson)
    # The node order is deliberate: it ensures the top qubit is the one measured each time, which reduces
    # runtime significantly.
    qc.stabilizer_measurement(operation, nodes=["B", "A", "D", "C"])

    PBAR.update(10) if PBAR is not None else None


def expedient_swap(qc: QuantumCircuit, *, operation, tqubit=None):
    """
        Runs the Expedient protocol using named qubits and SWAP gates.

        Each attempt performs the following steps:
          1. Create Bell pairs on the AB and CD links, move both halves from "X-e" to "X-e+1", and purify them
             with a CZ and then a CNOT double selection (``swap=True``), retrying each link until both rounds
             succeed.
          2. Run single-dot steps on the AC and BD links. Their outcomes must match, unless
             ``qc.cut_off_time_reached`` is set, in which case the comparison is skipped.
          3. Run a second pair of single-dot steps on AC and BD.
        The whole attempt restarts whenever a comparison or a single-dot step fails. Afterwards the stabilizer
        measurement for ``operation`` is performed on all four nodes.

        Parameters
        ----------
        qc : QuantumCircuit
            Circuit created by create_quantum_circuit('expedient_swap', ...).
        operation :
            Operation passed on to ``qc.stabilizer_measurement``.
        tqubit : optional
            Forwarded unchanged to ``qc.stabilizer_measurement``; defaults to None.
    """
    ghz_success = False
    while not ghz_success:
        # Every attempt restarts the progress bar
        PBAR.reset() if PBAR is not None else None

        qc.start_sub_circuit("AB")
        success_ab = False
        # Retry the AB link until both double-selection rounds succeed; a failed CZ round skips the CNOT round
        while not success_ab:
            qc.create_bell_pair("A-e", "B-e")
            qc.SWAP("A-e", "A-e+1", efficient=True)
            qc.SWAP("B-e", "B-e+1", efficient=True)
            success_ab = qc.double_selection(CZ_gate, "A-e", "B-e", swap=True)
            if not success_ab:
                continue
            success_ab = qc.double_selection(CNOT_gate, "A-e", "B-e", swap=True)

        PBAR.update(20) if PBAR is not None else None

        qc.start_sub_circuit("CD")
        success_cd = False
        # Same retry scheme for the CD link
        while not success_cd:
            qc.create_bell_pair("C-e", "D-e")
            qc.SWAP("C-e", "C-e+1", efficient=True)
            qc.SWAP("D-e", "D-e+1", efficient=True)
            success_cd = qc.double_selection(CZ_gate, "C-e", "D-e", swap=True)
            if not success_cd:
                continue
            success_cd = qc.double_selection(CNOT_gate, "C-e", "D-e", swap=True)

        PBAR.update(20) if PBAR is not None else None

        qc.start_sub_circuit('AC')
        outcome_ac = qc.single_dot(CZ_gate, "A-e", "C-e", swap=True)
        qc.start_sub_circuit('BD')
        outcome_bd = qc.single_dot(CZ_gate, "B-e", "D-e", swap=True)
        qc.start_sub_circuit("ABCD", forced_level=True)
        # Past the cut-off time the outcome comparison is skipped and the attempt continues
        ghz_success = outcome_bd == outcome_ac if not qc.cut_off_time_reached else True
        if not ghz_success:
            # Mismatching outcomes: restart the whole attempt
            continue
        if not outcome_bd:
            qc.X("A-e+1")
            qc.X("B-e+1")

        PBAR.update(20) if PBAR is not None else None

        qc.start_sub_circuit('AC', forced_level=True)
        ghz_success_1 = qc.single_dot(CZ_gate, "A-e", "C-e", swap=True)
        qc.start_sub_circuit("BD")
        ghz_success_2 = qc.single_dot(CZ_gate, "B-e", "D-e", swap=True)
        # Either single-dot step failing restarts the whole attempt (regardless of the cut-off time)
        if any([not ghz_success_1, not ghz_success_2]):
            ghz_success = False
            continue

    # Runs once after the loop; this matches expedient, where the equivalent update is only reached on the final
    # successful attempt
    PBAR.update(20) if PBAR is not None else None

    # The node order is deliberate: it ensures the top qubit is the one measured each time, which reduces
    # runtime significantly.
    qc.stabilizer_measurement(operation, nodes=["B", "A", "D", "C"], swap=True, tqubit=tqubit)

    PBAR.update(10) if PBAR is not None else None
