from __future__ import print_function
try:
    from PyQt5 import QtWidgets, QtGui, QtCore
    from matplotlib.backends.backend_qt5agg import FigureCanvas
except ImportError:
    from PyQt4 import QtGui, QtCore
    QtWidgets = QtGui
    from matplotlib.backends.backend_qt4agg import FigureCanvas
from matplotlib.figure import Figure
from ..analysis.rotate import rotate
from ..analysis import orientation
from ..plotting.cmaps import viridis
import numpy as np
from ..datasets import search, GridDataset
import os
import pandas as pd


class InvalidTransects(Exception):
    """Raised when the transect-index text contains a non-integer entry."""


def line(orientation='h'):
    """Build a raised, 1-pixel-wide separator frame.

    :param orientation: ``'h'`` for a horizontal rule, ``'v'`` for a
        vertical one.
    :returns: the configured ``QFrame``.
    :raises KeyError: if ``orientation`` is neither ``'h'`` nor ``'v'``.
    """
    frame = QtWidgets.QFrame()
    frame.setLineWidth(1)
    if orientation not in ('h', 'v'):
        raise KeyError(orientation)
    frame.setFrameShape(
        QtWidgets.QFrame.HLine if orientation == 'h' else QtWidgets.QFrame.VLine)
    frame.setFrameShadow(QtWidgets.QFrame.Raised)
    return frame


