# Import necessary libraries
import numpy as np  # For numerical computations
import matplotlib.pyplot as plt  # For data visualization
from scipy.optimize import curve_fit  # For non-linear curve fitting
from scipy.interpolate import interp1d  # For interpolation operations
from scipy.io import loadmat  # For loading .mat files

# Package-qualified imports (these work in the author's Eclipse PyDev setup). If your setup is similar, keep the following 3 lines.
from PythonFiles.createFit1 import createFit1  # Import the createFit1 function
from PythonFiles.disper import disper  # Import the disper function
from PythonFiles.HU_method_Hs import HU_method_Hs  # Import the HU_method_Hs function

# If the above 3 lines do not work, comment them out and uncomment the following 3 lines instead.
#from createFit1 import createFit1  # Import the createFit1 function
#from disper import disper  # Import the disper function
#from HU_method_Hs import HU_method_Hs  # Import the HU_method_Hs function

# Read the MATLAB file and keep columns 1..3 of each variable
# (presumably one column per observation station -- confirm against the data):
#   Hs    -> significant wave height
#   Depth -> water depth
#   Tp    -> wave period (stored under the 'hld_Ts' key)
data = loadmat('hld_hs_ts_depth.mat')
Hs, Depth, Tp = (data[name][:, 1:4] for name in ('hld_Hs', 'hld_depth', 'hld_Ts'))

# Distances of the three observation stations, measured from the first one.
X = [0, 104, 216]
OBS_1, OBS_2, OBS_3 = X

# Calibration set: reject every record whose column-2 wave height is NaN or
# lies outside the open interval (0.05, 0.39), or whose wave period lies
# outside [2, 6].
# NOTE(review): the period limits are checked on data['hld_Ts'][:, 2], which is
# Tp[:, 1] of the sliced array, whereas the period kept for modelling is
# Tp[:, 2] (raw column 3). Confirm which station's period is intended.
Hs_in = Hs[:, 2]
reject = (
    np.isnan(Hs_in)
    | (Hs_in <= 0.05)
    | (Hs_in >= 0.39)
    | (data['hld_Ts'][:, 2] < 2)
    | (data['hld_Ts'][:, 2] > 6)
)
ind = np.flatnonzero(reject)  # indices of the rows that are dropped
keep = ~reject

# Drop the rejected rows. Keep the column-2 period and depth, and all three heights.
Tp, Depth, Hs = Tp[keep, 2], Depth[keep, 2], Hs[keep, :3]

# Number of wavelengths to tabulate. The shortest wavelength is assumed to be
# 5 m, which gives a generous L_num.
L_num = int(OBS_3 / 5)

# Per-record accumulators, filled by the calibration loop.
beta = []  # fitted decay coefficients
k = []     # wave numbers
L = []     # wavelengths
Ur = []    # Ursell numbers

# Modelled relative wave height: one row per wavelength count, one column per record.
KLn = np.zeros((L_num, Tp.shape[0]))

# Calibration pass. For every retained record:
#   - fit the decay coefficient beta,
#   - compute the wave number, wavelength and Ursell number,
#   - tabulate 1 / (1 + beta * L * n) for n = 1 .. L_num into column i of KLn.
n_wavelengths = np.arange(1, L_num + 1)  # identical multipliers for every record
for i, (Hs_row, period, depth) in enumerate(zip(Hs, Tp, Depth)):
    # Normalize by the column-2 height, then reverse the order. The first value
    # (which equals 1) is therefore paired with X[0] = 0.
    Hs_temp = np.flip(Hs_row / Hs_row[2])

    # beta of Eq. S4 in the SI, obtained by data fitting (originally Dalrymple, 1984)
    beta_val, _ = createFit1(X, Hs_temp)
    beta.append(beta_val)

    k_val = disper(2 * np.pi / period, depth)  # wave number
    k.append(k_val)
    L_val = 2 * np.pi / k_val  # wavelength
    L.append(L_val)
    Ur.append(Hs_row[2] * L_val ** 2 / (depth ** 3))  # Ursell number

    # Relative wave height along the forest, one entry per wavelength count
    KLn[:, i] = 1 / (1 + beta_val * L_val * n_wavelengths)

