"""
Interactive tool to draw a polygon mask on an image or image-like array.
Adapted from https://gist.github.com/tonysyu/3090704 which in turn
is adapted from matplotlib/examples/event_handling/poly_editor.py
"""
from __future__ import print_function, division
import numpy as np

import matplotlib.pyplot as plt
from matplotlib.patches import Polygon
from matplotlib.mlab import dist_point_to_segment
from matplotlib import path as mpl_path


class MaskCreator(object):
    """An interactive polygon editor for building a boolean image mask.

    The polygon is drawn on ``ax`` and edited with the mouse and keyboard;
    `get_mask` rasterizes the current polygon onto a pixel grid.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes on which the polygon is drawn and edited.
    canvas : matplotlib canvas, optional
        Canvas whose events are handled.  Defaults to the canvas of the
        figure the polygon was added to (i.e. the figure owning ``ax``).
    poly_xy : list of (float, float), optional
        List of (x, y) data coordinates used as vertices of the polygon.
        Defaults to ``default_vertices(ax)``.
    max_ds : float
        Max pixel distance to count as a vertex (or edge) hit.
    invert : bool
        If True, `get_mask` returns the region *outside* the polygon.

    Key-bindings
    ------------
    't' : toggle vertex markers on and off.  When vertex markers are on,
          you can move them and delete them.
    'd' : delete the vertex under the cursor.  The root vertex (first and
          last index of the stored outline) cannot be deleted.
    'a' : add a vertex at the cursor.  You must be within max_ds of the
          line connecting two existing vertices.
    'r' : reset the polygon so that only the root vertex remains.
    'i' : toggle between including (yellow) and excluding (red) the
          polygon region.
    'enter' : call `process`, a no-op hook here that subclasses may override.
    """

    description = ("Click and drag a point to move it\n"
                   "Press 'a' to add a point at the cursor position and 'd' to delete a selected point\n"
                   "Press 'r' to reset the polygon and 'i' to switch between including (yellow) or excluding (red) all points in the polygon")

    def __init__(self, ax, canvas=None, poly_xy=None, max_ds=10, invert=False):
        self.invert = invert
        self.showverts = True
        self.max_ds = max_ds
        if poly_xy is None:
            poly_xy = default_vertices(ax)
        # Both artists are animated: they are drawn explicitly in `redraw`
        # and blitted, rather than by the normal figure draw.
        self.poly = Polygon(poly_xy, animated=True,
                            fc=(.8, .8, .2), ec='none', alpha=0.2)

        ax.add_patch(self.poly)
        ax.set_clip_on(False)
        self.ax = ax

        x, y = zip(*self.poly.xy)
        self.line = plt.Line2D(x, y, color=self.edgecolor, lw=.5, marker='o',
                               mfc=self.edgecolor, alpha=0.8, animated=True)
        self._update_line()
        self.ax.add_line(self.line)

        self.poly.add_callback(self.poly_changed)
        self._ind = None  # index of the vertex being dragged, if any

        if canvas is None:
            canvas = self.poly.figure.canvas
        self.canvas = canvas

        self.canvas.mpl_connect('draw_event', self.draw_callback)
        self.canvas.mpl_connect('button_press_event', self.button_press_callback)
        self.canvas.mpl_connect('button_release_event', self.button_release_callback)
        self.canvas.mpl_connect('key_press_event', self.key_press_callback)
        self.canvas.mpl_connect('motion_notify_event', self.motion_notify_callback)

    @property
    def edgecolor(self):
        """RGB edge colour: red when inverted (excluding), yellow otherwise."""
        if self.invert:
            return (1, .4, .4)
        else:
            return (1, 1, .4)

    def get_mask(self, shape):
        """Return the image mask given by the mask creator.

        Parameters
        ----------
        shape : (int, int)
            (height, width) of the mask.  Pixel (row, col) is tested at the
            data coordinate (x=col, y=row).

        Returns
        -------
        ndarray of bool, shape (height, width)
            True inside the polygon, or outside it when ``self.invert`` is set.
        """
        h, w = shape
        y, x = np.mgrid[:h, :w]
        points = np.transpose((x.ravel(), y.ravel()))
        # Use the saved vertices: the polygon artist may be gone once the
        # figure has been closed.
        inside = mpl_path.Path(self.verts).contains_points(points)
        mask = inside.reshape(h, w)
        if self.invert:
            mask = np.invert(mask)
        return mask

    def poly_changed(self, poly):
        """Handle a polygon change (registered with ``add_callback``).

        Copying the polygon's artist properties onto the vertex line is
        deliberately not done; the line only keeps its own visibility state.
        """
        vis = self.line.get_visible()
        self.line.set_visible(vis)  # don't use the poly visibility state

    def draw_callback(self, event):
        """Cache the freshly drawn axes background, then blit the editor."""
        self.background = self.canvas.copy_from_bbox(self.ax.bbox)
        self.redraw()

    def button_press_callback(self, event):
        """Start dragging the vertex under the cursor on a left click."""
        ignore = not self.showverts or event.inaxes is None or event.button != 1
        if ignore:
            return
        self._ind = self.get_ind_under_cursor(event)

    def button_release_callback(self, event):
        """Stop dragging when the left button is released."""
        ignore = not self.showverts or event.button != 1
        if ignore:
            return
        self._ind = None

    def key_press_callback(self, event):
        """Dispatch a key press inside the axes to its editing action."""
        if not event.inaxes:
            return
        actions = {
            't': self.toggle_vertices_visibility,
            'd': self.delete_vertex,
            'a': self.insert_vertex,
            'r': self.reset,
            'enter': self.process,
            'i': self.toggle_invert,
        }
        action = actions.get(event.key)
        if action is not None:
            action(event)
        # The canvas is redrawn after any key pressed inside the axes.
        self.canvas.draw()

    def toggle_vertices_visibility(self, event):
        """Show or hide the vertex markers; hiding also ends any drag."""
        self.showverts = not self.showverts
        self.line.set_visible(self.showverts)
        if not self.showverts:
            self._ind = None

    def delete_vertex(self, event):
        """Delete the vertex under the cursor, unless it is the root vertex."""
        ind = self.get_ind_under_cursor(event)
        if ind is None:
            return
        if ind == 0 or ind == self.last_vert_ind:
            print("Cannot delete root node")
            return
        self.poly.xy = [xy for i, xy in enumerate(self.poly.xy) if i != ind]
        self._update_line()

    def reset(self, event):
        """Collapse the polygon to its root vertex (first and last entries)."""
        self.poly.xy = [self.poly.xy[0], self.poly.xy[-1]]
        self._update_line()

    def insert_vertex(self, event):
        """Insert a vertex at the cursor on the first edge within max_ds pixels."""
        # Edge distances are measured in display (pixel) coordinates.
        xys = self.poly.get_transform().transform(self.poly.xy)
        p = event.x, event.y  # cursor coords
        for i in range(len(xys) - 1):
            if dist_point_to_segment(p, xys[i], xys[i + 1]) <= self.max_ds:
                # The new vertex itself is stored in data coordinates.
                self.poly.xy = np.insert(self.poly.xy, i + 1,
                                         (event.xdata, event.ydata), axis=0)
                self._update_line()
                break

    def process(self, event):
        """Hook run when Enter is pressed; does nothing in this class."""
        pass

    def toggle_invert(self, event):
        """Switch between including and excluding the polygon region."""
        self.invert = not self.invert
        self.line.set_color(self.edgecolor)
        self.line.set_markerfacecolor(self.edgecolor)

    def motion_notify_callback(self, event):
        """Drag the active vertex with the mouse and blit the update."""
        ignore = (not self.showverts or event.inaxes is None or
                  event.button != 1 or self._ind is None)
        if ignore:
            return
        x, y = event.xdata, event.ydata

        if self._ind == 0 or self._ind == self.last_vert_ind:
            # The first and last entries are the same (root) point of the
            # outline, so they are moved together.
            self.poly.xy[0] = x, y
            self.poly.xy[self.last_vert_ind] = x, y
        else:
            self.poly.xy[self._ind] = x, y
        self._update_line()

        self.canvas.restore_region(self.background)
        self.redraw()

    def redraw(self):
        """Draw the animated polygon and vertex line, then blit the axes."""
        self.ax.draw_artist(self.poly)
        self.ax.draw_artist(self.line)
        self.canvas.blit(self.ax.bbox)

    def _update_line(self):
        """Sync the vertex line and cached vertex data with the polygon."""
        # save verts because polygon gets deleted when figure is closed
        self.verts = self.poly.xy
        self.last_vert_ind = len(self.poly.xy) - 1
        x, y = zip(*self.poly.xy)
        self.line.set_data(x, y)

    def get_ind_under_cursor(self, event):
        """Return the index of the vertex nearest the cursor, or None.

        Distances are measured in display (pixel) coordinates.  A vertex
        counts as hit only when it is strictly closer than ``max_ds``; ties
        go to the lowest index.
        """
        xyt = self.poly.get_transform().transform(np.asarray(self.poly.xy))
        d = np.hypot(xyt[:, 0] - event.x, xyt[:, 1] - event.y)
        ind = np.argmin(d)
        # `not <` (rather than `>=`) also rejects a NaN distance.
        if not d[ind] < self.max_ds:
            return None
        return ind