class TransectSelector(QtWidgets.QWidget):
    """Widget for picking transects from a gridded dataset and saving them.

    The left-hand form lets the user obtain a grid (by search or from a
    file), set the transect orientation (with the slider or via the
    ``orientation`` analysis), choose transect indices and write the chosen
    transects to a text file. The right-hand canvas shows the grid with the
    chosen transects overlaid (top axes) and their profiles (bottom axes).
    """

    # Form fields that are passed to ``search``.
    _SEARCH_FIELDS = ('xll', 'yll', 'w', 'h')

    def __init__(self, parent=None):
        super(TransectSelector, self).__init__(parent)
        # transect index -> (grid line, profile line, label text) artists
        self.transect_indices = dict()
        # last lmin/lmax used for analysis and last accepted index text
        self.last_data = dict()
        # current grid; None until a search or load succeeds
        self.griddata = None
        # (rotX, rotY, rotZ) masked arrays for the current grid/orientation
        self.rotated_data = None
        self.build()
        self.plot_grid()
        self.plot_transects()

    def make_header(self, val):
        """Return a QLabel with text ``val`` in a 10 pt font (section header)."""
        font = QtGui.QFont()
        font.setPointSize(10)
        label = QtWidgets.QLabel(val)
        label.setFont(font)
        return label

    def build(self):
        """Lay out the input form (left), a separator and the plot canvas (right)."""
        # NOTE(review): this attribute shadows QWidget.layout(); the name is
        # kept because other code may access ``self.layout`` directly.
        self.layout = QtWidgets.QHBoxLayout(self)
        self.build_form(self.layout)
        self.layout.addWidget(line('v'))
        self.build_plotcanvas(self.layout)
        self.draw()

    def build_search_form(self, layout):
        """Add the grid-search inputs (xll, yll, w, h) and the search button.

        Also creates ``self.fields``; later sections add their inputs to it,
        so this must run before :meth:`build_orientation_form`.
        """
        layout.addRow(self.make_header('1A: Search'))
        self.fields = {}
        for name, default in zip(self._SEARCH_FIELDS, ['', '', 5000, 5000]):
            w = self.fields[name] = QtWidgets.QLineEdit()
            w.setText(str(default))
            layout.addRow(name, w)
        self.search_button = QtWidgets.QPushButton('search')
        self.search_button.clicked.connect(self.search)
        layout.addWidget(self.search_button)

    def build_load_form(self, layout):
        """Add a data-file path input with a browse button."""
        layout.addRow(self.make_header('1B: Load from file'))
        self.datafile_layout = QtWidgets.QHBoxLayout()
        self.datafile_input = QtWidgets.QLineEdit()
        self.datafile_browse = QtWidgets.QPushButton('browse')
        self.datafile_browse.clicked.connect(self.browse_datafile)
        self.datafile_layout.addWidget(self.datafile_input)
        self.datafile_layout.addWidget(self.datafile_browse)
        layout.addRow('datafile', self.datafile_layout)

    def build_orientation_form(self, layout):
        """Add lmin/lmax inputs, the analyse button and the 0-180 deg slider."""
        layout.addRow(self.make_header('2: Select transect orientation'))
        for name, default in zip(['lmin', 'lmax'], [100, 1000]):
            w = self.fields[name] = QtWidgets.QLineEdit()
            w.setText(str(default))
            layout.addRow(name, w)
        self.analyse_button = QtWidgets.QPushButton('analyse')
        self.analyse_button.clicked.connect(self.analyse)
        layout.addWidget(self.analyse_button)

        value_layout = QtWidgets.QHBoxLayout()
        self.orientation_slider = QtWidgets.QSlider(QtCore.Qt.Horizontal)
        self.orientation_slider.setTickPosition(QtWidgets.QSlider.TicksBelow)
        self.orientation_slider.setTickInterval(45)
        self.orientation_slider.setMinimum(0)
        self.orientation_slider.setMaximum(180)
        self.orientation_slider.setValue(0)
        # Rotation is recomputed only when the handle is released, not on
        # every intermediate value while dragging.
        self.orientation_slider.sliderReleased.connect(self.orientation_changed)
        value_layout.addWidget(self.orientation_slider)
        self.orientation_label = QtWidgets.QLabel('0 deg')
        value_layout.addWidget(self.orientation_label)
        layout.addRow('orientation', value_layout)

    def build_select_transect_form(self, layout):
        """Add the comma-separated transect-index input."""
        layout.addRow(self.make_header('3: Select transects'))
        self.tr_index_input = QtWidgets.QLineEdit()
        self.tr_index_input.editingFinished.connect(self.indices_changed)
        layout.addRow('indices', self.tr_index_input)

    def build_save_transect_form(self, layout):
        """Add the output filename input, a browse button and the save button."""
        layout.addRow(self.make_header('4: Save transect data'))
        self.filename_layout = QtWidgets.QHBoxLayout()
        self.filename_input = QtWidgets.QLineEdit()
        self.filename_browse = QtWidgets.QPushButton('browse')
        self.filename_browse.clicked.connect(self.browse_filename)
        self.filename_layout.addWidget(self.filename_input)
        self.filename_layout.addWidget(self.filename_browse)
        layout.addRow('filename', self.filename_layout)

        self.save_button = QtWidgets.QPushButton('save')
        self.save_button.clicked.connect(self.save)
        layout.addWidget(self.save_button)

    def build_form(self, layout):
        """Build the titled form, with sections separated by horizontal rules."""
        form = QtWidgets.QFormLayout()

        xl_font = QtGui.QFont()
        xl_font.setPointSize(12)
        title = QtWidgets.QLabel('Transect selector')
        title.setFont(xl_font)
        form.addRow(title)

        form.addRow(line())
        self.build_search_form(form)
        form.addRow(line())
        self.build_load_form(form)
        form.addRow(line())
        self.build_orientation_form(form)
        form.addRow(line())
        self.build_select_transect_form(form)
        form.addRow(line())
        self.build_save_transect_form(form)
        layout.addLayout(form)

    def build_plotcanvas(self, layout):
        """Create the figure with the grid axes (top) and profile axes (bottom)."""
        self.fig = Figure(figsize=(4, 5.33), facecolor='none')
        canvas = FigureCanvas(self.fig)
        self.grid_ax = self.fig.add_axes([0, .25, 1, .75])
        self.tr_ax = self.fig.add_axes([0, 0, 1, .22])
        self.tr_ax.axis('off')
        canvas.mpl_connect('button_press_event', self.click_add_transect)
        layout.addWidget(canvas)

    def browse_filename(self):
        """Ask for an output filename; keep the current one if cancelled."""
        f = QtWidgets.QFileDialog.getSaveFileName()
        if isinstance(f, tuple):  # PyQt5 returns (filename, selected_filter)
            f = f[0]
        if f:
            self.filename_input.setText(f)

    def browse_datafile(self):
        """Ask for a data file and load it; do nothing if cancelled."""
        f = QtWidgets.QFileDialog.getOpenFileName()
        if isinstance(f, tuple):  # PyQt5 returns (filename, selected_filter)
            f = f[0]
        if not f:
            return
        self.datafile_input.setText(f)
        self.load()

    @property
    def indices(self):
        """Yield the integer transect indices typed in the index field.

        Entries are comma-separated; blank entries are skipped. Iteration
        raises :class:`InvalidTransects` when an entry is not an integer.
        """
        indices_text = str(self.tr_index_input.text())
        for istr in (part.strip() for part in indices_text.split(',')):
            if not istr:
                continue
            try:
                I = int(istr)
            except ValueError:
                raise InvalidTransects()
            yield I

    def save(self):
        """Write the selected transects of the rotated grid to the chosen file.

        The file starts with one header line holding the grid extent and the
        orientation. For each selected transect it then holds a line with the
        index and an ``x,y,z`` CSV table of the points where both x and y
        are unmasked. Problems are reported via :meth:`show_error`; nothing
        happens until a grid has been loaded and rotated.
        """
        if self.griddata is None or self.rotated_data is None:
            return
        try:
            tri = list(self.indices)
        except InvalidTransects:
            self.show_error('invalid transects input')
            return
        rotX, rotY, rotZ = self.rotated_data

        fname = str(self.filename_input.text())
        if not fname:
            self.show_error('no filename given')
            return
        # A bare filename has an empty dirname: it means the current directory.
        if not os.path.isdir(os.path.dirname(fname) or os.curdir):
            self.show_error('cannot create file in non-existent directory')
            return
        if not tri:
            self.show_error('no transects selected')
            return
        n_transects = rotX.shape[1]
        out_of_range = [I for I in tri if not 0 <= I < n_transects]
        if out_of_range:
            # Validate before opening so a bad index never leaves a partial file.
            self.show_error('transect(s) {} out of range (0-{})'.format(
                ', '.join(map(str, out_of_range)), n_transects - 1))
            return
        try:
            with open(fname, 'w') as f:
                f.write('xll={g.xll} yll={g.yll} w={g.w} h={g.h} orientation={o:.1f}deg\n'.format(
                    g=self.griddata, o=self.orientation))
                for I in tri:
                    f.write('{}\n'.format(I))
                    x, y, z = rotX[:, I], rotY[:, I], rotZ[:, I]
                    # getmaskarray always gives a full boolean array, even when
                    # the mask is the scalar ``nomask``.
                    m = ~(np.ma.getmaskarray(x) | np.ma.getmaskarray(y))
                    df = pd.DataFrame(dict(x=x[m], y=y[m], z=z[m].filled(np.nan)),
                                      columns=['x', 'y', 'z'])
                    df.to_csv(f, sep=',', index=False, float_format='%.3f')
        except (IOError, OSError) as e:
            self.show_error(str(e))

    def search(self):
        """Fetch a grid with the xll/yll/w/h inputs and redisplay it.

        On any failure the error is shown and the grid is cleared.
        """
        try:
            # Parse only the search inputs, so an invalid lmin/lmax cannot
            # break an unrelated search.
            data = dict((k, int(str(self.fields[k].text())))
                        for k in self._SEARCH_FIELDS)
            self.griddata = search(xll=data['xll'], yll=data['yll'], w=data['w'], h=data['h'])
        except Exception as e:
            self.show_error(str(e))
            self.griddata = None

        self.grid_changed()

    def load(self):
        """Load a grid from the data-file path and redisplay it.

        An empty path is ignored. A missing file or a load failure is
        reported; a load failure also clears the grid.
        """
        filename = str(self.datafile_input.text())
        if not filename:
            return
        elif not os.path.isfile(filename):
            self.show_error('file {} not found'.format(filename))
            return

        try:
            self.griddata = GridDataset.load(filename)
        except Exception as e:
            self.show_error(str(e))
            self.griddata = None

        self.grid_changed()

    @property
    def orientation(self):
        """Current slider orientation in whole degrees (0-180)."""
        return self.orientation_slider.value()

    @orientation.setter
    def orientation(self, val):
        # Round to the nearest degree (the slider is integer-valued); the
        # slider itself clamps to its 0-180 range.
        self.orientation_slider.setValue(int(round(val)))

    def analyse(self):
        """Estimate the orientation of the current grid and apply it.

        Reads lmin/lmax from the form, passes them to the ``orientation``
        analysis, moves the slider to the result and recomputes the
        rotation. Non-integer lmin/lmax values are reported as an error.
        """
        if self.griddata is None:
            return
        try:
            lmin = int(str(self.fields['lmin'].text()))
            lmax = int(str(self.fields['lmax'].text()))
        except ValueError:
            self.show_error('lmin and lmax must be integers')
            return
        self.orientation = orientation(self.griddata, lmin=lmin, lmax=lmax)
        self.last_data['lmin'] = lmin
        self.last_data['lmax'] = lmax
        self.orientation_changed()

    def orientation_changed(self):
        """Update the label, rotate the grid to the slider angle and replot."""
        self.orientation_label.setText('{} deg'.format(self.orientation))
        if self.griddata is None:
            # The slider can be moved before any grid has been loaded.
            self.rotated_data = None
            return
        self.rotated_data = tuple(map(np.ma.masked_invalid, rotate(
            self.griddata.x, self.griddata.y, self.griddata.Z, self.orientation)))
        self.indices_changed()

    def grid_changed(self):
        """Discard the old rotation, redraw the grid and re-run the analysis."""
        # A rotation of the previous grid must never be plotted or saved
        # alongside the new (or cleared) grid.
        self.rotated_data = None
        self.plot_grid()
        self.analyse()

    def show_error(self, msg):
        """Show ``msg`` in a modal warning dialog."""
        QtWidgets.QMessageBox.warning(self, 'Error', msg)

    def indices_changed(self):
        """Validate the index field and replot the listed transects.

        Unparsable input is reported and cleared. Only accepted text is
        recorded in ``last_data['indices']``.
        """
        indices_text = str(self.tr_index_input.text())
        if not indices_text.strip():
            self.last_data.pop('indices', None)
            return self.plot_transects([])
        try:
            indices = list(self.indices)
        except InvalidTransects:
            self.show_error('invalid transects input')
            self.tr_index_input.setText('')
            self.last_data.pop('indices', None)
            return self.plot_transects([])
        self.plot_transects(indices)
        self.last_data['indices'] = indices_text

    def plot_grid(self):
        """Redraw the grid colour map and drop all transect overlays."""
        self.grid_ax.clear()
        if self.griddata is not None:
            self.grid_ax.pcolormesh(self.griddata.x, self.griddata.y, self.griddata.Z, cmap=viridis)
        self.grid_ax.axis('off')
        self.grid_ax.set_aspect('equal')
        # grid_ax.clear() has already removed the overlay artists, so just
        # forget them instead of removing them a second time.
        self.transect_indices = {}
        self.clear_transects()
        self.draw()

    def clear_transects(self):
        """Remove all transect overlays and clear the profile axes."""
        for handles in self.transect_indices.values():
            for h in handles:
                try:
                    h.remove()
                except ValueError:
                    pass  # artist already detached from its axes
        self.transect_indices.clear()

        self.tr_ax.clear()
        self.tr_ax.axis('off')

    def _plot_transect(self, I, rotX, rotY, rotPos, rotZ):
        """Draw transect column ``I`` on the grid and its profile below.

        Out-of-range or fully masked transects are reported and skipped.
        Transects that are already drawn are skipped silently.
        """
        if I in self.transect_indices:
            # Drawing again would overwrite the handles and orphan the first
            # set of artists, which could then never be cleared.
            return
        n_transects = rotPos.shape[1]
        if not 0 <= I < n_transects:
            self.show_error('transect {} out of range (0-{})'.format(I, n_transects - 1))
            return
        x, y, pos, z = rotX[:, I], rotY[:, I], rotPos[:, I], rotZ[:, I]

        if np.ma.getmaskarray(pos).all():
            self.show_error('transect {} is empty'.format(I))
            return

        h1, = self.grid_ax.plot(x, y, color='r', lw=1)
        h2, = self.tr_ax.plot(pos, z, color='k', lw=1)
        t = self.grid_ax.text(x.mean(), y.mean(), str(I))
        self.transect_indices[I] = h1, h2, t

    def plot_transects(self, indices=()):
        """Clear existing overlays, plot the given transects and redraw.

        ``indices`` may be a lazy iterable such as :attr:`indices`; if it
        raises :class:`InvalidTransects`, the error is shown and the index
        field is cleared.
        """
        self.clear_transects()
        if self.griddata is not None and self.rotated_data is not None:
            try:
                rotX, rotY, rotZ = self.rotated_data
                # Distance from the rotated origin, centred on its overall
                # mean, serves as the along-transect coordinate of profiles.
                rotPos = (rotX ** 2 + rotY ** 2) ** .5
                rotPos -= rotPos.mean()
                for I in indices:
                    self._plot_transect(I, rotX, rotY, rotPos, rotZ)
            except InvalidTransects:
                self.show_error('invalid transects input')
                self.tr_index_input.setText('')
        # Redraw even without data so the cleared overlays disappear.
        self.draw()

    def draw(self):
        """Redraw the figure canvas."""
        self.fig.canvas.draw()

    def click_add_transect(self, event):
        """Add the transect nearest to a click on the grid axes to the index list."""
        if event.inaxes is not self.grid_ax or self.griddata is None:
            return
        x, y = event.xdata, event.ydata
        if self.rotated_data is None:
            self.orientation_changed()
        rotX, rotY = self.rotated_data[:2]
        # Nearest grid node by L1 distance; its column is the transect index.
        flat_index = np.argmin(np.absolute(rotX - x) + np.absolute(rotY - y))
        _, I = np.unravel_index(flat_index, rotX.shape)
        I = int(I)
        if I in self.transect_indices:
            return
        try:
            tri = list(self.indices)
        except InvalidTransects:
            # Replace unparsable text instead of crashing in the event handler.
            tri = []
        if I in tri:
            return
        tri.append(I)
        self.tr_index_input.setText(', '.join(map(str, tri)))
        self.indices_changed()
