"""
Processes data from the model database MATROOS, which is accessible through:
 - http://matroos.deltares.nl (default)
 - http://matroos.rijkswaterstaat.nl

Specific dataset classes are available as:
 - WaterlevelData
 - VelocitiesData
 - WaveData

There is also a convenience function `load` to load a dataset by name ('velocities', 'waves' or 'waterlevels')
"""

from __future__ import print_function, division
import netCDF4
import requests
import os
import tempfile
import logging
import datetime
import numpy as np


__all__ = [
    # server configuration
    'HOST', 'PORT', 'SCRIPT',
    # dataset classes
    'MatroosData', 'WaterlevelData', 'WaveData', 'VelocitiesData',
    # convenience loader
    'load',
]


def magnitude(a, b):
    """Return the Euclidean length of the vector (a, b), i.e. sqrt(a**2 + b**2).

    Works on scalars as well as element-wise on numpy arrays.
    """
    sum_of_squares = a ** 2 + b ** 2
    return sum_of_squares ** .5


def RMS(diff):
    """Return the root-mean-square of all values in `diff`.

    :param diff: array-like of differences (any shape; plain lists are accepted)
    :return: sqrt(mean(diff**2)), taken over every element
    """
    # np.mean divides by the total number of elements; the previous
    # len(diff) only counted the first axis, which is wrong for ndim > 1.
    # asanyarray keeps masked arrays masked.
    values = np.asanyarray(diff)
    return np.sqrt(np.mean(values ** 2))


# default MATROOS server; http://matroos.rijkswaterstaat.nl is an alternative
HOST = "http://matroos.deltares.nl"
# port appended to HOST when building the request url
PORT = 80
# path of the MATROOS request script on the server
SCRIPT = "matroos/scripts/matroos.pl"


