"""
module for processing data retrieved from waterbase (live.waterbase.nl)
"""

from __future__ import print_function, division
import numpy as np
from collections import namedtuple, Counter
import datetime


class WaterbaseDataset(object):
    """Time series of a single quantity at a single location.

    Instances are normally created through :meth:`load` or
    :meth:`load_multiple`, which parse a ';'-separated text export.

    Attributes set by ``__init__`` (when the matching column is supplied):
        location, latitude, longitude: per-dataset constants.
        value: numpy array built from the 'waarde' column.
        t: integer numpy array, minutes elapsed since 1970-01-01 00:00.
    """

    # columns that are picked up from the file; all others are ignored
    COLUMN_NAMES = ['locatie', 'datum', 'tijd', 'waarde', 'xlat', 'ylong']
    # converter applied to the raw string of each column
    COLUMN_TYPES = dict(
            locatie=str,
            datum=str,
            tijd=str,
            waarde=float,
            xlat=float,
            ylong=float)
    # columns that must hold one identical value on every row of a dataset
    COLUMN_CONSTS = ['xlat', 'ylong', 'locatie']
    # file column name -> attribute name on the instance
    ATTRNAMES = dict(
            locatie='location',
            datum='date',
            tijd='time',
            waarde='value',
            xlat='latitude',
            ylong='longitude')

    def __init__(self, **data):
        """Build a dataset from column data keyed by file column name.

        Keys that are not listed in ``COLUMN_NAMES`` are ignored. The
        'datum' and 'tijd' columns are combined into ``self.t`` and then
        discarded.
        """
        for k, v in data.items():
            if k in self.COLUMN_NAMES:
                setattr(self, self.ATTRNAMES.get(k, k), v)

        self.value = np.array(self.value)
        self.t = np.zeros(self.value.shape, dtype=int)
        reftime = datetime.datetime(year=1970, month=1, day=1)
        for i, d in enumerate(self.date):
            dt = datetime.datetime.strptime(d+self.time[i], '%Y-%m-%d%H:%M')
            t_diff = (dt-reftime)
            self.t[i] = t_diff.days*24*60+t_diff.seconds//60
        del self.date
        # the 'tijd' column was stored as instance attribute ``time`` and
        # shadowed the ``time`` method; deleting it makes the method
        # reachable again
        del self.time

    @classmethod
    def _read_rows(cls, loc):
        """Read the file at path ``loc``.

        The first three lines are skipped, the fourth holds the
        ';'-separated column names (slashes removed, cut at the first
        space). Every following non-blank line becomes one namedtuple row
        of raw strings.

        Returns:
            (colnames, rows)
        """
        with open(loc, 'r') as f:
            for _ in range(3):
                f.readline()
            colnames = [c.replace('/', '').split(' ')[0]
                        for c in f.readline().strip().split(';')]

            rows = []
            Row = namedtuple('Row', colnames)
            for line in f:
                line = line.strip()
                if not line:
                    # blank (e.g. trailing) lines carry no record
                    continue
                rows.append(Row(*line.split(';')))

        return colnames, rows

    @classmethod
    def load(cls, loc):
        """Load the whole file at ``loc`` as one dataset."""
        colnames, rows = cls._read_rows(loc)

        return cls(**cls.rows2data(colnames, rows))

    @classmethod
    def load_multiple(cls, loc, splitter='locatie'):
        """Yield one dataset per distinct value of column ``splitter``.

        This is a generator, so errors surface when iteration starts.

        Raises:
            ValueError: if ``splitter`` is not a column of the file.
        """
        colnames, rows = cls._read_rows(loc)
        try:
            split_index = colnames.index(splitter)
        except ValueError:
            # list.index signals a missing item with ValueError
            raise ValueError('unknown splitter {}'.format(splitter))

        row_datasets = {}
        for row in rows:
            row_datasets.setdefault(row[split_index], []).append(row)

        for group in row_datasets.values():
            yield cls(**cls.rows2data(colnames, group))

    @classmethod
    def rows2data(cls, colnames, rows):
        """Turn rows into a dict of converted column data.

        Columns in ``COLUMN_CONSTS`` are reduced to a single value, all
        other known columns become lists with one entry per row.

        Raises:
            ValueError: if ``rows`` is empty or a constant column varies.
        """
        if not rows:
            raise ValueError('empty')
        data = dict()
        for i, c in enumerate(colnames):
            if cls.COLUMN_NAMES is not None and c not in cls.COLUMN_NAMES:
                continue
            col_type_fn = (lambda x: x) if not cls.COLUMN_TYPES else cls.COLUMN_TYPES[c]
            if c in cls.COLUMN_CONSTS:
                first = rows[0][i]
                if not all(r[i] == first for r in rows):
                    raise ValueError('{} expected constant'.format(c))
                data[c] = col_type_fn(first)
            else:
                data[c] = [col_type_fn(r[i]) for r in rows]
        return data

    def __repr__(self):
        return '<{} at {} lat={} lon={}>'.format(self.__class__.__name__, self.location, self.latitude, self.longitude)

    def time(self, unit='days', reset=False, start=None):
        """Return ``self.t`` converted from minutes to ``unit``.

        Args:
            unit: 'minutes', 'hours' or 'days' (KeyError otherwise).
            reset: when true, shift the series by an offset.
            start: offset (in ``unit``) subtracted when ``reset`` is true;
                defaults to the first converted sample.
        """
        time_series = self.t
        scale = dict(minutes=1, hours=60, days=24*60)[unit]
        time_series = time_series/(1.*scale)

        if reset:
            if start is not None:
                time_series = time_series - start
            else:
                time_series = time_series - time_series[0]

        return time_series


