"""
Module for handling grid datasets
A grid is defined as a dataset with:
 - x positions (1D)
 - y positions (1D)
 - values (2D)
Cells are square and must be equidistant
"""

from __future__ import print_function, division
import re
import numpy as np
import logging

from .timeseries_selectors import collapse_timeseries, compress_timesteps
from ..plotting.grid import plot_depth, plot_grid
from .. import tools


# names exported by `from <module> import *`
# plot_depth is re-exported from the plotting package imported above
__all__ = ['BaseGridDataset', 'GridDataset', 'GridTimeStackDataset', 'DifferenceGridDataset', 'join_grids', 'plot_depth']


class BaseGridDataset(object):
    """
    Base class for grid objects
    Each grid object has at least the following attributes:
    x: 1D array of x positions
    y: 1D array of y positions
    Z: 2D array of values
    """

    # operand types that are applied directly onto Z by the arithmetic operators
    _value_types = (int, float, np.number, np.ndarray)

    def __init__(self, x, y, Z):
        self.x = x
        self.y = y
        self.Z = Z

    def get_data(self):
        """convenience method for retrieving the data as (x, y, Z)"""
        return self.x, self.y, self.Z

    def validate(self, vmin, vmax):
        """
        mask all values outside [vmin, vmax] (modifies the mask of Z in place)
        a warning is logged when more than 10 values are out of range
        :param vmin: lowest accepted value
        :param vmax: highest accepted value
        """
        extreme_mask = (self.Z < vmin) | (self.Z > vmax)
        num_outliers = np.sum(extreme_mask)
        if num_outliers > 10:
            # not every grid class carries metadata, so do not rely on self.meta
            meta = getattr(self, 'meta', None) or {}
            logging.getLogger(__name__).warning(
                'masking {num_outliers} outliers at grid definition xll={xll} yll={yll} sid={sid}'.format(
                    num_outliers=num_outliers,
                    xll=self.x.min(),
                    yll=self.y.min(),
                    sid=meta.get('sid', '<undefined>')))
        self.Z.mask[extreme_mask] = True

    def __repr__(self):
        name = '{}.{}'.format(self.__module__, self.__class__.__name__)
        xll = np.amin(self.x)
        yll = np.amin(self.y)
        return '<{} x{} y{} {}>'.format(name, xll, yll, self.Z.shape)

    def as_points(self, mask=None):
        """
        return x, y and Z as tuple of 1D point arrays
        :param mask: boolean array (True = leave out); defaults to the mask of Z if Z is a masked array
        """

        # create 2D x and y arrays
        X, Y = np.meshgrid(self.x, self.y)

        Z = self.Z

        # get mask from values
        if mask is None and isinstance(Z, np.ma.MaskedArray):
            mask = Z.mask

        # apply mask
        if mask is not None:
            X = X[~mask]
            Y = Y[~mask]
            Z = Z[~mask]

        return X.flatten(), Y.flatten(), Z.flatten()

    def subset(self, xll, yll, w, h):
        """
        get a subset of the grid
        :param xll: x position of lower left corner
        :param yll: y position of lower left corner
        :param w: width of the dataset
        :param h: height of the dataset
        :return: instance of the same class, but a subset of the data for x, y and Z
        """
        # half-open ranges: [xll, xll+w) and [yll, yll+h)
        xmask = (self.x >= xll) & (self.x < xll+w)
        ymask = (self.y >= yll) & (self.y < yll+h)
        Xmask, Ymask = np.meshgrid(xmask, ymask)
        submask = (Xmask & Ymask).astype(bool)
        Z = self.Z[submask].reshape(ymask.sum(), xmask.sum())
        return type(self)(self.x[xmask], self.y[ymask], Z)

    def _require_equal_positions(self, other):
        """raise ValueError unless other has exactly the same x and y positions"""
        if not np.array_equal(self.x, other.x) or not np.array_equal(self.y, other.y):
            raise ValueError('cannot match positions')

    def __sub__(self, other, outcls=None):
        """
        subtract a value or a grid
        :param other: number or array (applied onto Z) or a grid with equal positions
        :param outcls: class of the result for grid - grid (default BaseGridDataset)
        :return: new grid, or NotImplemented for unsupported operand types
        :raises ValueError: if the positions of both grids differ
        """

        # simple numeric value
        if isinstance(other, self._value_types):
            return type(self)(self.x, self.y, self.Z - other)

        # if other type than grid: let python try the reflected operation / raise TypeError
        if not isinstance(other, BaseGridDataset):
            return NotImplemented

        self._require_equal_positions(other)

        # return the result as a grid object
        return (outcls or BaseGridDataset)(self.x, self.y, self.Z - other.Z)

    def __add__(self, other, outcls=None):
        """
        add a value or a grid
        :param other: number or array (applied onto Z) or a grid with equal positions
        :param outcls: class of the result for grid + grid (default BaseGridDataset)
        :return: new grid, or NotImplemented for unsupported operand types
        :raises ValueError: if the positions of both grids differ
        """

        # simple numeric value
        if isinstance(other, self._value_types):
            return type(self)(self.x, self.y, self.Z + other)

        # if other type than grid
        if not isinstance(other, BaseGridDataset):
            return NotImplemented

        self._require_equal_positions(other)

        # return the result as a grid object
        return (outcls or BaseGridDataset)(self.x, self.y, self.Z + other.Z)

    def __mul__(self, other, outcls=None):
        """
        multiply with a value or a grid
        :param other: number or array (applied onto Z) or a grid with equal positions
        :param outcls: class of the result for grid * grid (default: class of self)
        :return: new grid, or NotImplemented for unsupported operand types
        :raises ValueError: if the positions of both grids differ
        """

        # simple numeric value
        if isinstance(other, self._value_types):
            return type(self)(self.x, self.y, self.Z * other)

        # if other type than grid
        if not isinstance(other, BaseGridDataset):
            return NotImplemented

        self._require_equal_positions(other)

        # return the result as a grid object
        return (outcls or type(self))(self.x, self.y, self.Z * other.Z)

    def __div__(self, other, outcls=None):
        """
        divide by a value or a grid
        :param other: number or array (applied onto Z) or a grid with equal positions
        :param outcls: class of the result for grid / grid (default: class of self)
        :return: new grid, or NotImplemented for unsupported operand types
        :raises ValueError: if the positions of both grids differ
        """

        # simple numeric value
        if isinstance(other, self._value_types):
            return type(self)(self.x, self.y, self.Z / other)

        # if other type than grid
        if not isinstance(other, BaseGridDataset):
            return NotImplemented

        self._require_equal_positions(other)

        # return the result as a grid object
        return (outcls or type(self))(self.x, self.y, self.Z / other.Z)

    def __truediv__(self, other, **kwargs):
        """the `/` operator, forwarded to __div__"""
        return self.__div__(other, **kwargs)

    def overlap(self, *others, **kwargs):
        """return the overlap of this grid and the given grids, see overlap_grids"""
        grids = [self]+list(others)
        return self.overlap_grids(grids, **kwargs)

    @classmethod
    def overlap_grids(cls, grids, xmin=None, xmax=None, ymin=None, ymax=None):
        """
        return only the overlap of several grids as a new set of grids
        all returned grids share one mask: a cell is masked if it is masked in any grid
        :param grids: list or tuple of grids
        :param xmin: enforce lower x position
        :param xmax: enforce upper x position
        :param ymin: enforce lower y position
        :param ymax: enforce upper y position
        :return: tuple of adjusted grids (instances of cls)
        :raises ValueError: if the spacing of a grid differs from the spacing of the first grid
        """

        overlapped_grids = []

        # default bounding box: the intersection of all grid extents
        if xmin is None:
            xmin = np.amax([g.x[0] for g in grids])
        if xmax is None:
            xmax = np.amin([g.x[-1] for g in grids])

        if ymin is None:
            ymin = np.amax([g.y[0] for g in grids])
        if ymax is None:
            ymax = np.amin([g.y[-1] for g in grids])

        # the first grid defines the reference spacing
        xstep = grids[0].x[1] - grids[0].x[0]
        ystep = grids[0].y[1] - grids[0].y[0]

        for g in grids:
            xmask = (g.x >= xmin) & (g.x <= xmax)
            ymask = (g.y >= ymin) & (g.y <= ymax)
            box_mask = np.tile(xmask, (g.y.size, 1)) & np.tile(ymask, (g.x.size, 1)).T

            if not (np.diff(g.x) == xstep).all():
                raise ValueError('x not equally spaced with {} for {!r}'.format(xstep, g))

            if not (np.diff(g.y) == ystep).all():
                raise ValueError('y not equally spaced with {} for {!r}'.format(ystep, g))

            x = g.x[xmask]
            y = g.y[ymask]
            data = g.Z[box_mask].reshape(np.sum(ymask), np.sum(xmask))
            overlapped_grids.append(cls(x, y, data))

        # getmaskarray always yields a full boolean array, also when a grid has no masked cells
        overlapped_masks = np.array([np.ma.getmaskarray(g.Z) for g in overlapped_grids])

        data_mask = np.logical_or.reduce(overlapped_masks)

        for g in overlapped_grids:
            g.Z.mask = data_mask

        return tuple(overlapped_grids)

    def remove_mean(self):
        """subtract the mean from the data"""
        return type(self)(self.x, self.y, self.Z-np.mean(self.Z))

    @property
    def cellarea(self):
        """return the area of a single cell (assuming equal position spacing)"""
        cw, ch = self.cellsize
        return cw*ch

    @property
    def cellsize(self):
        """return the cell size as (dx, dy), taken from the first two positions"""
        return (self.x[1] - self.x[0]), (self.y[1] - self.y[0])

    @property
    def shape(self):
        """shape of Z"""
        return self.Z.shape

    @property
    def position(self):
        """return the position of the grid as xll, yll, w, h"""
        return self.x[0], self.y[0], self.x[-1] - self.x[0], self.y[-1] - self.y[0]

    @classmethod
    def calculate_coverage(cls, data):
        """
        return the fraction (0..1) of cells that contain data
        :param data: boolean array (True = no data), masked array or array with NaN for missing values
        """
        # compare the dtype by value; an identity check against the type never matches
        if data.dtype == np.bool_:
            nanmask = data
        elif isinstance(data, np.ma.MaskedArray):
            nanmask = data.mask
        else:
            nanmask = np.isnan(data)
        return 1. - float(nanmask.sum()) / float(nanmask.size)

    @property
    def coverage(self):
        """coverage of Z, see calculate_coverage"""
        return self.calculate_coverage(self.Z)

    def sample(self, px, py, scale=1, mode='linear'):
        """
        sample the grid at the given positions using tools.interp.grid_interp
        :param px: x position(s): scalar, list or array
        :param py: y position(s): scalar, list or array
        :param scale: factor applied to the grid positions before sampling
        :param mode: passed on to grid_interp
        :return: sampled values, NaN where the result drops below the minimum of Z
        """
        if not isinstance(px, np.ndarray):
            if not isinstance(px, list):
                px = [px]
            px = np.array(px)
        if not isinstance(py, np.ndarray):
            if not isinstance(py, list):
                py = [py]
            py = np.array(py)

        vmin = self.Z.min()
        vmax = self.Z.max()
        # masked cells get a value far below the data range, so that samples
        # influenced by them end up below vmin and are turned into NaN
        fill_value = vmin - (vmax - vmin)
        pz = tools.interp.grid_interp(self.x*scale, self.y*scale, self.Z.filled(fill_value), px, py, mode=mode)
        pz[pz < vmin] = np.nan
        return pz

    @property
    def xll(self):
        """lowest x position"""
        return self.x.min()

    @property
    def yll(self):
        """lowest y position"""
        return self.y.min()

    @property
    def w(self):
        """width of the grid including the last cell"""
        return self.x.max()+self.cellsize[0] - self.x.min()

    @property
    def h(self):
        """height of the grid including the last cell"""
        return self.y.max() + self.cellsize[1] - self.y.min()

    def xlim(self, scale=1):
        """return (min, max) of x multiplied by scale"""
        return self.x.min()*scale, self.x.max()*scale

    def ylim(self, scale=1):
        """return (min, max) of y multiplied by scale"""
        return self.y.min()*scale, self.y.max()*scale


