"""
Module for reducing xyt datasets to xy datasets by collapsing the time dimension
"""

from __future__ import print_function
import numpy as np
from collections import namedtuple
import logging


# Only the high-level entry point is exported by ``from <module> import *``;
# the other functions in this module must be imported explicitly.
__all__ = ['collapse_timeseries']


def collapse_timeseries(A, method='recent'):
    """
    Collapse the leading (time) axis of ``A`` into a single layer.

    :param A: data array with time on axis 0; arrays with fewer than three
        dimensions are returned unchanged. Layers are presumably ordered
        oldest -> newest (``'recent'`` walks the axis backwards) - confirm
        with callers.
    :param method: ``'recent'``, ``'first'``, ``'second'`` or an integer
        position; non-negative integers count forward from the first layer,
        negative integers count back from the last (-1 is the last)
    :return: the collapsed array (one dimension fewer than ``A`` when
        ``A.ndim >= 3``); a masked array stays masked
    :raises ValueError: if ``method`` is not recognised
    """
    order = _parse_order(method)
    return collapse_array(A, index=order.index, reverse=order.reverse)


def _parse_order(method):
    Result = namedtuple('Order', ['index', 'reverse'])

    if method == 'recent':
        return Result(0, True)
    elif method == 'first':
        return Result(0, False)
    elif method == 'second':
        return Result(1, False)
    elif isinstance(method, int):
        if method < 0:
            return Result(-1 - method, True)
        else:
            return Result(method, False)
    else:
        raise ValueError('invalid order {!r}'.format(method))


def collapse_array(A, index=0, reverse=True, covered=True, missing_value=-999):
    """
    Reduce a stack of layers along axis 0 to a single layer.

    :param A: array whose first axis is time; returned as-is (same object)
        when it has fewer than three dimensions
    :param index: 0-based position of the layer to select, counted in the
        walking direction chosen by ``reverse``
    :param reverse: walk the time axis from its end backwards when truthy
    :param covered: when truthy, pick per cell the ``index``-th non-missing
        value (``select_covered``); otherwise take layer ``index`` as-is
        (``select_step``)
    :param missing_value: sentinel for missing cells; masked input is filled
        with it before selection and the result is re-masked wherever it
        equals it
    :return: the collapsed array, a ``MaskedArray`` if ``A`` was one
    """
    if A.ndim < 3:
        return A

    was_masked = isinstance(A, np.ma.MaskedArray)
    data = A.filled(missing_value) if was_masked else A

    ordered = data[::-1, :, :] if reverse else data[::1, :, :]
    selector = select_covered if covered else select_step
    result = selector(ordered, index, missing_value=missing_value)

    if not was_masked:
        return result
    return np.ma.MaskedArray(result, mask=result == missing_value)


def select_covered(A, index, missing_value=-999):
    """
    Per cell, pick the ``index``-th (0-based) non-missing value along axis 0.

    Layers are visited in the order they appear along the first axis. For
    every cell, the value from the ``index + 1``-th layer in which that cell is
    not missing is kept. Cells with fewer valid layers stay ``missing_value``.

    :param A: array whose first axis is iterated; the remaining axes form the
        output shape
    :param index: number of valid values to skip per cell before selecting one
    :param missing_value: sentinel marking missing input cells and unfilled
        output cells; ``NaN`` is supported as a sentinel
    :return: float64 array of shape ``A.shape[1:]``
    """
    shape = A.shape[1:]
    out = np.full(shape, missing_value, dtype=float)
    # How many more valid values each cell must see before it is filled.
    remaining = np.full(shape, index + 1)
    # Cells not filled yet (tracked explicitly instead of re-deriving it from
    # ``out == missing_value``, which never matches when the sentinel is NaN).
    pending = np.ones(shape, dtype=bool)
    # NaN never compares equal to itself, so it needs an explicit isnan test.
    nan_sentinel = bool(np.isnan(missing_value))

    for layer in A:
        valid = ~np.isnan(layer) if nan_sentinel else (layer != missing_value)
        hits = pending & valid
        remaining[hits] -= 1
        take = hits & (remaining == 0)
        out[take] = layer[take]
        pending &= ~take
        # Stop early once every cell has been filled.
        if not pending.any():
            break

    return out


def select_step(A, index, **kwargs):
    """
    Return layer ``index`` along axis 0 of a (at least) 3-D array.

    Extra keyword arguments (such as ``missing_value``) are accepted only so
    this function is call-compatible with ``select_covered``; they are ignored.
    """
    layer_key = (index, slice(None), slice(None))
    return A[layer_key]


def compress_timesteps(S, dt_max, overlap_max, min_coverage):
    """
    compress timesteps of the data sensibly

    Layers are ordered from most recent to oldest. Each group of layers is
    greedily merged with the next (older) one while the largest pairwise time
    difference stays within ``dt_max`` and the fraction of grid cells valid in
    both groups does not exceed ``overlap_max``. Each resulting group is
    yielded as a sub-stack if ``substack.recent().coverage >= min_coverage``.

    :param S: the stack of datasets; ``S.t`` holds the layer times, ``S.Z``
        the layer data with time on axis 0 (masked or plain array), and
        ``S[idx]`` selects a sub-stack from a list of layer indices
    :param dt_max: maximum time difference between layers for joining
    :param overlap_max: maximum overlap ratio (cells valid in both groups /
        all cells) between layers for joining
    :param min_coverage: minimum coverage ratio of each of the compressed output layers
    :return: generator of output stacks (GridTimeStackDataset) with the most recent layer first
    """
    class TimeSet(list):
        """Layer times of one group, plus their indices into S and a validity mask."""

        def __init__(self, val, idx, mask):
            super(TimeSet, self).__init__(val)
            self.idx = idx  # indices into S, most recent first
            self.mask = mask.astype(bool)  # True where any member layer is valid

        def extend(self, other):
            # Merge ``other`` into this group: times, indices and coverage.
            # Call list.extend explicitly instead of relying on ``+=``.
            super(TimeSet, self).extend(other)
            self.idx += other.idx
            self.mask |= other.mask

        def overlap(self, other):
            # Fraction of all grid cells that are valid in both groups.
            n_total = self.mask.size
            if n_total == 0:
                return 0.0
            n_overlap = np.count_nonzero(self.mask & other.mask)
            return float(n_overlap) / n_total

        def timediff(self, other):
            # Largest pairwise time difference between the two groups.
            return max(abs(min(self) - max(other)),
                       abs(max(self) - min(other)))

        def check_join(self, other, dt_max=365, overlap_max=.1):
            if self.timediff(other) > dt_max:
                return False
            if self.overlap(other) > overlap_max:
                return False
            return True

    # sort layers from most recent to oldest
    t = S.t
    order = np.argsort(t)[::-1]

    # one group per layer; getmaskarray also copes with an unmasked S.Z
    # (and with ``nomask``) by returning an all-False mask of full shape
    sets = [TimeSet([t[i]], [i], ~np.ma.getmaskarray(S.Z[i, :, :]))
            for i in order]

    # greedily merge each group into the previous one while joining is
    # allowed; same result as repeatedly popping the right-hand neighbour,
    # but linear instead of quadratic
    merged = []
    for current in sets:
        if merged and merged[-1].check_join(current, dt_max, overlap_max):
            merged[-1].extend(current)
        else:
            merged.append(current)

    # filter coverage
    for group in merged:
        substack = S[group.idx]
        if substack.recent().coverage >= min_coverage:
            yield substack
