from __future__ import print_function
from . import datasets
from . import analysis
import sys
import os
from .cmd_utils import setup_parser, parse_args, expose, Directory, File, NewFile, RawTextHelpFormatter, ChoiceList
try:
    from PyQt5 import QtWidgets, QtGui
    from matplotlib.backends.backend_qt5agg import FigureCanvas, NavigationToolbar2QT as NavigationToolbar
except ImportError:
    from PyQt4 import QtGui
    QtWidgets = QtGui
    from matplotlib.backends.backend_qt4agg import FigureCanvas, NavigationToolbar2QT as NavigationToolbar
from matplotlib.figure import Figure
from matplotlib import pyplot as plt


# Command-line parser for the package; its description is filled in further
# down in this module, once the command docstrings are available.
PARSER = setup_parser(prog='sandwaves', formatter_class=RawTextHelpFormatter)
# Widgets created by show_figure(). The __main__ block checks this list to
# decide whether there is anything on screen worth starting the event loop for.
_FIGWIDGETS = []


def _search(xll, yll, width=10000, height=10000, collapse='recent'):
    """
    Load data for the given area through the ``datasets`` module.

    :param xll: x position of the lower left corner
    :param yll: y position of the lower left corner
    :param width: width of the area
    :param height: height of the area
    :param collapse: ``'all'`` loads the result of ``datasets.search_stack``;
        any other value is forwarded to ``datasets.search``
    :return: whatever the selected ``datasets`` search function returns
    """
    if collapse != 'all':
        grid = datasets.search(xll, yll, width, height, collapse=collapse)
        print('data loaded')
        return grid

    stack = datasets.search_stack(xll, yll, width, height)
    print('data stack loaded')
    return stack


def show_figure(fig):
    """
    Display a figure in its own widget, with a navigation toolbar above it.

    The widget is recorded in ``_FIGWIDGETS``.

    :param fig: figure whose ``canvas`` can be added to a Qt layout
    """
    window = QtWidgets.QWidget()
    layout = QtWidgets.QVBoxLayout(window)
    toolbar = NavigationToolbar(fig.canvas, window)
    layout.addWidget(toolbar)
    layout.addWidget(fig.canvas)
    window.show()
    # NOTE(review): presumably this reference also keeps the widget from being
    # garbage collected while it is on screen - confirm.
    _FIGWIDGETS.append(window)


def plt_show():
    """Display pyplot's current figure through show_figure()."""
    show_figure(plt.gcf())


# Replace pyplot's own show() so that later plt.show() calls end up in
# plt_show() above.
plt.show = plt_show


def _filename_suffix(fname, suffix):
    base, ext = os.path.splitext(fname)
    return base+suffix+ext


def _stack_item_suffix(g):
    """Return the file name suffix for one slice of a time stack.

    The suffix is ``'_sid<sid>'``, or ``'_t<time>'`` when the slice has no
    ``sid`` attribute.
    """
    try:
        return '_sid{}'.format(g.sid)
    except AttributeError:
        return '_t{}'.format(g.time)


def _process_grid(G, show=True, datafile=None, imagefile=None):
    """
    Plot a grid dataset and show and/or save the result.

    A ``datasets.GridTimeStackDataset`` is handled slice by slice: each slice
    is processed on its own, with a per-slice suffix inserted into the file
    names.

    :param G: grid dataset, or a GridTimeStackDataset of them
    :param show: display the figure in a window
    :param datafile: file to save the data to (nothing saved if falsy)
    :param imagefile: file to save the figure to (nothing saved if falsy)
    """
    print('processing', G, show, datafile, imagefile)
    if isinstance(G, datasets.GridTimeStackDataset):
        for g in G:
            print(g)
            # The suffix is only looked up for files that were requested, so
            # slices are not asked for sid/time when nothing is saved.
            if datafile is None:
                itemdatafile = None
            else:
                itemdatafile = _filename_suffix(datafile, _stack_item_suffix(g))
            if imagefile is None:
                itemimagefile = None
            else:
                itemimagefile = _filename_suffix(imagefile, _stack_item_suffix(g))
            # Forward the caller's choice; this used to be hard-coded to True,
            # which opened a window per slice even when show was False.
            _process_grid(g, show=show, datafile=itemdatafile, imagefile=itemimagefile)
        return

    if show or imagefile:
        fig = Figure(figsize=(8, 6))
        c = G.plot(fig.gca())
        fig.colorbar(c).set_label('depth (m)')

        # Constructing the canvas attaches it to the figure as fig.canvas,
        # which show_figure() places in a widget; no local reference needed.
        FigureCanvas(fig)

        if imagefile:
            fig.savefig(imagefile)
            print('image saved to '+imagefile)

        if show:
            show_figure(fig)

    if datafile:
        G.save(datafile)
        print('data saved to '+datafile)


