from __future__ import print_function
import logging
from collections import namedtuple
import numpy as np
from . import transects


# Public API of this module (names exported by ``from ... import *``).
__all__ = ['shape', 'transect_shape']


def shape(rotX, rotY, rotZ, pm, tm, desc=None, max_length=2000, max_basedepth=-10):
    """Compute sand wave shape properties for all valid transects of a grid.

    Column ``i`` of the grids is treated as one transect. The peak and trough
    selectors ``pm[:, i]`` and ``tm[:, i]`` are applied to the row indices
    ``np.arange(rotZ.shape[0])``, validated with ``transects.check`` and passed
    to ``transect_shape``. The results of all valid transects are concatenated
    and then filtered on ``length`` and ``basedepth``.

    Parameters
    ----------
    rotX, rotY, rotZ : ndarray
        2-D grids of x, y and bed level; each column is one transect.
    pm, tm : ndarray
        Peak and trough selectors laid out like ``rotZ`` (presumably boolean
        masks -- verify against the caller).
    desc : str, optional
        Description of the data set, used in log messages.
    max_length : float, optional
        Only sand waves with ``length < max_length`` are kept (default 2000).
    max_basedepth : float, optional
        Only sand waves with ``basedepth < max_basedepth`` are kept
        (default -10).

    Returns
    -------
    SandwaveShapeResult
        Namedtuple of 1-D arrays ``x, y, height, length, asymmetry,
        basedepth`` with one entry per retained sand wave. All arrays are
        empty when no transect is valid.

    Raises
    ------
    IndexError
        If ``rotX`` or ``rotY`` has fewer columns than ``rotZ``; the array
        shapes are logged before re-raising.
    """
    Result = namedtuple('SandwaveShapeResult',
                        ['x', 'y', 'height', 'length', 'asymmetry', 'basedepth'])
    logger = logging.getLogger(__name__)

    x = []
    y = []
    height = []
    asymmetry = []
    basedepth = []
    length = []

    indices = np.arange(rotZ.shape[0])
    for i, tr_z in enumerate(rotZ.T):
        pi = indices[pm[:, i]]
        ti = indices[tm[:, i]]

        try:
            tr_is_valid, pi, ti = transects.check(pi, ti)
        except transects.InvalidTransectError:
            logger.error('skipped transect due to invalid peaks and troughs in %s',
                         desc or 'undefined')
            continue

        if not tr_is_valid:
            continue

        try:
            tr_x = rotX[:, i]
            tr_y = rotY[:, i]
        except IndexError:
            # Coordinate grids do not match rotZ: log the shapes to aid debugging.
            logger.warning('shapeX %s', rotX.shape)
            logger.warning('shapeY %s', rotY.shape)
            logger.warning('shapeZ %s', rotZ.shape)
            logger.error('array(shape=%s)[:, %s]', rotX.shape, i)
            raise

        # get shape properties for current transect
        tr_x, tr_y, tr_height, tr_length, tr_asymmetry, tr_basedepth = \
            transect_shape(tr_x, tr_y, tr_z, pi, ti)

        x.append(tr_x)
        y.append(tr_y)
        height.append(tr_height)
        asymmetry.append(tr_asymmetry)
        basedepth.append(tr_basedepth)
        length.append(tr_length)

    if not x:
        # np.concatenate([]) raises ValueError; return empty results instead.
        logger.warning('no valid transects found in %s', desc or 'undefined')
        return Result(*[np.zeros(0) for _ in Result._fields])

    # make arrays for filtering
    length = np.concatenate(length)
    basedepth = np.concatenate(basedepth)

    condition = np.logical_and(length < max_length, basedepth < max_basedepth)

    return Result(np.concatenate(x)[condition],
                  np.concatenate(y)[condition],
                  np.concatenate(height)[condition],
                  length[condition],
                  np.concatenate(asymmetry)[condition],
                  basedepth[condition])


def transect_shape(x, y, z, pi, ti):
    """Compute shape properties of the sand waves along a single transect.

    Every peak ``pi[k]`` must lie between the troughs ``ti[k]`` and
    ``ti[k + 1]``, so ``len(ti) == len(pi) + 1``.

    Parameters
    ----------
    x, y, z : ndarray
        1-D coordinates and bed level along the transect.
    pi : array_like of int
        Indices of the peaks into ``x``, ``y`` and ``z``.
    ti : array_like of int
        Indices of the troughs; one more element than ``pi``.

    Returns
    -------
    SandwaveShapeTransectResult
        Namedtuple of 1-D arrays with one entry per peak:

        - ``x``, ``y``: coordinates of the peak;
        - ``height``: peak level minus ``basedepth``;
        - ``length``: straight-line distance between the two surrounding
          troughs;
        - ``asymmetry``: ``log((peak - previous trough) / (next trough - peak))``
          using sample indices, so 0 for a symmetric wave;
        - ``basedepth``: trough levels linearly interpolated (by index) at the
          peak position.

    Raises
    ------
    ValueError
        If there are peaks and ``len(ti) != len(pi) + 1``.
    """
    Result = namedtuple('SandwaveShapeTransectResult',
                        ['x', 'y', 'height', 'length', 'asymmetry', 'basedepth'])

    pi = np.asarray(pi)
    ti = np.asarray(ti)

    if pi.size == 0:
        # No complete sand wave: return consistently empty arrays.
        return Result(x[:0], y[:0], np.zeros(0), np.zeros(0), np.zeros(0), np.zeros(0))
    if ti.size != pi.size + 1:
        # Without this check NumPy broadcasting may silently yield wrongly sized output.
        raise ValueError('expected one more trough than peaks, got {} peaks and {} troughs'
                         .format(pi.size, ti.size))

    peakpos = pi.astype(float)
    troughpos = ti.astype(float)
    prev_trough = troughpos[:-1]
    next_trough = troughpos[1:]

    asymmetry = np.log((peakpos - prev_trough) / (next_trough - peakpos))

    # Base level: linear interpolation of the surrounding trough levels at the peak.
    troughdepth = z[ti]
    ratios = (peakpos - prev_trough) / (next_trough - prev_trough)
    basedepth = troughdepth[:-1] * (1 - ratios) + troughdepth[1:] * ratios

    height = z[pi] - basedepth

    # Wave length is the Euclidean distance between consecutive troughs (the
    # difference of distances to the coordinate origin is only correct for
    # transects pointing radially away from the origin).
    length = np.hypot(np.diff(x[ti]), np.diff(y[ti]))

    return Result(x[pi], y[pi], height, length, asymmetry, basedepth)