#!/usr/bin/env python
# coding: utf-8

# In[44]:


import flopy.utils.binaryfile as bf
import matplotlib.pyplot as plt
import flopy
import numpy as np
import matplotlib.patheffects as pltpe
import math


def plotplume(time, L, ztop, zbot, modelname, model_ws, filename, title):
    """Plot the row-0 cross-section of one UCN output file at one saved time.

    Parameters
    ----------
    time : int
        Position in the list returned by ``UcnFile.get_times()`` (negative
        values count from the end, e.g. ``-1`` for the last saved time).
    L, ztop, zbot, modelname
        Not used by this function; kept so existing calls keep working.
    model_ws : str or path-like
        Directory that contains ``filename``.
    filename : str
        Name of the UCN file to read.
    title : str
        Axes title. ``"DO (mmol)"`` additionally halves the scaled values
        and caps the colour scale at 0.28.

    The figure is displayed with ``plt.show()``; nothing is returned.
    """
    import os

    # os.path.join instead of a hard-coded "\\" separator so the path also
    # resolves on non-Windows systems (the result is identical on Windows).
    ucnobjp = bf.UcnFile(os.path.join(str(model_ws), str(filename)))
    timesp = ucnobjp.get_times()
    concentrationp = ucnobjp.get_data(totim=timesp[time])

    _, ax = plt.subplots(1, 1, figsize=(8, 4.5), dpi=150, constrained_layout=True)

    # Values are multiplied by 1000 (presumably mol -> mmol to match the
    # "(mmol)" titles -- TODO confirm the UCN units).
    # extent=(left, right, bottom, top) puts 0 at the top of the y axis.
    if title == "DO (mmol)":
        # NOTE(review): the extra /2 presumably converts O(0) to O2; confirm.
        conc_minus = concentrationp * 1000 / 2
        modelmap = ax.imshow(conc_minus[:, 0, :], interpolation='nearest',
                             extent=(-35, 250, 10, 0), aspect="auto", vmax=0.28)
    else:
        conc_minus = concentrationp * 1000
        modelmap = ax.imshow(conc_minus[:, 0, :], interpolation='nearest',
                             extent=(-35, 250, 10, 0), aspect="auto")

    csfont = {'fontname': 'Times new roman', 'fontsize': 30}

    ax.set_title(str(title), **csfont)
    ax.set_xlabel('Distance', **csfont)
    ax.set_ylabel('Depth', **csfont)
    ax.tick_params(axis='both', which='major', labelsize=25)

    cb = plt.colorbar(modelmap, shrink=1, ax=ax, pad=0.01)
    cb.ax.tick_params(labelsize=25)

    # plt.colorbar restores the plotting axes as current, so these act on ax.
    plt.yticks(ticks=[0, 5, 10], fontname="Times New Roman")
    plt.xticks(ticks=[0, 50, 100, 150, 200, 250], fontname="Times New Roman")

    for label in cb.ax.yaxis.get_ticklabels():
        label.set_family("Times new roman")

    plt.xlim(-35, 200)
    plt.show()



# In[ ]:


def plotallplumesinHD(modelname):
    """Show the last-saved-time cross-section of a fixed set of species.

    Reads the ``PHT3D0NN.UCN`` files for the species numbers in ``savelist``
    from the directory ``"Output_files_" + modelname`` and opens one figure
    per file, titled from ``plumefiguretitle`` (matched by position).
    Nothing is returned.

    Raises
    ------
    ValueError
        If the file list and the title list ever differ in length.
    """
    import os

    model_ws = "Output_files_" + modelname

    savelist = [1, 2, 6, 7, 8, 10, 12, 13, 14, 15, 16,
                19, 21, 34, 35, 36, 37, 38, 39, 40, 41, 42]
    UCN = ["PHT3D0" + str(i).zfill(2) + ".UCN" for i in savelist]

    plumefiguretitle = ["BEX (mmol)", "NVDOC (mmol)", "Total DIC (mmol)", "CH$_{4}$ (mmol)",
                        "Ca$^{2+}$ (mmol)", "Fe$^{2+}$ (mmol)", "K$^{+}$ (mmol)", "Mg$^{2+}$ (mmol)",
                        "NO$_{3}$ (mmol)", "Na$^{+}$ (mmol)", "DO (mmol)", "SO$_{4}$ (mmol)",
                        "pH", "Calcite (mmol/Lv)", "Siderite (mmol/Lv)", "Fe(OH)$_{3}$ (mmol)", "Pyrite (mmol/Lv)",
                        "CaX2 (mmol)", "MgX2 (mmol)", "KX (mmol)", "NaX (mmol)", "FeX2 (mmol)"
                        ]

    # zip() would silently drop an unmatched tail; fail loudly instead so a
    # file is never plotted under another species' title.
    if len(UCN) != len(plumefiguretitle):
        raise ValueError(
            f"{len(UCN)} UCN files but {len(plumefiguretitle)} titles")

    csfont = {'fontname': 'Times new roman', 'fontsize': 30}
    timeend = -1  # index of the last saved output time

    def _plot_single_plume(time, filename, title):
        """Draw one species' row-0 cross-section at output-time index *time*."""
        ucnobjp = bf.UcnFile(os.path.join(model_ws, str(filename)))
        timesp = ucnobjp.get_times()
        concentrationp = ucnobjp.get_data(totim=timesp[time])

        _, ax = plt.subplots(1, 1, figsize=(8, 4.5), dpi=150, constrained_layout=True)

        # pH is plotted unscaled; everything else is multiplied by 1000
        # (presumably mol -> mmol, matching the titles -- TODO confirm).
        if title == "DO (mmol)":
            conc_minus = concentrationp * 1000 / 2
            extra = {"vmax": 0.28}
        elif title == "pH":
            conc_minus = concentrationp
            extra = {}
        else:
            conc_minus = concentrationp * 1000
            extra = {}
        modelmap = ax.imshow(conc_minus[:, 0, :], interpolation='nearest',
                             extent=(-35, 250, 10, 0), aspect="auto",
                             cmap="viridis", **extra)

        ax.set_title(str(title), **csfont)
        ax.set_xlabel('Distance (m)', **csfont)
        ax.set_ylabel('Depth (m)', **csfont)
        ax.tick_params(axis='both', which='major', labelsize=25)

        cb = plt.colorbar(modelmap, shrink=1, ax=ax, pad=0.01)
        cb.ax.tick_params(labelsize=25)

        plt.yticks(ticks=[0, 5, 10], fontname="Times New Roman")
        plt.xticks(ticks=[0, 50, 100, 150, 200, 250], fontname="Times New Roman")

        for label in cb.ax.yaxis.get_ticklabels():
            label.set_family("Times new roman")

        plt.xlim(-35, 200)
        plt.show()

    for filename, title in zip(UCN, plumefiguretitle):
        _plot_single_plume(timeend, filename, title)


# In[11]:


def plotplumesindifferenttimescmap(time1, time2, time3, L, ztop, zbot, modelname, model_ws, filename, title, plot, plotwells, cmap):
    """Plot one species at three output times side by side with a shared colour scale.

    Parameters
    ----------
    time1, time2, time3 : int
        Positions in ``UcnFile.get_times()``. Panel titles label them as
        ``(time - 1) / 12`` years (assumes monthly output -- TODO confirm).
    L, ztop, zbot, modelname
        Not used; kept for call compatibility.
    model_ws : str or path-like
        Directory containing the UCN files.
    filename : str
        UCN file to plot.
    title : str
        Species label, prefixed to the first panel title. It also selects
        the value transform and colour limits:

        * ``"pH"``: values unscaled;
        * ``"ORP (mV)"``: values multiplied by 2.303*R*T/F (in mV);
        * ``"DO (mmol)"``: values * 1000 / 2, vmax fixed at 0.28;
        * ``"Fe(OH)$_{3}$ (mmol)"``, ``"Calcite (mmol)"``, ``"CaX2 (mmol)"``,
          ``"Pyrite (mmol)"``: change relative to the smallest initial value;
        * ``"BEX (mmol)"``: values * 1000, limits taken from PHT3D017.UCN;
        * anything else: values * 1000.
    plot : bool
        Add the x-axis label to every panel.
    plotwells : bool
        Overlay the well-screen segments and the source rectangle.
    cmap
        Matplotlib colormap used for all panels.

    One colourbar is attached to the last panel and the colour limits are
    printed. Nothing is returned.
    """
    import os

    ucnobjp = bf.UcnFile(os.path.join(str(model_ws), str(filename)))
    timesp = ucnobjp.get_times()
    times = [time1, time2, time3]
    titles = [str(title) + " " + f"{((time1 - 1) / 12):.0f}" + "-year",
              f"{((time2 - 1) / 12):.0f}" + "-year",
              f"{((time3 - 1) / 12):.0f}" + "-year"]
    conc_time3 = ucnobjp.get_data(totim=timesp[time3])
    conc_time0 = ucnobjp.get_data(totim=timesp[0])

    fig = plt.figure(figsize=(18, 3.5), dpi=150, constrained_layout=True)

    # Default shared colour range: initial state together with the last
    # requested time, scaled by 1000.
    max_val = np.amax((conc_time0, conc_time3)) * 1000
    min_val = np.amin((conc_time0, conc_time3)) * 1000

    csfont = {'fontname': 'Times new roman', 'fontsize': 28}

    # Choose the value transform and colour limits once; they do not depend
    # on the panel being drawn.
    # NOTE(review): titles such as "Calcite (mmol/Lv)" / "Pyrite (mmol/Lv)"
    # do not match the mineral branch and fall through to the default.
    if title == "pH":
        def transform(conc):
            return conc
        vmax = max_val / 1000
        vmin = min_val / 1000
    elif title == "ORP (mV)":
        R = 8.314 * (10**-3)  # Gas constant (kJ/deg/mol)
        T = 281.15  # absolute temperature
        F = 96.42  # Faradays constant (kJ/Volt gram)
        factor = 2.303 * R * T / F * 1000  # Conversion factor to mV

        def transform(conc):
            return conc * factor
        vmax = (max_val / 1000 * factor)
        vmin = (min_val / 1000 * factor)
    elif title == "DO (mmol)":
        def transform(conc):
            return conc * 1000 / 2
        vmax = 0.28
        vmin = (min_val / 2)
    elif title in ("Fe(OH)$_{3}$ (mmol)", "Calcite (mmol)", "CaX2 (mmol)", "Pyrite (mmol)"):
        # Show the change relative to the smallest value at the first saved time.
        base_val = np.amin(conc_time0) * 1000

        def transform(conc):
            return (conc * 1000) - base_val
        vmax = max_val - base_val
        vmin = min_val - base_val
    elif title == "BEX (mmol)":
        # Colour limits come from PHT3D017.UCN (max at time3, min at time1).
        ucnobjp_Wu = bf.UcnFile(os.path.join(str(model_ws), "PHT3D017.UCN"))
        timesp_Wu = ucnobjp_Wu.get_times()
        vmax = np.amax(ucnobjp_Wu.get_data(totim=timesp_Wu[time3])) * 1000
        vmin = np.amin(ucnobjp_Wu.get_data(totim=timesp_Wu[time1])) * 1000

        def transform(conc):
            return conc * 1000
    else:
        def transform(conc):
            return conc * 1000
        vmax = max_val
        vmin = min_val

    # (distance, top depth) of each well screen drawn when plotwells is set.
    well_screens = [(5, 2), (55, .6), (55, 2.2), (55, 3.8),
                    (85, 1), (85, 2.6), (85, 4.2), (145, 4.6)]

    def _draw_wells():
        """Draw each screen as a vertical segment from depth to depth + 1, plus the source box."""
        for distance, depth in well_screens:
            plt.plot((distance, distance), (depth, depth + 1), "-", color="white",
                     markersize=1, alpha=0.8,
                     path_effects=[pltpe.Stroke(linewidth=3, foreground='black', alpha=0.9),
                                   pltpe.Normal()])
        rectangle = plt.Rectangle((-25, 0), 50, 2, fill=False, edgecolor='white')
        plt.gca().add_patch(rectangle)

    def _draw_panel(ax, data, panel_title, first_panel):
        """Draw one cross-section on *ax* (the current axes) and return the image."""
        modelmap = ax.imshow(data[:, 0, :], interpolation='nearest',
                             extent=(-35, 250, 10, 0), aspect="auto",
                             vmin=vmin, vmax=vmax, cmap=cmap)
        ax.tick_params(axis='both', which='major', labelsize=23)
        plt.yticks(ticks=[0, 5, 10], fontname="Times New Roman")
        plt.xticks(ticks=[0, 50, 100, 150, 200, 250], fontname="Times New Roman")
        plt.xlim(-35, 200)
        ax.set_title(str(panel_title), **csfont)

        if first_panel:
            ax.set_ylabel('Depth (m)', **csfont)
        if plot:
            ax.set_xlabel('Distance (m)', **csfont)
        if plotwells:
            _draw_wells()
        return modelmap

    # enumerate (not times.index) so repeated time values still get their
    # own subplot position.
    for panel_no, (time_idx, panel_title) in enumerate(zip(times, titles), start=1):
        ax = fig.add_subplot(1, 3, panel_no)
        concentration = ucnobjp.get_data(totim=timesp[time_idx])
        modelmap = _draw_panel(ax, transform(concentration), panel_title, panel_no == 1)

    print(vmin, vmax)
    cb = plt.colorbar(modelmap, shrink=1, ax=ax, pad=0.01)
    cb.ax.tick_params(labelsize=23)

    for label in cb.ax.yaxis.get_ticklabels():
        label.set_family("Times new roman")

    plt.show()