# Prediction set. Reload the raw columns and keep only the high-wave records:
# finite column-2 wave height of at least 0.40 and a wave period in [2, 6].
# This Hs range does not overlap the calibration range (0.05, 0.39).
# NOTE(review): as in the calibration filter, the period limits use
# data['hld_Ts'][:, 2] (sliced Tp[:, 1]) while the period kept is Tp[:, 2].
Hs, Depth, Tp = (data[name][:, 1:4] for name in ('hld_Hs', 'hld_depth', 'hld_Ts'))
Hs_in = Hs[:, 2]

reject = (
    np.isnan(Hs_in)
    | (Hs_in < 0.40)
    | (data['hld_Ts'][:, 2] < 2)
    | (data['hld_Ts'][:, 2] > 6)
)
ind = np.flatnonzero(reject)  # indices of the rows that are dropped
keep = ~reject

# Drop the rejected rows. Keep the column-2 period and depth, and all three heights.
Tp, Depth, Hs = Tp[keep, 2], Depth[keep, 2], Hs[keep, :3]

# Predict the relative wave height at the two landward stations for every
# high-wave record, and collect the matching observations.
#   column 0 <-> station at OBS_3 (observed Hs[:, 0] / Hs[:, 2])
#   column 1 <-> station at OBS_2 (observed Hs[:, 1] / Hs[:, 2])
MM = np.full((len(Tp), 2), np.nan)  # predictions; skipped rows stay NaN
Ob = np.zeros((len(Tp), 2))         # observations

# Largest Ursell number covered by the calibration records. It is loop-invariant,
# so compute it once. Use nanmax because the builtin max() over a list holding
# NaN is order-dependent: a NaN at index 0 disables the range check entirely.
Ur_max = np.nanmax(Ur)

for p in range(len(Tp)):
    L_val = 2 * np.pi / disper(2 * np.pi / Tp[p], Depth[p])  # wavelength
    Ur_big = Hs[p, 2] * L_val ** 2 / Depth[p] ** 3           # Ursell number

    # Predict only inside the calibrated Ursell range. A NaN Ursell number
    # compares False, so it is skipped as well and its row stays NaN.
    if Ur_big <= Ur_max:
        # HU_method_Hs is the main calculation; it returns the modelled
        # relative wave height (KLm) at the given station distance.
        MM[p, 0] = HU_method_Hs(KLn, Ur, Hs[p, 2], OBS_3, Tp[p], Depth[p])
        MM[p, 1] = HU_method_Hs(KLn, Ur, Hs[p, 2], OBS_2, Tp[p], Depth[p])

    # Observed heights, normalized by the column-2 height
    Ob[p, :] = Hs[p, :2] / Hs[p, 2]

# Predicted-vs-observed scatter: one point per (record, station) pair,
# plus the 1:1 line for reference. Drawn on the current figure and saved to H.jpg.
MM_all, Ob_all = MM.flatten(), Ob.flatten()

plt.plot(MM_all, Ob_all, linestyle='none', marker='+')
plt.plot([0, 1], [0, 1])  # y = x
plt.xlabel("Predicted")
plt.ylabel("Observed")
plt.savefig("H.jpg", dpi=300)
plt.show()

# Goodness of fit. R2 is the coefficient of determination, 1 - SS_res / SS_tot
# (not a correlation coefficient).
# SS_res and SS_tot must run over the same (prediction, observation) pairs.
# Predictions left as NaN (out-of-range Ursell number) drop out of the
# residuals, so their observations must also be excluded from SS_tot and from
# the mean. Otherwise SS_tot is inflated and R2 is biased upward.
valid = ~(np.isnan(MM_all) | np.isnan(Ob_all))
MM_valid = MM_all[valid]
Ob_valid = Ob_all[valid]
residual = MM_valid - Ob_valid

R2 = 1 - np.sum(residual ** 2) / np.sum((Ob_valid - np.mean(Ob_valid)) ** 2)
RMSE = np.sqrt(np.mean(residual ** 2))  # root mean square error over valid pairs
print("R2:", R2)
print("RMSE:", RMSE)

# Write predictions (MMH.txt) and observations (ObH.txt) as plain text,
# 5 decimal places.
for out_name, matrix in (("MMH.txt", MM), ("ObH.txt", Ob)):
    np.savetxt(out_name, matrix, fmt="%.5f")