class GridDataset(BaseGridDataset):
    """
    class for handling grid datasets
    offers a more extensive interface compared to the simple functions for loading grids

    creating an instance:
    >>>grid = GridDataset(x, y, data)

    loading:
    >>>grid = GridDataset.load(filename, ftype)

    saving:
    >>>GridDataset(x, y, data).save(filename, ftype)

    plotting on given axes:
    >>>GridDataset(x, y, data).plot(ax)

    fast plotting:
    >>>GridDataset(x, y, data).show()

    get difference of two grids as a new grid object:
    >>>diff = GridDataset(x, y, data1) - GridDataset(x, y, data2)
    and add a grid to the difference to get a new normal grid
    >>>diff + GridDataset(x, y, data3)
    """

    # the pattern to which datafiles have to comply
    filename_pattern = re.compile(r'^x([0-9]+)_y([0-9]+)_(\w+)_.*')

    def __init__(self, x, y, Z, **meta):
        """
        :param x: 1D array of x positions
        :param y: 1D array of y positions
        :param Z: 2D array of values; stored as masked array with invalid values masked
        :param meta: metadata, also available as attributes of the grid
        :raises ValueError: if Z is not 2D
        """
        if Z.ndim != 2:
            raise ValueError('Z must be 2D array')

        # force masked array: masked cells become NaN first, then all invalid values are masked
        if isinstance(Z, np.ma.MaskedArray):
            Z = Z.filled(np.nan)
        Z = np.ma.masked_invalid(Z)

        if Z.mask.all():
            logging.getLogger(__name__).warning('defined grid is empty')

        super(GridDataset, self).__init__(x, y, Z)
        self.meta = meta

    def __getattr__(self, item):
        """
        Access to metadata is possible using object attributes
        >>>g = GridDataset(x, y, z, source='some location')
        >>>assert g.source=='some location'
        """
        # __getattr__ is only called after the normal lookup failed; fetch meta
        # without going through __getattr__ again, otherwise an instance that has
        # no meta yet (e.g. while being copied or unpickled) recurses endlessly
        try:
            meta = object.__getattribute__(self, 'meta')
        except AttributeError:
            raise AttributeError(item)

        try:
            return meta[item]
        except KeyError:
            # raises the regular AttributeError
            return object.__getattribute__(self, item)

    @property
    def mask(self):
        """mask of Z"""
        return self.Z.mask

    @classmethod
    def load(cls, file_or_buf, ftype=None, source=None, collapse=None, **kwargs):
        """
        load data from a file
        supported loaders are defined in datasets.loaders
        3D data are reduced to 2D with collapse_timeseries, using method `collapse`
        (default 'recent'); in that case the 'time' entry is dropped from the metadata
        """

        from .iotools import load_griddata

        x, y, data, meta = load_griddata(file_or_buf, ftype=ftype, source=source, **kwargs)
        if data.ndim == 3:
            meta.pop('time', None)
            data = collapse_timeseries(data, method=collapse or 'recent')
        return cls(x, y, data, **meta)

    def save(self, filename, ftype=None, **kwargs):
        """save data to a file"""
        from .iotools import save_griddata
        save_griddata(filename, self.x, self.y, self.Z, ftype=ftype, meta=self.meta, **kwargs)

    def show(self, plottype='depth', colorbar=True, clabel=None, **kwargs):
        """
        fast plot the dataset and show in a window
        :param plottype: passed on to plot
        :param colorbar: add a colorbar to the figure
        :param clabel: label of the colorbar (default 'depth (m)' for plottype 'depth')
        :param kwargs: passed on to plot
        """
        from matplotlib import pyplot as plt
        # create figure
        fig = plt.figure()
        ax = fig.gca()

        c = self.plot(ax, plottype=plottype, **kwargs)

        if colorbar:
            # add colorbar
            cb = fig.colorbar(c)
            if clabel is None and plottype == 'depth':
                clabel = 'depth (m)'
            cb.set_label(clabel or '')

        # show figure
        plt.show()

    def plot(self, ax, plottype='depth', **kwargs):
        """
        plot the grid on the given axes
        :param ax: axes to plot on
        :param plottype: 'depth' (plot_depth) or 'simple' (plot_grid)
        :param kwargs: passed on to the plot function
        :return: result of the plot function, None for any other plottype
        """
        if plottype == 'depth':
            return plot_depth(ax, self.x, self.y, self.Z, **kwargs)
        elif plottype == 'simple':
            return plot_grid(ax, self, **kwargs)

    def __sub__(self, other):
        """
        subtract grids (incorporating DifferenceGridDataset)
        if other is an instance of DifferenceGridDataset, return an instance of self using BaseGridDataset.__sub__
        if other is an instance of GridDataset, return an instance of DifferenceGridDataset using BaseGridDataset.__sub__
        else, return using BaseGridDataset.__sub__
        see BaseGridDataset.__sub__ and DifferenceGridDataset for more info
        """
        if isinstance(other, DifferenceGridDataset):
            return BaseGridDataset.__sub__(self, other, outcls=type(self))
        elif isinstance(other, GridDataset):
            return BaseGridDataset.__sub__(self, other, outcls=DifferenceGridDataset)
        else:
            return BaseGridDataset.__sub__(self, other)

    def __add__(self, other):
        """
        add grids (incorporating DifferenceGridDataset)
        if other is an instance of DifferenceGridDataset, return an instance of self using BaseGridDataset.__add__
        else, return using BaseGridDataset.__add__
        see BaseGridDataset.__add__ and DifferenceGridDataset for more info
        """
        if isinstance(other, DifferenceGridDataset):
            return BaseGridDataset.__add__(self, other, outcls=type(self))
        else:
            return BaseGridDataset.__add__(self, other)

    def aggregate(self, blocksize=5):
        """
        return a new GridDataset (without metadata) built from tools.interp.aggregate
        the mean over the last axis of the aggregated result is used for x, y and Z;
        x and y are cast to int
        :param blocksize: passed on to tools.interp.aggregate
        """
        Z = tools.interp.aggregate(self.Z, blocksize).mean(axis=-1)
        x = tools.interp.aggregate(self.x, blocksize).mean(axis=-1).astype(int)
        y = tools.interp.aggregate(self.y, blocksize).mean(axis=-1).astype(int)
        return GridDataset(x, y, Z)