class WaveHeightDataset(WaterbaseDataset):
    """Dataset whose 'waarde' column is parsed as an integer and scaled by .01.

    NOTE(review): presumably this converts centimetres to metres - confirm
    against the export format.
    """

    # same converters as the base class, with only 'waarde' replaced
    COLUMN_TYPES = dict(WaterbaseDataset.COLUMN_TYPES,
                        waarde=lambda x: int(x)*.01)


class WaveDirectionDataset(WaterbaseDataset):
    """Dataset of directions, with a helper to rotate them."""

    def cycle_angles(self, diff):
        """Add ``diff`` to every value in place and wrap the result modulo 360."""
        shifted = self.value + diff
        self.value = shifted % 360


class WavePeriodDataset(WaterbaseDataset):
    """Dataset type for wave periods; adds nothing to the base class."""


def check_time(*times, **kwargs):
    """Build a common regular time axis for one or more time series.

    Args:
        *times: one or more indexable time series; the first and last
            element of each are taken as its start and end.
        step (keyword only): spacing of the common axis, default 60.

    Returns:
        (t, overlaps): ``t`` runs from the earliest start to the latest
        end in increments of ``step``; ``overlaps`` is a tuple with, for
        every input series, a boolean array marking which entries of
        ``t`` occur in that series.

    Raises:
        ValueError: if no time series is given.
    """
    step = kwargs.get('step', 60)
    if not times:
        raise ValueError('at least one time series is required')

    # pass the generator itself: unpacking it with * breaks for a single
    # series, because min/max would then try to iterate over a scalar
    t_start = min(t[0] for t in times)
    t_stop = max(t[-1] for t in times)
    t = np.arange(t_start, t_stop+step, step)

    # np.isin supersedes np.in1d; fall back for NumPy releases without it
    contains = getattr(np, 'isin', None) or np.in1d
    overlaps = tuple(contains(t, ti) for ti in times)
    return t, overlaps


class Location(object):
    """Bundle of the series belonging to one named location.

    Stores the arguments unchanged as ``name``, ``t``, ``hsig``, ``T`` and
    ``direction``.
    """

    def __init__(self, name, t, hsig, T, direction):
        self.name, self.t = name, t
        self.hsig, self.T, self.direction = hsig, T, direction

    def waverose(self):
        """Create a WaveRose from this location's t, hsig, direction and name."""
        rose = WaveRose(self.t, self.hsig, self.direction, self.name)
        return rose

    def time(self, unit='days', reset=False, start=None):
        """Return ``self.t`` divided by the number of minutes in ``unit``.

        ``unit`` is one of 'minutes', 'hours' or 'days'. With ``reset`` the
        result is shifted by ``start``, or by its own first element when
        ``start`` is None.
        """
        raw = self.t
        minutes_per_unit = {'minutes': 1, 'hours': 60, 'days': 24*60}
        converted = raw/(1.*minutes_per_unit[unit])

        if not reset:
            return converted

        offset = converted[0] if start is None else start
        return converted - offset


