# %%
import pandas as pd
import os
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import r2_score
import logging

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')

# Option to disable logging for simplified output
# Uncomment the next line to disable all logging
# logging.disable(logging.CRITICAL)

# Specify the product to analyse
product = 'Acetate'  # Replace with the desired product column name

# Load the data. `file_path` is relative, so it is resolved against the
# current working directory (also captured in `cwd`, which nothing else in
# this script references).
cwd = os.getcwd()
file_path = 'DoE_Product_Concentrations_HAc.xlsx'
df = pd.read_excel(file_path, sheet_name='Sheet1')

# Validate product column
if product not in df.columns:
    raise ValueError(f"Product '{product}' not found in the dataframe columns.")

# Validate the time, reactor and design-factor columns the analysis below
# filters and fits on, so a malformed sheet fails here with a clear message
# instead of a bare KeyError part-way through the analysis.
required_columns = ['Time', 'Reactor', 'pH', 'CO2_rate', 'H2_rate', 'Added_HAc']
missing_columns = [col for col in required_columns if col not in df.columns]
if missing_columns:
    raise ValueError(f"Required column(s) {missing_columns} not found in the dataframe columns.")

# Straight-line model used as the fitting function for curve_fit
def linear_func(x, a, b):
    """Evaluate the line ``a * x + b`` at ``x``.

    ``a`` is the slope and ``b`` the intercept; ``x`` may be a scalar, a
    NumPy array or a pandas Series (the result has the same kind).
    """
    scaled = a * x
    return scaled + b

# Define permutation test function
def permutation_test(data, product, n_permutations=1000, random_seed=42):
    """
    Perform a permutation test to evaluate the significance of the slope
    from a linear fit of `data[product] ~ data['Time']`.

    The Time values are shuffled relative to the product values to simulate
    the null hypothesis of no linear association, and the slope is refitted
    for every shuffle.

    Parameters
    data : pandas.DataFrame
        DataFrame containing columns ['Time'] and the given `product`.
    product : str
        Column name for the product (dependent variable).
    n_permutations : int, optional
        Number of permutations to perform (default=1000).
    random_seed : int, optional
        Seed for reproducibility (default=42). A private random generator is
        used, so NumPy's global random state is not modified.

    Returns
    tuple or None
        ``(observed_slope, permuted_slopes, p_value)`` on success:

        observed_slope : float
            The slope extracted from the best-fit linear model on the
            unpermuted data.
        permuted_slopes : np.ndarray
            Slopes from each permuted dataset whose fit succeeded; failed
            permutations are skipped, so it may hold fewer than
            `n_permutations` values.
        p_value : float
            Two-sided p-value: fraction of permuted slopes whose absolute
            value is at least the absolute observed slope. NaN when no
            permutation could be fitted.

        Returns None (not a tuple) when the unpermuted data cannot be fitted:
        fewer than two rows, invalid (e.g. non-finite) values, or a fit that
        fails to converge.
    """
    # Private generator with the same legacy MT19937 seeding as
    # np.random.seed, so the permutation sequence is reproducible without
    # clobbering the global random state other code may rely on.
    rng = np.random.RandomState(random_seed)

    # Extract the independent variable (Time) and dependent variable (product)
    x_data = data['Time']
    y_data = data[product]

    # A two-parameter fit needs at least two observations; curve_fit raises
    # TypeError (which is not caught below) when given fewer.
    if len(x_data) < 2:
        logging.error(f"Curve fitting failed for observed data: need at least 2 points, got {len(x_data)}")
        return None

    try:
        # Fit the linear model (linear_func) to the unpermuted data
        # popt will contain the optimized parameters: [slope, intercept]
        popt, _ = curve_fit(linear_func, x_data, y_data)
    except (RuntimeError, ValueError) as e:
        # RuntimeError: the least-squares fit did not converge.
        # ValueError: invalid input, e.g. non-finite values in the data.
        logging.error(f"Curve fitting failed for observed data: {e}")
        return None
    observed_slope = popt[0]

    # Slope estimates from each successfully fitted permutation
    permuted_slopes = []
    for perm_idx in range(n_permutations):
        # Shuffle the x_data (independent variable) to simulate the null hypothesis
        # Note: Some studies shuffle y_data instead, depending on the null hypothesis.
        shuffled_x = rng.permutation(x_data)
        try:
            popt_perm, _ = curve_fit(linear_func, shuffled_x, y_data)
        except RuntimeError as e:
            logging.warning(f"Permutation {perm_idx} failed: {e}")
            continue
        permuted_slopes.append(popt_perm[0])

    permuted_slopes = np.array(permuted_slopes)

    if permuted_slopes.size == 0:
        # Avoid dividing by zero: with no successful permutations the
        # p-value is undefined.
        logging.warning("No permutation could be fitted; p-value is undefined (NaN).")
        p_value = np.nan
    else:
        # Proportion of permuted slopes whose absolute value is at least as
        # large as the absolute observed slope (two-sided test)
        p_value = np.count_nonzero(np.abs(permuted_slopes) >= np.abs(observed_slope)) / permuted_slopes.size

    return observed_slope, permuted_slopes, p_value