class MatroosData(object):
    """
    Base class for gridded datasets retrieved from the MATROOS model database.

    Subclasses set `fields` (the MATROOS field names to request) and provide a
    `load` classmethod that reads those fields from the netCDF file returned by
    `_get_datafile`.
    """

    # MATROOS field names requested for this dataset (set by subclasses)
    fields = []
    # origin of the data; subclasses store the opened netCDF4.Dataset here
    source = None
    dimensions = ['x', 'y', 't']

    # url of the most recent download, used in error messages
    _last_source_url = None

    @classmethod
    def get_server_url(cls):
        """
        :return: url of the MATROOS request script, built from the module-level HOST, PORT and SCRIPT
        """
        # exactly one '/' between port and script path (the old '//' produced
        # an empty path segment); strip so user-modified settings stay valid
        return '{}:{}/{}'.format(HOST.rstrip('/'), PORT, SCRIPT.lstrip('/'))

    @classmethod
    def _get_datafile(cls, xll=None, yll=None, w=None, h=None, tmin=None, tmax=None, tspan=None, source=None, directory=None, filename=None, reset=False, **kwargs):
        """
        load data from matroos database

        :param xll: x position of lower left corner for the dataset
        :param yll: y position of lower left corner for the dataset
        :param w: width of the dataset
        :param h: height of the dataset
        :param tmin: start time as datetime object (limited to availability of MATROOS data); default=tmax-tspan
        :param tmax: stop time as datetime object (limited to availability of MATROOS data); default=current time
        :param tspan: timespan as datetime.timedelta object (required when tmin is not specified)
        :param source: data source
        :param directory: directory for the datafile; combined with `filename`, or used as location
            of the temporary file when `filename` is omitted (default: system temporary files location)
        :param filename: filename to save the data to / load the data from (is combined with `directory`);
            required when `reset` is False
        :param reset: if True, download the data from the original source (slow); otherwise load
            an existing local file (fast)
        :param kwargs: additional url settings passed on to `build_matroos_url`
        :return: netCDF4.Dataset object opened read-only
        :raises ValueError: when arguments required for a download are missing or invalid
        :raises IOError: when no usable local datafile is available
        """

        logger = logging.getLogger(__name__)

        # validate and build the request before any file is created, so a
        # failing validation does not leave stray temporary files behind
        url = None
        if reset:
            current_time = datetime.datetime.now()

            if tmax is None:
                tmax = current_time

            if tmin is None:
                if not isinstance(tspan, datetime.timedelta):
                    raise ValueError('tspan must be a timedelta object if tmin is not specified')
                # documented default: the window ends at tmax, not at "now"
                tmin = tmax - tspan

            if None in (xll, yll, w, h):
                raise ValueError('position required in arguments (xll, yll, w, h)')

            if source is None:
                raise ValueError('source not defined')

            kwargs.update(dict(
                xmin=xll, ymin=yll, xmax=xll+w, ymax=yll+h,
                tmin=tmin, tmax=tmax,
                current_time=current_time,
                source=source, fields=cls.fields, format='nc', coords='UTM31'))

            url = cls.build_matroos_url(**kwargs)

        if filename is not None:
            filepath = filename if directory is None else os.path.join(directory, filename)
        elif url is not None:
            # mkstemp returns an open OS-level handle; close it so it does not
            # leak (download_file reopens the file by name)
            fd, filepath = tempfile.mkstemp(suffix='.nc', dir=directory)
            os.close(fd)
        else:
            # without a filename there is nothing local to load; previously an
            # empty temporary file was opened, failing with an obscure error
            raise IOError('no local datafile specified; pass filename or use reset=True to download')

        if url is not None:
            cls._last_source_url = cls.download(filepath, url)
        else:
            logger.info('loading MATROOS data from local {!r}'.format(filepath))

        if not os.path.isfile(filepath):
            raise IOError('local file {} not found'.format(filepath))

        return netCDF4.Dataset(filepath, 'r')

    @classmethod
    def download(cls, dst, url):
        """
        download data from matroos
        :param dst: destination at which to save the datafile
        :param url: url from which to download
        :return: the used url
        """
        logger = logging.getLogger(__name__)
        logger.info('loading MATROOS data from remote {!r}'.format(url))
        logger.info('storing output at {!r}'.format(dst))
        try:
            cls.download_file(url, dst)
        except Exception:
            # log with traceback, then let the caller handle the failure
            logger.exception('error while loading data')
            raise

        return url

    @classmethod
    def _get_default_url_data(cls):
        """
        defines default data to use for the matroos request
        :return: dict with settings as used by `build_matroos_url`
        """
        # NOTE(review): the numeric values presumably describe the MATROOS
        # grid (cell counts / absolute extent) — TODO confirm against the server docs
        return dict(anal='000000000000', z=0, interpolate='count',
                    cellx="", celly="",
                    stridex="", stridey="", stridetime=1,
                    xn=1121, yn=1261,
                    xmin_abs="", xmax_abs=987586, ymin_abs=4920879, ymax_abs=7135393)

    @classmethod
    def build_matroos_url(cls, tmin, tmax, current_time, fields, **kwargs):
        """
        use various settings to build a url for the matroos database
        :param tmin: start time for data (datetime)
        :param tmax: end time for data (datetime)
        :param current_time: current time (datetime)
        :param fields: data fields to use
        :param kwargs: additional settings; they override the defaults from `_get_default_url_data`
        :return: url
        """
        date_format = '%Y%m%d0000'
        datetime_format = '%Y%m%d%H00'
        data = cls._get_default_url_data()

        fieldstr = ','.join(fields)
        fmt = kwargs.pop('format', 'nc')

        data['fieldoutput'] = fieldstr
        data['outputformat'] = fmt
        data['format'] = fmt
        data['from'] = tmin.strftime(datetime_format)
        data['to'] = tmax.strftime(datetime_format)
        data['now'] = current_time.strftime(date_format)

        data.update(kwargs)

        # values are inserted as-is (not percent-encoded)
        argstr = '&'.join(['{}={}'.format(k, v) for k, v in data.items()])
        return cls.get_server_url()+'?'+argstr

    @staticmethod
    def download_file(url, filename):
        """
        download a file from a given url and write it to `filename`

        :raises requests.HTTPError: when the server responds with an error status
        """
        logger = logging.getLogger(__name__)
        logger.info('downloading file from {!r}'.format(url))
        r = requests.get(url, stream=True)
        try:
            # fail early instead of saving an error page as the datafile
            r.raise_for_status()
            with open(filename, 'wb') as f:
                for chunk in r.iter_content(chunk_size=1024):
                    if chunk:  # filter out keep-alive new chunks
                        f.write(chunk)
        finally:
            # release the streamed connection
            r.close()
        logger.info('download of file complete')

    def interp_values(self, values, px, py):
        """
        interpolate gridded `values` at the points (px, py) with a bivariate spline

        :param values: 2d array with the same shape as `self.x` (and `self.y`)
        :param px: x coordinate(s) of the points to evaluate
        :param py: y coordinate(s) of the points to evaluate
        :return: interpolated values at (px, py)
        :raises ValueError: when `values` has the wrong shape or the grid is not rectilinear
        """
        from scipy.interpolate import RectBivariateSpline

        if values.shape != self.x.shape:
            raise ValueError('invalid shape for values; must be {}'.format(self.x.shape))

        # a meshgrid-style grid has x constant along axis 0 (rows share x) and
        # y constant along axis 1 (columns share y); checking y along axis 0
        # rejected every valid grid
        if not (np.diff(self.x, axis=0) == 0).all() or not (np.diff(self.y, axis=1) == 0).all():
            raise ValueError('irregular grid')
        x = self.x[0, :]
        y = self.y[:, 0]

        # RectBivariateSpline expects z[i, j] at (x[i], y[j]), hence the transpose
        return RectBivariateSpline(x, y, values.T).ev(px, py)

    def __repr__(self):
        if self.source is None:
            source = 'unknown'
        elif isinstance(self.source, netCDF4.Dataset):
            # 'summary' is an optional global attribute; do not fail without it
            source = getattr(self.source, 'summary', 'netCDF4.Dataset')
        else:
            source = repr(self.source)
        return '<{} {!r} fields={}, dimensions={}>'.format(
                self.__class__.__name__, source, self.fields, self.dimensions)

    __str__ = __repr__


