import numpy as np
import logging


# Public names for ``from <module> import *``. The exception base class and
# EmptyTransectError are exported so that callers of
# ``check(..., raise_error=True)`` can catch everything it re-raises.
__all__ = ['TransectError', 'InvalidTransectError', 'EmptyTransectError', 'check']


class TransectError(Exception):
    """Common base class for the transect validation errors in this module."""


class InvalidTransectError(TransectError):
    """Signals that peaks and troughs do not form a valid alternating transect.

    The peak and trough sequences passed in are stored unchanged on the
    instance as ``pi`` and ``ti`` and are also rendered into the message.
    """

    def __init__(self, pi, ti):
        self.pi = pi
        self.ti = ti
        text = 'invalid transect\n  p={}\n  t={}'.format(pi, ti)
        super(InvalidTransectError, self).__init__(text)


class EmptyTransectError(TransectError):
    """Signals that a transect lacks the peaks or troughs needed to validate it."""

    def __init__(self, msg='transect has no peaks or troughs'):
        # Forward the (possibly default) message to the base exception.
        super(EmptyTransectError, self).__init__(msg)


def check(pi, ti, correct=True, raise_error=False, return_peak_mask=False):
    """Validate that peaks and troughs alternate, starting and ending with a trough.

    A valid transect satisfies ``ti[0] < pi[0] < ti[1] < pi[1] < ... < ti[-1]``:
    exactly one peak strictly between each pair of consecutive troughs.

    Parameters
    ----------
    pi : numpy.ndarray
        Peak positions (presumably 1-D sample indices -- confirm with callers).
    ti : numpy.ndarray
        Trough positions, same assumptions as ``pi``.
    correct : bool
        If True, peaks outside the open interval ``(ti[0], ti[-1])`` are
        dropped before validation. If False, any such peak makes the
        transect invalid.
    raise_error : bool
        If True, validation errors propagate to the caller and the leading
        ``success`` flag is omitted from the result.
    return_peak_mask : bool
        If True, append a boolean mask over the *input* peaks marking the
        peaks that were kept.

    Returns
    -------
    tuple
        ``(success, pi, ti[, mask])`` when ``raise_error`` is False, otherwise
        ``(pi, ti[, mask])``. On failure ``pi`` and ``ti`` are empty lists and
        ``mask`` is all False with one entry per input peak.

    Raises
    ------
    EmptyTransectError
        Only if ``raise_error``: there are no peaks or fewer than two troughs.
    InvalidTransectError
        Only if ``raise_error``: the peaks and troughs do not alternate.
    """
    # NOTE: the ``correct`` parameter shadows the module-level correct()
    # function inside this body.
    n_peaks = pi.size  # input length; pi may be filtered below
    try:
        if pi.shape[0] == 0 or ti.shape[0] <= 1:
            raise EmptyTransectError()

        # Peaks lying strictly inside the outermost troughs.
        m = (pi > ti[0]) & (pi < ti[-1])
        if not correct and not m.all():
            raise InvalidTransectError(pi, ti)
        pi = pi[m]

        # Explicit checks rather than ``assert``, which ``python -O`` strips.
        # The size test comes first (short-circuit) so the element-wise
        # comparisons only run on equal-length arrays.
        alternating = (
            pi.size == ti.size - 1
            and (pi > ti[:-1]).all()
            and (pi < ti[1:]).all()
        )
        if not alternating:
            raise InvalidTransectError(pi, ti)

        success = True

    except TransectError:
        if raise_error:
            raise
        success = False
        # Size the mask from the original input, not the possibly filtered pi,
        # so it always lines up with the peaks the caller passed in.
        m = np.zeros(n_peaks, dtype=bool)
        pi, ti = [], []

    out = [pi, ti]
    if not raise_error:
        out.insert(0, success)

    if return_peak_mask:
        out.append(m)
    return tuple(out)


def correct(pi, ti):
    """Reduce peaks and troughs to an alternating trough-peak-...-trough sequence.

    Peaks before ``ti[0]`` or at/after ``ti[-1]`` are discarded. Of the
    remaining peaks, only the first one (in input order) that falls in each
    trough interval ``[ti[k], ti[k+1])`` is kept. The returned troughs are
    ``ti[0]`` followed by, for each kept peak, the trough closing its interval.

    Parameters
    ----------
    pi : numpy.ndarray
        Peak positions (1-D).
    ti : numpy.ndarray
        Trough positions (1-D), expected to be increasing; ``np.digitize``
        requires monotonic bins.

    Returns
    -------
    tuple of numpy.ndarray
        ``(pi_out, ti_out)`` with ``ti_out.size == pi_out.size + 1``.

    Raises
    ------
    IndexError
        Re-raised after logging the current peak and trough arrays.
    """
    try:
        # Bin index per peak: 0 means before ti[0]; ti.size means at/after ti[-1].
        bins = np.digitize(pi, ti)
        inside = ~np.isin(bins, [0, ti.size])  # np.in1d is deprecated
        pi = pi[inside]
        # Zero-based interval k with ti[k] <= peak < ti[k+1].
        interval = bins[inside] - 1
        # return_index yields the first occurrence of each occupied interval.
        uq, idx = np.unique(interval, return_index=True)

        pi_out = pi[idx]
        # Leading trough ti[0], then the trough that closes each kept interval.
        ti_out = ti[np.concatenate(([0], uq + 1))]
    except IndexError:
        logger = logging.getLogger(__name__)
        logger.error('p %s', pi)
        logger.error('t %s', ti)
        raise

    return pi_out, ti_out




