import logging
import numpy as np
from ...transformations import fourier_filter
from .base import WindowTooSmall
from .xcorr_shift import calculate_shift
# from .fourier_shift import calculate_shift  # alternative using fourier phase shifts


# names exported by ``from <this module> import *``
__all__ = ['migration', 'migration_at_point', 'preprocess', 'preprocess_dataset']


def preprocess_dataset(x, y, data, direction=None, dir_offset=45, kmin=None, kmax=None):
    """
    filter a dataset and reduce it to the magnitude of the slope of its slope
    :param x: x positions, handed to fourier_filter unchanged
    :param y: y positions, handed to fourier_filter unchanged
    :param data: values as 2D array
    :param direction: handed to fourier_filter as ``theta``
    :param dir_offset: handed to fourier_filter as ``theta_offset``
    :param kmin: minimum wavenumber to account for (None is treated as 0)
    :param kmax: maximum wavenumber to account for (None is treated as np.inf)
    :return: 2D array with the gradient magnitude of the gradient magnitude of
             the filtered data; masked cells are filled with 0
    """
    if kmin is None:
        kmin = 0
    if kmax is None:
        kmax = np.inf
    data = fourier_filter(x, y, data, kmin=kmin, kmax=kmax, theta=direction, theta_offset=dir_offset)

    # np.gradient yields one derivative per array axis (in index units); the
    # magnitude is sqrt(g0**2 + g1**2). Both terms must be squared, otherwise
    # the radicand can become negative.
    g0, g1 = np.gradient(data)
    slope_data = (g0**2 + g1**2)**.5

    # Masked cells are replaced by 0 before differentiating a second time.
    # np.ma.filled accepts masked as well as plain arrays.
    gg0, gg1 = np.gradient(np.ma.filled(slope_data, 0))
    # The second-stage magnitude has to be built from the second-stage
    # gradients, not from the first-stage ones again.
    slope_slope_data = (gg0**2 + gg1**2)**.5

    return np.ma.filled(slope_slope_data, 0)


def preprocess(x, y, d1, d2, **kwargs):
    """
    run preprocess_dataset on both datasets with identical settings
    :param x: x positions
    :param y: y positions
    :param d1: first dataset
    :param d2: second dataset
    :param kwargs: forwarded to preprocess_dataset for both datasets
    :return: tuple of the preprocessed (d1, d2)
    """
    # no directional filtering is applied; both datasets share this setting
    direction = None
    first = preprocess_dataset(x, y, d1, direction=direction, **kwargs)
    second = preprocess_dataset(x, y, d2, direction=direction, **kwargs)
    return first, second


def subdivide(x, y, d1, d2, ipx, ipy, radius):
    """
    cut the window around a point out of both datasets
    :param x: x positions as 1D array
    :param y: y positions as 1D array
    :param d1: first 2D array, indexed as [y, x]
    :param d2: second 2D array, indexed as [y, x]
    :param ipx: x position of the window centre
    :param ipy: y position of the window centre
    :param radius: half width of the window (bounds are exclusive)
    :return: tuple with the window of d1 and the window of d2
    """
    in_x = (ipx-radius < x) & (x < ipx+radius)
    in_y = (ipy-radius < y) & (y < ipy+radius)

    assert in_x.any() and in_y.any(), 'point ({}, {}) not within domain'.format(ipx, ipy)

    # 2D selection mask: rows follow y, columns follow x
    window = np.logical_and.outer(in_y, in_x)
    shape = (np.count_nonzero(in_y), np.count_nonzero(in_x))

    sub1 = d1[window].reshape(shape)
    sub2 = d2[window].reshape(shape)
    return sub1, sub2


def migration(x, y, d1, d2, px, py, radius=100, kmin=None, kmax=None):
    """
    calculate bedform migration distance between two datasets
    :param x: x positions as 1D array
    :param y: y positions as 1D array
    :param d1: depth values as 2D array
    :param d2: depth values as 2D array
    :param px: x positions of points at which to calculate migration
    :param py: y positions of points at which to calculate migration
    :param radius: search radius for migration
    :param kmin: minimum wavenumber to account for
    :param kmax: maximum wavenumber to account for
    :return: masked arrays of migration in x and y direction (x, y) at points (px, py);
             failed or implausibly large values are masked
    """
    log = logging.getLogger(__name__)

    # shifts come back in cells; the x spacing converts them to distances
    cellsize = x[1] - x[0]
    d1, d2 = preprocess(x, y, d1, d2, kmin=kmin, kmax=kmax)

    mx = np.zeros(px.shape)
    my = np.zeros(px.shape)
    for idx in range(px.size):
        point_x = px[idx]
        point_y = py[idx]

        try:
            shift = migration_at_point(x, y, d1, d2, point_x, point_y, radius)
        except WindowTooSmall as err:
            # keep going; this point ends up masked in the result
            log.warning('window too small {}'.format(err))
            shift = np.array([np.nan, np.nan])

        mx[idx], my[idx] = shift[0]*cellsize, shift[1]*cellsize

        done = idx + 1
        if done % 500 == 0:
            log.debug('evaluated {} points for migration'.format(done))

    # discard shifts longer than half the search radius
    limit = .5*radius
    magnitude = (mx**2+my**2)**.5
    too_large = magnitude > limit
    if too_large.any():
        mx[too_large] = np.nan
        my[too_large] = np.nan
        log.warning(
                '{} migration values hidden for being too large relative to radius (>{}m)'.format(
                        too_large.sum(), limit))

    log.info('migration rates calculated for {} points'.format(px.size))
    return np.ma.masked_invalid(mx), np.ma.masked_invalid(my)


def migration_at_point(x, y, d1, d2, ipx, ipy, radius):
    """
    get migration at a single point
    :param x: x positions as 1D array
    :param y: y positions as 1D array
    :param d1: first (preprocessed) dataset as 2D array
    :param d2: second (preprocessed) dataset as 2D array
    :param ipx: x position of the point
    :param ipy: y position of the point
    :param radius: half width of the window cut out around the point
    :return: first element of the pair returned by calculate_shift
    """
    window1, window2 = subdivide(x, y, d1, d2, ipx, ipy, radius=radius)
    # the second dataset's window is passed first
    shift, _ignored = calculate_shift(window2, window1)
    return shift