### Created by Jaho Koo, IHE Delft, TU Delft, K-water

import numpy as np
import pandas as pd
from BNN_Build_MCdropout import MCDropoutBNN
from BNN_Build_MCdropout import get_activation, convert_torch_tensor, get_device
from BNN_Build_MCdropout import train_model_mse_early_stop
from BNN_hyper_TPE import opt_BO, early_stopping_check_1, early_stopping_check_2
from sklearn.model_selection import train_test_split
import optuna
from optuna.samplers import TPESampler



def _build_supervised_frame(DC, B, F):
    """Return a new frame with B-1 lagged copies of DC and F lead columns of 'INF'."""
    # Lagged copies are named '<column>-<lag>'; the caller's frame is never mutated.
    pieces = [DC] + [DC.shift(lag).add_suffix('-{}'.format(lag)) for lag in range(1, B)]
    # Lead (future) values of the target column are named 'INF_+<step>'.
    pieces += [DC['INF'].shift(-step).rename('{}_+{}'.format('INF', step)) for step in range(1, F + 1)]
    # Rows without a complete lag/lead window contain NaN and are removed.
    return pd.concat(pieces, axis=1).dropna()


def _min_max_scale(frame):
    """Scale every column of frame to [0, 1] using its own min and max."""
    lowest = frame.min(axis=0)
    span = frame.max(axis=0) - lowest
    # A constant column has zero range; divide by 1 so it maps to 0 instead of NaN.
    span = span.replace(0, 1)
    return (frame - lowest) / span


def _features_and_targets(DC, B, F):
    """Return (scaled feature frame, unscaled F-step target frame) for lag depth B."""
    frame = _build_supervised_frame(DC, B, F)
    target_cols = ['{}_+{}'.format('INF', step) for step in range(1, F + 1)]
    features = _min_max_scale(frame).drop(columns=target_cols)
    targets = frame[target_cols]
    return features, targets


def bnn_main(DC, F=12, *, nm='INF', loss='mse', n_trials=5000, n_samples=10000):
    """Tune, train and evaluate an MC-dropout BNN that forecasts 'INF' F steps ahead.

    Parameters
    ----------
    DC : pandas.DataFrame
        Input series. Must contain a column named 'INF'; column labels are
        combined with string suffixes, so they are expected to be strings.
        The frame is not modified.
    F : int, default 12
        Number of future steps of 'INF' used as targets. Must be >= 1.
    nm, loss : str
        Labels used only to build the optuna SQLite storage file name
        ``BNN_<nm>_<loss>.db``.
    n_trials : int, default 5000
        Upper bound on optuna trials (the early-stopping callbacks may end
        the study sooner).
    n_samples : int, default 10000
        Number of stochastic forward passes requested from ``predict`` and
        ``predict_scenarios``.

    Returns
    -------
    tuple
        ``(model, y_pred_mean, y_pred_std, y_test_scenarios)`` as returned by
        the trained model's ``predict`` / ``predict_scenarios`` on the full
        (train + validation) feature set.

    Raises
    ------
    ValueError
        If ``F`` is smaller than 1.
    """
    if F < 1:
        raise ValueError("F must be at least 1, got {}".format(F))

    # ## ================================ hyperparameter optimisation =====================================================
    # The lag depth B is itself a tuned hyperparameter, so the search cannot use
    # the final lagged data set; it is given the un-lagged (B = 1) data instead.
    # NOTE(review): the original passed X_train / y_train here before they were
    # assigned (UnboundLocalError). opt_BO is not visible from this file -
    # confirm it expects un-lagged tensors of this form.
    X_search_, y_search_ = _features_and_targets(DC, 1, F)
    X_search = convert_torch_tensor(X_search_.to_numpy())
    y_search = convert_torch_tensor(y_search_.to_numpy())

    optuna.logging.set_verbosity(optuna.logging.WARNING)
    study = optuna.create_study(direction='minimize', sampler=TPESampler(), storage="sqlite:///BNN_{}_{}.db".format(nm, loss))
    study.optimize(lambda trial: opt_BO(trial, X_search, y_search, F), n_trials=n_trials, callbacks=[early_stopping_check_1, early_stopping_check_2])
    best_params = study.best_params

    nodes = best_params['nodes']
    layers = best_params['layers']
    activation_fn = get_activation(best_params['act'])
    lr = best_params['lr']
    batch_size = best_params['batch']
    epochs = best_params['epochs']
    dropout_proba = best_params['dropout']
    B = best_params['B']

    # ## ================================ data preparation ================================================================
    # Features are min-max scaled; targets stay in their original units.
    X_all_, y_all_ = _features_and_targets(DC, B, F)

    X_train_, X_val_, y_train_, y_val_ = train_test_split(X_all_, y_all_, test_size=0.2, random_state=10)
    # Train only on the training split so the validation rows stay unseen and
    # early stopping is not driven by data the model has already fitted.
    X_train = convert_torch_tensor(X_train_.to_numpy())
    y_train = convert_torch_tensor(y_train_.to_numpy())
    X_val = convert_torch_tensor(X_val_.to_numpy())
    y_val = convert_torch_tensor(y_val_.to_numpy())
    # Predictions are produced for every available row (train + validation).
    X_test = convert_torch_tensor(X_all_.to_numpy())

    # ## ================================ training and prediction =========================================================
    device = get_device()
    # dropout and lr are divided by 100 and 10000 before use.
    # NOTE(review): presumably opt_BO samples them as scaled values - confirm.
    nn__ = MCDropoutBNN(X_train.shape[1], [nodes] * layers, y_train.shape[1], dropout_proba / 100, activation_fn).to(device)
    nn__ = train_model_mse_early_stop(nn__, X_train, y_train, X_val, y_val, epochs, lr / 10000, batch_size)
    nn__.eval()

    y_pred_mean, y_pred_std = nn__.predict(X_test, n_samples)
    y_test_scenarios = nn__.predict_scenarios(X_test, n_samples)

    return nn__, y_pred_mean, y_pred_std, y_test_scenarios



if __name__ == '__main__':

    # Location of the input workbook.
    input_path = "D:\\"
    workbook = input_path + r"DC_original_UP.xlsx"

    raw = pd.read_excel(workbook)

    # Discard the first two columns of the sheet.
    leading_cols = [raw.columns[0], raw.columns[1]]
    raw.drop(columns=leading_cols, inplace=True)

    # Rows that already contain NaN are removed first; single-space cells are
    # turned into NaN afterwards and filled by interpolating along axis=1
    # (i.e. across the columns of each row).
    # NOTE(review): confirm that interpolating across columns rather than down
    # the rows is intended.
    raw.dropna(inplace=True)
    raw.replace(" ", np.nan, inplace=True)
    raw.interpolate(axis=1, inplace=True)

    # Tune, train and predict with a 12-step forecast horizon.
    results = bnn_main(raw, F=12)
    bnn_model, y_pred_mean, y_pred_std, y_test_scenarios = results