class Rose(object):
    """Joint distribution of a value and a direction, drawable as a rose.

    Attributes:
        t: time axis (stored, not used by the methods below).
        value: magnitudes; may be a masked array.
        direction: directions in degrees; may be a masked array.
        label: free-form label.
    """

    def __init__(self, t, value, direction, label):
        self.t = t
        self.value = value
        self.direction = direction
        self.label = label

    def bin(self, vmax=7, vstep=.25):
        """Count samples per (value, direction) bin.

        Value bins start at 0 and are ``vstep`` wide up to ``vmax``;
        direction bins are 10 degrees wide from 0 to 360. Samples outside
        the covered range are counted in the nearest edge bin. Samples
        whose value or direction is masked or non-finite are skipped.

        Returns:
            (val_bins, dir_bins, bins): lower bin edges for both axes and
            an integer count array of shape (val_bins.size, dir_bins.size).
        """
        val_bins = np.arange(0, vmax, vstep)
        dir_bins = np.arange(0, 360, 10)

        values = np.asarray(np.ma.getdata(self.value), dtype=float)
        directions = np.asarray(np.ma.getdata(self.direction), dtype=float)
        # masked entries are missing data and must not be counted using
        # whatever number happens to sit underneath the mask
        missing = (np.ma.getmaskarray(self.value)
                   | np.ma.getmaskarray(self.direction))
        valid = ~missing & np.isfinite(values) & np.isfinite(directions)

        # digitizing against the inner edges [1:] yields indices
        # 0..size-1, so every bin (including the last) can be filled
        bin_indices_val = np.digitize(values[valid], val_bins[1:])
        bin_indices_dir = np.digitize(directions[valid], dir_bins[1:])

        bins = np.zeros((val_bins.size, dir_bins.size), dtype=int)
        # unbuffered add, so repeated index pairs are all counted
        np.add.at(bins, (bin_indices_val, bin_indices_dir), 1)

        return val_bins, dir_bins, bins

    def plot(self, ax, ylabel=None, cax=None, lw=0, cmap='Spectral_r', vmin=0, vmax=7, colorscale=1, vstep=.25):
        """Draw the rose as stacked polar bars.

        Args:
            ax: matplotlib polar axes to draw on.
            ylabel: label for the colour bar axis (used only with ``cax``).
            cax: optional axes that receives a colour bar of the value bins.
            lw: when > 0, outline the total bar of every direction with
                this line width.
            cmap: colormap name or instance used to colour the value bins.
            vmin: currently unused.
            vmax, vstep: forwarded to :meth:`bin`.
            colorscale: exponent applied to the normalised value before the
                colormap lookup.

        Returns:
            ``ax``.

        Raises:
            ValueError: if ``ax`` is not a polar axes.
        """
        import matplotlib
        from matplotlib.projections.polar import PolarAxes
        from matplotlib import cm, patches

        # cm.get_cmap is gone from recent Matplotlib; use the colormap
        # registry where it exists and keep the old call as fallback
        registry = getattr(matplotlib, 'colormaps', None)
        if hasattr(registry, 'get_cmap'):
            cmap = registry.get_cmap(cmap)
        else:
            cmap = cm.get_cmap(cmap)

        if not isinstance(ax, PolarAxes):
            raise ValueError('axes not a polar projection')

        value, direction, counts = self.bin(vmax, vstep=vstep)
        total = np.sum(counts)
        if total > 0:
            fractions = counts/float(total)
        else:
            # no samples at all: draw empty bars rather than NaNs
            fractions = np.zeros(counts.shape)
        bases = np.zeros(counts.shape[1])

        # bar positions: bin centre (edge + 5 degrees), mirrored and
        # rotated by a quarter turn to match the E/N/W/S tick labels below
        theta = (-direction-5)*np.pi/180+.5*np.pi
        bar_width = 10*np.pi/180

        if lw > 0:
            # plot surrounding line
            ax.bar(theta,
                   np.sum(fractions, axis=0),
                   width=bar_width,
                   color='none',
                   lw=lw)

        # plot bins
        top_value = np.amax(value)
        for i, row in enumerate(fractions):
            # each bin spans one vstep, both in the legend and the colour bar
            label = '{:g}-{:g}m'.format(value[i], value[i]+vstep)
            color = cmap(((value[i]+.05)/(top_value+.05))**colorscale)
            ax.bar(theta,
                   row,
                   width=bar_width,
                   bottom=bases,
                   color=color,
                   lw=0,
                   label=label)
            bases = bases + row

            if cax:
                p = patches.Rectangle((0, value[i]), 1, vstep, clip_on=True, lw=0, fc=color)
                cax.add_patch(p)

        if cax:
            cax.set_xlim(0, 1)
            cax.set_ylim(0, top_value)
            cticks = list(value[::4])+[vmax]
            cax.set_yticks(cticks)
            cax.set_yticklabels([str(c) for c in cticks])
            cax.yaxis.tick_right()
            cax.yaxis.set_label_position("right")
            cax.set_xticks([])
            if ylabel:
                cax.set_ylabel(ylabel)

        rticks = np.arange(.03, .15, .03)
        ax.set_yticks(rticks)
        ax.set_yticklabels(['{:.0f}%'.format(t*100) for t in rticks])
        ax.set_xticks(np.arange(0, 2*np.pi, .5*np.pi))
        ax.set_xticklabels(['E', 'N', 'W', 'S'])
        ax.set_rlim(0, .12)
        # set_axis_bgcolor was replaced by set_facecolor in Matplotlib 2.x
        if hasattr(ax, 'set_facecolor'):
            ax.set_facecolor('none')
        else:
            ax.set_axis_bgcolor('none')

        return ax


