# -*- coding: utf-8 -*-
"""
Created on Sun Jan 30 14:39:34 2022

@author: curran
"""

import pandas as pd
import numpy as np
import math
import matplotlib.pyplot as plt
pd.options.mode.chained_assignment = None  # default='warn'. Needed for reordering
from scipy.stats import f_oneway
#import statsmodels.api as sm
#from scipy.stats import f_oneway


def GetSummaryStats(EMDAT):
    """Print headline record counts for an EM-DAT table.

    Parameters
    ----------
    EMDAT : pandas.DataFrame
        Must contain the columns 'Year', 'Disaster Category', 'Flood Category',
        'Total Deaths', 'Total Affected' and "Total Damages ('000 US$)".

    Returns
    -------
    None
        The counts are only written to stdout, one per line.
    """
    DamageColumn = 'Total Damages (\'000 US$)'
    ExtremeWeather = ['Storm','Landslide','Drought','Wildfire','Extreme temperature','Flood']

    # Compare on whitespace-stripped labels: this file spells the extreme
    # temperature category both with and without a trailing space, so an exact
    # match silently misses rows for whichever spelling the data does not use.
    DisCat = EMDAT['Disaster Category'].astype(str).str.strip()
    FloodCat = EMDAT['Flood Category']
    HasDeaths = EMDAT['Total Deaths'].notnull()
    HasAffected = EMDAT['Total Affected'].notnull()

    def Report(Label, Mask):
        # Number of True entries == number of rows the mask would select
        print(Label + str(int(Mask.sum())))

    Report('Events between 1975 and X: ', EMDAT['Year'] >= 1975)
    Report('Extreme Weather events: ', DisCat.isin(ExtremeWeather))
    Report('Storms: ', DisCat == 'Storm')
    Report('Landslide: ', DisCat == 'Landslide')
    Report('Drought: ', DisCat == 'Drought')
    Report('Wildfire: ', DisCat == 'Wildfire')
    Report('Extreme Temperature: ', DisCat == 'Extreme temperature')
    Report('Flood: ', DisCat == 'Flood')
    Report('Riverine Flood: ', FloodCat == 'Riverine Flood')
    Report('Coastal Flood: ', FloodCat == 'Coastal Flood')
    Report('Geophysical Flood: ', FloodCat == 'Geophysical Flood')
    Report('Flash/pluvial Flood: ', FloodCat == 'Flash/pluvial Flood')
    Report('Uncategorised Flood: ', FloodCat == 'Uncategorised Flood')
    Report('Events with Fatailities reported: ', HasDeaths)
    Report('Events with Affected reported: ', HasAffected)
    # Both figures must be present (previously this tested 'Total Affected' twice)
    Report('Events with Fatalities and Affected reported: ', HasDeaths & HasAffected)
    Report('Events with Damage reported: ', EMDAT[DamageColumn].notnull())

    return

def GetIndividualStats(EMDAT):
    """Return the five deadliest events overall and within three subsets.

    The subsets are: all records, extreme-weather records, flood records and
    records whose flood category is flash/pluvial, riverine or uncategorised.
    Records without a 'Total Deaths' value are ignored for the ranking.

    Returns
    -------
    tuple of four pandas.DataFrame
        (all events, extreme weather, floods, freshwater floods), each ordered
        by descending 'Total Deaths' and limited to the reporting columns.
    """
    ReportColumns = ['Year', 'Total Deaths','Disaster Category','Country','Disaster Subtype','Flood Category','Total Damages (\'000 US$)','Event Name','Total Affected']

    IsExtremeWeather = EMDAT['Disaster Category'].isin(['Drought','Flood','Landslide','Wildfire','Storm','Extreme temperature '])
    IsFlood = EMDAT['Disaster Category'].isin(['Flood'])
    IsFreshFlood = EMDAT['Flood Category'].isin(['Flash/pluvial Flood','Riverine Flood','Uncategorised Flood'])

    def FiveDeadliest(Subset):
        # Rank on the subset, then fetch the matching rows from the full table
        Deaths = Subset['Total Deaths'].dropna().to_frame().astype(float)
        Ranked = Deaths.nlargest(5, 'Total Deaths', keep='first')
        return EMDAT.loc[Ranked.index, :][ReportColumns]

    LargestEvents = FiveDeadliest(EMDAT)
    LargestExWeEvents = FiveDeadliest(EMDAT[IsExtremeWeather])
    LargestFloodEvents = FiveDeadliest(EMDAT[IsFlood])
    LargestFreshFloodEvents = FiveDeadliest(EMDAT[IsFreshFlood])

    return (LargestEvents, LargestExWeEvents, LargestFloodEvents, LargestFreshFloodEvents)

def ThreshFilter(EMDAT,Thresholds):
    """Filter EMDAT with string-built threshold expressions and refresh two columns of EMDAT.

    Parameters
    ----------
    EMDAT : pandas.DataFrame
        Must contain every column named in Thresholds plus 'Total Affected',
        'Total Deaths', 'No Injured', 'No Affected', 'No Homeless' and
        'Mortality'. It is modified in place by the second loop.
    Thresholds : sequence of sequences
        Each entry is a flat list: [column, operator, value], optionally
        followed by any number of [joiner, column, operator, value] groups.
        column, operator and joiner are concatenated into source code, so
        they must be strings; value is passed through str().
        Entries are applied one after another, each on the result of the
        previous one.

    Returns
    -------
    pandas.DataFrame
        The rows of EMDAT that satisfy every entry of Thresholds.

    Notes
    -----
    NOTE(review): the filter is executed with eval() on a string assembled
    from Thresholds, so Thresholds must never come from an untrusted source.
    The generated text refers to the local name EMDAT_Filt; renaming that
    variable breaks the filter.
    """
    # Build a string to be evaluated as a filter
    EMDAT_Filt = EMDAT
    for FP in range(0,len(Thresholds)): # create each part of the filter string
        FilterParams = Thresholds[FP]
        # 3 items for the first condition + 4 per additional (joiner, column, operator, value)
        NumFilters = int((len(FilterParams)+1)/4)
        FilterString = 'EMDAT_Filt.loc[(EMDAT_Filt[\"' + FilterParams[0] + '\"]' + FilterParams[1] + str(FilterParams[2]) + ')'
        for SubFilter in range(1,NumFilters):
           FilterString = FilterString + FilterParams[(SubFilter*4)-1] + '(EMDAT_Filt[\"' + FilterParams[(SubFilter*4)] + '\"]' + FilterParams[(SubFilter*4)+1] + str(FilterParams[(SubFilter*4)+2]) + ')'
        FilterString = FilterString + ']'        
        EMDAT_Filt = eval(FilterString)

    # Ensure total affected is at least the sum of deaths, injured, affected and homeless
    # NOTE(review): this loop reads and writes EMDAT (the input), not EMDAT_Filt, and it
    # runs after the filter was evaluated. With an empty Thresholds EMDAT_Filt is the same
    # object as EMDAT; otherwise the returned frame presumably does not see these
    # corrections - confirm this is intended.
    for row in range(0,len(EMDAT)): #len(EMDataRec)
        # A NaN 'Total Affected' makes the comparison False, so such rows are left untouched
        if EMDAT.iloc[row,:]['Total Affected'] < np.nansum([EMDAT.iloc[row,:]['Total Deaths'],EMDAT.iloc[row,:]['No Injured'],EMDAT.iloc[row,:]['No Affected'],EMDAT.iloc[row,:]['No Homeless']]):
            EMDAT.iloc[row, EMDAT.columns.get_loc('Total Affected')] = max(np.nan_to_num(EMDAT.iloc[row,:]['Total Affected']),np.nan_to_num(EMDAT.iloc[row,:]['Total Deaths']) + np.nan_to_num(EMDAT.iloc[row,:]['No Injured']) + np.nan_to_num(EMDAT.iloc[row,:]['No Affected']) + np.nan_to_num(EMDAT.iloc[row,:]['No Homeless']))
        # Recalculate mortality (deaths / affected) for every row, corrected or not
        EMDAT.iloc[row,EMDAT.columns.get_loc('Mortality')] = EMDAT.iloc[row,:]['Total Deaths']/EMDAT.iloc[row,:]['Total Affected']  
    
    return EMDAT_Filt

