import json
import datetime
import time
import numpy as np
import logging
from .. import datasets
from .. import settings


# names exported by ``from <module> import *``: the exception, the block classes and the batch iterable
__all__ = ['InvalidData',
           'DataBlock', 'GridDataBlock', 'GridStackDataBlock', 'DuoGridDataBlock',
           'DataIterable']


class InvalidData(Exception):
    """raised when loaded data cannot be used, e.g. when a time stack has too few timesteps"""


class DataBlock(object):
    """
    rectangular block of data surrounded by a margin, so that analyses do not suffer
    from problems at the data edges; subclasses provide the means to strip the margin again

    :param locations: list of locations to load data from.
    :param name: name of the block
    :param xll: x coord of lower left corner of the data block (no margin)
    :param yll: y coord of lower left corner of the data block (no margin)
    :param w: width of the data block (no margin)
    :param h: height of the data block (no margin)
    :param margin: margin around the data block
    :param opts: optional keyword arguments such as metadata (currently unused)
    """

    def __init__(self, locations, name, xll, yll, w, h, margin, **opts):
        # data sources and identification
        self.locations, self.name = locations, name
        # extent of the block itself, margin excluded
        self.xll, self.yll = xll, yll
        self.w, self.h = w, h
        self.margin = margin
        # remaining keyword arguments are only stored
        self.opts = opts

    def load(self):
        """
        base implementation: there is nothing to load, so hand back the block definition

        :return: the attribute dict of this instance (not a copy)
        """
        return self.__dict__

    def __repr__(self):
        return '<{} {}>'.format(self.__class__.__name__, self.name)


class GridDataBlock(DataBlock):
    """
    data block that loads gridded datasets and can strip the margin from results
    """

    def load(self):
        """
        load the grids of all locations, join them and cut out the block including its margin

        :return: result of ``subset`` on the joined grid
        """
        G = datasets.join_grids([datasets.GridDataset.load(loc) for loc in self.locations])
        return G.subset(self.xll-self.margin,
                        self.yll-self.margin,
                        self.w+2*self.margin,
                        self.h+2*self.margin)

    def _inner_masks(self, x, y):
        """
        boolean masks selecting the positions that lie inside the block (margin excluded)

        the lower/left edge is inclusive, the upper/right edge is exclusive

        :param x: 1D array of x positions
        :param y: 1D array of y positions
        :return: xmask, ymask
        """
        xmask = (x >= self.xll) & (x < self.xll + self.w)
        ymask = (y >= self.yll) & (y < self.yll + self.h)
        return xmask, ymask

    def mask_gridded(self, x, y, *args):
        """
        apply a mask to remove the margin from gridded data
        :param x: 1D array of x positions
        :param y: 1D array of y positions
        :param args: 2D grids of values, shaped (len(y), len(x))
        :return: masked x, y, *args; every grid is shaped (number of kept y, number of kept x)
        """
        xmask, ymask = self._inner_masks(x, y)
        # rows follow y and columns follow x, so the mask has the same (ny, nx)
        # orientation as the reshaped result; a grid of any other shape raises IndexError
        inner = ymask[:, np.newaxis] & xmask[np.newaxis, :]
        shape = (ymask.sum(), xmask.sum())
        return tuple([x[xmask], y[ymask]] + [a[inner].reshape(shape) for a in args])

    def mask_points(self, x, y, *args):
        """
        apply a mask to remove the margin from point data
        :param x: 1D array of x positions
        :param y: 1D array of y positions
        :param args: 1D arrays of values
        :return: masked x, y, *args
        """
        xmask, ymask = self._inner_masks(x, y)
        # a point is kept only when both of its coordinates fall inside the block
        mask = xmask & ymask
        return tuple([x[mask], y[mask]] + [a[mask] for a in args])


class GridStackDataBlock(GridDataBlock):
    """
    data block that loads time stacks of grids instead of single grids
    """

    def load(self):
        """
        load a time stack (at most 20 timesteps) from every location and join them

        :return: result of ``datasets.join_grids`` on the loaded stacks
        """
        stacks = []
        for location in self.locations:
            stacks.append(datasets.GridTimeStackDataset.load(location, max_timesteps=20))
        return datasets.join_grids(stacks)


class DuoGridDataBlock(GridDataBlock):
    """
    data block that provides the last two timesteps of a single time stack
    """

    def load(self):
        """
        load the time stack of the only location and return its last two timesteps

        :return: tuple of (second to last timestep, last timestep)
        :raises NotImplementedError: when the block does not have exactly one location
        :raises InvalidData: when fewer than two timesteps remain after filtering
        """
        n_locations = len(self.locations)
        if n_locations != 1:
            raise NotImplementedError('cannot load more than 1 GridTimeStackDataset for temporal analysis')

        stack = datasets.GridTimeStackDataset.load(self.locations[0], max_timesteps=20)
        # clean up the stack before the timesteps are counted
        stack.combine_similar_timesteps()
        stack.filter_by_coverage(0.3)

        n_timesteps = stack.shape[0]
        if n_timesteps < 2:
            raise InvalidData('too few timesteps')

        previous, latest = stack[-2], stack[-1]
        return previous, latest


