import numpy as np
from collections import namedtuple
from ..utils import moving_avg


# names exported by `from <module> import *`; the remaining functions in this
# module are helper steps used by `transect_extremes`
__all__ = [
    'extremes',
    'transect_extremes']


def extremes(rotX, rotY, rotZ, minheight, window, step=1):
    """
    Build crest and trough masks for a gridded dataset.

    Every `step`-th transect along the 2nd dimension of `rotZ` is passed to
    `transect_extremes`; the crest and trough indices it returns are flagged in
    two boolean masks. Transects that are skipped stay completely False.

    The masks can be used to select the point data::

        pm, tm = extremes(rotX, rotY, rotZ, minheight=.5, window=3, step=1)
        px, py, pz = rotX[pm], rotY[pm], rotZ[pm]
        tx, ty, tz = rotX[tm], rotY[tm], rotZ[tm]

    Note: each transect is handed over as a view of `rotZ`. The gap
    interpolation performed by `transect_extremes` writes into that view, so
    gaps in `rotZ` may be filled as a side effect.

    :param rotX: 2D array of x positions (not used in the calculation)
    :param rotY: 2D array of y positions (not used in the calculation)
    :param rotZ: 2D array of values
    :param minheight: see transect_extremes
    :param window: see transect_extremes
    :param step: interval for which transects should be evaluated
    :return: named tuple (peakmask, troughmask) of boolean arrays with the shape of `rotZ`
    """
    Result = namedtuple('SandwaveExtremesResult', ['peakmask', 'troughmask'])

    peaks = np.zeros(rotZ.shape, dtype=bool)
    troughs = np.zeros_like(peaks)

    # transposing makes the iteration run over the 2nd dimension
    for column, profile in enumerate(rotZ.T):
        if column % step == 0:
            crest_rows, trough_rows = transect_extremes(profile, minheight, window)
            peaks[crest_rows, column] = True
            troughs[trough_rows, column] = True

    return Result(peaks, troughs)


def transect_extremes(data, minheight, window):
    """
    Locate sand wave crests and troughs along a single transect.

    The detection runs through the following steps:
    1. gaps in the data are interpolated (`interpolate_gaps`, done in place on `data`)
    2. the transect is smoothed with a moving average (`moving_avg`)
    3. extremes are located in the smoothed data (`get_extremes`)
    4. each smoothed extreme is moved to the matching extreme of the original data
       (`get_extremes_in_window`)
    5. extremes whose crest-trough difference is below `minheight` are removed
       (`filter_by_height`); this step is skipped when `minheight` is not positive

    When the smoothed data contains only crests or only troughs (or no extremes at all)
    two empty index arrays are returned.

    :param data: transect data as depth values (1D array)
    :param minheight: threshold value for difference between crest and adjacent troughs
    :param window: moving average window in number of cells. it is advised to use a window smaller than the
                   minimum sand wave length
    :return: indices of crests in data, indices of troughs in data
    """
    filled = interpolate_gaps(data)
    smoothed = moving_avg(filled, window=window, same_size=True)

    candidates, candidate_is_peak = get_extremes(smoothed)

    # a usable result needs at least one crest and one trough
    n_peaks = np.count_nonzero(candidate_is_peak)
    if n_peaks == 0 or n_peaks == candidates.size:
        return np.array([], dtype=int), np.array([], dtype=int)

    # undo the shift and flattening caused by the smoothing
    positions, is_peak = get_extremes_in_window(
        filled, candidates, candidate_is_peak, window=window)

    if minheight > 0:
        positions, is_peak = filter_by_height(filled, positions, is_peak, minheight)

    crests = positions[is_peak]
    troughs = positions[~is_peak]
    return crests, troughs


def interpolate_gaps(data):
    """
    Fill interior gaps of a 1D array by linear interpolation.

    Missing values are the masked entries of a ``np.ma.MaskedArray`` or the NaN entries
    of any other array. Only gaps that are enclosed by valid values are filled: leading
    and trailing missing values are left as they are. When every value is missing (or
    the array is empty) nothing is changed.

    The array is modified in place and returned as well.

    :param data: 1D array, optionally masked
    :return: the same array object with the interior gaps filled
    """

    assert data.ndim == 1

    # create mask of missing values
    if isinstance(data, np.ma.MaskedArray):
        # a masked array without any masked entry carries the scalar `nomask` instead of a
        # boolean array, which can not be sliced; getmaskarray always returns a full array
        mask = np.ma.getmaskarray(data)
    else:
        mask = np.isnan(data)

    # without at least one valid value there is nothing to interpolate from
    if mask.all():
        return data

    indices = np.arange(data.shape[0])

    # trim to the span between the first and the last valid value
    start = mask.argmin()
    end = len(data) - mask[::-1].argmin()

    # find missing values inside that span
    missing_indices = indices[start:end][mask[start:end]]

    if missing_indices.size > 0:
        # fill values by linear interpolation between the surrounding valid values
        data[missing_indices] = np.interp(missing_indices, indices[~mask], data[~mask])

    # return changed data
    return data