# Define experimental conditions (unique combinations of the design factors)
# and the reactors to analyse
conditions = df[['pH', 'CO2_rate', 'H2_rate', 'Added_HAc']].drop_duplicates()
reactors = df['Reactor'].unique()

# Prepare for results
results = []
perm_test_results = {}

# Set up the plot grid for regression results: one row per condition and one
# column per reactor. squeeze=False keeps `axs` a 2-D array even when there is
# a single condition or a single reactor, so `axs[i, j]` indexing always works.
fig, axs = plt.subplots(
    len(conditions),
    len(reactors),
    figsize=(20, 20),
    sharex=True,
    sharey=True,
    squeeze=False,
)
fig.tight_layout(pad=5.0)

# Iterate over each combination of conditions and reactors: fit a linear
# model of the product against Time, run the permutation test, and plot.
# 'cond' is a Series with keys: 'pH', 'CO2_rate', 'H2_rate', 'Added_HAc'
for i, (_, cond) in enumerate(conditions.iterrows()):
    for j, reactor in enumerate(reactors):
        # Filter the DataFrame 'df' for rows matching the current condition + reactor
        reactor_df = df[
            (df['pH'] == cond['pH']) &
            (df['CO2_rate'] == cond['CO2_rate']) &
            (df['H2_rate'] == cond['H2_rate']) &
            (df['Added_HAc'] == cond['Added_HAc']) &
            (df['Reactor'] == reactor)
        ]
        # Drop rows missing Time or the product: curve_fit rejects NaN input
        reactor_df = reactor_df.dropna(subset=['Time', product])

        # A two-parameter linear fit needs at least two points (curve_fit
        # raises TypeError otherwise); an all-zero Time column is also skipped.
        if len(reactor_df) < 2 or reactor_df['Time'].sum() == 0:
            logging.warning(f"Skipping reactor {reactor} for condition {cond.to_dict()} due to insufficient or invalid data.")
            continue

        # Extract x (Time) and y (the product we are measuring) for curve fitting
        x_data = reactor_df['Time']
        y_data = reactor_df[product]
        try:
            # Best-fit parameters: slope 'a' and intercept 'b'
            popt, _ = curve_fit(linear_func, x_data, y_data)
        except (RuntimeError, ValueError) as e:
            logging.error(f"Curve fitting failed for reactor {reactor} in condition {cond.to_dict()}: {e}")
            continue
        a, b = popt
        fitted = linear_func(x_data, *popt)
        # R² (coefficient of determination) quantifies goodness-of-fit
        r_squared = r2_score(y_data, fitted)

        # Default so the subplot title never shows an undefined p-value, or a
        # stale one from a previous iteration, when the permutation test fails.
        p_value = np.nan
        result = permutation_test(reactor_df, product)
        if result is not None:
            # Unpack the observed slope, array of permuted slopes, and p-value
            observed_slope, permuted_slopes, p_value = result
            # Store a dictionary of results for future analysis or tabulation
            results.append({
                'pH': cond['pH'],
                'CO2_rate': cond['CO2_rate'],
                'H2_rate': cond['H2_rate'],
                'Added_HAc': cond['Added_HAc'],
                'Reactor': reactor,
                'R_squared': r_squared,
                'Coefficient_Beta1': a,
                'Intercept_Beta0': b,
                'Observed_Slope': observed_slope,
                'P_Value': p_value
            })
            # Store the permuted slopes for this condition-reactor combination
            perm_test_results[(cond['pH'], cond['CO2_rate'], cond['H2_rate'], cond['Added_HAc'], reactor)] = permuted_slopes

        # Plot data and regression line
        ax = axs[i, j]
        ax.scatter(x_data, y_data, color='blue', label='Data Points')
        ax.plot(x_data, fitted, color='red', label=f'Fit Line (R²={r_squared:.2f})')
        ax.set_title(f'pH={cond["pH"]}, CO2_rate={cond["CO2_rate"]}, H2_rate={cond["H2_rate"]}, Added_HAc={cond["Added_HAc"]}, Reactor={reactor}, Coef_Beta1={a:.2f}, P={p_value:.3f}')
        ax.set_xlabel('Time')
        ax.set_ylabel(product)
        ax.legend()

