# -*- coding: utf-8 -*-
"""
@author: Avelon Gerritsma

"""

import matplotlib.pyplot as plt
from matplotlib import gridspec
import cmocean
import xarray as xr
import numpy as np
import pandas as pd

# Cross-section dataset plotted in figure 7, plus the bed level and the
# along-thalweg distance; the latter two outline and mask the bottom in the
# panels of both figures.
uds_reg = xr.open_dataset("crs_data_figure7.nc")
bedlevel, distance = (
    xr.open_dataarray(fname) for fname in ("bedlevel_crs.nc", "distance_crs.nc")
)

# Global plot style: black ticks, tick labels, axis labels and axes frame;
# sans-serif (Arial) text at font size 12.
params = dict.fromkeys(
    ("ytick.color", "xtick.color", "axes.labelcolor", "axes.edgecolor"),
    "black",
)
params.update({"font.family": "sans-serif",
               "font.size": 12,
               "font.sans-serif": "Arial"})
plt.rcParams.update(params)

# Colour limits for figure 7 and the variable to plot (salinity); `gridname`
# is the mesh prefix used to build the coordinate names below.
clims = [0, 15]
varname = 'mesh2d_sa1'
gridname = 'mesh2d'

# Create figure 7: stacked cross-section panels, one per time step.
# NOTE(review): the grid has 6 rows, so this assumes uds_reg holds at most
# 6 time steps (gs[timestep] raises IndexError otherwise) — confirm.
fig = plt.figure(figsize=(10, 12))
gs = gridspec.GridSpec(ncols=1, nrows=6, figure=fig, hspace=0.25)

# The lower edge of the white bed mask is the same for every panel.
ymin = np.nanmin(bedlevel)

for timestep, time in enumerate(uds_reg.time.values):
    # Positional selection: no label lookup, and never more than one match
    # even if a timestamp were duplicated.
    data_timestep = uds_reg.isel(time=timestep)
    # Create the axes through the figure so it is attached before drawing.
    ax = fig.add_subplot(gs[timestep])
    # Colour field; the colour limits are applied at creation time.
    data_timestep[varname].plot.pcolormesh(
        ax=ax, x=f"{gridname}_distance", y=f"{gridname}_depth",
        cmap=cmocean.cm.haline, vmin=clims[0], vmax=clims[1],
        add_labels=False, add_colorbar=False, zorder=0,
    )
    # White contour lines spanning the same range as the colour map.
    data_timestep[varname].plot.contour(
        ax=ax, x=f"{gridname}_distance", y=f"{gridname}_depth",
        colors='white', add_labels=False, add_colorbar=False,
        vmin=clims[0], vmax=clims[1], levels=16, zorder=1, linewidths=0.8,
    )
    # Paint white between the bed line (raised by 0.5 m) and the deepest bed
    # level, then draw the bed line itself on top of the field.
    ax.fill_between(distance, bedlevel + 0.5, ymin, color='white', zorder=2)
    ax.plot(distance, bedlevel + 0.5, c='k', lw=1.25)
    ax.set_ylabel('z [m]', fontsize=12)
    ax.set_yticks([0, -10, -20, -30])
    ax.set_xlabel('Distance along thalweg [m]', fontsize=12)
    ax.tick_params(labelsize=12)
    # Timestamp label with a bold mathtext "t", e.g. "t = 2020-01-01 12:00".
    stamp = pd.to_datetime(time).strftime('%Y-%m-%d %H:%M')
    ax.text(0.03, 0.05, rf'$\bf{{t}}$ = {stamp}',
            transform=ax.transAxes, fontsize=12, color='black')


# Open the cross-section dataset plotted in figure 9 (the bed level and
# distance loaded earlier are reused for the bottom mask).
uds_reg = xr.open_dataset("crs_data_figure9.nc")

# Colour limits for figure 9 and the variable to plot (salinity); `gridname`
# is the mesh prefix used to build the coordinate names below.
clims = [0.1, 0.45]
varname = 'mesh2d_sa1'
gridname = 'mesh2d'

# Create figure 9: stacked cross-section panels, one per time step.
# NOTE(review): the grid has 4 rows, so this assumes uds_reg holds at most
# 4 time steps (gs[timestep] raises IndexError otherwise) — confirm.
fig = plt.figure(figsize=(10, 8))
gs = gridspec.GridSpec(ncols=1, nrows=4, figure=fig, hspace=0.2)

# The lower edge of the white bed mask is the same for every panel.
ymin = np.nanmin(bedlevel)

for timestep, time in enumerate(uds_reg.time.values):
    # Positional selection: no label lookup, and never more than one match
    # even if a timestamp were duplicated.
    data_timestep = uds_reg.isel(time=timestep)
    # Create the axes through the figure so it is attached before drawing.
    ax = fig.add_subplot(gs[timestep])
    # Colour field; the colour limits are applied at creation time.
    data_timestep[varname].plot.pcolormesh(
        ax=ax, x=f"{gridname}_distance", y=f"{gridname}_depth",
        cmap=cmocean.cm.haline, vmin=clims[0], vmax=clims[1],
        add_labels=False, add_colorbar=False, zorder=0,
    )
    # White contour lines spanning the same range as the colour map.
    data_timestep[varname].plot.contour(
        ax=ax, x=f"{gridname}_distance", y=f"{gridname}_depth",
        colors='white', add_labels=False, add_colorbar=False,
        vmin=clims[0], vmax=clims[1], levels=8, zorder=1, linewidths=0.8,
    )
    # Paint white between the bed line (raised by 0.5 m) and the deepest bed
    # level, then draw the bed line itself on top of the field.
    ax.fill_between(distance, bedlevel + 0.5, ymin, color='white', zorder=2)
    ax.plot(distance, bedlevel + 0.5, c='k', lw=1.25)
    ax.set_ylabel('z [m]', fontsize=12)
    ax.set_yticks([0, -10, -20, -30])
    ax.set_xlabel('Distance along thalweg [m]', fontsize=12)
    ax.tick_params(labelsize=12)
    # Timestamp label with a bold mathtext "t", e.g. "t = 2020-01-01 12:00".
    stamp = pd.to_datetime(time).strftime('%Y-%m-%d %H:%M')
    ax.text(0.03, 0.05, rf'$\bf{{t}}$ = {stamp}',
            transform=ax.transAxes, fontsize=12, color='black')

