import numpy as np
from matplotlib.colors import ListedColormap, LinearSegmentedColormap
from matplotlib import cm
from .cmaps import viridis_ext



def _add_gridlines(ax, color='k', lw=0.5, alpha=1, **kw):
    """Draw a full-span line through every current tick position of ``ax``.

    Vertical lines are drawn at the values returned by ``ax.get_xticks()``
    first, then horizontal lines at the values from ``ax.get_yticks()``.

    Parameters
    ----------
    ax
        Object providing ``get_xticks``, ``get_yticks``, ``axvline`` and
        ``axhline`` (presumably a matplotlib ``Axes``).
    color, lw, alpha
        Line colour, width and opacity passed to every line.
    **kw
        Any further keyword arguments, passed unchanged to ``axvline`` and
        ``axhline``.
    """
    line_style = dict(color=color, lw=lw, alpha=alpha, **kw)
    # Each tick getter is only called once the previous direction has been
    # drawn completely, so x ticks are read before any line is added.
    for get_ticks, draw_line in ((ax.get_xticks, ax.axvline),
                                 (ax.get_yticks, ax.axhline)):
        for position in get_ticks():
            draw_line(position, **line_style)


def plot_grid(ax, G, gridlines=True, rasterized=True, scale=1, cmap=viridis_ext, griddict=None, **kwargs):
    """Draw the values of a grid object as a pseudocolor mesh on ``ax``.

    Parameters
    ----------
    ax
        Axes to draw on; must provide ``pcolormesh``, ``set_xlim`` and
        ``set_ylim`` (presumably a matplotlib ``Axes``).
    G
        Grid object exposing ``x``, ``y`` and ``Z`` attributes.  ``x`` and
        ``y`` must support multiplication by ``scale`` and ``.min()`` /
        ``.max()``; ``Z`` is passed to ``pcolormesh`` as the colour data.
    gridlines : bool
        If true, draw lines through the tick positions after plotting.
    rasterized : bool
        Passed to ``set_rasterized`` on the returned mesh.
    scale : float
        Factor applied to the ``x`` and ``y`` coordinates before plotting.
    cmap
        Colormap passed to ``pcolormesh``.
    griddict : dict, optional
        Keyword arguments for the grid lines (e.g. ``color``, ``lw``,
        ``alpha``).  ``None`` or an empty dict uses the defaults.
    **kwargs
        Additional keyword arguments forwarded to ``pcolormesh`` (for
        example ``vmin`` / ``vmax``).

    Returns
    -------
    The mesh object returned by ``ax.pcolormesh``.
    """
    # Forward the extra keyword arguments; they used to be accepted but
    # silently discarded, so options such as vmin/vmax had no effect.
    c = ax.pcolormesh(G.x*scale, G.y*scale, G.Z, cmap=cmap, **kwargs)
    c.set_rasterized(rasterized)

    # Clamp the view to the extent of the (scaled) grid coordinates.
    ax.set_xlim(G.x.min()*scale, G.x.max()*scale)
    ax.set_ylim(G.y.min()*scale, G.y.max()*scale)
    if gridlines:
        _add_gridlines(ax, **(griddict or {}))

    return c