# Show the regression plot
plt.show()

# Tabulate the per-(condition, reactor) regression results and print the table
results_df = pd.DataFrame(data=results)
print(results_df)


# Option to visualize histograms of permuted slopes.
# Index each observed slope by its (pH, CO2_rate, H2_rate, Added_HAc, Reactor)
# key once. `results` and `perm_test_results` are appended together, so every
# histogram key has a matching entry. This replaces re-filtering results_df
# (with exact float comparisons) for every plot.
observed_slopes = {
    (res['pH'], res['CO2_rate'], res['H2_rate'], res['Added_HAc'], res['Reactor']): res['Observed_Slope']
    for res in results
}
for key, slopes in perm_test_results.items():
    pH, CO2_rate, H2_rate, Added_HAc, reactor = key
    plt.figure(figsize=(8, 6))
    plt.hist(slopes, bins=30, alpha=0.7, label='Permuted Slopes')
    plt.axvline(x=observed_slopes[key], color='red', linestyle='--', label='Observed Slope')
    # Include Added_HAc so conditions differing only in it get distinct titles
    plt.title(f'Histogram of Permuted Slopes\n(pH={pH}, CO2_rate={CO2_rate}, H2_rate={H2_rate}, Added_HAc={Added_HAc}, Reactor={reactor})')
    plt.xlabel('Slope')
    plt.ylabel('Frequency')
    plt.legend()
    plt.show()

# Calculate Time-weighted averages of the product for all conditions and reactors
weighted_averages = []
for _, cond in conditions.iterrows():
    for reactor in reactors:
        # Filter the DataFrame for the current condition and reactor
        reactor_df = df[
            (df['pH'] == cond['pH']) &
            (df['CO2_rate'] == cond['CO2_rate']) &
            (df['H2_rate'] == cond['H2_rate']) &
            (df['Added_HAc'] == cond['Added_HAc']) &
            (df['Reactor'] == reactor)
        ]
        # Use only rows where both Time and the product are present. np.sum on
        # a Series skips NaN, so without this a row with a missing product value
        # would drop out of the numerator while its Time still counted in the
        # denominator, biasing the average towards zero.
        reactor_df = reactor_df.dropna(subset=['Time', product])

        total_time = reactor_df['Time'].sum()
        if reactor_df.empty or total_time == 0:
            continue

        # Weight each row by its Time relative to the total Time in this subset.
        # The zero-total case is skipped above, so this division is well defined.
        weights = reactor_df['Time'] / total_time
        weighted_avg = np.sum(weights * reactor_df[product])
        weighted_averages.append({
            'pH': cond['pH'],
            'CO2_rate': cond['CO2_rate'],
            'H2_rate': cond['H2_rate'],
            'Added_HAc': cond['Added_HAc'],
            'Reactor': reactor,
            'Weighted_Average': weighted_avg
        })

# Convert the weighted averages to a DataFrame and display
weighted_averages_df = pd.DataFrame(weighted_averages)
print(weighted_averages_df)
# %%