import itertools
import pandas as pd
import numpy as np
from utils import maximize
import geopandas as gpd
from collections import Counter
import matplotlib.pyplot as plt
from collections import defaultdict
from sklearn.svm import OneClassSVM
from sklearn.model_selection import KFold
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from utils import trim_columns_for_each_experiment


# pandas console display settings: allow up to 500 columns when printing frames.
pd.options.display.max_columns = 500
pd.options.display.width = 0

# Metadata columns: they describe/locate each observation and are excluded
# from the model inputs (the features are "every column not in META_COLS").
META_COLS = ["datetime", "longitude", "latitude", "safety_region", "geometry"]

# Full list of candidate feature columns.
# NOTE(review): not referenced elsewhere in this script; the per-experiment
# feature set is derived from the trimmed DataFrame instead.
DATA_COLS = [
    "svf-25m",
    "ws_loc",
    "wd_loc",
    "ws_loc_6h",
    "wd_loc_6h",
    "BG2017",
    "trees-pct",
    "trees-height",
    "spi1",
    "spi3",
    "spi6",
    "top10nl_Height",
    "dist-top10nl_waterbody_line",
    "dist-top10nl_waterbody_surface",
    "trees_pixel",
    "trees_buff",
    "closest_tree",
    "avg_dist",
]

################
# Main program #
################

# (features version, experiment name) pairs; every experiment uses the v1.4 feature set.
experiments = [
    ("1.4", exp_name)
    for exp_name in ("landuse-cobra", "no-landuse-cobra", "landuse-no-cobra", "no-landuse-no-cobra")
]
# English month names, looked up with (month number - 1).
lmonname = [
    "January", "February", "March", "April",
    "May", "June", "July", "August",
    "September", "October", "November", "December",
]

def _read_damage_data(path_in_dmg, name_exp):
    """Read the storm-damage feature table and apply the sanity-check filters.

    Rows with a negative ``svf-25m`` value or with any missing value are
    dropped, then ``trim_columns_for_each_experiment`` selects the columns
    used by experiment ``name_exp``. The shape is printed after each step.

    Returns the filtered and trimmed DataFrame (indexed by ``rowid``).
    """
    print("\nReading data and sanity checks")
    print("-" * 80)
    # `infer_datetime_format` is not passed: it is deprecated since pandas 2.0,
    # where strict format inference is the default and the flag only warns.
    df_dmg = pd.read_csv(path_in_dmg, sep=";", header=0, index_col=["rowid"], parse_dates=["datetime"])
    print("Original shape: ", df_dmg.shape)
    # Filter the invalid rows out. DataFrame.mask() would only blank them to NaN:
    # the printed shape would not change and integer columns could be upcast.
    # Rows with a missing SVF pass this filter; dropna() below removes them.
    df_dmg = df_dmg[~(df_dmg["svf-25m"] < 0)]
    print("Clean SVF: ", df_dmg.shape)
    df_dmg = df_dmg.dropna(axis=0)  # Drops observations from 2023, since we have no SPI for them
    print("Drops rows with NaN: ", df_dmg.shape)
    print("Years in the train/test dataset: ", np.unique(df_dmg["datetime"].dt.year))
    return trim_columns_for_each_experiment(name_exp, df_dmg)


def _tune_month(df, data_cols, training_sizes, percent_outliers, gamma_influence):
    """Grid-search OneClassSVM hyperparameters on one month of observations.

    The columns ``data_cols`` of ``df`` are min-max scaled. For every training
    size a single train/test split (``random_state=42``) is made, and an RBF
    OneClassSVM is fit for every (nu, gamma) combination.

    Returns a dict ``{training_size: [(nu, gamma, percent_inliers), ...]}``.
    ``percent_inliers`` is the percentage (rounded to one decimal) of test
    rows predicted as inliers.
    """
    # NOTE(review): the scaler is fit on the whole month, test rows included,
    # before splitting (data leakage into the evaluation). Kept to preserve the
    # reported numbers; fitting it on xtrain only would be the cleaner protocol.
    arr_scaled = MinMaxScaler().fit_transform(df[data_cols].to_numpy())

    results = defaultdict(list)
    for s in training_sizes:
        xtrain, xtest = train_test_split(arr_scaled, train_size=s, random_state=42)
        print("\tTraining with: ", round(s * 100), "%", "\tTrain/test: {0} {1}".format(xtrain.shape, xtest.shape))
        for p in percent_outliers:
            for g in gamma_influence:
                clf = OneClassSVM(kernel='rbf', gamma=g, nu=p, shrinking=True).fit(xtrain)
                # predict() labels inliers +1 and outliers -1.
                positive_samples = Counter(clf.predict(xtest))[1]
                percent_inliers = np.round(np.divide(100 * positive_samples, xtest.shape[0]), decimals=1)
                results[s].append((p, g, percent_inliers))
    return results