def _NanAwareTotal(Values):
    """Sum Values treating NaN as 0; return NaN when every entry is NaN."""
    if all(np.isnan(Values)):
        return float('NaN')
    return sum([0 if math.isnan(x) else x for x in Values])


def _MortalityRatio(Deaths, Affected):
    """Return Deaths/Affected with NumPy float rules (x/0 -> inf, 0/0 -> NaN)."""
    # Values read back from the object-typed event table may be plain Python
    # floats, for which a zero denominator would raise ZeroDivisionError.
    with np.errstate(divide='ignore', invalid='ignore'):
        return np.float64(Deaths) / np.float64(Affected)


def Create_EMDAT_Event(EMDAT_Rec):
    """Collapse per-country records into one row per event.

    Records belong to the same event when the first nine characters of
    'Dis No' are equal. For an event with several records the output row
    is the first record (in 'Dis No' order) with these fields replaced:

    * 'Country'                  - list of the countries of all records
    * 'Total Deaths', 'Total Affected', 'GDP ($Bln) During Event'
                                 - NaN-aware sums over the records
    * 'Disaster Category', 'Flood Category'
                                 - taken from the record with the highest
                                   death toll (first record if all are NaN)

    'Dis No' is shortened to the nine-character event key and 'Mortality'
    is recomputed as Total Deaths / Total Affected for every event.

    Parameters
    ----------
    EMDAT_Rec : pandas.DataFrame
        Record table with at least the columns named above, 'Country' and
        'Mortality'. NOTE: it is sorted by 'Dis No' in place.

    Returns
    -------
    pandas.DataFrame
        One row per event, same columns as EMDAT_Rec, object dtype.
    """
    # Re-sort so that records of the same event are adjacent
    EMDAT_Rec.sort_values(by=['Dis No'], inplace=True)
    EMDAT_Rec = EMDAT_Rec.reset_index(drop=True)

    EventIds = EMDAT_Rec['Dis No'].str.slice(0, 9)

    # New blank dataframe, one row per distinct event key
    EMDAT_Evn = pd.DataFrame(index = range(0,len(pd.unique(EventIds))), columns = list(EMDAT_Rec.columns.values))

    # Values accumulated over the records of the event currently being built
    AffectedCountries = []
    TotalDeathList = []
    TotalAffList = []
    TotalGDPList = []
    DisCats = []
    FloodCats = []

    count = -1
    PreviousId = None
    for row in range(0,len(EMDAT_Rec)):
        EventId = EventIds.iat[row]

        if EventId != PreviousId:
            # First record of an event: copy it and restart the accumulators
            count += 1
            EMDAT_Evn.iloc[count,:] = EMDAT_Rec.iloc[row,:]
            EMDAT_Evn.at[count,'Dis No'] = EventId
            AffectedCountries = [EMDAT_Rec.at[row,'Country']]
            TotalDeathList = [EMDAT_Rec.at[row,'Total Deaths']]
            TotalAffList = [EMDAT_Rec.at[row,'Total Affected']]
            TotalGDPList = [EMDAT_Rec.at[row,'GDP ($Bln) During Event']]
            DisCats = [EMDAT_Rec.at[row,'Disaster Category']]
            FloodCats = [EMDAT_Rec.at[row,'Flood Category']]
        else:
            # Further record of the same event: merge it into the event row
            AffectedCountries.append(EMDAT_Rec.at[row,'Country'])
            TotalDeathList.append(EMDAT_Rec.at[row,'Total Deaths'])
            TotalAffList.append(EMDAT_Rec.at[row,'Total Affected'])
            TotalGDPList.append(EMDAT_Rec.at[row,'GDP ($Bln) During Event'])
            DisCats.append(EMDAT_Rec.at[row,'Disaster Category'])
            FloodCats.append(EMDAT_Rec.at[row,'Flood Category'])
            try:
                Deadliest = np.nanargmax(TotalDeathList) # record with highest death toll
            except ValueError: # every death toll is NaN
                Deadliest = 0
            # .at writes straight into the frame; df[col][i] = v is a chained
            # assignment that pandas does not guarantee to reach the frame.
            EMDAT_Evn.at[count,'Country'] = AffectedCountries
            EMDAT_Evn.at[count,'Total Deaths'] = _NanAwareTotal(TotalDeathList)
            EMDAT_Evn.at[count,'Total Affected'] = _NanAwareTotal(TotalAffList)
            EMDAT_Evn.at[count,'GDP ($Bln) During Event'] = _NanAwareTotal(TotalGDPList)
            EMDAT_Evn.at[count,'Disaster Category'] = DisCats[Deadliest]
            EMDAT_Evn.at[count,'Flood Category'] = FloodCats[Deadliest]

        # Recalculate mortality from the (possibly merged) totals
        EMDAT_Evn.at[count,'Mortality'] = _MortalityRatio(EMDAT_Evn.at[count,'Total Deaths'], EMDAT_Evn.at[count,'Total Affected'])
        PreviousId = EventId

    return EMDAT_Evn

def CategoriseFloodEvents(EMDAT,CategoryRules):
    """Add 'Flood Category' and 'Disaster Category' columns from rule tables.

    Parameters
    ----------
    EMDAT : pandas.DataFrame
        Must contain 'Disaster Type', 'Disaster Subtype' and every column
        named in the rules. Modified in place (and returned). The rows are
        addressed by the labels 0..len(EMDAT)-1.
    CategoryRules : sequence of rule sets
        Any number of rule sets. Each rule set is a list of rules evaluated
        as one if/elif chain per event; each rule is an 8-item sequence

            [column, operator, literal, joiner, column, operator, value, category]

        which becomes

            (EMDAT[column][event] operator 'literal') joiner
            (EMDAT[column][event] operator value)  ->  'Flood Category' = category

        `literal` is wrapped in quotes by this function, `value` is inserted
        through str() (so string values must carry their own quotes).

    Returns
    -------
    pandas.DataFrame
        EMDAT with 'Flood Category' set by the first matching rule of each
        rule set ('' if none matched) and 'Disaster Category' set to 'Flood'
        where a flood category was assigned, else to 'Disaster Type'. NaN
        values in 'Disaster Subtype' are replaced by the string 'nan'.

    Notes
    -----
    SECURITY: the rules are turned into Python source and run with exec().
    CategoryRules must therefore only ever come from a trusted source.
    """
    # Add flood category field using flowchart logic (event data)
    EMDAT['Disaster Category'] = ''
    EMDAT['Flood Category'] = ''
    # Change nan values so that the rules can compare them as text
    EMDAT['Disaster Subtype'] = EMDAT['Disaster Subtype'].fillna('nan')

    for RuleSet in CategoryRules:
        # Build the if/elif chain for this rule set
        Branches = []
        for R, rule in enumerate(RuleSet):
            Keyword = 'if' if R == 0 else 'elif'
            Branches.append(
                Keyword + ' (EMDAT[\'' + rule[0] + '\'][event] ' + rule[1] + ' \'' + rule[2] + '\') ' + rule[3] +
                ' (EMDAT[\'' + rule[4] + '\'][event] ' + rule[5] + ' ' + str(rule[6]) + '): ' +
                'EMDAT.at[event, \'Flood Category\'] = \'' + rule[7] + '\'')
        # Compile once per rule set rather than re-parsing the text per event
        CompiledRules = compile('\n'.join(Branches), '<CategoryRules>', 'exec')

        for event in range(0,len(EMDAT)):
            # The generated code only needs these two names; module globals
            # stay visible so rule values may reference imported modules.
            exec(CompiledRules, globals(), {'EMDAT': EMDAT, 'event': event})
            if not EMDAT.at[event,'Flood Category']:
                EMDAT.at[event,'Disaster Category'] = EMDAT.at[event,'Disaster Type']
            else:
                EMDAT.at[event,'Disaster Category'] = 'Flood'

    return EMDAT

