"""
the most important sand wave analysis functions are:
1. analysis.orientation      determines the bedform orientation in degrees
2. analysis.extremes2D       identifies crests and troughs
3. analysis.static2D         calculates bedform shape characteristics
4. analysis.growth2D         calculates sand wave growth
5. analysis.migration        calculates sand wave migration
"""
from __future__ import print_function, division
import logging
import time
from functools import wraps, partial
import numpy as np
import textwrap
from collections import namedtuple

from . import transformations
from . import static
from . import dynamic
from . import postprocessing
from ..datasets.griddata import GridDataset, GridTimeStackDataset


# names exported by a star-import of this module; the `expose` decorator
# defined below appends the name of every object it decorates
__all__ = ['transformations', 'static', 'dynamic']


# fallback values returned by `apply_defaults` when a caller passes None for
# the keyword of the same name
# NOTE(review): units are not stated in this file -- presumably the length
# unit of the grid coordinates; confirm
DEFAULTS = dict(minheight=.8, lmin=100, lmax=1000)


def expose(obj):
    """
    decorator that registers an object as part of the public module interface
    :param obj: object with a `__name__` attribute (e.g. a function)
    :return: the object itself, unmodified
    :raises ValueError: if the name is already listed in `__all__`
    """
    name = obj.__name__
    if name not in __all__:
        # `__all__` is only mutated in place, never rebound
        __all__.append(name)
        return obj
    raise ValueError('{} already exists in `__all__`'.format(name))


def apply_defaults(kwarg, name):
    """
    fall back to the module level default when no value was supplied
    :param kwarg: value supplied by the caller (None means "not supplied")
    :param name: key of the parameter in `DEFAULTS`
    :return: `kwarg` if it is not None, otherwise `DEFAULTS[name]`
    """
    return DEFAULTS[name] if kwarg is None else kwarg


def mask2points(X, Y, *masks):
    """
    select the coordinates that belong to each mask
    :param X: array of x coordinates
    :param Y: array of y coordinates
    :param masks: any number of index masks applied to both X and Y
    :return: flat tuple (X[m1], Y[m1], X[m2], Y[m2], ...) in the order of `masks`
    """
    return tuple(coords[mask] for mask in masks for coords in (X, Y))