class GridTimeStackDataset(BaseGridDataset):

    """
    Stack of grid data timesteps

    To create a dataset:
    >>>GS = GridTimeStackDataset(x, y, t, Z, **meta)
    To extract the most recent point at each location
    >>>G = GS.recent()
    Which then contains a map of timesteps at each location:
    >>>G.time

    To show each of the contained grids
    >>>for G in GridTimeStackDataset(x, y, t, Z, **meta):
    >>>    G.show()
    """

    def __init__(self, x, y, t, Z, **meta):
        """
        :param x: 1D array of x positions
        :param y: 1D array of y positions
        :param t: 1D array of timesteps, one for each layer of Z
        :param Z: 3D array of values (time, y, x)
        :param meta: metadata, also available as attributes
        :raises ValueError: if Z is not 3D
        """
        if Z.ndim != 3:
            raise ValueError('Z must be 3D array')

        if not isinstance(Z, np.ma.MaskedArray):
            Z = np.ma.masked_invalid(Z)

        super(GridTimeStackDataset, self).__init__(x, y, Z)
        self.t = t
        self.meta = meta

        # every timestep may occur only once
        if np.unique(self.t).size != self.t.size:
            self.combine_duplicate_timesteps()

    def get_data(self):
        """convenience method for retrieving the data as (x, y, t, Z)"""
        return self.x, self.y, self.t, self.Z

    def combine_duplicate_timesteps(self, start_at_last=True):
        """
        merge layers that share the same timestep into a single layer (in place)
        for each cell the first unmasked value among the duplicate layers is used
        metadata arrays with one value per timestep keep the first value of each timestep
        :param start_at_last: search the duplicate layers in reversed order (last layer first)
        :return: self
        """
        unique_t, unique_t_ind = np.unique(self.t, return_index=True)
        logging.getLogger(__name__).debug('reducing timesteps from {} to {} (timestep duplicates)'.format(self.t.size, unique_t.size))
        new_Z = np.ma.masked_all((unique_t.size, self.y.size, self.x.size), dtype=self.Z.dtype)
        # flattened row and column index of every cell
        i_ind = np.tile(np.arange(self.Z.shape[1]), (self.Z.shape[2], 1)).T.flatten()
        j_ind = np.tile(np.arange(self.Z.shape[2]), (self.Z.shape[1], 1)).flatten()

        for i, ti in enumerate(unique_t):
            if start_at_last:
                item = self.Z[::-1, :, :][self.t[::-1] == ti]
            else:
                item = self.Z[self.t == ti]
            # index of the first layer with an unmasked value for each cell
            t_ind = np.argmin(np.ma.getmaskarray(item), axis=0).flatten()
            new_Z[i, :, :] = item[t_ind, i_ind, j_ind].reshape(item.shape[1:])

        for k, v in list(self.meta.items()):
            if isinstance(v, np.ndarray) and v.size == self.t.size:
                self.meta[k] = v[unique_t_ind]

        self.Z = new_Z
        self.t = unique_t

        return self

    @property
    def mask(self):
        """2D mask: True where a cell is masked in at least one timestep"""
        return np.ma.getmaskarray(self.Z).any(axis=0)

    @property
    def total_coverage(self):
        """fraction (0..1) of cells that are unmasked in every timestep"""
        mask = self.mask
        return 1. - float(mask.sum()) / float(mask.size)

    @classmethod
    def calculate_coverage(cls, data):
        """
        return the coverage of the data, see BaseGridDataset.calculate_coverage
        :param data: 2D array (single value returned) or 3D array (array with a value for each layer)
        :raises ValueError: if data is neither 2D nor 3D
        """
        if data.ndim == 2:
            return super(GridTimeStackDataset, cls).calculate_coverage(data)
        elif data.ndim == 3:
            coverage = [super(GridTimeStackDataset, cls).calculate_coverage(d) for d in data]
        else:
            raise ValueError
        return np.array(coverage)

    def __getattr__(self, item):
        """access to metadata using object attributes"""
        # fetch meta without going through __getattr__ again, otherwise an instance
        # that has no meta yet (e.g. while being copied or unpickled) recurses endlessly
        try:
            meta = object.__getattribute__(self, 'meta')
        except AttributeError:
            raise AttributeError(item)

        if item in meta:
            return meta[item]
        else:
            # raises the regular AttributeError
            return object.__getattribute__(self, item)

    def subset(self, xll, yll, w, h):
        """
        get a subset of the grid stack
        :param xll: x position of lower left corner
        :param yll: y position of lower left corner
        :param w: width of the dataset
        :param h: height of the dataset
        :return: instance of the same class with all timesteps and the metadata
        """
        xmask = (self.x >= xll) & (self.x < xll+w)
        ymask = (self.y >= yll) & (self.y < yll+h)
        Xmask, Ymask = np.meshgrid(xmask, ymask)
        submask = np.tile((Xmask & Ymask).astype(bool), (self.t.size, 1, 1))
        Z = self.Z[submask].reshape(self.t.size, np.sum(ymask), np.sum(xmask))
        assert Z.shape == (self.t.size, np.sum(ymask), np.sum(xmask))
        return type(self)(self.x[xmask], self.y[ymask], self.t, Z, **self.meta)

    # collapse the layer indices (masked like Z) with collapse_timeseries;
    # kwargs are passed on to collapse_timeseries
    @tools.extend_docstring(collapse_timeseries)
    def collapsed_indices(self, **kwargs):
        time_index = np.tile(np.arange(self.t.size).reshape(-1, 1, 1), (1, self.y.size, self.x.size))
        indices = collapse_timeseries(np.ma.MaskedArray(time_index, mask=self.Z.mask), **kwargs).astype(int)
        return indices

    # translate the collapsed layer indices into timesteps (masked where the index is masked)
    @tools.extend_docstring(collapsed_indices)
    def collapse_timesteps(self, **kwargs):
        collapsed_time_index = self.collapsed_indices(**kwargs)
        r = np.ma.MaskedArray(self.t[collapsed_time_index.filled(0)], mask=collapsed_time_index.mask)
        return r

    # thin wrapper around the module level compress_timesteps function
    @tools.extend_docstring(compress_timesteps)
    def compress_timesteps(self, dt_max=365*2, overlap_max=0.1, min_coverage=.05):
        return compress_timesteps(self, dt_max=dt_max, overlap_max=overlap_max, min_coverage=min_coverage)

    def __getitem__(self, i):
        """
        :param i: integer (returns a GridDataset), or list / array / slice (returns a GridTimeStackDataset)
        :raises IndexError: for any other type of index
        """
        if isinstance(i, list):
            i = np.array(i)
        if isinstance(i, (np.ndarray, slice)):
            return self.slice(i)
        elif isinstance(i, (int, np.integer)):
            return self.get_grid(i)
        else:
            raise IndexError(repr(i))

    def slice(self, I):
        """
        return a new GridTimeStackDataset with the timesteps selected by I
        metadata arrays with the shape of t are indexed with I as well
        """
        meta = dict()
        for k, v in self.meta.items():
            if isinstance(v, np.ndarray) and v.shape == self.t.shape:
                meta[k] = v[I]
            else:
                meta[k] = v

        return GridTimeStackDataset(self.x, self.y, self.t[I], self.Z[I, :, :], **meta)

    def __len__(self):
        """number of timesteps"""
        return self.Z.shape[0]

    def recent(self):
        """return a GridDataset collapsed with method 'recent'; its `time` holds the collapsed timesteps"""
        return GridDataset(self.x,
                           self.y,
                           collapse_timeseries(self.Z, method='recent'),
                           time=self.collapse_timesteps(method='recent'))

    def get_grid(self, i, overlap_only=False):
        """
        return timestep i as a GridDataset with metadata time, sid and density
        :param i: index of the timestep
        :param overlap_only: use the shared mask of all timesteps (see overlapped)
        """
        if overlap_only:
            Z = self.overlapped
        else:
            Z = self.Z

        # sid and density are optional; None if missing or not indexable
        try:
            sid = self.meta['sid'][i]
        except (KeyError, IndexError, TypeError):
            sid = None

        try:
            density = self.meta['density'][i]
        except (KeyError, IndexError, TypeError):
            density = None

        return GridDataset(self.x, self.y, Z[i, :, :], time=self.t[i], sid=sid, density=density)

    @property
    def overlapped(self):
        """copy of Z in which every timestep carries the combined mask (see mask)"""
        mask = np.tile(self.mask, (self.Z.shape[0], 1, 1))
        Z = self.Z.copy()
        Z.mask = mask
        return Z

    def iter_grids(self, area_min_ratio=0, N=None):
        """
        iterate over grids
        :param area_min_ratio: minimum coverage ratio
        :param N: only the first N timesteps are considered
        :return: generator of grids
        """
        for ti, data in enumerate(self.Z):
            if N is not None and ti >= N:
                break

            if area_min_ratio > 0 and self.calculate_coverage(data) < area_min_ratio:
                continue

            yield self.get_grid(ti, overlap_only=False)

    def __iter__(self):
        return self.iter_grids(area_min_ratio=0)

    def get_grids(self, **kwargs):
        """return the result of iter_grids as a list"""
        return list(self.iter_grids(**kwargs))

    @classmethod
    def from_grids(cls, grids, t=None, **meta):
        """
        build a stack from a sequence of grids with equal positions
        :param grids: sequence of grids
        :param t: timesteps; by default the `time` attribute of each grid, or its index if it has none
        :param meta: metadata of the stack
        :raises ValueError: without grids, with unequal positions or with a time vector of the wrong length
        """
        if not grids:
            raise ValueError('no grids given')

        x, y = grids[0].x, grids[0].y
        data = np.zeros([len(grids)]+list(grids[0].Z.shape))

        for i, g in enumerate(grids):
            if not np.array_equal(g.x, x) or not np.array_equal(g.y, y):
                raise ValueError('positions not equal for all grids')
            if isinstance(g.Z, np.ma.MaskedArray):
                data[i, :, :] = g.Z.filled(np.nan)
            else:
                data[i, :, :] = g.Z

        data = np.ma.masked_invalid(data)

        if t is None:
            t = np.zeros(len(grids))
            for i, g in enumerate(grids):
                try:
                    t[i] = g.time
                except AttributeError:
                    t[i] = i
        elif len(t) != len(grids):
            raise ValueError('time vector must match number of grids')

        return cls(x, y, t, data, **meta)

    def filter_by_coverage(self, min_coverage):
        """
        remove all timesteps with a coverage below min_coverage (in place)
        metadata arrays with one value per timestep are filtered as well
        :param min_coverage: minimum fraction (0..1) of unmasked cells in a timestep
        """
        N = np.prod(self.Z.shape[1:])
        coverage = 1 - np.ma.getmaskarray(self.Z).sum(axis=(1, 2)).astype(float) / N
        indices = np.arange(self.t.size)[coverage >= min_coverage]

        # keep the per-timestep metadata aligned with t (before t itself is replaced)
        for k, v in list(self.meta.items()):
            if isinstance(v, np.ndarray) and v.size == self.t.size:
                self.meta[k] = v[indices]

        self.Z = self.Z[indices, :, :]
        self.t = self.t[indices]

    def diff(self):
        """return the list of differences between consecutive grids (next - current)"""
        out = []
        grids = self.get_grids()
        for i, g in enumerate(grids[:-1]):
            out.append(grids[i+1] - g)
        return out

    @classmethod
    def load(cls, f, ftype=None, max_timesteps=None, **kwargs):
        """
        load a grid stack using iotools.load_gridstackdata
        :param f: passed on to load_gridstackdata
        :param ftype: passed on to load_gridstackdata
        :param max_timesteps: only keep the last max_timesteps timesteps
        :param kwargs: passed on to load_gridstackdata
        :raises ValueError: if the loaded time data is no array or does not match the data
        """
        from .iotools import load_gridstackdata

        x, y, t, data, meta = load_gridstackdata(f, ftype=ftype, **kwargs)

        if not isinstance(t, np.ndarray) or t.size != data.shape[0]:
            raise ValueError('invalid time data')

        if max_timesteps is not None:
            sl = slice(-max_timesteps, None, None)
            logging.getLogger(__name__).debug(
                'applying max timesteps {} on {} timesteps'.format(max_timesteps, t.size))

            # check meta data when slicing
            for k, v in list(meta.items()):
                if isinstance(v, np.ndarray) and v.size == t.size:
                    meta[k] = v[sl]

            t = t[sl]
            data = data[sl, :, :]

        # construct object
        return cls(x, y, t, data, **meta)

    def save(self, location, ftype=None, **kw):
        """save the stack using iotools.save_gridstackdata"""
        from .iotools import save_gridstackdata
        save_gridstackdata(location, self.x, self.y, self.t, self.Z, ftype=ftype, meta=self.meta, **kw)


