# -*- coding: utf-8 -*-
### Created by Jaho Koo, IHE Delft, TU Delft, K-water

import pandas as pd
import numpy as np
import tensorflow as tf
from pyomo.environ import value

from SMPC_formulation_CVaR import MPC_formula
from Input_scenario import call_DC_, call_GP_, call_MH_, Scenarios_DC_1S_, Scenarios_GP_1S_, Scenarios_MH_1S_, Prepare_data_iter
from MPC_solver import c_solver

# Base directories. input_path is prepended to the data/model file names
# loaded in MPC_main; result_path is not referenced in the visible part of
# this file (presumably used where results are saved -- confirm).
result_path = "D:\\"
input_path = "D:\\"

# Load the level/storage lookup table used by LtoS and StoL. The absolute
# location is tried first; a copy in the current working directory is the
# fallback.
try:
    LV_curve = pd.read_csv(r"/home/koojah/RL/LV_curve.csv")
except OSError:
    # Only a missing or unreadable file triggers the fallback. Parse errors,
    # KeyboardInterrupt, etc. now propagate instead of being silently
    # replaced by a different file.
    LV_curve = pd.read_csv(r"LV_curve.csv")


def LtoS(lvl):
    """Convert a water level to storage.

    Linearly interpolates ``lvl`` on the module-level ``LV_curve`` table,
    using column ``'0'`` as the level axis and column ``'1'`` as the
    storage axis. Accepts whatever ``np.interp`` accepts for its first
    argument and returns the interpolated value(s).
    """
    level_axis, storage_axis = LV_curve['0'], LV_curve['1']
    return np.interp(lvl, level_axis, storage_axis)

def StoL(sto):
    """Convert a storage to a water level (inverse lookup of ``LtoS``).

    Linearly interpolates ``sto`` on the module-level ``LV_curve`` table,
    using column ``'1'`` as the storage axis and column ``'0'`` as the
    level axis. Accepts whatever ``np.interp`` accepts for its first
    argument and returns the interpolated value(s).
    """
    storage_axis, level_axis = LV_curve['1'], LV_curve['0']
    return np.interp(sto, storage_axis, level_axis)


# ================================================================================

# --- Discharge limits ---
MT = 264  # the maximum discharge via turbines
MSP = 11680  # the maximum discharge via spillway gates

# --- Characteristic water levels ---
LWL, NHWL, FWL = 60.0, 76.5, 80.0

IWL = 76.5  # initial water level

# Target water levels; TWL is the target at the end of the episode.
TWL = 76.5
TWL_N = 76.5
TWL_F = 77.0
TWL_H = 78.4
TWL_L = 76.0
TWL_U = 75.5

# --- Storages for the levels above ---
# Each level is looked up on the level/storage curve via LtoS and truncated
# to an int.
LWS, FWS, NHWS = (int(LtoS(level)) for level in (LWL, FWL, NHWL))
IWS = int(LtoS(IWL))
TWS = int(LtoS(TWL))
TWS_F, TWS_H, TWS_L, TWS_N, TWS_U = (
    int(LtoS(level)) for level in (TWL_F, TWL_H, TWL_L, TWL_N, TWL_U)
)

# Demand value; MPC_main repeats it for every step of the horizon and
# passes it to the model as 'Md'.
O_demand = 52

# ================================================================================

def Sol_to_Weight(solution):
    """Scale the first six entries of ``solution`` by 1000.

    Args:
        solution: Indexable object with entries at positions 0..5; any
            further entries are ignored.

    Returns:
        A new list of six values, ``solution[i] * 1000`` for i = 0..5.
    """
    return [solution[idx] * 1000 for idx in range(6)]