class FourierMaskCreator(MaskCreator):
    """Polygon mask editor for data laid out on a (kx, ky) grid.

    The mask is evaluated on a regular grid spanning the ranges of ``kx``
    and ``ky``.  It is made point-symmetric about the centre of that grid
    (presumably to respect the conjugate symmetry of the Fourier transform
    of real data — confirm with callers) before the optional inversion.
    Pressing Enter passes the mask to ``callback``.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes on which the polygon is drawn and edited.
    shape : (int, int)
        Default (height, width) of the mask returned by `get_mask`.
    kx, ky : array_like
        Coordinates along x and y; only their minimum and maximum are used.
    callback : callable
        Called with the boolean mask when Enter is pressed.
    **kwargs
        Forwarded to `MaskCreator`.
    """

    def __init__(self, ax, shape, kx, ky, callback, **kwargs):
        super(FourierMaskCreator, self).__init__(ax, **kwargs)
        self.shape = shape
        self.kx = kx
        self.ky = ky
        self.callback = callback

    def get_mask(self, shape=None):
        """Return the symmetrized polygon mask on the (kx, ky) grid.

        Parameters
        ----------
        shape : (int, int), optional
            (height, width) of the mask.  Defaults to ``self.shape``.  Being
            optional keeps the call compatible with ``MaskCreator.get_mask``.

        Returns
        -------
        ndarray of bool, shape (height, width)
        """
        if shape is None:
            shape = self.shape
        h, w = shape

        x = np.linspace(np.amin(self.kx), np.amax(self.kx), w)
        y = np.linspace(np.amin(self.ky), np.amax(self.ky), h)
        X, Y = np.meshgrid(x, y)

        points = np.transpose((X.ravel(), Y.ravel()))
        mask = mpl_path.Path(self.verts).contains_points(points)

        M = mask.reshape(h, w)
        # Add the polygon's image under a 180-degree rotation of the grid
        # (point reflection through the grid centre).
        M = np.logical_or(M, np.fliplr(np.flipud(M)))

        if self.invert:
            M = np.invert(M)

        return M

    def process(self, event):
        """Compute the current mask and hand it to ``self.callback``."""
        m = self.get_mask()
        self.callback(m)