def _NearestYearCPI(EMDAT, ISO, Year):
    """Return the CPI of the EMDAT record of country ISO whose 'Year' is closest to Year."""
    CountryCPIs = EMDAT.loc[EMDAT['ISO'] == ISO, ['CPI','Year']]
    return CountryCPIs.loc[CountryCPIs['Year'].sub(Year).abs().idxmin(),'CPI']


def _CoastalFloodFields(row, CPI):
    """Map one QC-table row onto EMDAT column names for a coastal flood record."""
    return {'Dis No':row['EM-DAT-ISO'], 'Year':row['Year'], 'Seq':row['EM-DAT'], 'Disaster Group': 'Natural',
            'Disaster Subgroup': 'Meteorological','Disaster Type': 'Flood','Disaster Subtype': 'Coastal flood',
            'Disaster Subsubtype': row['Disaster Subsubtype'], 'Country': row['Country'], 'ISO': row['ISO'],
            'Region': row['Region_EM'], 'Continent': row['Continent_EM'], 'Start Year': row['Year'],'Start Month': row['Start Month'],
            'Start Day': row['Start Date'], 'Total Affected': row['Total affected population'], 'Total Deaths': row['Total deaths'],
            'Total Damages (\'000 US$)': row['Est. Damage (US$ Million)']*1000,'CPI': CPI,'Disaster Category': 'Flood','Flood Category':'Coastal Flood'}


def _AppendRecord(EMDAT, Entry):
    """Return EMDAT with the dict Entry added as a new last row (index renumbered)."""
    return pd.concat([EMDAT, pd.DataFrame([Entry])], ignore_index=True)


def _CountryYearValue(Table, CountryCode, Column):
    """Return Table[Column] for the single row of CountryCode as float, else NaN.

    NaN is returned when the column or country is missing, when the country
    matches more than one row, or when the cell cannot be converted to float.
    """
    try:
        Matches = Table.loc[Table['Country Code'] == CountryCode, Column]
    except KeyError:
        return np.nan
    if len(Matches) != 1:
        return np.nan
    try:
        return float(Matches.iloc[0])
    except (TypeError, ValueError):
        return np.nan


def CalcAddition(EMDAT,GDPData,IncomeData,QCData):
    """Merge QC'd coastal flood events into EMDAT and add derived columns.

    Steps, in order:

    1. Missing 'CPI' values are set to 100.
    2. Where 'Total Affected' is below 'Total Deaths' it is raised to the
       largest of 'Total Affected', 'Total Deaths', 'No Affected' and
       'No Injured' (NaN counted as 0).
    3. Every row of QCData is merged:
       * 'EM-DAT-ISO' equals an existing 'Dis No': that record is marked as
         a coastal flood and its deaths/affected/damage are overwritten with
         the QC values that are not 0.
       * otherwise, if 'EM-DAT' occurs inside some 'Dis No': the first such
         record is copied, overwritten with the QC values and appended.
       * otherwise a new record holding only the QC values is appended.
       The CPI of an appended record is taken from the existing record of
       the same 'ISO' whose year is nearest; a ValueError is raised when the
       country has no record at all.
    4. 'GDP ($Bln) During Event', 'Income', 'Adjusted Income', 'Mortality',
       'Relative Damage %' and 'Adjusted Damage' are computed. Year and
       country code are read from characters 0-3 and 10-12 of 'Dis No'.

    Parameters
    ----------
    EMDAT : pandas.DataFrame
        Record table; columns are added to it in place.
    GDPData : pandas.DataFrame
        'Country Code' plus one column per year named '<YYYY> [YR<YYYY>]'.
    IncomeData : pandas.DataFrame
        'Country Code' plus one column per year named '<YYYY>'.
    QCData : pandas.DataFrame
        Quality-controlled coastal flood events.

    Returns
    -------
    pandas.DataFrame
        The completed table. When QC records were appended this is a new
        frame, so callers must use the return value.
    """
    DamageColumn = 'Total Damages (\'000 US$)'

    # Create the derived columns up front so they keep their position in the
    # table; every one of them is filled in at the end of this function.
    for Derived in ('GDP ($Bln) During Event', 'Income', 'Adjusted Income', 'Mortality', 'Relative Damage %', 'Adjusted Damage'):
        EMDAT[Derived] = np.nan

    # CPI data appears to be missing for 2020 and 2021: assume 100
    EMDAT['CPI'] = EMDAT['CPI'].fillna(100)

    # Ensure total affected is at least same value as deaths
    Undercounted = EMDAT['Total Affected'] < EMDAT['Total Deaths']
    if Undercounted.any():
        Floor = EMDAT.loc[Undercounted, ['Total Affected','Total Deaths','No Affected','No Injured']].fillna(0).max(axis=1)
        EMDAT.loc[Undercounted,'Total Affected'] = Floor

    # Add/update QCed events
    for index,row in QCData.iterrows():
        SameRecord = EMDAT['Dis No'] == row['EM-DAT-ISO']
        if SameRecord.any(): # Full record found: update from sheet
            # Categorisation
            EMDAT.loc[SameRecord,'Disaster Subsubtype'] = row['Disaster Subsubtype']
            EMDAT.loc[SameRecord,'Disaster Category'] = 'Flood'
            EMDAT.loc[SameRecord,'Flood Category'] = 'Coastal Flood'
            # Values: Only use non-zero QC data
            # NOTE(review): a NaN QC value is "!= 0" and therefore overwrites
            # the existing figure with NaN - confirm that this is wanted.
            if row['Total deaths'] != 0:
                EMDAT.loc[SameRecord,'Total Deaths'] = row['Total deaths']
            if row['Total affected population'] != 0:
                EMDAT.loc[SameRecord,'Total Affected'] = row['Total affected population']
            if row['Est. Damage (US$ Million)'] != 0:
                EMDAT.loc[SameRecord,DamageColumn] = row['Est. Damage (US$ Million)']*1000
            continue

        # Only CPI is missing from the QC table and must be estimated
        CPI = _NearestYearCPI(EMDAT, row['ISO'], row['Year'])

        SameEvent = EMDAT['Dis No'].str.contains(row['EM-DAT'], case=False, na=False)
        if SameEvent.any():
            # Event exists for another country: start from a copy of it
            NewEMDATA_Entry = EMDAT.loc[SameEvent].iloc[0,:].to_dict()
            NewEMDATA_Entry.update(_CoastalFloodFields(row, CPI))
        else:
            # Completely new event
            NewEMDATA_Entry = _CoastalFloodFields(row, CPI)
        # The result must be re-bound: appending returns a new frame
        EMDAT = _AppendRecord(EMDAT, NewEMDATA_Entry)

    # Add GDP and income data (looked up per record from 'Dis No')
    GDPValues = []
    IncomeValues = []
    for DisNo in EMDAT['Dis No'].astype(str):
        Year = DisNo[0:4]
        CountryCode = DisNo[10:13]
        GDPValues.append(_CountryYearValue(GDPData, CountryCode, Year + ' [YR' + Year + ']')/1000000000)
        IncomeValues.append(_CountryYearValue(IncomeData, CountryCode, str(Year)))
    EMDAT['GDP ($Bln) During Event'] = GDPValues
    EMDAT['Income'] = IncomeValues
    EMDAT['Adjusted Income'] = (EMDAT['Income']/EMDAT['CPI'])*100
    # Mortality Data
    EMDAT['Mortality'] = EMDAT['Total Deaths']/EMDAT['Total Affected']
    # Relative Damage Data (damage is in '000 US$, GDP in billion US$)
    EMDAT['Relative Damage %'] = ((EMDAT[DamageColumn]*1000)/(EMDAT['GDP ($Bln) During Event']*1000000000))*100
    # Adjusted Damage Data
    EMDAT['Adjusted Damage'] = (EMDAT[DamageColumn]/EMDAT['CPI'])*100

    return EMDAT