def get_window(x, y, length=75):
    """
    derive an odd smoothing window size (in cells) from the grid spacing
    :param x: sequence of x coordinates; only the first two values are used
    :param y: unused, kept for interface compatibility
    :param length: distance the window should span, expressed in the same unit
                   as x (default 75)
    :return: largest odd number of cells fitting in `length`, never smaller than 1
    """
    # use the absolute spacing so a descending axis does not yield a negative window
    spacing = abs(x[1] - x[0])
    w = int(length // spacing)
    if w % 2 == 0:
        w -= 1
    # a spacing larger than `length` would otherwise result in a window of -1
    return max(w, 1)


def apply_docstring(docstr, indent=1):
    """
    decorator factory that replaces the docstring of a function
    :param docstr: docstring text to apply; it is dedented and re-indented.
                   If None (e.g. the source object is undocumented or python
                   runs with -OO) the function is left untouched
    :param indent: number of 4-space indentation levels put before every line
    :return: decorator that sets `__doc__` and returns the same function
    """
    def decorator(fn):
        if docstr is None:
            # textwrap.dedent(None) would raise at import time
            return fn
        prefix = '    ' * indent
        fn.__doc__ = prefix + textwrap.dedent(docstr).replace('\n', '\n' + prefix)
        return fn
    return decorator


def allow_grid_input(fn=None, bind=False, attrname=None):
    """
    decorator that lets a function called as fn(x, y, Z, ...) also be called as fn(grid, ...)
    can be applied bare (`@allow_grid_input`) or with options (`@allow_grid_input(bind=True)`)
    :param fn: function to wrap (None when the decorator is called with options only)
    :param bind: if True the wrapper is also set as an attribute on GridDataset
    :param attrname: attribute name used when binding (defaults to fn.__name__)
    :return: the wrapper, or a partially applied decorator when fn is None
    :raises TypeError: if fn is not callable
    """
    if fn is None:
        return partial(allow_grid_input, bind=bind, attrname=attrname)

    if not hasattr(fn, '__call__'):
        raise TypeError('function argument not callable')

    docstr_addage = """
    NB: This function also accepts a {grid_cls_name} object from which x, y and Z are extracted.
    The call of:
    >>>{fn_name}(grid, *args, **kwargs)
    subsequently becomes:
    >>>{fn_name}(grid.x, grid.y, grid.Z, *args, **kwargs)
    """.format(grid_cls_name=GridDataset.__module__+'.'+GridDataset.__name__,
               fn_name=fn.__module__+'.'+fn.__name__)

    @wraps(fn)
    def wrapper(*args, **kwargs):
        # only a leading positional GridDataset is unpacked; anything else passes through
        if args and isinstance(args[0], GridDataset):
            x, y, Z = args[0].get_data()
            return fn(x, y, Z, *args[1:], **kwargs)
        return fn(*args, **kwargs)

    # __doc__ is None for undocumented functions and when python runs with -OO
    wrapper.__doc__ = (wrapper.__doc__ or '') + docstr_addage

    if bind:
        setattr(GridDataset, attrname or fn.__name__, wrapper)

    return wrapper


def allow_coupled_grid_input(fn):
    """
    decorator that lets a function called as fn(x, y, Z1, Z2, T1, T2, ...) also be
    called as fn(grid1, grid2, ...)
    :param fn: function to wrap
    :return: the wrapper
    :raises TypeError: if fn is not callable
    the wrapper raises ValueError when the coordinates of the two grids differ
    """
    if not hasattr(fn, '__call__'):
        raise TypeError('function argument not callable')

    docstr_addage = """
    NB: This function also accepts two {grid_cls_name} objects from which x, y and Z are extracted.
    Also, the timestamps of both grids are extracted (grid1.time and grid2.time)
    The call of:
    >>>{fn_name}(grid1, grid2, *args, **kwargs)
    subsequently becomes:
    >>>{fn_name}(grid1.x, grid1.y, grid1.Z, grid2.Z, grid1.time, grid2.time, *args, **kwargs)
    x and y coordinates should match between the grids
    """.format(grid_cls_name=GridDataset.__module__+'.'+GridDataset.__name__,
               fn_name=fn.__module__+'.'+fn.__name__)

    @wraps(fn)
    def wrapper(*args, **kwargs):
        coupled = (len(args) >= 2
                   and isinstance(args[0], GridDataset)
                   and isinstance(args[1], GridDataset))
        if not coupled:
            return fn(*args, **kwargs)

        x1, y1, Z1 = args[0].get_data()
        x2, y2, Z2 = args[1].get_data()
        # array_equal also returns False (instead of failing) when the shapes differ
        if not (np.array_equal(x1, x2) and np.array_equal(y1, y2)):
            raise ValueError('coordinates do not match')
        T1, T2 = args[0].time, args[1].time
        return fn(x1, y1, Z1, Z2, T1, T2, *args[2:], **kwargs)

    # __doc__ is None for undocumented functions and when python runs with -OO
    wrapper.__doc__ = (wrapper.__doc__ or '') + docstr_addage

    return wrapper


def force_gridstack_input(fn=None, bind=False, attrname=None):
    """
    decorator that lets a function called as fn(gridstack, ...) also be called as
    fn(x, y, t, Z, ...); the four arrays are converted to a GridTimeStackDataset
    can be applied bare or with options (`@force_gridstack_input(bind=True)`)
    :param fn: function to wrap (None when the decorator is called with options only)
    :param bind: if True the wrapper is also set as an attribute on GridTimeStackDataset
    :param attrname: attribute name used when binding (defaults to fn.__name__)
    :return: the wrapper, or a partially applied decorator when fn is None
    :raises TypeError: if fn is not callable
    """
    if fn is None:
        return partial(force_gridstack_input, bind=bind, attrname=attrname)

    if not hasattr(fn, '__call__'):
        raise TypeError('function argument not callable')

    docstr_addage = """
    NB: This function also accepts x, y, t, Z which are converted to a {grid_cls_name} object.
    The call of:
    >>>{fn_name}(x, y, t, Z, *args, **kwargs)
    subsequently becomes:
    >>>{fn_name}(gridtimestack, *args, **kwargs)
    """.format(grid_cls_name=GridTimeStackDataset.__module__+'.'+GridTimeStackDataset.__name__,
               fn_name=fn.__module__+'.'+fn.__name__)

    @wraps(fn)
    def wrapper(*args, **kwargs):
        if len(args) < 1:
            return fn(*args, **kwargs)
        gs, args = args[0], args[1:]
        if not isinstance(gs, GridTimeStackDataset):
            # the first four positional arguments are x, y, t and Z
            x, y, t, Z, args = gs, args[0], args[1], args[2], args[3:]
            gs = GridTimeStackDataset(x, y, t, Z)
        return fn(gs, *args, **kwargs)

    # __doc__ is None for undocumented functions and when python runs with -OO
    wrapper.__doc__ = (wrapper.__doc__ or '') + docstr_addage

    if bind:
        setattr(GridTimeStackDataset, attrname or fn.__name__, wrapper)

    return wrapper


@expose
@allow_grid_input(bind=True)
@apply_docstring(transformations.Fourier.transform.__doc__)
def fourier(x, y, Z, **kwargs):
    # thin public alias: the docstring is taken from the wrapped transform by
    # the apply_docstring decorator above
    transform = transformations.Fourier.transform
    return transform(x, y, Z, **kwargs)


@expose
@allow_grid_input(bind=True)
def orientation(x, y, Z, lmin=None, lmax=None):
    """
    determine the sand wave orientation (computed by static.direction)
    :param x: 1D vector of x coordinates
    :param y: 1D vector of y coordinates
    :param Z: 2D vector of depth values
    :param lmin: minimum sand wave length for orientation analysis
                 (DEFAULTS['lmin'] when None)
    :param lmax: maximum sand wave length for orientation analysis
                 (DEFAULTS['lmax'] when None)
    :return: sand wave orientation (deg)
    """
    logger = logging.getLogger(__name__)
    started = time.time()

    wavelength_range = dict(lmin=apply_defaults(lmin, 'lmin'),
                            lmax=apply_defaults(lmax, 'lmax'))
    direction = static.direction(x, y, Z, **wavelength_range)

    elapsed = time.time() - started
    logger.info('direction {:.2f} deg determined in {:.2f}s'.format(direction, elapsed))

    return direction


@expose
def extremes1D(x, y, z, window=3, minheight=None):
    """
    locate sand wave crests and troughs along a single transect
    :param x: 1D vector of x coordinates (only its shape is used here)
    :param y: 1D vector of y coordinates (not used here)
    :param z: 1D vector of depth values
    :param window: smoothing window size (cells) for crest detection
    :param minheight: minimum sand wave height (DEFAULTS['minheight'] when None)
    :return: boolean mask of peak positions, boolean mask of trough positions
    """
    from .static.extremes import transect_extremes

    started = time.time()
    peak_mask, trough_mask = (np.zeros(x.shape, dtype=bool) for _ in range(2))

    height_threshold = apply_defaults(minheight, 'minheight')
    peak_idx, trough_idx = transect_extremes(z, minheight=height_threshold, window=window)

    # convert the returned positions into masks with the shape of x
    peak_mask[peak_idx] = True
    trough_mask[trough_idx] = True

    elapsed = time.time() - started
    logging.getLogger(__name__).info('peaks and troughs detected in {:.2f}s'.format(elapsed))

    return peak_mask, trough_mask


@expose
@allow_grid_input(bind=True, attrname='find_extremes')
def extremes2D(x, y, Z, minheight=None, window=None, d=None, lmin=None, lmax=None, step=1):
    """
    locate sand wave crests and troughs in a 2D dataset
    the data is rotated by the sand wave direction before the extremes are detected
    :param x: 1D vector of x coordinates
    :param y: 1D vector of y coordinates
    :param Z: 2D vector of depth values
    :param minheight: minimum sand wave height (DEFAULTS['minheight'] when None)
    :param window: smoothing window size (cells) for crest detection
                   (derived with get_window when None)
    :param d: sand wave direction (determined with orientation() when None)
    :param lmin: minimum sand wave length for orientation analysis
    :param lmax: maximum sand wave length for orientation analysis
    :param step: transect interval
    :return: x positions (rotated 2D), y positions (rotated 2D), values (rotated 2D), mask of peaks and mask of troughs
    """
    from .static.extremes import extremes
    from .rotate import rotate

    logger = logging.getLogger(__name__)

    if d is None:
        d = orientation(x, y, Z, lmin=lmin, lmax=lmax)

    started = time.time()
    rotX, rotY, rotZ = rotate(x, y, Z, d)
    logger.debug('data rotated in {:.2f}s'.format(time.time() - started))

    if window is None:
        window = get_window(x, y)

    started = time.time()
    options = dict(minheight=apply_defaults(minheight, 'minheight'),
                   window=window,
                   step=step)
    pm, tm = extremes(rotX, rotY, rotZ, **options)
    logger.info('peaks and troughs detected in {:.2f}s'.format(time.time() - started))

    return rotX, rotY, rotZ, pm, tm


@expose
def peak_uncertainty(rotX, rotY, rotZ, pm, tm):
    """
    public alias of postprocessing.peak_uncertainty; the arguments match the output
    of sandwaves.analysis.extremes2D
    :param rotX: rotated X values
    :param rotY: rotated Y values
    :param rotZ: rotated depth values
    :param pm: mask of peak positions
    :param tm: mask of trough positions
    :return: xerr_crest, yerr_crest, xerr_trough, yerr_trough
    """
    estimate = postprocessing.peak_uncertainty
    return estimate(rotX, rotY, rotZ, pm, tm)


@expose
def shape1D(x, y, z, pm=None, tm=None, window=3, **kwargs):
    """
    calculate shape characteristics as points of 1D dataset
    :param x: 1D vector of x coordinates
    :param y: 1D vector of y coordinates
    :param z: 1D vector of depth values
    :param pm: mask of crest locations for 1D coordinate matrices
    :param tm: mask of trough locations for 1D coordinate matrices
    :param window: smoothing window size for extremes detection (see extremes1D);
                   only used when the masks are not supplied
    :param kwargs: additional keyword arguments (see extremes1D);
                   only used when the masks are not supplied
    :return: px (N,1),
             py (N,1),
             height (N,1),
             length (N,1),
             asymmetry (N,1),
             depth (N,1)
    :raises ValueError: if only one of pm and tm is supplied
    :raises transects.InvalidTransectError: if transects.check rejects the extremes
    """
    from .static.shape import transect_shape
    from .static import transects

    if pm is None and tm is None:
        pm, tm = extremes1D(x, y, z, window=window, **kwargs)
    elif pm is None or tm is None:
        # indexing with a None mask would silently add an axis instead of selecting points
        raise ValueError('pm and tm must be supplied together')

    # convert the boolean masks to index positions along the transect
    ind = np.arange(pm.size)
    pi, ti = ind[pm], ind[tm]
    tr_is_valid, pi, ti = transects.check(pi, ti)
    if not tr_is_valid:
        raise transects.InvalidTransectError(pi, ti)

    t = time.time()
    s = transect_shape(x, y, z, pi=pi, ti=ti)
    logging.getLogger(__name__).info('shape characteristics (1D) calculated in {:.2f}s'.format(time.time()-t))
    return s


@expose
@allow_grid_input(bind=True, attrname='calculate_shape')
def shape2D(X, Y, Z, pm=None, tm=None, **kwargs):
    """
    calculate shape characteristics as points of 2D dataset
    when no masks are given the crests and troughs are detected with extremes2D,
    which also rotates the data
    :param X: 1d vector or rotated 2D matrix of x coordinates
    :param Y: 1d vector or rotated 2D matrix of y coordinates
    :param Z: 2D matrix or rotated 2D matrix of depth values
    :param pm: mask of crest locations for 2D coordinate matrices
    :param tm: mask of trough locations for 2D coordinate matrices
    :param kwargs: additional keyword arguments (see extremes2D)
    :return: px (N,1),
             py (N,1),
             height (N,1),
             length (N,1),
             asymmetry (N,1),
             depth (N,1)
    :raises ValueError: if masks are given while X or Y is not 2D
    """
    from .static.shape import shape

    logger = logging.getLogger(__name__)

    # label built from the input coordinates, used to identify the dataset in the log
    desc = 'xll={} yll={}'.format(X.min(), Y.min())

    masks_missing = pm is None and tm is None
    if masks_missing:
        X, Y, Z, pm, tm = extremes2D(X, Y, Z, **kwargs)
    elif not (X.ndim == 2 and Y.ndim == 2):
        raise ValueError('X and Y matrices must be rotated and have 2 dimensions when masks are included')

    started = time.time()

    try:
        result = shape(X, Y, Z, pm, tm, desc=desc)
    except Exception:
        # record which dataset failed, then let the error propagate
        logger.error('error at {}'.format(desc))
        raise

    logger.info('shape characteristics (2D) calculated in {:.2f}s'.format(time.time() - started))
    return result


@expose
def static1D(x, y, z, *args, **kwargs):
    """
    run the static analysis on a 1D dataset: a logged and timed call of shape1D
    :param x: 1D vector of x coordinates
    :param y: 1D vector of y coordinates
    :param z: 1D vector of depth values
    :param args: additional arguments (see shape1D)
    :param kwargs: additional keyword arguments (see shape1D)
    :return: result from shape1D
    """
    started = time.time()
    log = logging.getLogger(__name__)
    log.info('static analysis (1D) shape={}'.format(z.shape))
    result = shape1D(x, y, z, *args, **kwargs)
    elapsed = time.time() - started
    log.info('static analysis completed in {}'.format(elapsed))
    return result


@expose
@allow_grid_input(bind=True, attrname='analyse_static')
def static2D(x, y, Z, *args, **kwargs):
    """
    run the static analysis on a 2D dataset: a logged and timed call of shape2D
    :param x: 1D vector of x coordinates
    :param y: 1D vector of y coordinates
    :param Z: 2D matrix of depth values
    :param args: additional arguments (see shape2D)
    :param kwargs: additional keyword arguments (see shape2D)
    :return: result from shape2D
    """
    started = time.time()
    log = logging.getLogger(__name__)
    log.info('static analysis (2D) started; shape={}'.format(Z.shape))
    result = shape2D(x, y, Z, *args, **kwargs)
    elapsed = time.time() - started
    log.info('static analysis completed in {}'.format(elapsed))
    return result


def _make_2D_time_array(T, shape):
    """
    return the time information as an array with the requested shape
    :param T: scalar timestep (int, float or numpy scalar) or an array that
              already has the requested shape
    :param shape: required shape of the time array
    :return: T itself when it is an array of the right shape, otherwise a float
             array of `shape` filled with the scalar T
    :raises TypeError: if T is an array of a different shape or of an unsupported type
    """
    shape = tuple(shape)
    if isinstance(T, np.ndarray):
        if T.shape != shape:
            raise TypeError('invalid time')
        return T
    if isinstance(T, (int, float, np.integer, np.floating)):
        # broadcast a scalar timestep over the whole grid
        return np.ones(shape) * T
    raise TypeError('invalid time')


def iter_sets(GS, indices=None, logger=None, **kw):
    """
    iterate over pairs of consecutive layers of a grid stack  !NB most recent first
    :param GS: object with a `compress_timesteps(**kw)` method returning an iterable
               of items that provide `recent()`
    :param indices: iteration indices to yield (all if None)
    :param logger: optional logger used to warn about pairs that are skipped
                   because `(layer0.mask | layer1.mask)` is True everywhere
    :param kw: keyword arguments passed to `GS.compress_timesteps`
    :return: generator of (layer0, layer1) where layer1 is the layer that precedes
             layer0 in the iteration order; nothing is yielded for an empty stack
    """
    it = iter(GS.compress_timesteps(**kw))
    try:
        layer1 = next(it).recent()
    except StopIteration:
        # an escaping StopIteration becomes a RuntimeError in generators (PEP 479)
        return

    for i, layer0 in enumerate(step.recent() for step in it):
        if indices is None or i in indices:
            if (layer0.mask | layer1.mask).all():
                if logger:
                    logger.warning('no overlap between datasets at iteration {}'.format(i))
            else:
                yield layer0, layer1
        # the current layer is always the reference for the next pair, also when skipped
        layer1 = layer0


@expose
@allow_coupled_grid_input
def growth_duogrid(x, y, Z1, Z2, T1, T2, d=None, debug=0, lmin=None, lmax=None, **kwargs):
    """
    calculate growth rate from a dataset pair
    both datasets are rotated with the same orientation and their extremes are
    passed to dynamic.growth
    :param x: 1D vector of x coordinates
    :param y: 1D vector of y coordinates
    :param Z1: 2D matrix of depth values for dataset 1
    :param Z2: 2D matrix of depth values for dataset 2
    :param T1: timestep for dataset 1
    :param T2: timestep for dataset 2
    :param d: sand wave orientation (deg); determined from Z2 when None
    :param debug: bool|int that defines the debug level
    :param lmin: see orientation()
    :param lmax: see orientation()
    :param kwargs: additional keyword arguments (see extremes2D)
    :return: namedtuple as returned from dynamic.growth
    """
    from .rotate import rotate2D

    logger = logging.getLogger(__name__)

    if d is None:
        d = orientation(x, y, Z2, lmin=lmin, lmax=lmax)

    first = extremes2D(x, y, Z1, d=d, **kwargs)
    second = extremes2D(x, y, Z2, d=d, **kwargs)
    rotX, rotY, rotZ1, pm1, tm1 = first
    # the rotated coordinates of the second call are not needed
    rotZ2, pm2, tm2 = second[2:]

    time1 = _make_2D_time_array(T1, Z1.shape)
    rotT1 = rotate2D(time1, d)
    time2 = _make_2D_time_array(T2, Z2.shape)
    rotT2 = rotate2D(time2, d)

    logger.debug('determining sand wave growth for layer set')

    inspect_level = int(debug) - 1
    return dynamic.growth(rotX, rotY, rotZ1, pm1, tm1, rotZ2, pm2, tm2, rotT1, rotT2, inspect=inspect_level)


@expose
def steepening(gridstack, min_coverage=.1, **kwargs):
    """
    collect the mean asymmetry per grid of a grid stack
    :param gridstack: object providing `iter_grids(area_min_ratio=...)`
    :param min_coverage: passed to `gridstack.iter_grids` as `area_min_ratio`
    :param kwargs: additional keyword arguments (see shape2D)
    :return: array with the `time` attribute of every grid,
             array with the mean asymmetry of every grid,
             array with the number of height values of every grid
    """
    times, mean_asymmetries, counts = [], [], []
    for grid in gridstack.iter_grids(area_min_ratio=min_coverage):
        # shape2D returns x, y, height, length, asymmetry and base depth
        _, _, height, _, asymmetry, _ = shape2D(grid, **kwargs)
        times.append(grid.time)
        mean_asymmetries.append(asymmetry.mean())
        counts.append(height.size)
    return np.array(times), np.array(mean_asymmetries), np.array(counts)


@expose
@force_gridstack_input(bind=True, attrname='analyse_growth')
def growth2D(gridstack, lmin=None, lmax=None, min_coverage=.3, dt_max=365*2, overlap_max=.1, indices=None, debug=0, **kwargs):
    """
    calculate growth from GridTimeStackDataset
    calculates a single orientation (from the most recent data), extracts crests and
    troughs and determines the sand wave height change per pair of layers
    :param gridstack: GridTimeStackDataset instance
    :param lmin: minimum wavelength (see orientation)
    :param lmax: maximum wavelength (see orientation)
    :param min_coverage: see GridTimeStackDataset.compress_timesteps
    :param dt_max: see GridTimeStackDataset.compress_timesteps
    :param overlap_max: see GridTimeStackDataset.compress_timesteps
    :param indices: iteration indices to execute (all if None)
    :param debug: bool|int that defines the debug level; when > 0 a matplotlib
                  figure with depth and time of both layers is created per pair
    :param kwargs: additional options passed to growth_duogrid (and on to extremes2D)
    :return: generator yielding the result of growth_duogrid for every layer pair
    """
    t = time.time()
    logger = logging.getLogger(__name__)
    logger.info('dynamic growth analysis (2D) started; shape={}'.format(gridstack.Z.shape))

    # direction of the sand waves
    d = orientation(gridstack.recent(),
                    lmin=lmin,
                    lmax=lmax)

    # loop over layers and yield output per step; the logger is handed over so
    # that pairs skipped for lack of overlap are reported
    layer_pairs = iter_sets(gridstack, indices=indices, logger=logger,
                            min_coverage=min_coverage, dt_max=dt_max, overlap_max=overlap_max)
    for g1, g2 in layer_pairs:
        if debug > 0:
            import matplotlib.pyplot as plt
            fig, axes = plt.subplots(2, 2, figsize=(10, 8))
            # top row: depth of both layers, bottom row: their time grids
            g1.plot(axes[0, 0])
            g2.plot(axes[0, 1])
            for ax, g in zip(axes[1], (g1, g2)):
                ax.set_title(g.time.mean())
                c = GridDataset(g.x, g.y, g.time).plot(ax)
                fig.colorbar(c, ax=ax)

        yield growth_duogrid(g1, g2, d=d, debug=int(debug)-1, **kwargs)

    # only reached once the generator has been fully consumed
    logger.info('dynamic growth analysis completed in {:.2f}s'.format(time.time()-t))


@expose
@force_gridstack_input
def migration(gridstack, radius=100, points='c', lmin=None, lmax=None,
              min_coverage=.3, dt_max=365*2, overlap_max=.1, indices=None, **kwargs):
    """
    calculate sand wave migration on specified points
    :param gridstack: GridTimeStackDataset instance
    :param radius: search radius for migration vectors
    :param points: migration at crests (c), troughs (t) or both (ct or tc)
    :param lmin: minimum wavelength for orientation
    :param lmax: maximum wavelength for orientation
    :param min_coverage: see GridTimeStackDataset.compress_timesteps
    :param dt_max: see GridTimeStackDataset.compress_timesteps
    :param overlap_max: see GridTimeStackDataset.compress_timesteps
    :param indices: iteration indices to execute (all if None)
    :param kwargs: additional keyword arguments for migration_duogrid
    :return: generator of namedtuples (x, y, mx, my, t1, t2)
    :raises ValueError: for an unsupported `points` value (raised when the
                        generator is first advanced)
    """

    from .rotate import rotate2D

    # map the `points` option to the mask that selects the evaluation points
    selectors = {
        'c': lambda pm, tm: pm,
        't': lambda pm, tm: tm,
        'ct': lambda pm, tm: pm | tm,
        'tc': lambda pm, tm: pm | tm,
    }
    if points not in selectors:
        # previously an unknown value left the mask undefined inside the loop
        raise ValueError("points must be 'c', 't' or 'ct', got {!r}".format(points))
    select = selectors[points]

    t = time.time()
    logger = logging.getLogger(__name__)
    logger.info('dynamic migration analysis (2D) started; shape={}'.format(gridstack.Z.shape))

    recent = gridstack.recent()

    d = orientation(recent.x, recent.y, recent.Z, lmin=lmin, lmax=lmax)
    Result = namedtuple('Result', ['x', 'y', 'mx', 'my', 't1', 't2'])

    # the logger is handed over so that pairs skipped for lack of overlap are reported
    layer_pairs = iter_sets(gridstack, indices=indices, logger=logger,
                            min_coverage=min_coverage, dt_max=dt_max, overlap_max=overlap_max)
    for g1, g2 in layer_pairs:
        # evaluation points are the extremes of g2, rotated by the shared orientation
        rotX, rotY, rotZ, pm, tm = extremes2D(g2, d=d)
        rotT1 = rotate2D(_make_2D_time_array(g1.time, g1.Z.shape), d)
        rotT2 = rotate2D(_make_2D_time_array(g2.time, g2.Z.shape), d)
        m = select(pm, tm)
        yield migration_duogrid(g1.x, g1.y, g1.Z, g2.Z,
                                px=rotX[m], py=rotY[m],
                                pt1=rotT1[m], pt2=rotT2[m],
                                radius=radius, resultcls=Result, **kwargs)

    # only reached once the generator has been fully consumed
    logger.info('dynamic migration analysis completed in {:.2f}s'.format(time.time() - t))


@expose
def migration_duogrid(x, y, Z1, Z2, px, py, pt1, pt2, radius=100, resultcls=tuple, **kwargs):
    """
    calculate sand wave migration on specified points for a dataset pair
    the migration itself is computed by dynamic.migration
    :param x: x coordinates of the grids (passed to dynamic.migration)
    :param y: y coordinates of the grids (passed to dynamic.migration)
    :param Z1: depth values of dataset 1
    :param Z2: depth values of dataset 2
    :param px: x positions of the points at which migration is determined
    :param py: y positions of the points at which migration is determined
    :param pt1: time of dataset 1 at the points (only copied into the result)
    :param pt2: time of dataset 2 at the points (only copied into the result)
    :param radius: search radius for migration vectors
    :param resultcls: class of the result; `tuple` and `list` receive the six
                      values as one sequence, any other class (e.g. a namedtuple)
                      is called with six positional arguments
    :param kwargs: additional keyword arguments for dynamic.migration
    :return: resultcls instance holding (px, py, mx, my, pt1, pt2)
    """

    logger = logging.getLogger(__name__)

    # return migrations
    t = time.time()
    mx, my = dynamic.migration(x, y, Z1, Z2, px, py, radius=radius, **kwargs)
    logger.info('migration points={} radius={}m completed in {:.2f}s'.format(px.size, radius, time.time()-t))
    logger.debug('average migration {:.2f}m'.format(float(np.mean((mx**2+my**2)**.5))))

    fields = (px, py, mx, my, pt1, pt2)
    if resultcls in (tuple, list):
        # builtin sequences take a single iterable; tuple(a, b, ...) raises TypeError
        return resultcls(fields)
    return resultcls(*fields)