def plot_depth(ax, x, y, z, scale=.001, tick_interval=None, rasterized=True,
               lw=.5, la=.5, lc=(.2, .2, .2), cmap=viridis_ext,
               reverse_depth=False, gridlines=True, title=None, **kwargs):
    """Draw gridded values ``z`` over coordinates ``x``/``y`` as a map on ``ax``.

    Parameters
    ----------
    ax
        Axes to draw on (presumably a matplotlib ``Axes``).
    x, y : array-like, 1-D with at least two entries
        Coordinates of the grid columns and rows.  Each vector is extended
        by one extra step (the spacing of its last two entries) before
        being handed to ``pcolormesh``.
    z : array-like
        Values to colour; invalid entries (NaN/inf) are masked.  Presumably
        shaped ``(len(y), len(x))`` -- confirm against callers.
    scale : float
        Factor applied to coordinates before plotting.  ``.001`` labels the
        axes in ``km``, ``1`` in ``m``; any other value is labelled with
        ``1 / scale`` metres.
    tick_interval : float, optional
        Tick spacing in unscaled coordinate units.  By default it is derived
        from the larger coordinate span, rounded to a multiple of 1000.
        If the interval is not positive (e.g. the default rounds to zero for
        spans of 2500 or less), matplotlib's automatic ticks are kept.
    rasterized : bool
        Passed to ``set_rasterized`` on the returned mesh.
    lw, la, lc
        Line width, alpha and colour of the grid lines.
    cmap : str or Colormap
        Colormap for the mesh.
    reverse_depth : bool
        If true, plot ``abs(z)`` and use the reversed colormap.
    gridlines : bool
        If true, draw solid lines through the tick positions.
    title : optional
        If not ``None``, ``str(title)`` is set as the axes title.
    **kwargs
        Additional keyword arguments forwarded to ``pcolormesh``.

    Returns
    -------
    The mesh object returned by ``ax.pcolormesh``.
    """
    # Accept plain sequences as well as arrays; .min()/.max() are used below.
    x = np.asarray(x)
    y = np.asarray(y)

    if reverse_depth:
        z = np.absolute(z)
        if isinstance(cmap, str):
            # Named colormaps are reversed by toggling the '_r' suffix.
            cmap = cmap[:-2] if cmap.endswith('_r') else cmap + '_r'
        elif hasattr(cmap, 'reversed'):
            # Colormap.reversed() handles every Colormap subclass; the
            # previously used cm.revcmap helper is no longer available in
            # current Matplotlib releases.
            cmap = cmap.reversed()

    # plot data
    # Extend each coordinate vector by one step so that it has one more
    # entry than the original, using the spacing of the last two entries.
    dx = np.diff(x[-2:])[0]
    dy = np.diff(y[-2:])[0]
    x_ext = np.append(x, x[-1]+dx)
    y_ext = np.append(y, y[-1]+dy)

    c = ax.pcolormesh(x_ext*scale, y_ext*scale, np.ma.masked_invalid(z), cmap=cmap, **kwargs)
    c.set_rasterized(rasterized)

    # set axis properties
    if scale == .001:
        unit = 'km'
    elif scale == 1:
        unit = 'm'
    else:
        unit = '{:.1f} m'.format(1. / scale)

    ax.set_xlabel('easting ({unit})'.format(unit=unit))
    ax.set_ylabel('northing ({unit})'.format(unit=unit))
    ax.set_aspect('equal')

    if tick_interval is None:
        span = max(x.max() - x.min(), y.max() - y.min())
        tick_interval = np.around(span/5000)*1000

    # A zero interval (small extents round down to 0) would divide by zero
    # and make np.arange fail, so only place explicit ticks for a positive
    # interval and otherwise keep matplotlib's automatic ticks.
    if tick_interval > 0:
        # First tick at or after the lower edge; the upper bound is pushed
        # one step past the last edge because np.arange excludes its stop.
        xmin_r = np.ceil(x_ext[0] / tick_interval) * tick_interval
        xmax_r = np.ceil((x_ext[-1]+dx) / tick_interval) * tick_interval

        ymin_r = np.ceil(y_ext[0] / tick_interval) * tick_interval
        ymax_r = np.ceil((y_ext[-1]+dy) / tick_interval) * tick_interval

        xticks = np.arange(xmin_r*scale, xmax_r*scale, tick_interval*scale)
        yticks = np.arange(ymin_r*scale, ymax_r*scale, tick_interval*scale)

        ax.set_xticks(xticks)
        ax.set_yticks(yticks)

    # Limits are set after the ticks so the view matches the data extent.
    ax.set_xlim(x_ext[0]*scale, x_ext[-1]*scale)
    ax.set_ylim(y_ext[0]*scale, y_ext[-1]*scale)

    # add major grid lines
    if gridlines:
        _add_gridlines(ax, color=lc, lw=lw, alpha=la, ls='-')

    if title is not None:
        ax.set_title(str(title))

    return c