def _analyse(G, minheight, lmin, lmax):
    """
    Run ``analysis.shape2D`` on a dataset and return the result as a dict.

    For a ``datasets.GridTimeStackDataset`` every slice is analysed
    separately. Each slice result gets two extra columns, ``'t'`` and
    ``'sid'``, holding the slice's ``time`` and ``sid`` repeated once per
    element of its ``'x'`` column. The columns of all slices are then
    concatenated per key.

    :param G: grid dataset, or a GridTimeStackDataset of them
    :param minheight: forwarded to ``analysis.shape2D``
    :param lmin: forwarded to ``analysis.shape2D``
    :param lmax: forwarded to ``analysis.shape2D``
    :return: dict of result columns (an OrderedDict for a time stack)
    """
    if not isinstance(G, datasets.GridTimeStackDataset):
        result = analysis.shape2D(G, minheight=minheight, lmin=lmin, lmax=lmax)._asdict()
        print('data analysed')
        return result

    from collections import OrderedDict
    import numpy as np

    merged = OrderedDict()
    for grid in G:
        columns = analysis.shape2D(grid, minheight=minheight, lmin=lmin, lmax=lmax)._asdict()
        count = columns['x'].size
        columns['t'] = np.zeros(count) + grid.time
        columns['sid'] = np.zeros(count) + grid.sid
        for key, values in columns.items():
            # The first slice is appended to an empty list.
            collected = merged.get(key, [])
            merged[key] = np.concatenate([collected, values])
    print('data stack analysed')
    return merged


def _process_analysis_result(G, d, datafile=None, show=True, imagefile=None, property=None):
    """
    Save the analysis result and/or plot one of its properties.

    :param G: dataset that was analysed; drawn in gray underneath the points
        (for a GridTimeStackDataset its ``recent()`` result is drawn)
    :param d: dict of result columns; must hold ``'x'`` and ``'y'`` for plotting
    :param datafile: file to save the result to (nothing saved if falsy)
    :param show: display the figure in a window
    :param imagefile: file to save the figure to (nothing saved if falsy)
    :param property: key of ``d`` used to colour the points; nothing is
        plotted when this is None
    """
    if datafile:
        datasets.PointDataset(d, columns=d.keys()).save(datafile)
        print('data saved to '+datafile)

    if property is None:
        return

    if property not in d.keys():
        print('property not in {}'.format('|'.join(d.keys())))
        return

    if not (show or imagefile):
        return

    print('plotting results for {}'.format(property))
    fig = Figure(figsize=(8, 6))
    # Constructing the canvas attaches it to the figure as fig.canvas.
    FigureCanvas(fig)

    background = G
    if isinstance(background, datasets.GridTimeStackDataset):
        background = background.recent()

    ax = fig.gca()
    background.plot(ax, cmap='gray')

    # NOTE(review): the .001 factor presumably converts m to km to match the
    # axes used by the grid plot - confirm.
    points = ax.scatter(d['x']*.001, d['y']*.001, c=d[property], cmap='viridis', lw=0, s=5)
    fig.colorbar(points).set_label(property)

    if imagefile:
        fig.savefig(imagefile)
        print('image saved to '+imagefile)

    if show:
        show_figure(fig)


@expose(xll=int, yll=int, width=int, height=int,
        imagefile=NewFile, datafile=NewFile,
        collapse=ChoiceList('recent', 'first', 'second', 'all'))
def search(xll, yll, width=5000, height=5000, collapse='recent', imagefile=None, datafile=None):
    """
    Search for data at the given location with given size.
    :param xll: x position (UTM31) of lower left corner
    :param yll: y position (UTM31) of lower left corner
    :param width: width of the dataset
    :param height: height of the dataset
    :param collapse: method of reducing timesteps to single dataset
    :param imagefile: save the image of the data to a file (show figure if not specified)
    :param datafile: save the data to a file
    """
    # The docstring above is read at runtime (short_desc() puts its first line
    # into the parser description), so its text is kept as is.
    found = _search(xll, yll, width, height, collapse=collapse)
    # A window is only opened when no image file was requested.
    display = not imagefile
    _process_grid(found, show=display, datafile=datafile, imagefile=imagefile)


