### Created by Jaho Koo, IHE Delft, TU Delft, K-water

from BNN_Build_MCdropout import MCDropoutBNN
from BNN_Build_MCdropout import get_activation, train_model_mse_early_stop
from BNN_Build_MCdropout import convert_torch_tensor, get_device
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from optuna import Trial
import pandas as pd


def opt_BO(trial: Trial, DC, F, lambda_reg):
    """Optuna objective: train an MC-dropout BNN and return its validation MSE.

    The trial samples the network architecture, the optimiser settings and the
    look-back window ``B``. The input frame is expanded with ``B - 1`` lagged
    copies of every column (features) and ``F`` future values of the ``'INF'``
    column (targets), the features are min-max scaled, and a model is trained
    on an 80/20 random split.

    Args:
        trial: Optuna trial used to sample the hyper-parameters.
        DC: pandas DataFrame of input series; must contain an ``'INF'`` column.
            It is not modified.
        F: Number of future ``'INF'`` steps to predict. Also bounds the
            look-back window search (``B`` is sampled from 6 to ``2 * F``).
        lambda_reg: Regularisation weight forwarded to
            ``train_model_mse_early_stop``.

    Returns:
        Mean squared error between the un-normalised validation targets and
        the mean prediction of the trained model.
    """
    # Search space. The order of the suggest calls is kept stable so that
    # seeded samplers reproduce the same sequence of parameters.
    nodes = trial.suggest_categorical('nodes', [256, 512])
    layers = trial.suggest_categorical('layers', [3, 6, 9])
    activation_name = trial.suggest_categorical('act', ['relu', 'tanh'])
    activation_fn = get_activation(activation_name)
    lr = trial.suggest_categorical('lr', [5, 10, 50])  # scaled by 1e-4 below
    batch_size = trial.suggest_categorical('batch', [64, 128, 256])
    epochs = trial.suggest_int('epochs', low=5000, high=5000)  # fixed value
    dropout_proba = trial.suggest_int('dropout', low=10, high=50, step=10)  # percent
    B = trial.suggest_int('B', low=6, high=F*2, step=6)

    # Build every lagged / lead column first and concatenate once: this avoids
    # re-copying the growing frame on each iteration and never writes into the
    # caller's DataFrame.
    lagged = [DC] + [
        DC.shift(lag).add_suffix('-{}'.format(lag)) for lag in range(1, B)
    ]
    lead_names = ['{}_+{}'.format('INF', step) for step in range(1, F + 1)]
    leads = [
        DC['INF'].shift(-step).rename(name)
        for step, name in enumerate(lead_names, start=1)
    ]
    # Rows at both ends lack either history or future values; drop them.
    frame = pd.concat(lagged + leads, axis=1).dropna()

    # The last F columns are the targets; everything before them is a feature.
    features = frame.iloc[:, :-F]
    col_min = features.min(axis=0)
    col_range = features.max(axis=0) - col_min
    # A constant column has zero range; dividing by it would fill the column
    # with NaN and poison training, so scale such columns to 0 instead.
    col_range = col_range.where(col_range != 0, 1)
    X_train = (features - col_min) / col_range
    # Targets are deliberately left in their original (un-normalised) units.
    y_train = frame[lead_names]

    # NOTE(review): a shuffled split on overlapping lag windows lets
    # near-duplicate rows land in both sets -- confirm this is intended.
    X_fit, X_val, y_fit, y_val = train_test_split(
        X_train, y_train, test_size=0.2, random_state=10
    )
    X_fit_t = convert_torch_tensor(X_fit.to_numpy())
    X_val_t = convert_torch_tensor(X_val.to_numpy())
    y_fit_t = convert_torch_tensor(y_fit.to_numpy())
    y_val_t = convert_torch_tensor(y_val.to_numpy())

    device = get_device()
    model = MCDropoutBNN(
        X_fit_t.shape[1],
        [nodes] * layers,
        y_fit_t.shape[1],
        dropout_proba / 100,
        activation_fn,
    ).to(device)
    model = train_model_mse_early_stop(
        model, X_fit_t, y_fit_t, X_val_t, y_val_t,
        epochs, lr / 10000, batch_size, lambda_reg=lambda_reg,
    )

    model.eval()
    # Only the predictive mean is scored; the second return value is unused.
    y_pred_mean, _ = model.predict(X_val_t, 100)

    return mean_squared_error(y_val.to_numpy(), y_pred_mean.cpu())


def early_stopping_check_1(study, trial, early_stopping_rounds=100):
    """Optuna callback: stop the study after a streak of non-improving trials.

    Args:
        study: The running study; must expose ``best_trial`` and ``stop()``.
        trial: The trial that has just finished.
        early_stopping_rounds: Stop once the finished trial's number is at
            least this many trials past the best trial's number.
    """
    try:
        best_trial_number = study.best_trial.number
    except ValueError:
        # No trial has completed yet (e.g. the first ones failed or were
        # pruned), so there is no best trial to measure the streak against.
        return

    if trial.number - best_trial_number >= early_stopping_rounds:
        study.stop()

def early_stopping_check_2(study, trial, early_stopping_threshold=0.001):
    """Optuna callback: stop the study once the best value is good enough.

    The comparison is ``best value <= threshold``, so it is only meaningful
    for a study that minimises its objective.

    Args:
        study: The running study; must expose ``best_trial`` and ``stop()``.
        trial: The trial that has just finished (unused; required by the
            callback signature).
        early_stopping_threshold: Objective value at or below which the study
            is stopped.
    """
    try:
        best_value = study.best_trial.value
    except ValueError:
        # No trial has completed yet (e.g. the first ones failed or were
        # pruned), so there is no best value to compare.
        return

    if best_value <= early_stopping_threshold:
        study.stop()


