
from scipy.stats import norm

from Quantum_Network_Architecture.utils.utility_functions import quadratic_formula, combinatorial_factor

from math import ceil, floor, log


def find_min_instances_normal_approximation(number_of_instances: int, pga_success_probability: float, failure_probability: float) -> int:
    """
    Estimate the minimum number of attempts N needed for `number_of_instances` successes, via a normal approximation.

    The number of successes in N attempts is treated as Binomial(N, p) approximated by a normal with mean N p and
    variance N p (1 - p). The boundary N solves (N p - k)^2 = phi^2 N p (1 - p) with phi = Phi^{-1}(failure_probability),
    i.e. the quadratic p^2 N^2 - p (2 k + phi^2 (1 - p)) N + k^2 = 0.

    :param number_of_instances: Required number of successes k
    :param pga_success_probability: Success probability p of a single attempt
    :param failure_probability: Allowed probability of obtaining fewer than k successes; must lie in (0, 1)
    :return: The boundary N rounded up to an integer
    :raises ValueError: if failure_probability is not strictly between 0 and 1 (norm.ppf would give +/-inf or nan)
    """
    if not 0 < failure_probability < 1:
        raise ValueError("failure_probability must lie strictly between 0 and 1")

    phi = norm.ppf(failure_probability)
    _k = number_of_instances
    _p = pga_success_probability
    # The discriminant p^2 ((2k + phi^2 (1 - p))^2 - 4 k^2) is non-negative, so both roots are real.
    sol1, sol2 = quadratic_formula(_p ** 2, -_p * (2 * _k + phi ** 2 * (1 - _p)), _k ** 2)

    # The two roots multiply to (k / p)^2, so they straddle k / p (>= k); comparing a root against k cannot
    # tell them apart. The wanted root has N p - k with the sign of -phi: the larger root when phi <= 0
    # (failure_probability <= 0.5), the smaller one otherwise. Selecting by value also removes any dependence
    # on the order in which quadratic_formula returns its roots.
    root = max(sol1, sol2) if phi <= 0 else min(sol1, sol2)

    return ceil(root)

def find_min_instances_hoeffding(number_of_instances: int, pga_success_probability: float, failure_probability: float) -> int:
    """
    Estimate the minimum number of attempts N needed for `number_of_instances` successes, via Hoeffding's inequality.

    Solves 2 p^2 N^2 + (ln(delta) - 4 p k) N + 2 k^2 = 0, which is equivalent to
    exp(-2 (N p - k)^2 / N) = delta, where delta is the failure probability.

    :param number_of_instances: Required number of successes k
    :param pga_success_probability: Success probability p of a single attempt
    :param failure_probability: Allowed failure probability delta; must lie in (0, 1)
    :return: The boundary N rounded up to an integer
    :raises ValueError: if failure_probability is not strictly between 0 and 1
    """
    if not 0 < failure_probability < 1:
        raise ValueError("failure_probability must lie strictly between 0 and 1")

    _k = number_of_instances
    _p = pga_success_probability
    sol1, sol2 = quadratic_formula(2 * _p ** 2, log(failure_probability) - 4 * _p * _k, 2 * _k ** 2)

    # The roots multiply to (k / p)^2, so they straddle k / p. Hoeffding's tail bound only applies for a
    # non-negative deviation N p - k, which holds for the larger root only; pick it by value so the result
    # does not depend on the order quadratic_formula returns its roots in.
    return ceil(max(sol1, sol2))