@expose(filename=File, imagefile=NewFile)
def open_file(filename, imagefile=None):
    """
    Open a saved gridfile
    :param filename: filename of the grid data
    :param imagefile: save the image of the data to a file (show figure if not specified)
    """
    # The docstring above is read at runtime (short_desc() puts its first line
    # into the parser description), so its text is kept as is.
    grid = datasets.GridDataset.load(filename)
    # A window is only opened when no image file was requested; the data
    # itself is never written back.
    display = not imagefile
    _process_grid(grid, show=display, datafile=None, imagefile=imagefile)

@expose(infile=File, datafile=NewFile, imagefile=NewFile, minheight=float, lmin=float, lmax=float,
        property=ChoiceList('height', 'length', 'asymmetry'))
def analyse_file(infile, datafile=None, imagefile=None, minheight=.5, lmin=100, lmax=1000, property='height'):
    """
    analyse shape data from a file and save the result
    :param infile: filename from which to load the data
    :param datafile: file to which to store the result
    :param imagefile: save the image of the result to a file (show figure if not specified)
    :param minheight: minimum sand wave height detected
    :param lmin: minimum bedform length for directional fitting
    :param lmax: maximum bedform length for directional fitting
    :param property: property to show if defined (height | length | asymmetry)
    """
    # The docstring above is read at runtime (short_desc() puts its first line
    # into the parser description), so its text is kept as is.
    grid = datasets.GridDataset.load(infile)
    result = _analyse(grid, minheight=minheight, lmin=lmin, lmax=lmax)
    # A window is only opened when no image file was requested.
    display = not imagefile
    _process_analysis_result(grid, result, datafile=datafile, show=display,
                             imagefile=imagefile, property=property)


@expose(xll=int, yll=int, width=int, height=int, datafile=NewFile, imagefile=NewFile,
        minheight=float, lmin=float, lmax=float,
        collapse=ChoiceList('recent', 'first', 'second', 'all'),
        property=ChoiceList('height', 'length', 'asymmetry', 'all'))
def analyse(xll, yll, width=5000, height=5000, collapse='recent',
            datafile=None, imagefile=None, minheight=.5, lmin=100, lmax=1000, property=None):
    """
    Search for data at the given location, extract shape data and save the result
    :param xll: x position (UTM31) of lower left corner
    :param yll: y position (UTM31) of lower left corner
    :param width: width of the dataset
    :param height: height of the dataset
    :param collapse: method of reducing timesteps to single dataset
    :param datafile: file to which to store the result
    :param imagefile: save the image of the result to a file (show figure if not defined)
    :param minheight: minimum sand wave height detected
    :param lmin: minimum bedform length for directional fitting
    :param lmax: maximum bedform length for directional fitting
    :param property: property to show if defined
    """
    # The docstring above is read at runtime (short_desc() puts its first line
    # into the parser description), so its text is kept as is.
    found = _search(xll, yll, width, height, collapse=collapse)
    result = _analyse(found, minheight=minheight, lmin=lmin, lmax=lmax)
    # A window is only opened when no image file was requested.
    display = not imagefile
    _process_analysis_result(found, result, datafile=datafile, show=display,
                             imagefile=imagefile, property=property)


@expose(direct=True)
def transects():
    """Extract transects from a dataset"""
    # Imported here rather than at module level, so the selector tool is only
    # loaded when this command is actually used.
    from .tools.transect_selector import TransectSelector
    # The new selector is handed back to the caller.
    return TransectSelector()


@expose(name=ChoiceList('shape', 'migration', 'growth'), location=Directory,
        xll=int, yll=int, xur=int, yur=int,
        blocksize=int, margin=int, progresslog=NewFile, reset=bool,
        debug=bool, configfile=File)