def default_vertices(ax):
    """Return the corners of a rectangle inset from the axes limits.

    The rectangle leaves a border of a quarter of the axes width on each
    x side and a quarter of the axes height on each y side.

    Parameters
    ----------
    ax : matplotlib Axes
        Only ``get_xlim()`` and ``get_ylim()`` are used.

    Returns
    -------
    tuple of four (x, y) tuples
        Corners ordered (x1, y1), (x1, y2), (x2, y2), (x2, y1), where x1/y1
        are inset from the first limit and x2/y2 from the second.
    """
    x_start, x_stop = np.asarray(ax.get_xlim(), dtype=float)
    y_start, y_stop = np.asarray(ax.get_ylim(), dtype=float)

    x_margin = (x_stop - x_start) / 4
    y_margin = (y_stop - y_start) / 4

    inner_x1, inner_x2 = x_start + x_margin, x_stop - x_margin
    inner_y1, inner_y2 = y_start + y_margin, y_stop - y_margin

    return (
        (inner_x1, inner_y1),
        (inner_x1, inner_y2),
        (inner_x2, inner_y2),
        (inner_x2, inner_y1),
    )


def mask_creator_demo():
    """Edit a mask over random noise, then display the masked result.

    A random 90x100 image is shown with a `MaskCreator` on top.  Once the
    first window returns, every pixel outside the drawn mask is darkened
    by 100 (clipped to [0, 255]) and the image is shown again.
    """
    image = np.random.uniform(0, 255, size=(90, 100))
    axes = plt.subplot(111)
    axes.pcolormesh(image, cmap='gray')

    creator = MaskCreator(axes)
    plt.show()  # typically blocks until the editor window is closed

    outside = ~creator.get_mask(image.shape)
    darkened = np.clip(image[outside] - 100., 0, 255)
    image[outside] = np.uint8(darkened)

    plt.pcolormesh(image, cmap='gray')
    plt.title('Region outside of mask is darkened')
    plt.show()


if __name__ == '__main__':
    # Running this module as a script launches the interactive demo.
    mask_creator_demo()
