# Import necessary libraries
import numpy as np  # For numerical computations
from scipy.interpolate import interp1d  # For interpolation operations
import matplotlib.pyplot as plt  # For data visualization
from scipy.optimize import curve_fit  # For non-linear curve fitting

# The IDE I am using is eclipse-Pydev. If you are like me, please use the following 4 lines of code
from PythonFiles.createFit1 import createFit1  # Import the createFit1 function
from PythonFiles.createFit2 import createFit2  # Import the createFit2 function
from PythonFiles.disper import disper  # Import the disper function
from PythonFiles.HU_method_Hs import HU_method_Hs  # Import the HU_method_Hs function

# If the above 4 lines do not work, please use the following 4 lines of code and comment out the above 4 lines of code
#from createFit1 import createFit1  # Import the createFit1 function
#from createFit2 import createFit2  # Import the createFit2 function
#from disper import disper  # Import the disper function
#from HU_method_Hs import HU_method_Hs  # Import the HU_method_Hs function

# Load the three input tables, in this order:
#   Hs.txt - significant wave height, WL.txt - water depth, Tp.txt - wave period.
# Later code indexes each of them as a 2-D table and reads column 3.
Hs, Depth, Tp = (np.loadtxt(fname) for fname in ("Hs.txt", "WL.txt", "Tp.txt"))

# Positions of the observation stations, accumulated from the distance between
# neighbouring stations (these spacings must be supplied for the site at hand).
OBS_1, OBS_2 = 0, 140
OBS_3 = 130.1 + OBS_2
OBS_4 = 45.8 + OBS_3

# There are 5 stations in total, but the most seaward one is on bare flat,
# so only the remaining 4 stations are used.
X = [OBS_1, OBS_2, OBS_3, OBS_4]

# Column 3 of the wave height table decides which cases are kept
Hs_in = Hs[:, 3]

# Row indices of the cases outside the calibration band 0.1 <= Hs <= 0.2.
# NOTE: a NaN entry compares False on both sides, so such a row is NOT removed.
ind = np.where((Hs_in < 0.10) | (Hs_in > 0.20))[0]

# Drop the same rows from all three tables so that they stay row-aligned
Tp, Depth, Hs = (np.delete(table, ind, axis=0) for table in (Tp, Depth, Hs))

# Reduce each table to the part used in the calculation
Tp, Depth = Tp[:, 3], Depth[:, 3]
Hs = Hs[:, :4]  # the significant wave height at the mangrove seaward edge is used to build the predictor

# Number of wavelength steps tabulated per case: the position of the last
# station divided by an assumed smallest wavelength of 5 meter, rounded down.
L_num = int(np.floor(OBS_4 / 5))

# Per-case results; one entry is appended for every calibration case
beta, k, L, Ur = [], [], [], []  # beta values, wave numbers, wavelengths, Ursell numbers

# KLn[j, i]: relative wave height of case i after (j + 1) wavelengths
KLn = np.zeros((L_num, len(Tp)))

# Wavelength counts 1..L_num, used to fill one whole KLn column at a time
wave_counts = np.arange(1, L_num + 1)

# Calibration loop: one pass per retained low-wave case
for i, period in enumerate(Tp):
    depth = Depth[i]

    # Wave heights relative to column 3, reversed so that column 3 pairs with X[0]
    Hs_temp = (Hs[i, :] / Hs[i, 3])[::-1]

    # Fit beta of equation S4 in the SI (originally from Dalrymple, 1984)
    beta_val, gof = createFit1(X, Hs_temp)
    beta.append(beta_val)

    # Wave number from the dispersion helper, then the wavelength
    k_val = disper(2 * np.pi / period, depth)
    k.append(k_val)
    L_val = 2 * np.pi / k_val
    L.append(L_val)

    # Ursell number of this case.
    # NOTE(review): this uses Hs column 0, whereas the prediction loop further
    # down builds its Ursell number from Hs column 3 - confirm which is intended.
    Ur_val = Hs[i, 0] * L_val ** 2 / (depth ** 3)
    Ur.append(Ur_val)

    # Relative wave height along the forest after 1, 2, ..., L_num wavelengths
    KLn[:, i] = 1 / (1 + beta_val * L_val * wave_counts)