def naus_approximation(k, m, N, p) -> float:
    """
    Using the approximation from Naus 1982 "Approximations for Distributions of Scan Statistics" as given in Glaz, Naus & Wallenstein's "Scan Statistics" (2001, p. 47) for P(waiting time for k-in-m < N).

    The approximation is 1 - Q2 * (Q3 / Q2) ** (N / m - 2), where Q2 and Q3 are built from the binomial
    pmf b(i; s, p) and cdf F_b(r; s, p).

    :param k: Number of successes
    :param m: Window Size
    :param N: Number of trials
    :param p: Probability of success
    :return: P[wait time < N]
    """

    def b(i, s, q):
        # Binomial pmf b(i; s, q). It is zero outside 0 <= i <= s; returning 0 explicitly keeps the result
        # independent of how combinatorial_factor handles i > s, and avoids (1 - q) ** (negative), which
        # raises ZeroDivisionError when q == 1. Terms such as b(2k - r; m, p) fall outside the support
        # whenever 2k - r > m.
        if i < 0 or i > s:
            return 0.0
        return combinatorial_factor(s, i) * q ** i * (1 - q) ** (s - i)

    def f_b(r, s, q):
        # Binomial cdf F_b(r; s, q) = sum_{i=0}^{r} b(i; s, q); the empty sum (0) for r < 0.
        return sum(b(i, s, q) for i in range(r + 1)) if r >= 0 else 0

    # Terms that appear several times in Q2 and Q3.
    b_k = b(k, m, p)
    cdf_km1 = f_b(k - 1, m, p)
    cdf_km2 = f_b(k - 2, m, p)
    cdf_km3_m1 = f_b(k - 3, m - 1, p)

    q2 = cdf_km1 ** 2 - (k - 1) * b_k * cdf_km2 + m * p * b_k * cdf_km3_m1

    a1 = 2 * b_k * cdf_km1 * ((k - 1) * cdf_km2 - m * p * cdf_km3_m1)
    a2 = 0.5 * b_k ** 2 * (
        (k - 1) * (k - 2) * f_b(k - 3, m, p)
        - 2 * (k - 2) * m * p * f_b(k - 4, m - 1, p)
        + m * (m - 1) * p ** 2 * f_b(k - 5, m - 2, p)
    )
    a3 = sum(b(2 * k - r, m, p) * f_b(r - 1, m, p) ** 2 for r in range(1, k))
    a4 = sum(
        b(2 * k - r, m, p) * b(r, m, p) * ((r - 1) * f_b(r - 2, m, p) - m * p * f_b(r - 3, m - 1, p))
        for r in range(2, k)
    )

    q3 = cdf_km1 ** 3 - a1 + a2 + a3 - a4

    # NOTE(review): q2 == 0 (e.g. p == 1 with k <= m) makes this divide by zero; callers presumably
    # never pass such inputs — confirm.
    return 1 - q2 * (q3 / q2) ** ((N / m) - 2)


def calculate_length_of_pga_window(number_of_pairs: int, window: int, entanglement_success_probability: float,
                                   pga_success_probability: float) -> int:
    """
    Calculate the number of entanglement generation attempts required to generate a packet with some probability.

    If T is the time the packet is generated, P[T < N] is approximated with Naus '82 (see naus_approximation).
    The bracket [window, 10 * window] is first grown (tripling its upper end) until the approximation at the
    upper end reaches pga_success_probability. Interval bisection then narrows it to width one.

    :param number_of_pairs: Number of entangled pairs required
    :param window: Window in which these all need to have been generated
    :param entanglement_success_probability: Probability of generating a link in each time step
    :param pga_success_probability: Target probability of successfully generating a packet.
    :return: Length of the PGA: the upper end of the final bracket. The approximated probability there is
        at least pga_success_probability, and the value is always greater than window.
    :raises ValueError: if number_of_pairs exceeds window. k successes cannot occur within m < k Bernoulli
        trials, so the target would never be reached.
    """
    if number_of_pairs > window:
        raise ValueError(
            f"number_of_pairs ({number_of_pairs}) cannot exceed window ({window}): "
            "that many links cannot be generated within the window"
        )

    def reaches_target(length: int) -> bool:
        # One predicate for both phases, so the bracket invariant holds: upper_bound always reaches the target.
        return naus_approximation(number_of_pairs, window, length,
                                  entanglement_success_probability) >= pga_success_probability

    lower_bound = window
    upper_bound = 10 * window

    # Exponential search: grow the bracket until its upper end is long enough.
    while not reaches_target(upper_bound):
        lower_bound = upper_bound
        upper_bound += 2 * lower_bound

    # Interval bisection. An exact match is not expected, so stop once the bracket has width one.
    while upper_bound - lower_bound > 1:
        midpoint = (upper_bound + lower_bound) // 2
        if reaches_target(midpoint):
            upper_bound = midpoint
        else:
            lower_bound = midpoint

    return upper_bound  # Return UB so can promise that the prob. is at least as advertised

