"""
@author: j.h.koo@tudelf.nl
"""

import pandas as pd
import numpy as np
from tensorflow import keras

# Base directories. Each value is concatenated directly with a file name
# (``path + "file.ext"``), so it must end with a path separator.
result_path = "D:\\"  # presumably where results are written — not used in this chunk
input_path = "D:\\"  # directory of the dc_F*_E*_*.csv inflow files
DNN_model_path = "D:\\"  # directory of O_DNN.keras
Scaler_path = "D:\\"  # directory of Collected_scaler_train.xlsx


# Level-storage table read by LtoS / StoL (column '0' = level, '1' = storage).
# Try the absolute cluster path first and fall back to the working directory.
# Only a failure to open the file triggers the fallback: a malformed CSV is
# now reported instead of being masked by the retry, and KeyboardInterrupt /
# SystemExit are no longer swallowed by a bare ``except``.
try:
    LV_curve = pd.read_csv(r"/home/koojah/RL/LV_curve.csv")
except OSError:  # includes FileNotFoundError and PermissionError
    LV_curve = pd.read_csv(r"LV_curve.csv")

def LtoS(lvl):
    """Return the storage that corresponds to water level ``lvl``.

    Linear interpolation on the module-level ``LV_curve`` table, with column
    ``'0'`` as the x-coordinates (levels) and column ``'1'`` as the values
    (storages). ``lvl`` may be a scalar or array-like. ``np.interp`` clamps
    inputs outside the tabulated range to the end values and expects
    column ``'0'`` to be increasing.
    """
    levels, storages = LV_curve['0'], LV_curve['1']
    return np.interp(lvl, levels, storages)

def StoL(sto):
    """Return the water level that corresponds to storage ``sto``.

    Inverse lookup of ``LtoS``: linear interpolation on the module-level
    ``LV_curve`` table, with column ``'1'`` as the x-coordinates (storages)
    and column ``'0'`` as the values (levels). ``sto`` may be a scalar or
    array-like. ``np.interp`` clamps inputs outside the tabulated range to
    the end values and expects column ``'1'`` to be increasing.
    """
    storages, levels = LV_curve['1'], LV_curve['0']
    return np.interp(sto, storages, levels)


MT = 264  # the maximum discharge via turbines
MSP = 11680  # the maximum discharge via spillway gates
LWL = 60.0  # presumably the lowest allowed water level — confirm
FWL = 80.0  # presumably the full (maximum) water level — confirm
LWS = int(LtoS(LWL))  # storage at LWL, truncated to int
FWS = int(LtoS(FWL))  # storage at FWL, truncated to int

IWL = 76.5  # initial water level
IWS = int(LtoS(IWL))  # storage at IWL, truncated to int
TWL = 76.5  # Target water level at the end of the episode
# Alternative target levels; the meaning of the N / F / L suffixes is not
# established in this file — confirm with the author.
TWL_N = 76.5
TWL_F = 78.5
TWL_L = 76.0
# Storages matching the target levels above, truncated to int.
TWS = int(LtoS(TWL))
TWS_F = int(LtoS(TWL_F))
TWS_L = int(LtoS(TWL_L))
TWS_N = int(LtoS(TWL_N))

I_O = 150  # value used to fill the initial outflow schedule in MPC_main
O_demand = 52  # presumably an outflow demand — not used in this chunk

### ================================================================================
# Min-max scaling bounds for the DNN input. MPC_main scales its state vector
# element-wise with these arrays, so the selected rows (positions 3..15) must
# line up with that vector's layout — TODO confirm against the training code.
scaler_ = pd.read_excel(Scaler_path + 'Collected_scaler_train.xlsx', index_col=0)
# ``.iloc`` makes the slice explicitly positional. A plain ``series[3:16]``
# has semantics that depend on the index dtype in pandas.
scaler_max = scaler_['Max'].iloc[3:16].to_numpy()
scaler_min = scaler_['Min'].iloc[3:16].to_numpy()

### ================================================================================
# Trained Keras network that MPC_main queries for the next outflow values.
dnnR = keras.models.load_model(f"{DNN_model_path}O_DNN.keras")
### ================================================================================