# Reload the raw tables: Hs, Depth and Tp were overwritten by the calibration step
Hs, Depth, Tp = (np.loadtxt(fname) for fname in ("Hs.txt", "WL.txt", "Tp.txt"))

Hs_in = Hs[:, 3]

# Rows with Hs <= 0.2 are removed, so only the high wave cases (Hs > 0.2) are predicted
ind = np.where(Hs_in <= 0.20)[0]
Tp, Depth, Hs = (np.delete(table, ind, axis=0) for table in (Tp, Depth, Hs))

# Same column selection as used for the calibration data
Tp, Depth = Tp[:, 3], Depth[:, 3]
Hs = Hs[:, :4]

# Predicted relative wave heights, one row per high-wave case.
# Columns hold the prediction at OBS_4, OBS_3 and OBS_2 (in that order).
# Rows stay NaN when the case lies outside the calibrated Ursell range.
MM = np.full((len(Tp), 3), np.nan)

# Observed relative wave heights: Hs columns 0-2 divided by Hs column 3
Ob = Hs[:, :3] / Hs[:, 3:4]

# The upper limit of the calibrated Ursell range does not change inside the
# loop, so it is determined once instead of once per case.
if len(Tp) and not len(Ur):
    raise ValueError("no calibration cases available: Ur is empty")
Ur_max = max(Ur) if len(Ur) else np.nan

# Stations to predict, listed in MM column order
stations = (OBS_4, OBS_3, OBS_2)

# Second main loop: predict every high-wave case
for p in range(len(Tp)):
    # Wavelength and Ursell number of this case
    L_val = 2 * np.pi / disper(2 * np.pi / Tp[p], Depth[p])
    Ur_big = Hs[p, 3] * L_val ** 2 / Depth[p] ** 3

    # Outside the calibrated range: leave the NaN placeholders in this row
    if Ur_big > Ur_max:
        continue

    # HU_method_Hs is the main calculation; it returns the relative wave
    # height (KLm) at the given station
    for col, station in enumerate(stations):
        MM[p, col] = HU_method_Hs(KLn, Ur, Hs[p, 3], station, Tp[p], Depth[p])

# One-dimensional versions of both tables, paired element by element
MM_all = np.ravel(MM)
Ob_all = np.ravel(Ob)

# Predicted against observed values, with the line y = x for reference
fig, ax = plt.subplots()
ax.plot(MM_all, Ob_all, '+')  # one marker per prediction/observation pair
ax.plot([0, 1], [0, 1])  # reference line y=x
ax.set_xlabel("Predicted")
ax.set_ylabel("Observed")
ax.set_title("Prediction vs Observation")
fig.savefig("B.jpg", dpi=300)  # written before the window is shown
plt.show()

# Calculate R2 (coefficient of determination) and RMSE.
# Both metrics use only the pairs in which prediction and observation are
# available. Residuals and the total sum of squares must cover the same
# samples; otherwise the observations of unpredicted (NaN) cases would
# enlarge the denominator and distort R2.
valid = ~(np.isnan(MM_all) | np.isnan(Ob_all))
pred, obs = MM_all[valid], Ob_all[valid]

if obs.size:
    sq_err = (pred - obs) ** 2
    R2 = 1 - sq_err.sum() / ((obs - obs.mean()) ** 2).sum()
    RMSE = np.sqrt(sq_err.mean())  # root mean square error
else:
    # No valid pair at all: neither metric is defined
    R2 = RMSE = np.nan

# Print results
print("R2:", R2)
print("RMSE:", RMSE)

# Write the prediction table (MM) and the observation table (Ob) as text files
for out_name, table in (("MMB.txt", MM), ("ObB.txt", Ob)):
    np.savetxt(out_name, table)