class WaterlevelData(MatroosData):
    """water level data (MATROOS field 'SEP')"""

    fields = ['SEP']

    def __init__(self, x, y, t, wlvl, source=None):
        self.x = x
        self.y = y
        self.t = t
        self.wlvl = wlvl
        self.source = source

    @classmethod
    def load(cls, *args, **kwargs):
        """
        load water level data; arguments are passed on to `MatroosData._get_datafile`

        :param wind: only used when no `source` is given: selects 'dcsm_v6_ukmo' (True, default)
            or 'dcsm_v6_astro' (False); it is not forwarded to the MATROOS request
        :return: WaterlevelData instance
        :raises ValueError: when the datafile lacks the expected variables
        """
        # pop (not get) so 'wind' does not leak into the request url
        wind = kwargs.pop('wind', True)
        if kwargs.get('source', None) is None:
            kwargs['source'] = 'dcsm_v6_ukmo' if wind else 'dcsm_v6_astro'

        datafile = cls._get_datafile(*args, **kwargs)

        try:
            x = datafile.variables['x'][:]
            y = datafile.variables['y'][:]
            t = datafile.variables['time'][:]
            wlvl = datafile.variables['SEP'][:]
        except KeyError:
            # the dataset is unusable; release the file handle before reporting
            datafile.close()
            raise ValueError('invalid datafile at {!r}; check if timesteps are valid at http://matroos.deltares.nl/maps/start'.format(cls._last_source_url))
        return cls(x, y, t, wlvl, source=datafile)


class VelocitiesData(MatroosData):
    """velocity data (MATROOS fields 'VELU' and 'VELV')"""

    fields = ['VELU', 'VELV']

    def __init__(self, x, y, t, u, v, source=None):
        self.x = x
        self.y = y
        self.t = t
        self.u = u
        self.v = v
        self.source = source

    @classmethod
    def load(cls, *args, **kwargs):
        """
        load velocity data; arguments are passed on to `MatroosData._get_datafile`

        :param wind: only used when no `source` is given: selects 'dcsm_v6_ukmo' (True, default)
            or 'dcsm_v6_astro' (False); it is not forwarded to the MATROOS request
        :return: VelocitiesData instance (u and v are squeezed)
        :raises ValueError: when the datafile lacks the expected variables
        """
        # pop (not get) so 'wind' does not leak into the request url
        wind = kwargs.pop('wind', True)
        if kwargs.get('source', None) is None:
            kwargs['source'] = 'dcsm_v6_ukmo' if wind else 'dcsm_v6_astro'

        datafile = cls._get_datafile(*args, **kwargs)

        try:
            x = datafile.variables['x'][:]
            y = datafile.variables['y'][:]
            t = datafile.variables['time'][:]
            u = np.squeeze(datafile.variables['VELU'][:])
            v = np.squeeze(datafile.variables['VELV'][:])
        except KeyError:
            # the dataset is unusable; release the file handle before reporting
            datafile.close()
            raise ValueError('invalid datafile at {!r}; check if timesteps are valid at http://matroos.deltares.nl/maps/start'.format(cls._last_source_url))
        return cls(x, y, t, u, v, source=datafile)

    def mean_ellipse(self, position_scale=1, velocity_scale=1, apply_position=False):
        """
        spatially averaged velocity per time step

        Averages u and v over axes 1 and 2 (presumably the spatial axes of a
        (t, y, x) array — TODO confirm).

        :param position_scale: scale applied to the mean position when `apply_position` is True
        :param velocity_scale: scale applied to the velocities
        :param apply_position: if True, offset the result by the mean x / y position
        :return: tuple (u, v) with one value per time step
        """
        mean_u = np.mean(self.u, axis=(1, 2))*velocity_scale
        mean_v = np.mean(self.v, axis=(1, 2))*velocity_scale
        if apply_position:
            return mean_u+np.mean(self.x)*position_scale, mean_v+np.mean(self.y)*position_scale
        else:
            return mean_u, mean_v

    def ellipses(self, position_scale=.001, velocity_scale=1, apply_position=False):
        """
        velocity per grid point, flattened to shape (first axis of u, number of points)

        :param position_scale: scale applied to the grid positions when `apply_position` is True
        :param velocity_scale: scale applied to the velocities
        :param apply_position: if True, offset each velocity by its grid position
        :return: tuple (u, v) of 2d arrays
        """
        if apply_position:
            # 1d coordinate vectors are expanded to a full grid first
            if self.x.ndim == 1:
                X, Y = np.meshgrid(self.x, self.y)
            else:
                X, Y = self.x, self.y
            # leading singleton axis so the positions broadcast over the first axis of u/v
            x = X.reshape((1, X.shape[0], X.shape[1]))
            y = Y.reshape((1, Y.shape[0], Y.shape[1]))
            ell_u, ell_v = self.u*velocity_scale+x*position_scale, self.v*velocity_scale+y*position_scale
        else:
            ell_u, ell_v = self.u*velocity_scale, self.v*velocity_scale
        return ell_u.reshape(ell_u.shape[0], -1), ell_v.reshape(ell_v.shape[0], -1)

    @property
    def velocity_magnitude(self):
        """magnitude sqrt(u**2 + v**2), element-wise"""
        return (self.u**2+self.v**2)**.5

    def decompose(self):
        """
        decompose the velocities via `tide.Tide.decompose`

        :return: whatever `Tide.decompose` returns (previously discarded)
        """
        return self.asTide().decompose()

    def asTide(self):
        """:return: a `tide.Tide` built from this dataset"""
        from . import tide
        return tide.Tide.from_dataset(self)