class WaveRose(Rose):
    """Rose of significant wave height against wave direction."""

    def __init__(self, t, hsig, wdir, location):
        super(WaveRose, self).__init__(t=t, value=hsig, direction=wdir, label=location)

    @classmethod
    def from_data(cls, t_hsig, hsig, t_wdir, wdir, location):
        """Build a rose from two series that have separate time axes.

        Both series are placed on the common axis returned by
        ``check_time``; positions without a sample stay masked.
        """
        grid, presence = check_time(t_hsig, t_wdir)
        has_hsig, has_wdir = presence

        def fully_masked():
            # zero-filled array in which every entry is masked
            return np.ma.masked_equal(np.zeros(grid.shape), 0)

        heights = fully_masked()
        headings = fully_masked()

        heights[has_hsig] = hsig
        headings[has_wdir] = wdir

        return cls(grid, heights, headings, location)


def k_residual(omega, k, dep):
    """Return omega**2 - 9.81*k*tanh(k*dep).

    The result is zero when ``omega``, ``k`` and ``dep`` satisfy that
    relation exactly.
    """
    gravity_term = 9.81 * k * np.tanh(k * dep)
    return omega**2 - gravity_term


def k_iter(dep, T):
    """Solve omega**2 = 9.81*k*tanh(k*dep) for k with Newton's method.

    Args:
        dep: scalar depth.
        T: period(s); a scalar or an array of any shape.

    Returns:
        The solution(s) k, shaped like ``T`` (a scalar for scalar ``T``).
        Each element is iterated separately, starting from k = 2*pi/80,
        until the residual drops to 1e-10 or below.
    """
    omega = np.pi * 2 / np.asarray(T, dtype=float)
    flat_omega = np.ravel(omega)
    K = np.zeros(flat_omega.shape)

    res_norm = 1e-10
    # bound on Newton steps so a residual that stalls just above the
    # tolerance (round-off) cannot loop forever
    max_iter = 1000

    for ki, O in enumerate(flat_omega):
        k = 2 * np.pi / 80
        for _ in range(max_iter):
            # tanh is shared by the residual and its derivative
            th = np.tanh(k * dep)
            # same expression as k_residual(O, k, dep)
            res = O**2 - 9.81 * k * th
            d_res_dk = -9.81 * th - dep * k * 9.81 * (1 - th**2)
            k += -res / d_res_dk
            # written as "not >" so a NaN residual also ends the loop
            if not np.abs(res) > res_norm:
                break
        K[ki] = k

    K = K.reshape(omega.shape)
    if K.ndim == 0:
        return K[()]
    return K


def orbital_bed_velocity(hsig, T, dep):
    """Return pi*hsig / (T*sinh(k*dep)), with k taken from ``k_iter(dep, T)``."""
    wavenumber = k_iter(dep, T)
    denominator = T * np.sinh(wavenumber * dep)
    return np.pi * hsig / denominator


def orbital_diameter(hsig, T, dep):
    """Return hsig / sinh(k*dep), with k taken from ``k_iter(dep, T)``."""
    wavenumber = k_iter(dep, T)
    return hsig / np.sinh(wavenumber * dep)


def tau_waves(hsig, T, depth, d50):
    """Return the wave-induced bed shear stress.

    Computed as 1000 * uorb**2 * exp(-5.977 + 5.213*r**.194), where
    ``uorb`` is the orbital bed velocity and r = 2.5*d50*2 divided by the
    orbital diameter.

    NOTE(review): the constants resemble a Swart-type friction factor with
    a water density of 1000 - confirm against the intended reference.
    """
    uorb = orbital_bed_velocity(hsig, T, depth)
    diameter = orbital_diameter(hsig, T, depth)
    roughness_ratio = 2.5 * d50 * 2 / diameter
    friction = np.exp(-5.977 + 5.213*roughness_ratio**.194)
    return 1000 * uorb**2 * friction


def shields(hsig, T, depth, d50, rhos):
    """Return ``tau_waves`` divided by (rhos - 1000) * 9.81 * d50."""
    bed_stress = tau_waves(hsig, T, depth, d50)
    normaliser = (rhos-1000)*9.81*d50
    return bed_stress/normaliser