class DataIterable(object):

    """
    Data iterable for use in batch runs

    note: when the block definitions come from ``search`` or from a file they are produced
    by a generator, so the definitions are consumed by the first full iteration

    :param xll: x coord of lower left corner of area to analyse
    :param yll: y coord of lower left corner of area to analyse
    :param xur: x coord of upper right corner of area to analyse
    :param yur: y coord of upper right corner of area to analyse
    :param blocksize: size (m) of the individual blocks (width and height equal)
    :param margin: margin around the blocks (m)
    :param blocks: block definitions (see DataIterable.search) or filename of the block definitions
    :param blockcls: class of the blocks
    :param preload_cachedir: preload the datafiles to a directory
    :param reset: reset the preloaded datafiles
    :param max_count: break after <max_count> iterations (0 means no limit)
    """

    def __init__(self,
                 xll=450000, yll=5700000,
                 xur=700000, yur=5950000,
                 blocksize=10000, margin=1000,
                 blocks=None, blockcls=GridDataBlock,
                 preload_cachedir=None, reset=False, max_count=0):
        if blocks is None:
            # no definitions given: search the area for square blocks
            blocks = self.search(xll=xll, yll=yll, xur=xur, yur=yur, w=blocksize, h=blocksize, margin=margin)
        elif isinstance(blocks, str):
            # a string is taken as the name of a file written by save_search
            blocks = self.blocks_from_file(blocks)
        self.blocks = blocks
        self.preload_cachedir = preload_cachedir
        self.reset = reset
        self.blockcls = blockcls
        self.max_count = max_count

    @classmethod
    def blocks_from_file(cls, filename):
        """
        load the block definitions from a file

        the file is expected to have the layout written by ``save_search``: one header line
        followed by a JSON list of block definitions. The locations of every block are
        searched again, including the margin, just like ``search`` does.

        :param filename: name of the file with the block definitions
        :return: generator of dicts that can be passed as keyword arguments to the block class
        """
        with open(filename, 'r') as f:
            f.readline()  # skip the header line, the JSON document starts on the second line
            data = json.load(f)

        for item in data:
            margin = item['margin']
            # keep every stored field (xll, yll, w, h, margin, name, ...) so that the
            # definition can be used as ``blockcls(**definition)`` in __iter__
            definition = dict(item)
            definition['locations'] = datasets.search_nc(
                xll=item['xll']-margin,
                yll=item['yll']-margin,
                w=item['w']+2*margin,
                h=item['h']+2*margin)
            yield definition

    @classmethod
    def save_search(cls, setup_file, **kwargs):
        """
        search block locations and save to file

        :param setup_file: name of the file to write; an existing file is overwritten
        :param kwargs: arguments passed on to ``search``
        """
        # run the search before opening the file, so that a failing search
        # does not leave an empty or truncated setup file behind
        items = list(cls.search(**kwargs))
        with open(setup_file, 'w') as f:
            f.write('blocks {!s}\n'.format(datetime.datetime.now()))
            json.dump(items, f)

    @classmethod
    def search(cls, xll, yll, xur, yur, w, h, margin, **kw):
        """
        search blocks of a specified size within a specified area
        :param xll: x coord of the lower left corner of the area
        :param yll: y coord of the lower left corner of the area
        :param xur: x coord of the upper right corner of the area
        :param yur: y coord of the upper right corner of the area
        :param w: width of the datablocks
        :param h: height of the datablocks
        :param margin: margin around the datablocks
        :param kw: arguments added to the block definitions for each item
        :return: generator of dicts with locations, xll, yll, w, h, margin and name for each block
        """
        logging.getLogger(__name__).info('searching in {}>x>{} {}>y>{} for blocks of {} by {} km'.format(
                xll*.001, xur*.001, yll*.001, yur*.001, w*.001, h*.001))
        # tolist() turns the numpy scalars into plain python numbers; numpy integers
        # cannot be serialised by json.dump in save_search
        y_corners = np.arange(yll, yur, h).tolist()
        for _xll in range(xll, xur, w):
            for _yll in y_corners:
                locations = datasets.search_nc(
                    xll=_xll-margin,
                    yll=_yll-margin,
                    w=w+2*margin,
                    h=h+2*margin)
                if not locations:
                    # nothing found for this block: leave it out
                    continue
                name = 'xll{xll}yll{yll}w{w}h{h}'.format(xll=_xll, yll=_yll, w=w, h=h)
                data = kw.copy()
                data.update(dict(locations=locations, xll=_xll, yll=_yll, w=w, h=h, margin=margin, name=name))
                yield data

    def __iter__(self):
        """yield DataBlock instance for each iteration"""
        if self.preload_cachedir:
            # only a string is passed on as directory; any other truthy value
            # requests preloading with ``None`` as directory
            if isinstance(self.preload_cachedir, str):
                cachedir = self.preload_cachedir
            else:
                cachedir = None
            datasets.preload_nc(cachedir, dodsC2fileServer=True, reset=self.reset)

        for count, data in enumerate(self.blocks):
            # max_count == 0 disables the limit
            if self.max_count and count >= self.max_count:
                break

            yield self.blockcls(**data)