class WaveData(MatroosData):
    """wave data: direction, significant wave height and swell wave height"""

    fields = ['wave_dir_th0', 'wave_height_hm0', 'swellwave_height_hm0']

    @classmethod
    def load(cls, *args, **kwargs):
        """
        load wave data; arguments are passed on to `MatroosData._get_datafile`

        When `source` is missing or None, 'swan_zuno' is used.

        :return: WaveData instance
        :raises KeyError: when the datafile lacks one of the expected variables
        """
        # also replace an explicit source=None, consistent with the other loaders
        if kwargs.get('source', None) is None:
            kwargs['source'] = 'swan_zuno'
        datafile = cls._get_datafile(*args, **kwargs)

        try:
            x = datafile.variables['x'][:]
            y = datafile.variables['y'][:]
            t = datafile.variables['time'][:]
            wave_dir = datafile.variables['wave_dir_th0'][:]
            wave_hsig = datafile.variables['wave_height_hm0'][:]
            wave_hswell = datafile.variables['swellwave_height_hm0'][:]
        except KeyError:
            # report the available variables (the old print of a lazy `map`
            # showed nothing useful on Python 3), then release the file
            available = [str(name) for name in datafile.variables.keys()]
            logging.getLogger(__name__).error(
                'missing variable in wave datafile; available: {}'.format(available))
            datafile.close()
            raise

        return cls(x, y, t, wave_dir, wave_hsig, wave_hswell, source=datafile)

    def __init__(self, x, y, t, wave_dir, wave_hsig, wave_hswell, source=None):
        self.x = x
        self.y = y
        self.t = t
        self.wave_dir = wave_dir
        self.wave_hsig = wave_hsig
        self.wave_hswell = wave_hswell
        self.source = source

    @classmethod
    def _get_default_url_data(cls):
        """request defaults; currently identical to `MatroosData._get_default_url_data`"""
        data = super(WaveData, cls)._get_default_url_data()
        # NOTE: a coarser time stride (e.g. stridetime=5) could be set here
        return data

    @staticmethod
    def _percentile_99(values):
        """99th percentile along axis 0, masked wherever any entry along axis 0 is masked"""
        # getmaskarray also works when the mask is `nomask` or values is a
        # plain ndarray, where `.mask.any(axis=0)` fails
        mask = np.ma.getmaskarray(values).any(axis=0)
        return np.ma.MaskedArray(np.percentile(values, 99, axis=0), mask=mask)

    @property
    def hsig_99(self):
        """99th percentile of the significant wave height along axis 0"""
        return self._percentile_99(self.wave_hsig)

    @property
    def hswell_99(self):
        """99th percentile of the swell wave height along axis 0"""
        return self._percentile_99(self.wave_hswell)


def load(name, *args, **kwargs):
    """
    load a MATROOS dataset by name

    :param name: one of 'velocities', 'waves' or 'waterlevels'
    :param args: positional arguments for the `load` classmethod of the selected class
    :param kwargs: keyword arguments for the `load` classmethod of the selected class
    :return: instance of the selected dataset class
    :raises KeyError: for an unknown name
    """
    registry = {
        'velocities': VelocitiesData,
        'waves': WaveData,
        'waterlevels': WaterlevelData,
    }
    data_class = registry[name]
    return data_class.load(*args, **kwargs)