class DifferenceGridDataset(BaseGridDataset):
    """
    grid holding the difference between two grids
    plots use a diverging colormap that is symmetric around zero by default
    """

    def plot(self, ax, plottype='depth', cmap='seismic', vmin=None, vmax=None, **kwargs):
        """
        plot the difference on the given axes
        :param ax: axes to plot on
        :param plottype: 'depth' uses plot_depth, anything else plot_grid
        :param cmap: colormap
        :param vmin: lower color limit (default -vmax)
        :param vmax: upper color limit (default: largest absolute value, only if vmin is not given)
        :param kwargs: passed on to the plot function
        :return: result of the plot function
        """
        if vmin is None:
            if vmax is None:
                vmax = np.amax(np.absolute(self.Z))
            # symmetric limits around zero
            vmin = -vmax

        if plottype == 'depth':
            return plot_depth(ax, self.x, self.y, self.Z, cmap=cmap, vmin=vmin, vmax=vmax, **kwargs)
        else:
            return plot_grid(ax, self, cmap=cmap, vmin=vmin, vmax=vmax, **kwargs)

    def show(self, plottype='depth', colorbar=True, clabel=None, **kwargs):
        """
        fast plot the dataset and show in a window
        :param plottype: passed on to plot
        :param colorbar: add a colorbar to the figure
        :param clabel: label of the colorbar (default 'depth (m)' for plottype 'depth')
        :param kwargs: passed on to plot
        """
        from matplotlib import pyplot as plt
        # create figure
        fig = plt.figure()
        ax = fig.gca()

        c = self.plot(ax, plottype=plottype, **kwargs)

        if colorbar:
            # add colorbar
            cb = fig.colorbar(c)
            if clabel is None and plottype == 'depth':
                clabel = 'depth (m)'
            cb.set_label(clabel or '')

        # show figure
        plt.show()

    def __sub__(self, other):
        """subtracting from a difference always results in a difference"""
        return BaseGridDataset.__sub__(self, other, outcls=DifferenceGridDataset)

    def __add__(self, other):
        """
        difference + GridDataset results in a GridDataset,
        difference + difference in a difference,
        else return using BaseGridDataset.__add__
        """
        if isinstance(other, GridDataset):
            return BaseGridDataset.__add__(self, other, outcls=GridDataset)
        elif isinstance(other, DifferenceGridDataset):
            return BaseGridDataset.__add__(self, other, outcls=type(self))
        else:
            return BaseGridDataset.__add__(self, other)

    def __mul__(self, other):
        """multiply with a value, or cell by cell with a grid on the same positions"""
        # a grid must be combined through its Z values (with position check), not as an object
        if isinstance(other, BaseGridDataset):
            return BaseGridDataset.__mul__(self, other)
        return type(self)(self.x, self.y, self.Z * other)

    def __pow__(self, other):
        """raise the values to the given power"""
        return type(self)(self.x, self.y, self.Z ** other)

    def __div__(self, other):
        """divide by a value, or cell by cell by a grid on the same positions"""
        # a grid must be combined through its Z values (with position check), not as an object
        if isinstance(other, BaseGridDataset):
            return BaseGridDataset.__div__(self, other)
        return type(self)(self.x, self.y, self.Z / other)