def get_extremes(data):
    """
    determine locations of all crests and troughs in the supplied data
    this is calculated as the locations where the direction of the slope changes from
    positive to negative or reversed

    The first and last sample are never reported. Note that a flat run between a rise
    and a fall is flagged at both of its ends. Masked entries in the result of the
    slope calculation are treated as "no change".

    :param data: data that is evaluated (1D array)
    :return: tuple of crest/trough indices and mask of crests in the indices
    """
    # a change of slope direction needs at least three samples; for shorter input the
    # padded array below would not line up with the indices
    if data.size < 3:
        return np.array([], dtype=int), np.array([], dtype=bool)

    # sign of the change in slope direction: -1 at a crest, +1 at a trough
    ddx = np.sign(np.diff(np.sign(np.diff(data))))

    if isinstance(ddx, np.ma.MaskedArray):
        ddx = ddx.filled(0)

    # the double diff shortens the array by two; pad to line up with the data indices
    ddx = np.pad(ddx, (1, 1), 'constant', constant_values=(0, 0))
    is_extreme = ddx != 0

    extr_ind = np.arange(data.size)[is_extreme]
    return extr_ind, ddx[is_extreme] == -1


def get_extremes_in_window(data, extremes_avg, peak_mask, window):
    """
    after the extreme values have been selected from the averaged data this function corrects the locations of the
    extremes to the best matching extreme in the real data. This prevents topping off of sand waves as a consequence of
    averaging. The new extreme is selected within the averaging domain. Due to possible overlap between extremes in the
    averaging window the new extremes must be within the averaging window as well as between the midpoints of that
    extreme and its neighbouring extremes.

    Neighbouring extremes that end up on the same index are both removed from the result.

    :param data: transect data to be evaluated
    :param extremes_avg: locations of extremes based on moving average
    :param peak_mask: extremes that are peaks (others are troughs)
    :param window: averaging window used for the moving average
    :return: corrected extremes locations and peakmask
    """
    # calculate distance from midpoint of window
    windowradius = (window - 1) // 2

    # calculate midpoints
    # TODO midpoints as bending points avg between peaks
    midpoints = np.ceil((extremes_avg[1:] + extremes_avg[:-1]) / 2).astype(int)

    # minimum index within window after previous midpoint
    min_limits = np.maximum(np.concatenate(([0], midpoints)),
                            extremes_avg - windowradius).astype(int)
    # maximum index (inclusive) within window before next midpoint
    max_limits = np.minimum(np.concatenate((midpoints, [data.size - 1])),
                            extremes_avg + windowradius).astype(int)

    # extremes
    extremes = np.zeros(extremes_avg.shape, dtype=int)
    # duplicate points mask
    # duplicates may occur when the peaks in the moving average are the result of values outside the midpoint range
    # these points are not valid and are removed
    duplicate_mask = np.zeros(extremes.shape, dtype=bool)

    # iterate over the search ranges of the average extremes and calculate the corrected location
    for i, (lower, upper) in enumerate(zip(min_limits, max_limits)):
        # crests move to the highest value in their range, troughs to the lowest
        pick = np.argmax if peak_mask[i] else np.argmin

        # the position within the slice is relative to the lower limit
        extremes[i] = lower + pick(data[lower:upper + 1])

        # check duplicates
        if i and extremes[i - 1] == extremes[i]:
            duplicate_mask[i - 1] = True
            duplicate_mask[i] = True

    # return result without duplicates
    return extremes[~duplicate_mask], peak_mask[~duplicate_mask]


def filter_by_height(data, extremes, peak_mask, minheight):
    """
    Drop extremes that do not reach a minimum height. The height used here is the absolute difference in elevation
    between two subsequent extremes, which is not the same as the general definition of sand wave height.
    The pair of neighbouring extremes with the smallest difference is removed repeatedly until every remaining
    difference is at least `minheight`. The repetition is needed because removing a pair changes the differences
    of the extremes around it.

    :param data: transect data
    :param extremes: location of extremes in the data
    :param peak_mask: mask that differentiates peaks and troughs
    :param minheight: minimum height for filtering
    :return: corrected extremes and peakmask
    """
    n_extremes = extremes.shape[0]
    levels = data[extremes]

    # positions (into `extremes`) of the extremes that are still accepted
    keep = list(range(n_extremes))

    # every pass removes two extremes, so half the number of extremes is the upper bound
    passes_left = n_extremes // 2
    while passes_left > 0:
        passes_left -= 1

        kept_levels = levels[keep]
        steps = np.absolute(kept_levels[1:] - kept_levels[:-1])
        if (steps >= minheight).all():
            break

        # remove the neighbouring pair with the smallest difference
        weakest = np.argmin(steps)
        keep[weakest:weakest + 2] = []

    return extremes[keep], peak_mask[keep]