def batch(name, location, xll, yll, xur, yur, blocksize=5000, margin=0,
          progresslog=None, reset=True, debug=False, configfile=None):
    """
    Analyse several datasets of bathymetry to extract bed features and save the results
    :param name: name of the batchrun (shape|growth|migration NB: migration not validated)
    :param location: output storage location (directory)
    :param xll: x coord of lower left corner of the area to analyse
    :param yll: y coord of lower left corner of the area to analyse
    :param xur: x coord of upper right corner of the area to analyse
    :param yur: y coord of upper right corner of the area to analyse
    :param blocksize: size of the blocks for subdividing the area
    :param margin: size of the margin to add to each block
    :param progresslog: filename of the log file to store progress
    :param reset: restart all locations or only run failed locations
    :param debug: run with max logging
    :param configfile: configuration file (json) with additional configuration options
    """
    # The docstring above is read at runtime (short_desc() puts its first line
    # into the parser description), so its text is kept as is.

    from . import batchruns
    import logging
    from .__init__ import file_logging

    loglevel = 'DEBUG' if debug else 'WARNING'

    # Start from the optional json file; the explicit arguments are applied
    # afterwards and therefore override entries with the same name.
    config = {}
    if configfile:
        import json
        try:
            with open(configfile, 'r') as handle:
                loaded = json.load(handle)
            if not isinstance(loaded, dict):
                raise TypeError('data not a dictionary')
            config.update(loaded)
        except (IOError, ValueError, TypeError) as e:
            # Any read, parse or type problem is reported as one ValueError.
            raise ValueError('invalid configuration file ({})'.format(e))

    config.update(xll=xll, yll=yll, xur=xur, yur=yur,
                  progresslog=progresslog, continue_previous=not reset,
                  blocksize=blocksize, margin=margin)

    logfile = os.path.join(location, 'batchrun.log')
    file_logging(level=loglevel, names=['sandwaves'], filename=logfile, reset=reset)
    logger = logging.getLogger('sandwaves.batchruns.config')
    # Logged at warning level, which is also the level used without debug.
    logger.warning('running batchrun for {}'.format(name))
    for key in sorted(config):
        logger.warning('{:<20s} :: {}'.format(key, config[key]))

    print('starting batchrun {}'.format(name))
    print('this may take some time...')
    runner = batchruns.get_by_name(name)
    runner(location, **config)
    print('done')


def short_desc(fn, indent=15):
    """
    Build the short help entry for one command function.

    The entry holds the command name (underscores replaced by dashes) padded
    to ``indent - 1`` characters, the first line of the function's docstring,
    and a hint on how to get the full help.

    :param fn: command function; its ``__name__`` and ``__doc__`` are used
    :param indent: column at which the description text starts
    :return: the formatted, multi-line help entry
    """
    fmt = '{namelabel} {docstr}\n{indent}For more information:\n{indent}>python -m sandwaves {name} --help'
    command = fn.__name__.replace('_', '-')
    # __doc__ is None for undocumented functions and for every function under
    # ``python -OO``; a blank docstring has no first line. Fall back to an
    # empty summary instead of failing.
    doc_lines = (fn.__doc__ or '').strip().splitlines()
    summary = doc_lines[0].strip() if doc_lines else ''
    return fmt.format(
        indent=' '*indent,
        name=command,
        namelabel=command.ljust(indent-1),
        docstr=summary)


# The module docstring is assembled at import time from the short description
# of each command (see short_desc) and is reused below as the description of
# the command-line parser.
__doc__ = '''
This package was written for the loading and analysis of sand waves in the Northsea

{search}

{open_file}

{analyse}

{analyse_file}

{batch}

{transects}

gui            Use the graphical interface to enter arguments
               >python -m sandwaves gui
               >python -m sandwaves gui search
'''.format(search=short_desc(search),
           open_file=short_desc(open_file),
           analyse=short_desc(analyse),
           batch=short_desc(batch),
           analyse_file=short_desc(analyse_file),
           transects=short_desc(transects))

PARSER.description = __doc__


if __name__ == '__main__':
    import warnings

    # Suppress all warnings for command-line use.
    warnings.filterwarnings("ignore")

    # The application object is created before the command is parsed and run.
    app = QtWidgets.QApplication([])
    outcome = parse_args(sys.argv[1:])

    try:
        # A command may hand back something with a show() method.
        outcome.show()
    except AttributeError:
        # Nothing showable came back: leave right away unless show_figure()
        # has put at least one figure widget on screen.
        if len(_FIGWIDGETS) == 0:
            sys.exit(0)

    # Run the Qt event loop and exit with its return code.
    sys.exit(app.exec_())