def MPC_main(F, E, Weight, init_RWL = 76.5, init_OUT = MT, n_cluster = 5, wn = 10, type_ = 'close', delta_SO_1 = 100, DCW_H = 4.5, beta1=0.99, beta2=0.99):

    inf_real = pd.read_excel(input_path + "DC_training_original.xlsx", index_col=None)
    inf_real.drop(inf_real.columns[0], axis=1, inplace=True)
    inf_real__ = inf_real.copy()
    inf_real__ = inf_real__[inf_real__['E']>=20][['E', 'INF']]
    inf_real__.reset_index(drop=True, inplace=True)
    inf_real_ = inf_real__[inf_real__['E']==E]['INF']

    DC_, nm, DC_min_, DC_max_ = call_DC_()
    GP_, nm, GP_min_, GP_max_ = call_GP_()
    MH_, nm, MH_min_, MH_max_ = call_MH_()

    DC_B = 11
    GP_B = 9
    MH_B = 11

    DC_X_, DC_y_ = Prepare_data_iter(DC_, DC_B, F, E, DC_min_, DC_max_, t_nm='INF')
    GP_X_, GP_y_ = Prepare_data_iter(GP_, GP_B, F, E, GP_min_, GP_max_, t_nm='gp_wl')
    MH_X_, MH_y_ = Prepare_data_iter(MH_, MH_B, F, E, MH_min_, MH_max_, t_nm='mh_wl')

    ss_k = max(DC_X_.index[0], GP_X_.index[0], MH_X_.index[0])
    ee_k = min(DC_X_.index[-1], GP_X_.index[-1], MH_X_.index[-1])

    DC_inf_nn = tf.keras.models.load_model(input_path + "Deep_BNN_DC_inf_AR_opt_F12.keras")
    GP_nn = tf.keras.models.load_model(input_path + "Deep_BNN_gp_wl_opt_F12.keras")
    MH_nn = tf.keras.models.load_model(input_path + "Deep_BNN_mh_wl_opt_F12.keras")

    dw_real = pd.read_excel(input_path + "wl_down_events.xlsx")
    dw_real.drop(dw_real.columns[0], axis=1, inplace=True)
    dw_real_ = dw_real[dw_real['E'] == E][dw_real.columns[3:]].copy()

    I_SO_k = [init_OUT for i in range(F)]

    OUT_que = []
    GWL_que = []
    inf_mean_ls = []

    ## ====================== start MPC iterations ================================================================

    model = MPC_formula()
    solver = c_solver(model)

    GN_nn = tf.keras.models.load_model(input_path + "ANN_wl_opt_F1.keras")
    GN_scaler = pd.read_excel(input_path + "DCn_scaler.xlsx")
    GN_scaler = GN_scaler.iloc[1:-1]

    for k in range(ss_k, ee_k-1):

        QIN_k_df_ = Scenarios_DC_1S_(DC_inf_nn = DC_inf_nn, X_=DC_X_, E=E, F=F, k=k, n_clusters=n_cluster, T_scenarios=wn, type_ = type_)
        QIN_k_ = QIN_k_df_[QIN_k_df_.columns[:-4]].to_numpy()
        QIN_k_p_ = QIN_k_df_[QIN_k_df_.columns[-1]].to_numpy()

        inf_mean_ = QIN_k_.mean(axis=0)
        inf_mean_ls.append(inf_mean_)

        gpf_k_df_ = Scenarios_GP_1S_(GP_nn = GP_nn, X_=GP_X_, E=E, F=F, k=k, n_clusters=n_cluster, T_scenarios=wn, type_ = type_)
        gpf_k_ = gpf_k_df_[gpf_k_df_.columns[:-4]].to_numpy()
        gpf_k_p_ = gpf_k_df_[gpf_k_df_.columns[-1]].to_numpy()

        mhf_k_df_ = Scenarios_MH_1S_(MH_nn = MH_nn, X_=MH_X_, E=E, F=F, k=k, n_clusters=n_cluster, T_scenarios=wn, type_ = type_)
        mhf_k_ = mhf_k_df_[mhf_k_df_.columns[:-4]].to_numpy()
        mhf_k_p_ = mhf_k_df_[mhf_k_df_.columns[-1]].to_numpy()
        QIN_real_k_ = inf_real_.loc[k:k + F-1].to_list()
        Md_k_ = [O_demand for i in range(F)]

        if k == ss_k:
            I_GWL = dw_real_.loc[ss_k]['gn_wl']
        else:
            I_GWL = GWL_que[-1]

        dw_real_.loc[k, ('gn_wl')] = I_GWL

        if k-ss_k < 6:
            p_down_ = np.array([dw_real_[['gp_wl', 'mh_wl', 'gn_wl']].loc[ss_k].tolist() for i in range(6 - (k-ss_k))] + dw_real_[['gp_wl', 'mh_wl', 'gn_wl']].loc[:k + 1].to_numpy().tolist())
            OUT_tmp_ = [[init_OUT for i in range(F)] for j in range(6 - (k-ss_k))] + OUT_que[:k-ss_k]
        else:
            p_down_ = dw_real_[['gp_wl', 'mh_wl', 'gn_wl']].loc[k - 6:k + 1].to_numpy()
            OUT_tmp_ = OUT_que[k-ss_k - 6:]

        if k-ss_k == 0:
            max_INF = MT  # MT/0.63 = 419
            SO_p_k = I_SO_k
            I_RWS_k = LtoS(init_RWL)
        else:
            max_INF = MT

        etc_1 = 0.807587321097601 * p_down_[-1][2]

        State_k = {None: {'I_RWS': {None: I_RWS_k},
                          # 'w': {None: [i for i in range(len(weights))]},
                          'W1': {None: 1},
                          'W2': {None: 1},
                          'W3': {None: 1},
                          'W4': {None: 1},
                          'W5': {None: 1},
                          'W6': {None: 1},
                          'TWS_N': {None: TWS_N},
                          'TWS_L': {None: TWS_L},
                          'TWS_F': {None: TWS_F},
                          'TWS_H': {None: TWS_H},
                          'TWS_U': {None: TWS_U},
                          'FWS': {None: FWS},
                          'LWS': {None: LWS},
                          'NHWS': {None: NHWS},
                          'DCW_H': {None: DCW_H},
                          'I_G': {None: I_GWL},
                          't': {None: [i for i in range(F)]},
                          'w': {None: [i for i in range(len(QIN_k_p_))]},
                          'wp': {None: [i for i in range(len(gpf_k_p_))]},
                          'wm': {None: [i for i in range(len(mhf_k_p_))]},
                          'QIN': {(i, j): QIN_k_[j][i] for j in range(len(QIN_k_p_)) for i in range(F)},
                          'gpf': {(i, j): gpf_k_[j][i] for j in range(len(gpf_k_p_)) for i in range(F)},
                          'mhf': {(i, j): mhf_k_[j][i] for j in range(len(mhf_k_p_)) for i in range(F)},
                          'P_QIN': {j: QIN_k_p_[j] for j in range(len(QIN_k_p_))},
                          'P_gpf': {j: gpf_k_p_[j] for j in range(len(gpf_k_p_))},
                          'P_mhf': {j: mhf_k_p_[j] for j in range(len(mhf_k_p_))},
                          # 'etc_': {i: etc_[i] for i in range(F)},
                          'etc_': {None: etc_1},
                          'N_QIN': {None: sum(QIN_k_p_)},
                          'N_gpf': {None: sum(gpf_k_p_)},
                          'N_mhf': {None: sum(mhf_k_p_)},
                          'SO_P': {i: SO_p_k[i] for i in range(F)},
                          'SP_P': {i: max(SO_p_k[i] - MT, 0) for i in range(F)},
                          'max_INF': {None: max_INF},
                          'delta_SO_1': {None: delta_SO_1},
                          'beta1': {None: beta1},
                          'beta2': {None: beta2},
                          'F': {None: F},
                          'Md': {i: Md_k_[i] for i in range(F)}}}

        weights = Sol_to_Weight(Weight)
        State_k[None]['W1'][None] = weights[0]
        State_k[None]['W2'][None] = weights[1]
        State_k[None]['W3'][None] = weights[2]
        State_k[None]['W4'][None] = weights[3]
        State_k[None]['W5'][None] = weights[4]
        State_k[None]['W6'][None] = weights[5]

        inst, results = solver.call_solver(State_k)

        SO_n_k = [value(inst.SO[i]) for i in inst.t]


        RWL_que_tmp = []
        RWS_que_tmp = []
        for h in range(F):
            if h == 0:
                next_RWS = I_RWS_k + (QIN_real_k_[h] - round(SO_n_k[h])) * 3600
            else:
                next_RWS = next_RWS + (inf_mean_[h] - round(SO_n_k[h])) * 3600  # for training and testing(evaluating), apply real inflow
            RWS_que_tmp.append(next_RWS)
            RWL_que_tmp.append(round(StoL(next_RWS), 3))

        n_down_ = dw_real_[['gp_wl', 'mh_wl']].loc[k+1].to_list()
        X_real = [OUT_tmp_[-1][0], p_down_[-1][0], p_down_[-1][1], p_down_[-1][2], p_down_[-2][2], SO_n_k[0], n_down_[0],
                  n_down_[1]]
        X_real = [(X_real[i] - GN_scaler['min'].iloc[i]) / (GN_scaler['max'].iloc[i] - GN_scaler['min'].iloc[i]) for i in
                  range(len(X_real))]
        GWL_1 = GN_nn.predict([X_real])[0][0]

        ## =============================================================================================

        I_RWS_k = RWS_que_tmp[0]
        SO_p_k = [round(i) for i in SO_n_k]

        ## ============================================== to save detailed results ==================================