def _plot_results(dic_board, training_sizes, suptitle, nr=2, nc=5):
    """Draw one 3D scatter of (nu, gamma, % inliers) per training size.

    Panels are filled row-major on an ``nr`` x ``nc`` grid. Panels not needed
    by ``training_sizes`` are removed. Each panel title reports the
    (nu, gamma) pair with the highest percentage of inliers.

    Raises ValueError if the grid has fewer panels than training sizes.
    """
    if len(training_sizes) > nr * nc:
        raise ValueError("{0} training sizes do not fit in a {1}x{2} grid".format(len(training_sizes), nr, nc))

    pairs = list(itertools.product(range(nr), range(nc)))
    fig, ax = plt.subplots(nrows=nr, ncols=nc, subplot_kw={"projection": "3d"}, squeeze=False)
    plt.subplots_adjust(wspace=0.5, hspace=0.1)
    plt.suptitle("per-month hyperparameter tuning (nu, gamma) - {0}".format(suptitle), y=0.95, size=28)

    for (r, c), s in zip(pairs, training_sizes):
        results_per_training_size = dic_board[s]
        x_vals = [nu for nu, _, _ in results_per_training_size]
        y_vals = [gamma for _, gamma, _ in results_per_training_size]
        z_vals = [inliers for _, _, inliers in results_per_training_size]
        # max() returns the first of equally good triples, like the stable
        # descending sort it replaces.
        best_nu, best_gamma, best_inliers = max(results_per_training_size, key=lambda item: item[2])

        panel = ax[r, c]
        plot = panel.scatter(x_vals, y_vals, z_vals, c=z_vals, vmin=60, vmax=100)
        # round() instead of int(): s comes from np.arange and carries float
        # noise; ":g" keeps values such as gamma=0.30000000000000004 readable.
        panel.set_title("Training data: {0}% \n Max: (ν={1:g}, ɣ={2:g}, in={3}%)".format(
            round(s * 100), best_nu, best_gamma, int(best_inliers)))
        panel.set_zlim(0, 100)
        panel.set_xlabel("pct outliers (ν)", size=14, labelpad=15)
        panel.set_ylabel("influence (ɣ)", size=14, labelpad=15)
        panel.set_zlabel("pct inliers", size=14, labelpad=20)
        panel.tick_params(axis='x', labelsize=12)
        panel.tick_params(axis='y', labelsize=12)
        # NOTE(review): x=1 lies outside the plotted nu range (max 0.2) and this is a
        # 2D helper on 3D axes — confirm this line is still wanted.
        panel.axvline(x=1, ymin=0.5, ymax=1)
        fig.colorbar(plot, ax=panel, shrink=0.70, pad=0.2, orientation="horizontal")

    # Remove the panels left over after the last training size.
    for r, c in pairs[len(training_sizes):]:
        fig.delaxes(ax[r, c])
    # plt.show()  # NOTE(review): nothing here shows or saves the figure; confirm maximize() does.
    maximize()


for features_version, name_exp in experiments:
    # Path and titles follow features_version instead of hard-coding "v1.4".
    path_in_dmg = r"../data/features/Stormdamage_v{0}.csv".format(features_version)
    dic_suptitle = {"landuse-cobra": "v{0} (Exp1: incl. land use and Cobra data)".format(features_version),
                    "no-landuse-cobra": "v{0} (Exp2: no land use, incl. Cobra data)".format(features_version),
                    "landuse-no-cobra": "v{0} (Exp3: incl. land use, no Cobra data)".format(features_version),
                    "no-landuse-no-cobra": "v{0} (Exp4: no land use, no Cobra data)".format(features_version)}

    # Reading and pre-processing data
    # -------------------------------------------------------------
    df_dmg = _read_damage_data(path_in_dmg, name_exp)

    # Every column that is not metadata is a model input.
    DATA_COLS_EXP = df_dmg.columns[~df_dmg.columns.isin(META_COLS)]
    print("Experiment: \t\t{0}".format(name_exp))
    print("Training columns: ", DATA_COLS_EXP.tolist())
    print("-" * 80)
    print()

    # Modelling phase
    # ---------------------------------------------------------------
    training_sizes = np.arange(0.1, 1, 0.1)
    percent_outliers = [0.2, 0.15, 0.1, 0.05, 0.01, 0.005, 0.001]
    gamma_influence = np.arange(0.1, 5, 0.1)

    # Both collections accumulate over *all* months. That way the figure covers
    # the same data as the printed average/median; re-creating dic_board per
    # month left only the last month in the plot.
    dic_board = defaultdict(list)
    ltotal = []
    # groupby() yields the months in ascending order.
    for month, df in df_dmg.groupby(df_dmg["datetime"].dt.month):
        print("\nProcessing: ", lmonname[month - 1])
        month_results = _tune_month(df, DATA_COLS_EXP, training_sizes, percent_outliers, gamma_influence)
        for s, results in month_results.items():
            dic_board[s].extend(results)
            ltotal.extend(percent_inliers for _, _, percent_inliers in results)

    print("Total average for experiment {0} = {1}".format(name_exp, np.round(np.mean(ltotal), decimals=0)))
    print("Total median for experiment {0} = {1}".format(name_exp, np.round(np.median(ltotal), decimals=0)))

    # Plotting results
    # ----------------------------------------------------------------
    _plot_results(dic_board, training_sizes, dic_suptitle[name_exp])