def MPC_main(F, E, PP, random_it, init_RWL = 76.5):
    """Run a receding-horizon outflow-planning loop driven by the DNN ``dnnR``.

    At each step ``k`` the current reservoir level, the inflow row ``k`` of
    the forecast file and the previously planned outflow schedule are
    min-max scaled and passed to ``dnnR``. The rounded predictions, prefixed
    with the previously planned value ``SO_p_k[1]``, form the new schedule.
    Storage is then advanced one step using row ``k`` of the ``'PT'`` file
    (``inf_real``) instead of the forecast file (``inf_``).

    Args:
        F: Horizon length. Used in the input file names, as the length of the
            initial schedule, and as the number of horizon steps simulated
            per iteration.
        E: Used only to build the input file names.
        PP: Label inserted into the forecast file name. The second file
            always uses ``'PT'``.
        random_it: Index inserted into the forecast file name.
        init_RWL: Initial reservoir water level. Converted to storage with
            ``LtoS`` at ``k == 0``.

    NOTE(review): in the lines visible here, ``OUT_que``, ``N_C_1`` and
    ``St_`` are built but never returned or saved, and ``result_path`` is not
    referenced. Confirm whether more code follows the loop.
    """

    inf_ = pd.read_csv(input_path + "dc_F{}_E{}_{}_{}.csv".format(F, E, PP, random_it))
    inf_real = pd.read_csv(input_path + "dc_F{}_E{}_{}.csv".format(F, E, 'PT'))
    # Drop the first column of each file (presumably an index/time column).
    inf_.drop(inf_.columns[0], axis=1, inplace=True)
    inf_real.drop(inf_real.columns[0], axis=1, inplace=True)

    # Number of planning steps; the last F + 1 rows of inf_ never start a step.
    max_step = len(inf_) - F - 1

    # Initial outflow schedule: F copies of the module constant I_O.
    I_SO_k = [I_O for i in range(F)]

    OUT_que = []  # outflow schedule produced at every step
    N_C_1 = []    # 0 if the new schedule repeats the previous one shifted by one step, else 1
    St_ = []      # unscaled DNN input state at every step

    for k in range(0, max_step):

        # Row k of each file. QIN_k_[h] is read as the inflow for horizon step h
        # below, so each row presumably holds an F-step inflow horizon — confirm.
        QIN_k_ = inf_.iloc[k].to_list()
        QIN_real_k_ = inf_real.iloc[k].to_list()

        if k == 0:
            SO_p_k = I_SO_k
            I_RWS_k = LtoS(init_RWL)

        St_.append([StoL(I_RWS_k)] + QIN_k_ + SO_p_k)

        # Model input: [current level] + forecast inflows + previous schedule,
        # min-max scaled element-wise. Assumes len(state_X) == len(scaler_min).
        state_X = np.array([StoL(I_RWS_k)] + QIN_k_ + SO_p_k)
        state_Xs = (state_X - scaler_min) / (scaler_max - scaler_min)
        # Batch of one sample. NOTE(review): calling ``predict`` once per step
        # is relatively slow for single samples; a direct model call may be faster.
        O_predictions = dnnR.predict([state_Xs.tolist()])

        OUT_predictions = [round(i) for i in O_predictions[0].tolist()]
        # The first entry is fixed to the previously planned value for this
        # step; only the remaining entries come from the network.
        SO_n_k = [SO_p_k[1]] + OUT_predictions

        OUT_que.append(SO_n_k)

        # Storage/level trajectory under the forecast inflows.
        # NOTE(review): every h starts from I_RWS_k rather than from the
        # previous h's storage, so this is not a cumulative trajectory. These
        # lists are also unused in the visible code. The "* 3600" presumably
        # converts a per-second flow over an hourly step into volume — confirm.
        E_RWL_k = []
        E_RWS_k = []
        for h in range(F):
            next_RWS = I_RWS_k + (QIN_k_[h] - SO_n_k[h]) * 3600
            E_RWS_k.append(next_RWS)
            E_RWL_k.append(round(StoL(next_RWS), 3))

        # Same computation with the 'PT' inflows; only element 0 is used below
        # to advance the reservoir state.
        RWL_que_tmp = []
        RWS_que_tmp = []
        for h in range(F):
            next_RWS = I_RWS_k + (QIN_real_k_[h] - SO_n_k[h]) * 3600
            RWS_que_tmp.append(next_RWS)
            RWL_que_tmp.append(round(StoL(next_RWS), 3))

        # Flag whether the new schedule changed the previously planned values
        # over the overlapping window (compared after rounding).
        if [round(i) for i in SO_p_k[1:]] == [round(i) for i in SO_n_k[:-1]]:
            N_C_1.append(0)
        else:
            N_C_1.append(1)

        # Receding horizon: apply one step, then roll the schedule forward.
        I_RWS_k = RWS_que_tmp[0]
        SO_p_k = SO_n_k