class Gridset(object):
    """
    collection of grids that can be joined into a single grid
    the grids must share the same, equidistant cell size
    """

    def __init__(self, *grids):
        self.grids = []
        for g in grids:
            self.add_grid(g)

    def add_grid(self, g):
        """
        add a grid to the set
        :raises TypeError: if g is not a BaseGridDataset
        """
        if not isinstance(g, BaseGridDataset):
            raise TypeError('object not a valid grid type')
        self.grids.append(g)

    def calculate_cellsize(self):
        """
        return the cell size shared by all grids as (int dx, int dy)
        :raises ValueError: without grids or if any spacing differs from that of the first grid
        """
        if not self.grids:
            raise ValueError('no grids found')

        # the first grid defines the reference spacing
        x_cellsize = np.diff(self.grids[0].x)[0]
        y_cellsize = np.diff(self.grids[0].y)[0]

        for g in self.grids:
            if (np.diff(g.x) != x_cellsize).any():
                raise ValueError('non-consistent x cellsize')

            if (np.diff(g.y) != y_cellsize).any():
                raise ValueError('non-consistent y cellsize')

        return int(x_cellsize), int(y_cellsize)

    @staticmethod
    def _meta_equal(a, b):
        """compare two metadata values, arrays are compared as a whole"""
        if isinstance(a, np.ndarray) or isinstance(b, np.ndarray):
            return np.array_equal(a, b)
        return a == b

    def join_meta(self, metas):
        """
        combine metadata dicts into a new dict
        starts with the items of the first dict and removes every key for which a
        following dict holds a different value; the given dicts are left untouched
        :param metas: list of dicts
        :return: new dict (empty if no dicts are given)
        """
        if not metas:
            return dict()

        # work on a copy: the first dict is the metadata of a grid
        meta = dict(metas[0])
        for m in metas[1:]:
            for k, v in m.items():
                if k in meta and self._meta_equal(v, meta[k]):
                    continue
                meta.pop(k, None)
        return meta

    def _indexed_join(self, series):
        """
        join several series into their sorted unique values
        :param series: list of 1D sequences
        :return: tuple of
            - sorted unique values of all series
            - for each series: indices (into that series) of the values whose first occurrence is in that series
            - for each series: index into the unique values of each of its values
        """
        all_data = np.concatenate(series)
        uq, ind, inv = np.unique(all_data, return_index=True, return_inverse=True)
        indices = []
        inv_indices = []
        cursor = 0
        for s in series:
            inv_indices.append(inv[cursor:cursor + len(s)])
            I = ind[(cursor <= ind) & (ind < cursor + len(s))]
            indices.append(I - cursor)
            cursor += len(s)
        # the index arrays differ in length between series, so keep them in a list
        return uq, indices, inv_indices

    def join(self):
        """
        join all grids into a single grid
        a single grid is returned as is
        :return: GridDataset or GridTimeStackDataset
        :raises ValueError: without grids
        :raises TypeError: unless all grids are GridDataset or all are GridTimeStackDataset
        """
        if not self.grids:
            raise ValueError('no grids found')

        if len(self.grids) == 1:
            return self.grids[0]

        gridtypes = set(type(g) for g in self.grids)
        if gridtypes == {GridDataset}:
            return self.join_xy(
                [g.x for g in self.grids],
                [g.y for g in self.grids],
                [g.Z for g in self.grids],
                **self.join_meta([g.meta for g in self.grids]))
        elif gridtypes == {GridTimeStackDataset}:
            # survey ids and densities are only used if all grids provide them
            sids = [g.meta.get('sid', None) for g in self.grids]
            if any(item is None for item in sids):
                sids = None

            densities = [g.meta.get('density', None) for g in self.grids]
            if any(item is None for item in densities):
                densities = None

            return self.join_xyt(
                [g.x for g in self.grids],
                [g.y for g in self.grids],
                [g.t for g in self.grids],
                [g.Z for g in self.grids],
                sids=sids, densities=densities)

        raise TypeError('cannot join grids of types {}'.format(
            sorted(t.__name__ for t in gridtypes)))

    def join_xy(self, xs, ys, Zs, **meta):
        """
        join the 2D grids of this set on a common x, y raster
        where grids overlap, the value of the first grid with an unmasked value is used
        :param xs: x positions of each grid
        :param ys: y positions of each grid
        :param Zs: not used, values are taken from self.grids
        :param meta: metadata of the result
        :return: GridDataset
        """
        dx, dy = self.calculate_cellsize()

        x = self.join_equidistant(xs, stepsize=dx, name='x')
        y = self.join_equidistant(ys, stepsize=dy, name='y')

        shape = (y.size, x.size)

        # start fully masked
        Z = np.ma.masked_array(np.zeros(shape), mask=np.ones(shape, dtype=bool))

        for g in self.grids:
            i_start = int((g.y.min() - y.min()) // dy)
            j_start = int((g.x.min() - x.min()) // dx)
            rows = slice(i_start, i_start + len(g.y))
            cols = slice(j_start, j_start + len(g.x))

            # fill with block
            old_data = Z[rows, cols]
            m = old_data.mask & ~np.ma.getmaskarray(g.Z)  # new values from current layer
            old_data[m] = g.Z[m]
            Z[rows, cols] = old_data

        return GridDataset(x, y, Z, **meta)

    def join_xyt(self, xs, ys, ts, Zs, sids=None, densities=None):
        """
        join the grid stacks of this set on a common x, y raster and a common set of layers
        layers are identified by survey id if given, else by timestep, and are
        ordered by the sorted unique identifiers
        :param xs: x positions of each grid
        :param ys: y positions of each grid
        :param ts: timesteps of each grid
        :param Zs: not used, values are taken from self.grids
        :param sids: survey ids of each grid (one per timestep), or None
        :param densities: densities of each grid (one per timestep), or None
        :return: GridTimeStackDataset
        """
        dx, dy = self.calculate_cellsize()

        x = self.join_equidistant(xs, stepsize=dx, name='x')
        y = self.join_equidistant(ys, stepsize=dy, name='y')

        meta = dict()
        if sids is not None:
            key = sids
            meta['sid'], _, inverse_indices = self._indexed_join(sids)
            # pick the timestep of the first occurrence of each sid, in the order of the sorted sids
            t = self.join_series(ts, key=key)
        else:
            key = ts
            t, _, inverse_indices = self._indexed_join(ts)

        if densities is not None:
            # same order as the layers
            meta['density'] = self.join_series(densities, key=key)

        shape = (t.size, y.size, x.size)

        # start fully masked
        Z = np.ma.masked_array(np.zeros(shape), mask=np.ones(shape, dtype=bool))

        for Iinv, g in zip(inverse_indices, self.grids):
            i_start = int((g.y.min() - y.min()) // dy)
            j_start = int((g.x.min() - x.min()) // dx)
            rows = slice(i_start, i_start + len(g.y))
            cols = slice(j_start, j_start + len(g.x))

            gZ = g.Z[:, :, :]

            # fill with block: Iinv holds the output layer of each layer of this grid
            old_data = Z[Iinv, rows, cols]
            m = old_data.mask & ~np.ma.getmaskarray(gZ)
            old_data[m] = gZ[m]
            Z[Iinv, rows, cols] = old_data

        return GridTimeStackDataset(x, y, t, Z, **meta)

    def join_equidistant(self, arrays, stepsize=None, name=None):
        """
        return the equidistant positions spanning all given arrays
        :param arrays: list of 1D position arrays
        :param stepsize: spacing (default: spacing of the first two values of the first array)
        :param name: name used in the error message
        :raises ValueError: without arrays or if an array is not spaced with stepsize
        """
        if not arrays:
            raise ValueError('no arrays given')
        if stepsize is None:
            stepsize = arrays[0][1] - arrays[0][0]

        vmin = min(a.min() for a in arrays)
        vmax = max(a.max() for a in arrays)

        for a in arrays:
            if not (np.diff(a) == stepsize).all():
                raise ValueError('stepsize of {} not equidistant'.format(name or 'array'))

        return np.arange(vmin, vmax+stepsize, stepsize)

    def join_series(self, series, key=None, name=None):
        """
        join series into a single array with one value for each unique key
        the value at the first occurrence of each key is used; the result follows
        the order of the sorted unique keys
        :param series: list of 1D sequences
        :param key: list of 1D sequences matching series (default: the series themselves)
        :param name: not used
        :raises ValueError: without series or if key and series differ in length
        """
        if not series:
            raise ValueError('no series given')

        if key is None:
            key = series

        if len(key) != len(series):
            raise ValueError('key must have the same length as the series')

        _, indices = np.unique(np.concatenate(key), return_index=True)

        return np.concatenate(series)[indices]


def join_grids(gs):
    """
    join an iterable of grids into a single grid
    shortcut for collecting the grids in a Gridset and calling its join method
    """
    gridset = Gridset(*gs)
    return gridset.join()