def GraphOptions(EMDAT,EMDataPath):

    # Choose a graph type
    GraphNum = input('Choose a chart type:\n\
    1. Bar Chart\n\
    2. Stacked Bar Chart\n\
    3. Yearly trend plot\n\
    4. Annual Exceedence Frequency curves \n\
    5. Probability Exceedence Frequency curves (per event)\n\
    6. Scatter plot (all events)\n\
    7. Box whisker plot\n\
    8. Box whisker plot comparison\n\
    Number: ')
    
    # Variables for bar charts
    if int(GraphNum) in [1,2]:
        # X-axis
        XAxisBars = input('Define the X-axis variable:\n\
        1. Year\n\
        2. Disaster Category (Extreme weather)\n\
        3. Flood Category (data should be filtered for floods only)\n\
        4. World Region\n\
        5. Continent\n\
        6. Disaster Category\n\
        Number: ')
        if int(XAxisBars) == 1:
            XValColumnString = 'Year'
            XValues = list(pd.unique(EMDAT[XValColumnString])) 
        elif int(XAxisBars) == 2:
            XValColumnString = 'Disaster Category'
            XValues = ['Flood','Storm','Landslide','Extreme temperature','Wildfire','Drought']
        elif int(XAxisBars) == 3:
            XValColumnString = 'Flood Category'
            XValues = list(pd.unique(EMDAT[XValColumnString])) 
        elif int(XAxisBars) == 4:
            XValColumnString = 'Region'
            XValues = list(pd.unique(EMDAT[XValColumnString])) 
        elif int(XAxisBars) == 5:
            XValColumnString = 'Continent'
            XValues = list(pd.unique(EMDAT[XValColumnString]))
        elif int(XAxisBars) == 6:
            XValColumnString = 'Disaster Category'
            XValues = list(pd.unique(EMDAT[XValColumnString]))          
        
        # Y-axis
        YAxisValues = input('Define the Y-axis variable:\n\
        1. Number of events\n\
        2. Total Damage\n\
        3. Total Affected\n\
        4. Total Deaths\n\
        5. Adjusted Damage\n\
        Variable: ')
        if int(YAxisValues) == 1:
            YValColumnString = 'Number of events'
        elif int(YAxisValues) == 2:
            YValColumnString = 'Total Damages (\'000 US$)'
        elif int(YAxisValues) == 3:
            YValColumnString = 'Total Affected'
        elif int(YAxisValues) == 4:
            YValColumnString = 'Total Deaths'
        elif int(YAxisValues) == 5:
            YValColumnString = 'Adjusted Damage'
    
    # Extra variable for categorisation of stacks
    if int(GraphNum) == 2:
        XAxisCategories = input('Define the X-axis variable to categorise stacks:\n\
        1. Disaster Category\n\
        2. World Region\n\
        3. Flood Category (data should be filtered for floods only)\n\
        4. Continent\n\
        Variable: ')
        if int(XAxisCategories) == 1:
            XCatColumnString = 'Disaster Category'
            XCatValues = list(pd.unique(EMDAT[XCatColumnString]))
        elif int(XAxisCategories) == 2:
            XCatColumnString = 'Region'
            XCatValues = list(pd.unique(EMDAT[XCatColumnString])) 
        elif int(XAxisCategories) == 3:
            XCatColumnString = 'Flood Category'
            XCatValues = list(pd.unique(EMDAT[XCatColumnString])) 
        elif int(XAxisCategories) == 4:
            XCatColumnString = 'Continent'
            XCatValues = list(pd.unique(EMDAT[XCatColumnString])) 


    # Generate basic bar chart
    if int(GraphNum) == 1:    
        # Generate data for each stack, for the Y-variable
        YperX = {}
        count = 0
        XCatValues = XValues
        CoverStack = (len(XCatValues)+1)*[len(XValues) *[0]]
        if int(YAxisValues) == 1: # Count the number of events
            YperX[0] = [len(EMDAT[(EMDAT[XValColumnString]==XVal)]) for XVal in XValues]
        elif int(YAxisValues) in [2,3,4]: # Sum the value in each event (assume nan is 0)
            YperX[0] = [sum(EMDAT[YValColumnString][(EMDAT[XValColumnString]==XVal)].fillna(0)) for XVal in XValues]

    # Generate stacked bar chart data
    if int(GraphNum) == 2:    
        # Generate data for each stack, for the Y-variable
        YperX = {}
        count = 0
        CoverStack = (len(XCatValues)+1)*[len(XValues) *[0]]
        for XCatValue in XCatValues:
            count +=1
            if int(YAxisValues) == 1: # Count the number of events
                YperX[XCatValue] = [len(EMDAT[(EMDAT[XCatColumnString]==XCatValue) & (EMDAT[XValColumnString]==XVal)]) for XVal in XValues]
                CoverStack[count] = np.add(CoverStack[count-1], YperX[XCatValue]).tolist() 
            elif int(YAxisValues) in [2,3,4,5]: # Sum the value in each event (assume nan is 0)
                YperX[XCatValue] = [sum(EMDAT[YValColumnString][(EMDAT[XCatColumnString]==XCatValue) & (EMDAT[XValColumnString]==XVal)].fillna(0)) for XVal in XValues]
                CoverStack[count] = np.add(CoverStack[count-1], YperX[XCatValue]).tolist() 
            
    # Plot (stacked) bar chart
    if int(GraphNum) in [1,2]: 
        colors = {'Original data':'skyblue', 'Updated data':'b'} 
        labels = list(colors.keys()) 
        fig, ax = plt.subplots()
        ax.set_xlabel(XValColumnString)
        ax.set_ylabel(YValColumnString)
        tit = fig.suptitle(YValColumnString + ' per ' + XValColumnString + ' in dataset', fontsize=16)
        #ax.set_xticks(list(Years[0::5]))
        if int(GraphNum) == 1:
            plt.bar(XValues, YperX[0],width = 0.65)
        if int(GraphNum) == 2: # Add stacks
            count = 0
            for XCatValue in XCatValues: 
                plt.bar(XValues, YperX[XCatValue],width = 0.65, bottom=CoverStack[count], label = XCatValue)
                count +=1
        plt.legend()
        lgd = ax.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
        if len(XValues) > 10:
            fig.autofmt_xdate(rotation=45)
        plt.show()

        # To do: Fix legend etc

    # Variables for scatter/trend plot
    if int(GraphNum) == 3:
        # X-axis
        print('X-Axis will be years of provided dataset')
        XValColumnString = 'Year'
        XValues = list(pd.unique(EMDAT[XValColumnString]))  
        
        # Y-axis
        YAxisValues = input('Define the Y-axis variable:\n\
        1. Number of events\n\
        2. Damage\n\
        3. Affected\n\
        4. Deaths\n\
        5. Adjusted Damage\n\
        6. Mortality\n\
        Variable: ')
        
        # Get X variables (Nr per year is needed for averages anyway)
        NrXprYr = [len(EMDAT[(EMDAT[XValColumnString]==yr)]) for yr in XValues]
        if int(YAxisValues) == 1:
            YValColumnString ='Events'
            TotXprYr = NrXprYr
            TrendTotXPrYr = np.polyfit(XValues,np.log(NrXprYr), 1)
            AvgOrTot = 1
        elif int(YAxisValues) == 6: # Mortality
            YValColumnString ='Mortality'
            AvgOrTot = input('Choose plotting option:\n\
            1. Total per year\n\
            2. Average per event per year\n\
            Option: ')
            TotXprYr = [sum(EMDAT['Total Deaths'][(EMDAT['Year']==yr)].fillna(0))/sum(EMDAT['Total Affected'][(EMDAT['Year']==yr)].fillna(0)) for yr in XValues]
            TrendTotXPrYr = np.polyfit(XValues,np.log(TotXprYr), 1)
            AvXprYr = [i/j for i,j in zip(TotXprYr, NrXprYr)]
            TrendAvXPrYr = np.polyfit(XValues,np.log([x or 1 for x in AvXprYr]), 1) # can't accept 0 values
            #AvgOrTot = 1
        else:
            if int(YAxisValues) == 2:
                YValColumnString = 'Total Damages (\'000 US$)'
            elif int(YAxisValues) == 3:
                YValColumnString = 'Total Affected'
            elif int(YAxisValues) == 4:
                YValColumnString = 'Total Deaths'
            elif int(YAxisValues) == 5:
                YValColumnString = 'Adjusted Damage'
            AvgOrTot = input('Choose plotting option:\n\
            1. Total per year\n\
            2. Average per event per year\n\
            3. Both\n\
            Option: ')
                
            TotXprYr = [sum(EMDAT[YValColumnString][(EMDAT['Year']==yr)].fillna(0)) for yr in XValues]
            TrendTotXPrYr = np.polyfit(XValues,np.log([x or 1 for x in TotXprYr]), 1) # can't accept 0 values
            AvXprYr = [i/j for i,j in zip(TotXprYr, NrXprYr)]
            TrendAvXPrYr = np.polyfit(XValues,np.log([x or 1 for x in AvXprYr]), 1) # can't accept 0 values

        #  Generate plot
        fig, ax = plt.subplots()
        ax.set_xlabel('Year')
        ax.set_ylabel(YValColumnString)
        tit = fig.suptitle(YValColumnString + ' per ' + XValColumnString + ' in dataset', fontsize=16)
        if int(AvgOrTot) in [1,3]:
            # Total affected per year
            ax.scatter(XValues,TotXprYr, s=10, c='b', marker="s", label='Overall ' + YValColumnString)
            plt.plot(XValues,np.exp(TrendTotXPrYr[1])*np.exp(TrendTotXPrYr[0]*np.asarray(XValues)),"b-", label='Overall ' + YValColumnString + ' trend')
        if int(AvgOrTot) in [2,3]:
            # Average affected per event per year - if applicable              
            ax.scatter(XValues,AvXprYr, s=10, c='r', marker="s", label='Average ' + YValColumnString + ' per event')
            plt.plot(XValues,np.exp(TrendAvXPrYr[1])*np.exp(TrendAvXPrYr[0]*np.asarray(XValues)),"r-", label='Average ' + YValColumnString + ' trend')
        lgd = ax.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
        ax.set_yscale('log')
        plt.show()
        plt.draw()        

    # Variables for exceedence of impact per year graph
    if int(GraphNum) == 4:
        # X-axis
        print('Y-Axis will be yearly exceedence probabilty, based on provided data')
        print('N.B. If Nan values are not filtered out, they will be assumed = 0')
        YValColumnString = 'Year'
        #YValues = list(pd.unique(EMDAT[YValColumnString]))  
        
        # Y-axis
        XAxisValues = input('Define the X-axis variable:\n\
        1. Total Damage\n\
        2. Total Affected\n\
        3. Total Deaths\n\
        4. Adjusted Damage\n\
        5. Mortality\n\
        Variable: ')
        # if int(XAxisValues) == 1:
        #     XValColumnString = 'Number of events'
        if int(XAxisValues) == 1:
            XValColumnString = 'Total Damages (\'000 US$)'
            NMarkers = [1,2,5,10,20,50,100,200,500,1000,2000,5000,10000,20000,50000,100000,200000,500000,1000000,2000000,5000000,10000000,20000000,50000000,10000000] 
        elif int(XAxisValues) == 2:
            XValColumnString = 'Total Affected'
            NMarkers = [1,2,5,10,20,50,100,200,500,1000,2000,5000,10000,20000,50000,100000,200000,500000,1000000,2000000,5000000,10000000,20000000,50000000,10000000]
        elif int(XAxisValues) == 3:
            XValColumnString = 'Total Deaths'
            NMarkers = [1,2,5,10,20,50,100,200,500,1000,2000,5000,10000,20000,50000,100000,200000,500000,1000000,2000000,5000000,10000000,20000000,50000000,10000000]
        elif int(XAxisValues) == 4:
            XValColumnString = 'Adjusted Damage'
            NMarkers = [1,2,5,10,20,50,100,200,500,1000,2000,5000,10000,20000,50000,100000,200000,500000,1000000,2000000,5000000,10000000,20000000,50000000,100000000]
        elif int(XAxisValues) == 5:
            XValColumnString = 'Mortality'
            NMarkers = [10e-6,20e-6,50e-6,10e-5,20e-5,50e-5,10e-4,20e-4,50e-4,10e-3,20e-3,50e-3,10e-2,20e-2,50e-2,10e-1]
         #YValues = list(pd.unique(EMData[YValColumnString]))
        
        # Sub-Category
        XAxisCategories = input('Define the X-axis variable to categorise distributions:\n\
        1. Disaster Type\n\
        2. World Region\n\
        3. Flood Category (data should be filtered for floods only)\n\
        4. Continent\n\
        Variable: ')
        if int(XAxisCategories) == 1:
            XCatColumnString = 'Disaster Category'   
        elif int(XAxisCategories) == 2:
            XCatColumnString = 'Region' 
        elif int(XAxisCategories) == 3:
            XCatColumnString = 'Flood Category' 
        elif int(XAxisCategories) == 4:
            XCatColumnString = 'Continent' 
        XCatValues = list(pd.unique(EMDAT[XCatColumnString]))      
        
        # Visualisation variables
        #DisasterTypesDisplayed = ['Flood','Storm','Earthquake','Drought','Wildfire']
        
        XCatFreq = pd.DataFrame(index=NMarkers, columns=pd.unique(EMDAT[XCatColumnString]).tolist())
        NrYears = float(max(EMDAT['Year']) - min(EMDAT['Year']))

        # Calculate disaster frequency (Disaster Type)
        #for dis in pd.unique(EMDataRec['Disaster Type']).tolist():
        for XCat in XCatValues:
            for NMark in NMarkers:
                XCatFreq.at[NMark,XCat] = sum(EMDAT[XValColumnString][EMDAT[XCatColumnString]==XCat].fillna(0) > NMark)/NrYears
               
        # Plot figure
        lines = ["-","--","-.",":","-","--","-.",":","-","--","-.",":","-","--","-.",":","-","--","-.",":","-","--","-.",":"]
        markers = ['.','o','v','^','<','>','4','s','p','*','h','H','+','x','X','D','.','o','v','^','<','>','4','s','p','*','h','H','+','x','X','D']
        fig, ax = plt.subplots()
      
        ax.set_xlabel(XValColumnString + ' (N)')
        ax.set_ylabel('Annual exceedence frequency of event with impact N')
        
        count = 0
        for XCat in XCatValues:
            plt.plot(NMarkers,XCatFreq[XCat].tolist(), linestyle=lines[count], marker=markers[count], label=XCat)
            count +=1        
        ax.set_xscale('log')
        ax.set_yscale('log')
        #lgd = ax.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
        lgd = ax.legend(loc='upper right')
        tit = fig.suptitle('Annual Frequency - '+ str(XCatColumnString), fontsize=16)  
        plt.show()

    # Variables for exceedence of impact per event graph
    if int(GraphNum) == 5:

        # X-axis
        print('Y-Axis will be event exceedence probabilty, based on provided data')
        print('N.B. If Nan values are not filtered out, they will be assumed = 0')

        # Y-axis
        XAxisValues = input('Define the X-axis variable:\n\
        1. Total Damage\n\
        2. Total Affected\n\
        3. Total Deaths\n\
        4. Adjusted Damage\n\
        5. Mortality\n\
        6. Relative Damage\n\
        Variable: ')
        if int(XAxisValues) == 1:
            XValColumnString = 'Total Damages (\'000 US$)'
        elif int(XAxisValues) == 2:
            XValColumnString = 'Total Affected'
        elif int(XAxisValues) == 3:
            XValColumnString = 'Total Deaths'
        elif int(XAxisValues) == 4:
            XValColumnString = 'Adjusted Damage'
        elif int(XAxisValues) == 5:
            XValColumnString = 'Mortality'
        elif int(XAxisValues) == 6:
            XValColumnString = 'Relative Damage %'
            
        # Sub-Category
        XAxisCategories = input('Define the X-axis variable to categorise distributions:\n\
        1. Disaster Type\n\
        2. World Region\n\
        3. Flood Category (data should be filtered for floods only)\n\
        4. Continent\n\
        Variable: ')
        if int(XAxisCategories) == 1:
            XCatColumnString = 'Disaster Category'   
        elif int(XAxisCategories) == 2:
            XCatColumnString = 'Region' 
        elif int(XAxisCategories) == 3:
            XCatColumnString = 'Flood Category' 
        elif int(XAxisCategories) == 4:
            XCatColumnString = 'Continent' 
        XCatValues = list(pd.unique(EMDAT[XCatColumnString]))  


        # Get XValue sorted data with associated exceedence probability
        XCatXValSorted = []
        XCatXValtProbs = []
        for XCat in XCatValues:
            XSort = sorted(EMDAT[XValColumnString][(EMDAT[XValColumnString].notna()) & (EMDAT[XCatColumnString]== XCat)].tolist())
            # Now assign probability
            NewProb = []
            for evn in XSort:
                NewProb.append(1-float(XSort.index(evn))/max(float((len(XSort)-1)),1))
            XCatXValSorted.append(XSort)
            XCatXValtProbs.append(NewProb)             
                    
        # Plot graph
        lines = ["-","--","-.",":","-","--","-.",":","-","--","-.",":","-","--","-.",":","-","--","-.",":","-","--","-.",":"]
        markers = ['.','o','v','^','<','>','4','s','p','*','h','H','+','x','X','D','.','o','v','^','<','>','4','s','p','*','h','H','+','x','X','D']
        fig, ax = plt.subplots()
        count = 0
        for XCat in XCatValues:
            plt.plot(XCatXValSorted[count],XCatXValtProbs[count], linestyle=lines[count], marker=markers[count], label=XCat)
            count +=1
        ax.set_xscale('log')
        lgd = ax.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) 
        ax.set_xlabel(XValColumnString + ' (N)')
        ax.set_ylabel('Exceedence probability of N for a given event')
        tit = fig.suptitle('Impact exceedance probabilities per event - ' + XCatColumnString, fontsize=16)
        plt.show()

    # Variables for scatter plot of all events
    if int(GraphNum) == 6:
        
        # X-axis
        XAxisValues = input('Define the X-axis variable:\n\
        1. Total Damage\n\
        2. Adjusted Damage\n\
        3. Total Affected\n\
        4. Total Deaths\n\
        5. Mortality\n\
        6. Average GDP of affected countries at time of event\n\
        7. Average Income of affected countries at time of event\n\
        8. Relative Damage (damage/GDP)\n\
        Variable: ')
        if int(XAxisValues) == 1:
            XValColumnString = 'Total Damages (\'000 US$)'
        elif int(XAxisValues) == 2:
            XValColumnString = 'Adjusted Damage'
        elif int(XAxisValues) == 3:
            XValColumnString = 'Total Affected'
        elif int(XAxisValues) == 4:
            XValColumnString = 'Total Deaths'
        elif int(XAxisValues) == 5:
            XValColumnString = 'Mortality'
        elif int(XAxisValues) == 6:
            XValColumnString = 'GDP ($Bln) During Event'
        elif int(XAxisValues) == 7:
            XValColumnString = 'Income'
        elif int(XAxisValues) == 8:
            XValColumnString = 'Relative Damage %' 
            
        # Y-axis
        YAxisValues = input('Define the Y-axis variable:\n\
        1. Total Damage\n\
        2. Adjusted Damage\n\
        3. Total Affected\n\
        4. Total Deaths\n\
        5. Mortality\n\
        6. Average GDP of affected countries at time of event\n\
        7. Average Income of affected countries at time of event\n\
        8. Relative Damage (damage/GDP)\n\
        Variable: ')
        if int(YAxisValues) == 1:
            YValColumnString = 'Total Damages (\'000 US$)'
        elif int(YAxisValues) == 2:
            YValColumnString = 'Adjusted Damage'
        elif int(YAxisValues) == 3:
            YValColumnString = 'Total Affected'
        elif int(YAxisValues) == 4:
            YValColumnString = 'Total Deaths'
        elif int(YAxisValues) == 5:
            YValColumnString = 'Mortality'
        elif int(YAxisValues) == 6:
            YValColumnString = 'GDP ($Bln) During Event'
        elif int(YAxisValues) == 7:
            YValColumnString = 'Income'
        elif int(YAxisValues) == 8:
            YValColumnString = 'Relative Damage %'            

        # Sub-Category
        XAxisCategories = input('Define the scatter plot catergorisation:\n\
        1. Disaster Type\n\
        2. World Region\n\
        3. Flood Category (data should be filtered for floods only)\n\
        4. Continent\n\
        Variable: ')
        if int(XAxisCategories) == 1:
            XCatColumnString = 'Disaster Category'   
        elif int(XAxisCategories) == 2:
            XCatColumnString = 'Region' 
        elif int(XAxisCategories) == 3:
            XCatColumnString = 'Flood Category' 
        elif int(XAxisCategories) == 4:
            XCatColumnString = 'Continent' 
        XCatValues = list(pd.unique(EMDAT[XCatColumnString]))      
                                    
        # Plot figure
        fig, ax = plt.subplots() 
        markers = ['.','o','v','^','<','>','4','s','p','*','h','H','+','x','X','D']           
        count = 0
        ScatX = []
        ScatY = []
        for XCat in XCatValues:
             ScatX.append(EMDAT[XValColumnString][EMDAT[XCatColumnString] == XCatValues[count]])
             ScatY.append(EMDAT[YValColumnString][EMDAT[XCatColumnString] == XCatValues[count]])
             ax.scatter(ScatX[count],ScatY[count],marker=markers[count], label=XCat)
             count +=1        
        ax.set_xscale('log')
        ax.set_yscale('log')
        ax.set_xlabel(XValColumnString)
        ax.set_ylabel(YValColumnString)
        lgd = ax.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
        tit = fig.suptitle(YValColumnString + ' vs ' + XValColumnString + ': All events in dataset', fontsize=16)
        plt.xlim(right = np.nanmax(EMDAT[XValColumnString])*1.1)
        plt.show()

    # Variables for box whisker plot
    if int(GraphNum) == 7:
        
        # X-axis
        XAxisValues = input('Define the X-axis variable for boxes:\n\
        1. Total Damage\n\
        2. Adjusted Damage\n\
        3. Total Affected\n\
        4. Total Deaths\n\
        5. Mortality\n\
        6. Average GDP of affected countries at time of event\n\
        7. Average Income of affected countries at time of event\n\
        8. Relative Damage (damage/GDP)\n\
        9. Disaster Type\n\
        10. World Region\n\
        11. Flood Category (data should be filtered for floods only)\n\
        12. Continent\n\
        13. Average adjusted Income of affected countries at time of event\n\
        14. YearGroup\n\
        Variable: ')
        if int(XAxisValues) == 1:
            XValColumnString = 'Total Damages (\'000 US$)'
        elif int(XAxisValues) == 2:
            XValColumnString = 'Adjusted Damage'
        elif int(XAxisValues) == 3:
            XValColumnString = 'Total Affected'
        elif int(XAxisValues) == 4:
            XValColumnString = 'Total Deaths'
        elif int(XAxisValues) == 5:
            XValColumnString = 'Mortality'
        elif int(XAxisValues) == 6:
            XValColumnString = 'GDP ($Bln) During Event'
        elif int(XAxisValues) == 7:
            XValColumnString = 'Income'
        elif int(XAxisValues) == 8:
            XValColumnString = 'Relative Damage %' 
        elif int(XAxisValues) == 9:
            XCatColumnString = 'Disaster Category'   
        elif int(XAxisValues) == 10:
            XCatColumnString = 'Region' 
        elif int(XAxisValues) == 11:
            XCatColumnString = 'Flood Category' 
        elif int(XAxisValues) == 12:
            XCatColumnString = 'Continent'             
        elif int(XAxisValues) == 13:
            XValColumnString = 'Adjusted Income'
        elif int(XAxisValues) == 14:
            XValColumnString = 'Year'
            
        # Discrete or continuous categories
        if int(XAxisValues) in [9,10,11,12]:
            XCatValues = list(pd.unique(EMDAT[XCatColumnString]))
            print('X-axis boxes will be' + str(XCatValues))
        else:
            print('Max ' + XValColumnString + ' value is: ' + str(max(EMDAT[XValColumnString])))
            print('Min ' + XValColumnString + ' is: ' + str(min(EMDAT[XValColumnString])))
            XValSplitPts = [min(EMDAT[XValColumnString])]
            XValSplitPts.extend([int(item) for item in input('Define splitting points between these max/min values as a list (e.g. 100,1000,10000 or ): ').split(',')])
            XValSplitPts.append(max(EMDAT[XValColumnString]))
            XCatValues = []
            for box in range(0,len(XValSplitPts)-1):
                XCatValues.append(str(round(XValSplitPts[box])) + ' - ' + str(round(XValSplitPts[box+1])))
            print('X-axis boxes will be ' + ','.join([str(elem) for elem in XCatValues]))

        # Y-axis
        YAxisValues = input('Define the Y-axis variable:\n\
        1. Total Damage\n\
        2. Adjusted Damage\n\
        3. Total Affected\n\
        4. Total Deaths\n\
        5. Mortality\n\
        6. Average GDP of affected countries at time of event\n\
        7. Average Income of affected countries at time of event\n\
        8. Relative Damage (damage/GDP)\n\
        Variable: ')
        if int(YAxisValues) == 1:
            YValColumnString = 'Total Damages (\'000 US$)'
        elif int(YAxisValues) == 2:
            YValColumnString = 'Adjusted Damage (US$)'
        elif int(YAxisValues) == 3:
            YValColumnString = 'Total Affected'
        elif int(YAxisValues) == 4:
            YValColumnString = 'Total Deaths'
        elif int(YAxisValues) == 5:
            YValColumnString = 'Mortality'
        elif int(YAxisValues) == 6:
            YValColumnString = 'GDP ($Bln) During Event'
        elif int(YAxisValues) == 7:
            YValColumnString = 'Income'
        elif int(YAxisValues) == 8:
            YValColumnString = 'Relative Damage %'  

        # Generate Data
        DataBox = []
        for box in range(0,len(XValSplitPts)-1):
            TempList = list(EMDAT[YValColumnString][(EMDAT[XValColumnString] >= XValSplitPts[box]) & (EMDAT[XValColumnString] <= XValSplitPts[box+1])])
            TempList = -np.log10(TempList)
            DataBox.append([x for x in TempList if math.isnan(x) == False])
        # Plot
        fig, ax = plt.subplots() 
        plt.boxplot(DataBox)
     
        #plt.xticks(list(range(0,len(XValSplitPts)-1)),XCatValues)
        if int(YAxisValues) in [1,2,3,4,6]: # 5 can be removed for mortality
            ax.set_yscale('log')
            LegendAddition = ' (log scale)'
        elif int(YAxisValues) == 5:
            LegendAddition = ' (log scale)'
            # To do
            plt.ylim(6, 0) # Need to fix this as log of something
            labels = [item.get_text() for item in ax.get_yticklabels()]
            labels = ['0','$10^{-1}$','$10^{-}2$','$10^{-3}$','$10^{-4}$','$10^{-5}$','$10^{-6}$']
            ax.set_yticklabels(labels)
            ax.set_xticklabels(XCatValues)
        else:
            LegendAddition = ''
        ax.set_ylabel(YValColumnString + LegendAddition)
        ax.set_xlabel(XValColumnString + ' groups')      
            
        #lgd = ax.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
        tit = fig.suptitle('Boxplots of ' + YValColumnString + ' per ' + XValColumnString, fontsize=16)
        fig.autofmt_xdate(rotation=45)
        plt.grid()
        plt.show()

    # Variables for box whisker plot comparison
    if int(GraphNum) == 8:
        
        # X-axis
        YearDivision = input('Comparison variable assumed to be time. Define division year: ')
        XAxisValues = input('Define the X-axis variable for boxes:\n\
        1. Total Damage\n\
        2. Adjusted Damage\n\
        3. Total Affected\n\
        4. Total Deaths\n\
        5. Mortality\n\
        6. Average GDP of affected countries at time of event\n\
        7. Average Income of affected countries at time of event\n\
        8. Relative Damage (damage/GDP)\n\
        9. Disaster Type\n\
        10. World Region\n\
        11. Flood Category (data should be filtered for floods only)\n\
        12. Continent\n\
        13. Average adjusted Income of affected countries at time of event\n\
        Variable: ')
        if int(XAxisValues) == 1:
            XValColumnString = 'Total Damages (\'000 US$)'
        elif int(XAxisValues) == 2:
            XValColumnString = 'Adjusted Damage'
        elif int(XAxisValues) == 3:
            XValColumnString = 'Total Affected'
        elif int(XAxisValues) == 4:
            XValColumnString = 'Total Deaths'
        elif int(XAxisValues) == 5:
            XValColumnString = 'Mortality'
        elif int(XAxisValues) == 6:
            XValColumnString = 'GDP ($Bln) During Event'
        elif int(XAxisValues) == 7:
            XValColumnString = 'Income'
        elif int(XAxisValues) == 8:
            XValColumnString = 'Relative Damage %' 
        elif int(XAxisValues) == 9:
            XCatColumnString = 'Disaster Category'   
        elif int(XAxisValues) == 10:
            XCatColumnString = 'Region' 
        elif int(XAxisValues) == 11:
            XCatColumnString = 'Flood Category' 
        elif int(XAxisValues) == 12:
            XCatColumnString = 'Continent'             
        elif int(XAxisValues) == 13:
            XValColumnString = 'Adjusted Income'         

        # Discrete or continuous categories
        if int(XAxisValues) in [9,10,11,12]:
            XCatValues = list(pd.unique(EMDAT[XCatColumnString]))
            print('X-axis boxes will be' + XCatValues)
        else:
            print('Max ' + XValColumnString + ' value is: ' + str(max(EMDAT[XValColumnString])))
            print('Min ' + XValColumnString + ' is: ' + str(min(EMDAT[XValColumnString])))
            XValSplitPts = [min(EMDAT[XValColumnString])]
            XValSplitPts.extend([int(item) for item in input('Define splitting points between these max/min values as a list (e.g. 100,1000,10000): ').split(',')])
            XValSplitPts.append(max(EMDAT[XValColumnString]))
            XCatValues = []
            for box in range(0,len(XValSplitPts)-1):
                XCatValues.append(str(round(XValSplitPts[box])) + ' - ' + str(round(XValSplitPts[box+1])))
            print('X-axis boxes will be ' + ','.join([str(elem) for elem in XCatValues]))

        # Y-axis
        YAxisValues = input('Define the Y-axis variable:\n\
        1. Total Damage\n\
        2. Adjusted Damage\n\
        3. Total Affected\n\
        4. Total Deaths\n\
        5. Mortality\n\
        6. Average GDP of affected countries at time of event\n\
        7. Average Income of affected countries at time of event\n\
        8. Relative Damage (damage/GDP)\n\
        Variable: ')
        if int(YAxisValues) == 1:
            YValColumnString = 'Total Damages (\'000 US$)'
        elif int(YAxisValues) == 2:
            YValColumnString = 'Adjusted Damage (US$)'
        elif int(YAxisValues) == 3:
            YValColumnString = 'Total Affected'
        elif int(YAxisValues) == 4:
            YValColumnString = 'Total Deaths'
        elif int(YAxisValues) == 5:
            YValColumnString = 'Mortality'
        elif int(YAxisValues) == 6:
            YValColumnString = 'GDP ($Bln) During Event'
        elif int(YAxisValues) == 7:
            YValColumnString = 'Income'
        elif int(YAxisValues) == 8:
            YValColumnString = 'Relative Damage %'  

        # Generate Data
        DataBox1 = []
        DataBox2 = []
        for box in range(0,len(XValSplitPts)-1):
            TempList1 = list(EMDAT[YValColumnString][(EMDAT[XValColumnString] >= XValSplitPts[box]) & (EMDAT[XValColumnString] <= XValSplitPts[box+1]) & (EMDAT['Year'] <= int(YearDivision))])
            TempList1 = -np.log10(TempList1) # log vals
            TempList2 = list(EMDAT[YValColumnString][(EMDAT[XValColumnString] >= XValSplitPts[box]) & (EMDAT[XValColumnString] <= XValSplitPts[box+1]) & (EMDAT['Year'] > int(YearDivision))])
            TempList2 = -np.log10(TempList2) # log vals
            DataBox1.append([x for x in TempList1 if math.isnan(x) == False])
            DataBox2.append([x for x in TempList2 if math.isnan(x) == False])
            
        # ANOVA Tests
        #result = sm.stats.anova_lm(model, type=2)
        res = f_oneway(DataBox1[0], DataBox1[1], DataBox1[2], DataBox1[3])
        print('ANOVA test: Period1, all groups: ' + str(res))
        res = f_oneway(DataBox2[0], DataBox2[1], DataBox2[2], DataBox2[3])
        print('ANOVA test: Period2, all groups: '+ str(res))
        res = f_oneway(sum(DataBox1, []), sum(DataBox2, []))
        print('ANOVA test: Both periods: ' + str(res))
        res = f_oneway(DataBox1[0], DataBox2[0])
        print('ANOVA test: Both periods, group1: ' + str(res))
        res = f_oneway(DataBox1[1], DataBox2[1])
        print('ANOVA test: Both periods, group2: ' + str(res))
        res = f_oneway(DataBox1[2], DataBox2[2])
        print('ANOVA test: Both periods, group3: ' + str(res))
        res = f_oneway(DataBox1[3], DataBox2[3])
        print('ANOVA test: Both periods, group4: ' + str(res))
        
        # Plot
        fig, ax = plt.subplots()
        pos1 = [x-0.2 for x in list(range(0,len(DataBox1)))]
        pos2 = [x+0.2 for x in list(range(0,len(DataBox1)))]
        bp1 = plt.boxplot(DataBox1, positions= pos1,widths=0.3, patch_artist=True, manage_ticks=False)
        bp2 = plt.boxplot(DataBox2, positions= pos2,widths=0.3, patch_artist=True, manage_ticks=False)       

        for element in ['boxes', 'whiskers', 'fliers', 'medians', 'caps']:
            plt.setp(bp1[element], color = "tomato")
        for element in ['boxes', 'whiskers', 'fliers', 'medians', 'caps']:
            plt.setp(bp2[element], color = "skyblue")
        for patch in bp1['boxes']:
            patch.set(facecolor='white')
        for patch in bp2['boxes']:
            patch.set(facecolor='white')      
        
        plt.xticks(list(range(0,len(XValSplitPts)-1)),XCatValues)
        if int(YAxisValues) in [1,2,3,4,6]: # 5 can be removed for mortality
            ax.set_yscale('log')
            LegendAddition = ' (log scale)'
        elif int(YAxisValues) == 5:
            LegendAddition = ' (log scale)'
            # To do
            plt.ylim(6, 0) # Need to fix this as log of something
            labels = [item.get_text() for item in ax.get_yticklabels()]
            labels = ['0','$10^{-1}$','$10^{-}2$','$10^{-3}$','$10^{-4}$','$10^{-5}$','$10^{-6}$']
            ax.set_yticklabels(labels)
        else:
            LegendAddition = ''
        ax.set_ylabel(YValColumnString + LegendAddition)
        ax.set_xlabel(XValColumnString + ' groups')
        
        lgd = ax.legend([bp1['boxes'][0], bp2['boxes'][0], bp1['fliers'][0]],['1975-' + str(YearDivision),str(int(YearDivision)+1) + '-2022', 'Outliers'], bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
        #lgd.legendHandles[1].set_color("skyblue")
        tit = fig.suptitle('Boxplots of ' + YValColumnString + ' per ' + XValColumnString, fontsize=16)
        fig.autofmt_xdate(rotation=45)
        plt.grid(linestyle='dashed')
        plt.show()

    # Save figure (or not)
    SaveFig_YN = input('Save this figure? (y/n): ')
    if SaveFig_YN in ['y','Y']:
        FigName = input('Input figure name: ')
        fig.savefig(EMDataPath + '\\Pictures\\' + FigName + '.svg', bbox_extra_artists=(tit,lgd), bbox_inches='tight')
        fig.savefig(EMDataPath + '\\Pictures\\' + FigName + '.png', bbox_extra_artists=(tit,lgd), bbox_inches='tight')            
     
    return(fig,ax)

