from __future__ import division, print_function

import logging
import time
from collections import namedtuple
import datetime
import numpy as np

from ..static import transects


# Mean length of a Gregorian calendar year in days (365 + 1/4 - 1/100 + 1/400).
AVG_YEAR_LENGTH = 365.2425  # https://en.wikipedia.org/wiki/Year


def growth(rotX, rotY, rotZ1, pm1, tm1, rotZ2, pm2, tm2, rotT1, rotT2, inspect=0, inspect_indices=()):
    """
    Calculate crest heights at linked crest positions in two datasets.

    Example::

        d = sandwaves.analysis.direction(x, y, Z)
        rotX, rotY, rotZ1, pm1, tm1 = sandwaves.analysis.extremes2D(x, y, Z1, d=d)
        _, _, rotZ2, pm2, tm2 = sandwaves.analysis.extremes2D(x, y, Z2, d=d)
        x, y, h1, h2, t1, t2 = growth(rotX, rotY, rotZ1, pm1, tm1, rotZ2, pm2, tm2, rotT1, rotT2)

    Every column ``i`` of the rotated grids is treated as one transect and handed to
    ``transect_growth``; the per-transect results are concatenated.

    :param rotX: rotated X positions
    :param rotY: rotated Y positions
    :param rotZ1: rotated bedlevels dataset 1 (masked cells are excluded)
    :param pm1: boolean mask of crest positions on rotated dataset 1
    :param tm1: boolean mask of trough positions on rotated dataset 1
    :param rotZ2: rotated bedlevels dataset 2 (masked cells are excluded)
    :param pm2: boolean mask of crest positions on rotated dataset 2
    :param tm2: boolean mask of trough positions on rotated dataset 2
    :param rotT1: rotated time values of dataset 1, indexed like rotX
    :param rotT2: rotated time values of dataset 2, indexed like rotX
    :param inspect: number of inspection levels; when > 0 the transects listed in
                    ``inspect_indices`` are plotted
    :param inspect_indices: column indices of the transects to inspect; defaults to the
                            middle column
    :return: SandwaveGrowthResult(x, y, h1, h2, t1, t2): positions of the linked crests,
             crest heights in dataset 1 and 2, and the rotT1/rotT2 values at those crests.
             The input masks pm1, tm1, pm2 and tm2 are not modified.
    """
    Result = namedtuple('SandwaveGrowthResult', ['x', 'y', 'h1', 'h2', 't1', 't2'])
    xpos, ypos, h1s, h2s, t1s, t2s = [], [], [], [], [], []

    # Exclude extremes where either dataset is missing. getmaskarray also accepts plain
    # ndarrays, and combining into new arrays leaves the caller's masks untouched.
    valid = ~(np.ma.getmaskarray(rotZ1) | np.ma.getmaskarray(rotZ2))
    pm1 = np.asarray(pm1, dtype=bool) & valid
    tm1 = np.asarray(tm1, dtype=bool) & valid
    pm2 = np.asarray(pm2, dtype=bool) & valid
    tm2 = np.asarray(tm2, dtype=bool) & valid

    indices = np.arange(rotZ1.shape[0])
    n_transects = rotZ1.shape[1]

    if not inspect_indices:
        inspect_indices = (n_transects // 2,)

    for i in range(n_transects):
        if inspect > 0 and i in inspect_indices:
            tr_inspect = int(inspect) - 1
            plot_selected_transect(rotX, rotY, rotZ2, i)
        else:
            tr_inspect = 0

        I, h1, h2 = transect_growth(
            rotZ1[:, i],
            rotZ2[:, i],
            indices[pm1[:, i]],
            indices[tm1[:, i]],
            indices[pm2[:, i]],
            indices[tm2[:, i]],
            inspect=tr_inspect)

        xpos.append(rotX[I, i])
        ypos.append(rotY[I, i])
        h1s.append(h1)
        h2s.append(h2)
        t1s.append(rotT1[I, i])
        t2s.append(rotT2[I, i])

    def _concat(parts):
        # np.concatenate raises on an empty list (grid without any transect columns)
        return np.concatenate(parts) if parts else np.array([])

    return Result(*(_concat(parts) for parts in (xpos, ypos, h1s, h2s, t1s, t2s)))


def transect_growth(z1, z2, pi1, ti1, pi2, ti2, ref='first', inspect=0):
    """
    Compute crest heights along a single transect at two timesteps.

    The crest/trough indices are validated, linked between the timesteps with ``link``,
    and the height of every linked crest is taken relative to the mean level of the two
    troughs that surround it.

    :param z1: values of the transect at the first timestep
    :param z2: values of the transect at the second timestep
    :param pi1: indices of crests first timestep
    :param ti1: indices of troughs first timestep
    :param pi2: indices of crests second timestep
    :param ti2: indices of troughs second timestep
    :param ref: 'first' or 'second': which timestep's linked crest indices are returned
    :param inspect: integer indicating number of levels to inspect (plots when > 0)
    :return: (I, h1, h2): reference crest indices and crest heights at the first and
             second timestep; three empty lists when checking or linking fails
    :raises KeyError: if ``ref`` is neither 'first' nor 'second' (only checked once
                      linking has succeeded)
    """
    try:
        pi1, pi2, ti1, ti2 = _check_transects(pi1, pi2, ti1, ti2)
    except transects.TransectError:
        return [], [], []

    try:
        linked = link(pi1, pi2, ti1, ti2, inspect=int(inspect) - 1)
    except transects.EmptyTransectError:
        return [], [], []
    except transects.InvalidTransectError:
        logging.getLogger(__name__).exception('invalid transect')
        return [], [], []
    crests1, crests2, troughs1, troughs2 = linked

    if ref not in ('first', 'second'):
        raise KeyError(ref)
    I = crests1 if ref == 'first' else crests2

    def crest_height(z, crests, troughs):
        # crest level minus the mean of the troughs on either side of it
        trough_levels = z[troughs]
        return z[crests] - 0.5 * (trough_levels[1:] + trough_levels[:-1])

    h1 = crest_height(z1, crests1, troughs1)
    h2 = crest_height(z2, crests2, troughs2)

    if inspect > 0:
        plot_tr_growth(z1, z2, crests1, troughs1, crests2, troughs2, I, h2 - h1)

    return I, h1, h2


def link(pi1, pi2, ti1, ti2, inspect=0, attempts=0):
    """
    Link crests and troughs from two timesteps.

    Returns linked crests and troughs so that:
      - the number of crests is equal for both timesteps
      - the number of troughs is one larger than the number of crests
      - crests and troughs alternate

    Points are linked in both directions (first -> second and second -> first). Only
    points linked in both directions are kept, and crests/troughs are dropped together
    position by position. After ``transects.correct`` the result is linked again
    (recursively) while both timesteps still differ in length. The function gives up
    with an empty result after 5 retries.

    :param pi1: integer positions of crests first timestep
    :param pi2: integer positions of crests second timestep
    :param ti1: integer positions of troughs first timestep
    :param ti2: integer positions of troughs second timestep
    :param inspect: show the linking step as a matplotlib figure if inspect > 0
    :param attempts: number of previous attempts for valid output
    :return: tuple (lpi1, lpi2, lti1, lti2) of linked crest and trough positions;
             four empty integer arrays when nothing could be linked
    :raises ValueError: propagated from ``_link`` when the number of candidates does not
                        match the number of splits
    """
    def _empty():
        return tuple(np.array([], dtype=int) for _ in range(4))

    # link crests two ways
    lpi1a, lpi2a = _link(pi1, pi2, ti2)
    lpi2b, lpi1b = _link(pi2, pi1, ti1)

    # link troughs two ways
    lti1a, lti2a = _link(ti1, ti2, pi2)
    lti2b, lti1b = _link(ti2, ti1, pi1)

    # keep only points that are linked both from first to second and from second to
    # first timestep (np.isin replaces the deprecated np.in1d)
    mp = np.isin(lpi1a, lpi1b) & np.isin(lpi2a, lpi2b)
    mt = np.isin(lti1a, lti1b) & np.isin(lti2a, lti2b)

    # 1. exclude troughs belonging to missing crests
    # 2. exclude crests belonging to missing troughs
    N = min(mt.size, mp.size)  # number of overlapping points
    mp[:N] = mt[:N] = mp[:N] & mt[:N]

    # also covers _link returning empty results (mp/mt then have size 0)
    if not mp.any() or not mt.any():
        return _empty()

    # mask missing points
    lpi1 = lpi1a[mp]
    lpi2 = lpi2a[mp]
    lti1 = lti1a[mt]
    lti2 = lti2a[mt]

    lpi1, lti1 = transects.correct(lpi1, lti1)
    lpi2, lti2 = transects.correct(lpi2, lti2)

    # require equal lengths for t=1 and t=2; otherwise re-link the corrected points.
    # Calls with attempts 0..4 may retry, so at most 5 retries happen before giving up.
    if (lpi1.size != lpi2.size) or (lti1.size != lti2.size):
        if attempts > 4:
            return _empty()
        # forward inspect so the final (successful) linking step can still be plotted
        return link(lpi1, lpi2, lti1, lti2, inspect=inspect, attempts=attempts + 1)

    if inspect > 0:
        plot_link(pi1, ti1, pi2, ti2, lpi1, lti1, lpi2, lti2)

    return lpi1, lpi2, lti1, lti2


def _link(a, b, splits):
    """
    Link values of ``a`` to values of ``b`` using the intervals between ``splits``.

    Each value of ``a`` is binned with ``np.digitize(a, splits)``, and bin ``k`` is linked
    to ``b[k]``. When a bin holds several values of ``a``, only the one closest to
    ``b[k]`` is kept. Bins without a value of ``a`` produce no link.

    :param a: 1-D array of positions to link
    :param b: 1-D array of candidate positions, one per interval between the splits
    :param splits: 1-D array of interval boundaries (presumably ascending positions along
                   the transect, as required by np.digitize). If it has one element more
                   than ``b``, its outermost values are dropped first.
    :return: (la, lb) arrays of equal length with the linked values of ``a`` and ``b``.
             Both are empty (but still ndarrays) when ``a``, ``b`` or the trimmed
             ``splits`` is empty.
    :raises ValueError: if ``b`` does not have exactly one element more than ``splits``
    """
    # b lies in between the splits (e.g. crests between troughs): drop the outer splits
    if b.size == splits.size - 1:
        splits = splits[1:-1]

    if not all([a.size, b.size, splits.size]):
        # empty slices keep the ndarray type and dtype of the inputs
        return a[:0], b[:0]

    if b.size - 1 != splits.size:
        raise ValueError('length of array not 1 less than length of splits')

    # find points between splits
    links = np.digitize(a, splits)

    # count the link occurences
    unique, counts = np.unique(links, return_counts=True)

    # links that occur once are valid as-is; for duplicates keep the point closest to b
    mask = np.isin(links, unique[counts == 1])
    for i in unique[counts > 1]:
        indices = np.flatnonzero(links == i)  # indices of current duplicate set
        dists = np.absolute(a[indices] - b[i])
        mask[indices[np.argmin(dists)]] = True

    la, lb = a[mask], b[links[mask]]
    return la, lb


def _timespan_as_days_array(timespan, shape):
    """
    Return ``timespan`` as an array of days with the given shape.

    :param timespan: number of days (Python or NumPy int/float), ``datetime.timedelta``,
                     ``numpy.timedelta64``, or ``numpy.ndarray`` of days with shape ``shape``
    :param shape: int or sequence of ints, the shape of the returned array
    :return: numpy.ndarray of days. Scalar inputs are broadcast to a float array of
             ``shape``; an ndarray input is returned as-is.
    :raises TypeError: if ``timespan`` has an unsupported type
    :raises ValueError: if an ndarray ``timespan`` does not have shape ``shape``
    """
    # normalize so that an int or list shape compares equal to ndarray.shape
    shape = (int(shape),) if np.isscalar(shape) else tuple(shape)

    if isinstance(timespan, np.timedelta64):
        # checked before the numeric branch: numpy.timedelta64 subclasses numpy.signedinteger
        days = timespan / np.timedelta64(1, 'D')
    elif isinstance(timespan, (int, float, np.integer, np.floating)):
        days = timespan
    elif isinstance(timespan, datetime.timedelta):
        # total_seconds() keeps fractions of a day; timedelta.days would truncate them
        # (and floor negative spans)
        days = timespan.total_seconds() / 86400.0
    elif isinstance(timespan, np.ndarray):
        if timespan.shape != shape:
            raise ValueError(
                'invalid timespan, cannot create array of shape {0} from {1.__class__} {1.shape}'.format(shape,
                                                                                                          timespan))
        return timespan
    else:
        raise TypeError('invalid timespan, expected int, float or datetime.timedelta or ndarray')

    return np.full(shape, float(days))


def _check_transects(pi1, pi2, ti1, ti2):
    """
    Validate crest/trough indices of both timesteps with ``transects.check``.

    ``raise_error=True`` is passed, so validation failures raise from ``transects.check``
    rather than being handled here.

    :return: (pi1, pi2, ti1, ti2) as returned by ``transects.check`` (crests of both
             timesteps first, then troughs of both timesteps)
    """
    crests_t1, troughs_t1 = transects.check(pi1, ti1, raise_error=True)
    crests_t2, troughs_t2 = transects.check(pi2, ti2, raise_error=True)
    return crests_t1, crests_t2, troughs_t1, troughs_t2


def plot_selected_transect(rotX, rotY, rotZ, i):
    """
    Plot ``rotZ`` on the rotated (rotX, rotY) grid in a new figure and overlay transect
    column ``i`` as a red line.

    Invalid grid coordinates are masked with ``np.ma.masked_invalid`` before ``pcolor``.
    """
    from matplotlib import pyplot as plt
    grid_x = np.ma.masked_invalid(rotX)
    grid_y = np.ma.masked_invalid(rotY)
    fig = plt.figure()
    ax = fig.gca()
    ax.pcolor(grid_x, grid_y, rotZ, cmap='viridis')
    ax.plot(rotX[:, i], rotY[:, i], 'r')


def plot_link(pi1, ti1, pi2, ti2, lpi1, lti1, lpi2, lti2):
    """
    Show crests (red) and troughs (blue) of both timesteps before and after linking.

    The top panel shows the input positions and the bottom panel the linked positions.
    In each panel t=1 is drawn at y=1 and t=2 at y=0, with troughs offset by -0.1.
    The bottom panel reuses the x-limits of the top panel.
    """
    from matplotlib import pyplot as plt
    marker_kw = {'ms': 10}
    panels = (
        (211, 'original', pi1, ti1, pi2, ti2),
        (212, 'linked', lpi1, lti1, lpi2, lti2),
    )

    plt.figure(figsize=(16, 6))
    xlim = None
    for position, title, crests1, troughs1, crests2, troughs2 in panels:
        ax = plt.subplot(position)
        ax.plot(crests1, np.ones(crests1.size), 'r.', **marker_kw)
        ax.plot(troughs1, np.ones(troughs1.size) - .1, 'b.', **marker_kw)
        ax.plot(crests2, np.zeros(crests2.size), 'r.', **marker_kw)
        ax.plot(troughs2, np.zeros(troughs2.size) - .1, 'b.', **marker_kw)
        ax.set_yticks([])
        ax.set_ylim(-1, 2)
        # the first panel defines the x-range, later panels adopt it
        if xlim is None:
            xlim = ax.get_xlim()
        else:
            ax.set_xlim(*xlim)
        ax.text(xlim[0] + 3, 0, 't=2')
        ax.text(xlim[0] + 3, 1, 't=1')
        ax.set_title(title)


def plot_tr_growth(z1, z2, pi1, ti1, pi2, ti2, I, pdh):
    """
    Plot a transect at both timesteps and the height change of its linked crests.

    Top panel: ``z1`` as a black line with red markers at ``pi1``/``ti1``, and ``z2`` as a
    green line with blue markers at ``pi2``/``ti2``. Faint horizontal guide lines are
    drawn every 0.5.
    Bottom panel: bars of ``pdh`` at positions ``I``, using the x-limits of the top panel.

    :param z1: transect values at the first timestep
    :param z2: transect values at the second timestep
    :param pi1: crest indices first timestep
    :param ti1: trough indices first timestep
    :param pi2: crest indices second timestep
    :param ti2: trough indices second timestep
    :param I: indices at which the bars are drawn
    :param pdh: bar heights (height change per crest), one per entry of ``I``
    """
    # local import, like the other plotting helpers: matplotlib is only needed when inspecting
    from matplotlib import pyplot as plt

    indices = np.arange(z1.size)
    fig = plt.figure(figsize=(16, 6))
    ax = fig.add_subplot(211)
    ax.plot(indices, z1, 'k-')
    ax.plot(indices[pi1], z1[pi1], 'r.')
    ax.plot(indices[ti1], z1[ti1], 'r.')

    ax.plot(indices, z2, 'g-')
    ax.plot(indices[pi2], z2[pi2], 'b.')
    ax.plot(indices[ti2], z2[ti2], 'b.')
    ax.axis('off')

    # faint horizontal guide lines every 0.5 across the automatic y-tick range
    yticks = ax.get_yticks()
    for t in np.arange(yticks.min(), yticks.max(), .5):
        ax.axhline(t, color='k', lw=.2, alpha=.5)
    xlim = ax.get_xlim()

    ax = fig.add_subplot(212)
    ax.bar(indices[I], pdh)
    ax.set_xlim(*xlim)

    # reference lines from -1 to 1; the zero line is drawn thicker
    for t in [-1, -.5, 0, .5, 1]:
        ax.axhline(t, color='k', lw=1 if t == 0 else .2, alpha=.5)