# In[ ]:


def plotplumesindifferenttimes(time1, time2, time3, L, ztop, zbot, modelname, model_ws, filename, title, plot):
    """Plot one species at three output times side by side (default colormap).

    Parameters
    ----------
    time1, time2, time3 : int
        Positions in ``UcnFile.get_times()``. Panel titles label them as
        ``(time - 1) * 30`` days (assumes 30-day output steps -- TODO confirm).
    L, ztop, zbot, modelname
        Not used; kept for call compatibility.
    model_ws : str or path-like
        Directory containing ``filename``.
    filename : str
        UCN file to plot.
    title : str
        Species label, prefixed to the first panel title. ``"DO (mmol)"``
        halves the scaled values and fixes vmax at 0.28.
    plot : bool
        Add the x-axis label to every panel.

    The colour range (minimum at ``time1``, maximum at ``time3``, both
    scaled by 1000) is shared by all panels. One colourbar is attached to
    the last panel. Nothing is returned.
    """
    import os

    ucnobjp = bf.UcnFile(os.path.join(str(model_ws), str(filename)))
    timesp = ucnobjp.get_times()
    times = [time1, time2, time3]
    titles = [str(title) + " " + str((time1 - 1) * 30) + "d",
              str((time2 - 1) * 30) + "d",
              str((time3 - 1) * 30) + "d"]
    conc_max = ucnobjp.get_data(totim=timesp[time3])
    conc_min = ucnobjp.get_data(totim=timesp[time1])

    fig = plt.figure(figsize=(18, 3.5), dpi=150, constrained_layout=True)

    max_val = np.amax(conc_max) * 1000
    min_val = np.amin(conc_min) * 1000
    csfont = {'fontname': 'Times new roman', 'fontsize': 28}

    is_do = title == "DO (mmol)"
    if is_do:
        vmin, vmax = (min_val / 2), 0.28
    else:
        vmin, vmax = min_val, max_val

    # enumerate (not times.index) so repeated time values still get their
    # own subplot position.
    for panel_no, (time_idx, panel_title) in enumerate(zip(times, titles), start=1):
        ax = fig.add_subplot(1, 3, panel_no)
        concentration = ucnobjp.get_data(totim=timesp[time_idx])
        conc_minus = concentration * 1000 / 2 if is_do else concentration * 1000

        modelmap = ax.imshow(conc_minus[:, 0, :], interpolation='nearest',
                             extent=(-35, 250, 10, 0), aspect="auto",
                             vmin=vmin, vmax=vmax)
        ax.tick_params(axis='both', which='major', labelsize=23)
        plt.yticks(ticks=[0, 5, 10], fontname="Times New Roman")
        plt.xticks(ticks=[0, 50, 100, 150, 200, 250], fontname="Times New Roman")
        # Clip every panel, not just the last one.
        ax.set_xlim(-35, 200)
        ax.set_title(str(panel_title), **csfont)

        if panel_no == 1:
            ax.set_ylabel('Depth (m)', **csfont)
        if plot:
            ax.set_xlabel('Distance (m)', **csfont)

    cb = plt.colorbar(modelmap, shrink=1, ax=ax, pad=0.01)
    cb.ax.tick_params(labelsize=23)
    for label in cb.ax.yaxis.get_ticklabels():
        label.set_family("Times new roman")
    plt.show()


# In[ ]:


def plotallplumes(model_ws, ncomp):
    """Plot the last-saved concentration of the first ``ncomp`` species in a 9x5 grid.

    Parameters
    ----------
    model_ws : str or path-like
        Directory containing ``PHT3D001.UCN`` .. ``PHT3D0NN.UCN``.
    ncomp : int
        Number of species files to read, starting at ``PHT3D001.UCN``.

    Raises
    ------
    ValueError
        If ``ncomp`` exceeds the 45 cells of the 9x5 panel grid.
    """
    import os

    spec_name = ["Benz", "NVOC", "Toludeg", "Shortnal", "Longnal",
                 "C4", "C(-4)", "Ca", "Cl", "Fe(2)",
                 "Fe(3)", "K", "Mg", "N(5)", "Na",
                 "O(0)", "Wu", "S(-2)", "S(6)", "Ntwo",
                 "pH", "pe", "Benznapl", "NVOCnapl", "Tolunapldeg",
                 "Shortnapl", "Longnapl", "Lowsolubnapl", "Coutgas", "Methoutgas",
                 "Cardiox", "Ntwogas", "Wunapl", "Calcite", "Siderite",
                 "FeOH3a", "CaX2", "MgX2", "KX", "NaX",
                 "FeX2"
                 ]

    n_rows, n_cols = 9, 5
    species = int(ncomp)
    # Fail before reading any files rather than inside add_subplot.
    if species > n_rows * n_cols:
        raise ValueError(
            f"ncomp={ncomp} exceeds the {n_rows * n_cols}-panel grid")

    # The time list of PHT3D001.UCN is used for every species (assumes all
    # files share the same output times).
    timesp = bf.UcnFile(os.path.join(str(model_ws), "PHT3D001.UCN")).get_times()

    conc_species = []
    for i in range(1, species + 1):
        ucn_path = os.path.join(str(model_ws), "PHT3D0" + str(i).zfill(2) + ".UCN")
        conc_species.append(bf.UcnFile(ucn_path).get_data(totim=timesp[-1]))

    fig, ax = plt.subplots(1, 1, figsize=(20, 12), constrained_layout=True)
    ax.axis('off')  # hide the placeholder axes; the panels are added below

    for j, conc in enumerate(conc_species):
        ax = fig.add_subplot(n_rows, n_cols, j + 1)
        im = ax.imshow(conc[:, 0, :], interpolation='nearest',
                       extent=(-35, 250, 0, 10), aspect="auto")
        # Generic label when spec_name has fewer entries than ncomp.
        ax.set_title(spec_name[j] if j < len(spec_name) else f"species{j + 1}")
        plt.colorbar(im, format='%.0e')
        plt.yticks(ticks=[0, 5, 10])
        plt.xticks(ticks=[0, 50, 100, 150, 200, 250])
        # Clip every panel to x <= 200.
        ax.set_xlim(-35, 200)
    plt.show()


# In[ ]:


def plotplumesensor(filename, title):
    """Plot a sensor-type quantity at the last saved output time.

    ``filename`` selects the quantity, not a file:

    * ``"pH"``: reads PHT3D021.UCN;
    * ``"ORP"``: reads PHT3D022.UCN and multiplies by 2.303*R*T/F (in mV);
    * ``"EC"``: sums Ca, Fe(2) and Mg (each doubled) plus Na and K from
      PHT3D008/010/013/015/012.UCN, then multiplies by 100 * 1000
      (presumably an EC proxy -- TODO confirm the scaling);
    * ``"EC (PHREEQC)"``: calls ``EC_total`` on the ``conc_*`` arrays.

    NOTE(review): relies on names not defined in this function --
    ``model_ws``, ``mt_sim``, ``ztop``, ``zbot``, ``EC_total`` and the
    ``conc_Cl`` ... ``conc_K`` arrays -- presumably notebook globals; confirm
    they exist before calling.

    Raises
    ------
    ValueError
        If ``filename`` is not one of the supported quantities.
    """
    import os

    if filename == "pH":
        ucnobjp = bf.UcnFile(os.path.join(str(model_ws), "PHT3D021.UCN"))
        timesp = ucnobjp.get_times()
        concentrationp = ucnobjp.get_data(totim=timesp[-1])

    elif filename == "ORP":
        R = 8.314 * (10**-3)  # Gas constant (kJ/deg/mol)
        T = 281.15  # absolute temperature
        F = 96.42  # Faradays constant (kJ/Volt gram)
        factor = 2.303 * R * T / F * 1000  # Conversion factor to mV

        ucnobjp = bf.UcnFile(os.path.join(str(model_ws), "PHT3D022.UCN"))
        timesp = ucnobjp.get_times()
        concentrationp = ucnobjp.get_data(totim=timesp[-1]) * factor

    elif filename == "EC":
        ucnobj_Ca2 = bf.UcnFile(os.path.join(str(model_ws), "PHT3D008.UCN"))
        ucnobj_Fe2 = bf.UcnFile(os.path.join(str(model_ws), "PHT3D010.UCN"))
        ucnobj_K1 = bf.UcnFile(os.path.join(str(model_ws), "PHT3D012.UCN"))
        ucnobj_Mg2 = bf.UcnFile(os.path.join(str(model_ws), "PHT3D013.UCN"))
        ucnobj_Na1 = bf.UcnFile(os.path.join(str(model_ws), "PHT3D015.UCN"))
        timesp = ucnobj_Ca2.get_times()

        # Divalent cations are counted twice.
        conc_Ca2 = ucnobj_Ca2.get_data(totim=timesp[-1]) * 2
        conc_Fe2 = ucnobj_Fe2.get_data(totim=timesp[-1]) * 2
        conc_Mg2 = ucnobj_Mg2.get_data(totim=timesp[-1]) * 2
        conc_Na1 = ucnobj_Na1.get_data(totim=timesp[-1])
        conc_K1 = ucnobj_K1.get_data(totim=timesp[-1])

        concentrationp = (conc_Ca2 + conc_Fe2 + conc_Mg2 + conc_Na1 + conc_K1) * 100 * 1000

    elif filename == "EC (PHREEQC)":
        concentrationp = EC_total(conc_Cl, conc_Na, conc_SO4, conc_Mg, conc_Ca,
                                  conc_H, conc_HCO3, conc_NO3, conc_K)

    else:
        # Without this, an unknown name fell through to an UnboundLocalError
        # on concentrationp after an empty figure had already been created.
        raise ValueError(
            f"unsupported quantity {filename!r}; expected 'pH', 'ORP', 'EC' "
            f"or 'EC (PHREEQC)'")

    fig, ax = plt.subplots(1, 1, figsize=(8, 6), dpi=150, constrained_layout=True)
    modelmap = flopy.plot.PlotCrossSection(
        model=mt_sim,
        ax=ax,
        line={"row": 0},
    )

    # NOTE(review): modelmap.plot_array below draws on the same axes over
    # this image; one of the two renderings is probably redundant.
    ax.imshow(concentrationp[:, 0, :], interpolation='nearest',
              extent=(-35, 250, 0, (ztop - zbot)), aspect="auto")

    csfont = {'fontname': 'Times new roman', 'fontsize': 32}

    ax.set_title(str(title), **csfont)
    ax.set_xlabel('Distance', **csfont)
    ax.set_ylabel('Depth', **csfont)
    ax.tick_params(axis='both', which='major', labelsize=27)

    pa = modelmap.plot_array(concentrationp)
    cb = plt.colorbar(pa, shrink=1, ax=ax, pad=0)
    cb.ax.tick_params(labelsize=27)

    plt.yticks(ticks=[0, 5, 10], fontname="Times New Roman")
    plt.xticks(ticks=[0, 50, 100, 150, 200], fontname="Times New Roman")

    for label in cb.ax.yaxis.get_ticklabels():
        label.set_family("Times new roman")

    # Currently unused: only the disabled calls in the string literal that
    # follows this function refer to it.
    def wells(distance, depth):
        x = (distance, distance)
        y = (depth - 1, depth)
        plt.plot(x, y, "-", color="white", markersize=1, alpha=0.8,
                 path_effects=[pltpe.Stroke(linewidth=3, foreground='black', alpha=0.9), pltpe.Normal()])


# NOTE(review): the triple-quoted string below is a no-op module-level
# expression that holds disabled calls to a wells(distance, depth) helper.
# It starts at column 0, so it is not part of plotplumesensor and does not run.
"""        
    wells(40, 8)
    wells(90, 9.4)
    wells(90, 7.8)
    wells(90, 6.2)
    wells(120, 9)
    wells(120, 7.4)
    wells(120, 5.8)
    wells(180, 5.4)

"""
#plt.show()

