package DistanceMeasure;

import java.util.Arrays;
import java.lang.Math;

/**
 * Builds a pair of empirical cumulative distribution functions (ECDFs) per
 * time point - one for experimental data and one for simulation output - on a
 * shared set of x values, and offers several distance measures between them.
 *
 * <p>All distance methods return one value per time point (row).
 *
 * <p>The getters return the internal arrays without copying, so changes made
 * by a caller are visible to this object.
 */
public class EmpiricalCDF {
	/*
	 * The x values for the ECDFs, one row per time point. The x values are
	 * common to both ECDFs so that the two can be easily compared: each row is
	 * the sorted pool of every sample from both data sets at that time point.
	 */
	protected double[][] xi;
	/*
	 * The y values of the first ECDF (the experimental data), evaluated at
	 * the matching entries of xi.
	 */
	protected double[][] y1;
	/* The y values of the second ECDF (the simulation output), evaluated at xi */
	protected double[][] y2;
	/*
	 * A private copy of the experimental (ground truth) samples with every
	 * row sorted in ascending order.
	 */
	protected double[][] expSamples;

	/**
	 * Creates the two ECDFs for comparison.
	 *
	 * <p>Each row of the inputs is a time point and each column a sample. The
	 * pooled samples of both data sets form the x values, and both ECDFs are
	 * evaluated at every one of them.
	 *
	 * <p>The input arrays are not modified; the experimental samples are
	 * copied before being sorted.
	 *
	 * @param data1 the experimental data (ground truth), one row per time
	 *              point. This must be the ground truth, otherwise the
	 *              distances that treat the first ECDF as the reference are
	 *              incorrect
	 * @param data2 the simulation output, with the same number of rows as
	 *              {@code data1}
	 * @throws IllegalArgumentException if the two data sets do not have the
	 *                                  same number of time points
	 */
	public EmpiricalCDF(double[][] data1, double[][] data2) {
		if (data1.length != data2.length) {
			throw new IllegalArgumentException(
					"data1 and data2 must have the same number of time points: "
							+ data1.length + " != " + data2.length);
		}

		// Keep a sorted copy of the experimental samples rather than sorting
		// the caller's array in place.
		this.expSamples = new double[data1.length][];
		for (int i = 0; i < data1.length; i++) {
			expSamples[i] = data1[i].clone();
			Arrays.sort(expSamples[i]);
		}

		// The xi points are the complete set of all data points, so
		// concatenate the arrays and sort each row into ascending order.
		this.xi = HelperFunctions.arrayConcat(expSamples, data2);
		for (int i = 0; i < xi.length; i++) {
			Arrays.sort(xi[i]);
		}

		this.y1 = new double[xi.length][];
		this.y2 = new double[xi.length][];
		for (int i = 0; i < xi.length; i++) {
			int points = xi[i].length;
			// Divisors are doubles so the normalisation is always a
			// floating-point division, and are taken from the current row.
			double divisor1 = expSamples[i].length;
			double divisor2 = data2[i].length;

			y1[i] = new double[points];
			y2[i] = new double[points];
			for (int n = 0; n < points; n++) {
				// NOTE(review): ecdfSum presumably counts the samples that are
				// <= the given x value - confirm against HelperFunctions.
				y1[i][n] = HelperFunctions.ecdfSum(expSamples[i], xi[i][n]) / divisor1;
				y2[i][n] = HelperFunctions.ecdfSum(data2[i], xi[i][n]) / divisor2;
			}
		}
	}

	/**
	 * @return the shared, sorted x values (internal array, not a copy)
	 */
	public double[][] getXvals() {
		return xi;
	}

	/**
	 * @return the y values of the first (experimental) ECDF (internal array,
	 *         not a copy)
	 */
	public double[][] getFirstCDF() {
		return y1;
	}

	/**
	 * @return the y values of the second (simulation) ECDF (internal array,
	 *         not a copy)
	 */
	public double[][] getSecondCDF() {
		return y2;
	}

	/*
	 * Below methods are all related to getting a distance measure from the ECDFs
	 */

	/**
	 * Computes the unsigned area between the two ECDFs at each time point.
	 *
	 * <p>For every pair of neighbouring x values the interval width is
	 * multiplied by the absolute difference of the ECDFs at the left-hand x
	 * value, and the products are summed.
	 *
	 * @return the area between the ECDFs at each time point
	 */
	public double[] getAreas() {
		double[] areas = new double[xi.length];
		for (int i = 0; i < xi.length; i++) {
			double sum = 0;
			for (int j = 0; j < xi[i].length - 1; j++) {
				double width = xi[i][j + 1] - xi[i][j];
				double height = Math.abs(y1[i][j] - y2[i][j]);
				sum += width * height;
			}
			areas[i] = sum;
		}
		return areas;
	}

	/**
	 * The Kolmogorov-Smirnov distance: the largest absolute difference between
	 * the two ECDFs over all x values.
	 *
	 * @return an array where each element is the KS distance for the
	 *         corresponding time point
	 */
	public double[] ksDistance() {
		double[] ks = new double[xi.length];
		for (int i = 0; i < xi.length; i++) {
			double largest = 0;
			for (int j = 0; j < xi[i].length; j++) {
				double currDist = Math.abs(y1[i][j] - y2[i][j]);
				if (currDist > largest) {
					largest = currDist;
				}
			}
			ks[i] = largest;
		}
		return ks;
	}

	/**
	 * Calculates the area between the ECDFs at each time point but using the
	 * square of the height for each interval. Essentially a least squares
	 * equivalent of {@link #getAreas()}.
	 *
	 * @return the squared-height area at each time point
	 */
	public double[] squareAreas() {
		double[] areas = new double[xi.length];
		for (int i = 0; i < xi.length; i++) {
			double sum = 0;
			for (int j = 0; j < xi[i].length - 1; j++) {
				double width = xi[i][j + 1] - xi[i][j];
				double height = Math.abs(y1[i][j] - y2[i][j]);
				sum += width * (height * height);
			}
			areas[i] = sum;
		}
		return areas;
	}

	/**
	 * Returns the square root of {@link #squareAreas()} at each time point.
	 *
	 * <p>NOTE(review): this is the L2 distance between the two CDFs; confirm
	 * that this is the intended definition of the Wasserstein-2 distance.
	 *
	 * @return the distance at each time point
	 */
	public double[] wasserstein2() {
		double[] distance = squareAreas();
		for (int i = 0; i < distance.length; i++) {
			distance[i] = Math.sqrt(distance[i]);
		}
		return distance;
	}

	/**
	 * Calculates the signed area between the ECDFs. An interval contributes a
	 * positive amount where the second ECDF is above the first and a negative
	 * amount otherwise.
	 *
	 * @return the signed area at each time point
	 */
	public double[] signedArea() {
		double[] areas = new double[xi.length];
		for (int i = 0; i < xi.length; i++) {
			double sum = 0;
			for (int j = 0; j < xi[i].length - 1; j++) {
				double width = xi[i][j + 1] - xi[i][j];
				double height = y2[i][j] - y1[i][j];
				sum += width * height;
			}
			areas[i] = sum;
		}
		return areas;
	}

	/**
	 * A quadratic distance based on the Cramer von Mises distance: the mean of
	 * the squared ECDF differences over all x values.
	 *
	 * <p>IMPORTANT - this distance may need work if there are repeated
	 * samples, i.e. if the ordered samples have repeats. Nothing is wrong with
	 * the calculation but it may not match the definition unless this is taken
	 * into account.
	 *
	 * @return the quadratic distance at each time point
	 */
	public double[] quadraticDistance() {
		double[] distance = new double[xi.length];
		for (int i = 0; i < xi.length; i++) {
			double sum = 0;
			for (int j = 0; j < xi[i].length; j++) {
				double diff = y2[i][j] - y1[i][j];
				sum += diff * diff;
			}
			// Weight by the number of points at this time point. This is the
			// 1/(N+M) factor seen in Cramer von Mises.
			distance[i] = sum / xi[i].length;
		}
		return distance;
	}

	/**
	 * Calculates the Kuiper distance: the largest amount by which the first
	 * ECDF exceeds the second plus the largest amount by which the second
	 * exceeds the first.
	 *
	 * <p>Both extremes are tracked separately for every time point and start
	 * from zero, so one time point never influences another.
	 *
	 * @return the Kuiper distance at each time point
	 */
	public double[] kuiperDistance() {
		double[] kuiper = new double[xi.length];
		for (int i = 0; i < xi.length; i++) {
			// Reset for every time point; starting at zero means a one-sided
			// difference contributes nothing for the opposite direction.
			double dPlus = 0;
			double dMinus = 0;
			for (int j = 0; j < xi[i].length; j++) {
				double currDist = y1[i][j] - y2[i][j];
				if (currDist > dPlus) {
					dPlus = currDist;
				}
				if (currDist < dMinus) {
					dMinus = currDist;
				}
			}
			kuiper[i] = dPlus - dMinus;
		}
		return kuiper;
	}

	/**
	 * An extension of the Cramer von Mises distance; the Anderson Darling
	 * statistic weights the tails of the distributions by dividing each
	 * squared difference by H(1 - H), where H is the combined sample ECDF.
	 *
	 * <p>H is advanced by 1/(number of pooled samples) at every x value, which
	 * treats every pooled sample as distinct. The last x value is skipped
	 * because H would reach 1 there and the weight would be a division by
	 * zero.
	 *
	 * @return the weighted quadratic distance at each time point
	 */
	public double[] AndersonDarling() {
		double[] distance = new double[xi.length];
		for (int i = 0; i < xi.length; i++) {
			double numSamples = xi[i].length;
			double combined = 0; // The combined sample ECDF
			double sum = 0;
			for (int j = 0; j < xi[i].length - 1; j++) {
				double diff = y2[i][j] - y1[i][j];
				// The general form of the combined sample ECDF is
				// H_N(x) = (N*F_N(x) + M*G_M(x)) / (N+M); here it is stepped
				// by one pooled sample per x value instead.
				combined += 1 / numSamples;
				sum += diff * diff / (combined * (1 - combined));
			}
			// Weight by the number of points at this time point. This is the
			// 1/(N+M) factor seen in Cramer von Mises.
			distance[i] = sum / xi[i].length;
		}
		return distance;
	}

	/**
	 * The DTS statistic defined here - https://github.com/cdowd/twosamples
	 * It is the area between the ECDFs but weighting the tails in a similar
	 * manner to Anderson Darling: each interval's area is divided by
	 * H(1 - H), where H is the combined sample ECDF stepped by one pooled
	 * sample per x value.
	 *
	 * @return the weighted area at each time point
	 */
	public double[] DTS_twoSample() {
		double[] areas = new double[xi.length];
		for (int i = 0; i < xi.length; i++) {
			double numSamples = xi[i].length;
			double combined = 0;
			double sum = 0;
			for (int j = 0; j < xi[i].length - 1; j++) {
				double width = xi[i][j + 1] - xi[i][j];
				double height = Math.abs(y1[i][j] - y2[i][j]);
				combined += 1 / numSamples;
				sum += width * height / (combined * (1 - combined));
			}
			areas[i] = sum;
		}
		return areas;
	}

	/**
	 * The unsigned area weighted by the range of the combined sample at each
	 * time point. A time point whose combined sample has zero range yields 0.
	 *
	 * @return the range-weighted unsigned area at each time point
	 */
	public double[] unsignedAreaWeighted() {
		return divideByRange(getAreas());
	}

	/**
	 * The signed area weighted by the range of the combined sample at each
	 * time point. A time point whose combined sample has zero range yields 0.
	 *
	 * @return the range-weighted signed area at each time point
	 */
	public double[] signedAreaWeighted() {
		return divideByRange(signedArea());
	}

	/**
	 * The L2 area but weighted by the range of the data at each time point.
	 * This makes the distances comparable against different time points.
	 * A time point whose combined sample has zero range yields 0.
	 *
	 * @return the range-weighted squared-height area at each time point
	 */
	public double[] L2Weighted() {
		return divideByRange(squareAreas());
	}

	/**
	 * Divides each per-time-point value, in place, by the range (largest minus
	 * smallest x value) of the combined sample at that time point.
	 *
	 * <p>When the range is zero every interval has zero width, so the value is
	 * set to 0 instead of dividing by zero.
	 *
	 * @param values one value per time point; modified and returned
	 * @return {@code values}
	 */
	private double[] divideByRange(double[] values) {
		for (int i = 0; i < values.length; i++) {
			double[] row = xi[i];
			double range = row.length == 0 ? 0 : row[row.length - 1] - row[0];
			values[i] = range == 0 ? 0 : values[i] / range;
		}
		return values;
	}

	/*
	 * Disabled aggregate helpers, kept for reference. They reduce the
	 * per-time-point arrays above to a single unweighted mean or supremum.

	// Mean area over all time points
	public double getAreaMean() {
		double[] areas = getAreas();
		// Mean area calculated by the sum of all areas divided by the number of areas (number of timepoints)
		double areaMean = HelperFunctions.arraySum(areas) / areas.length;
		return areaMean;
	}

	// Mean KS distance over all time points
	public double ksDistanceMean() {
		double[] dists = ksDistance();
		double ksMean = HelperFunctions.arraySum(dists) / dists.length;
		return ksMean;
	}

	// Mean of the square areas
	public double squareAreaMean() {
		double[] areas = squareAreas();
		// Mean area calculated by the sum of all areas divided by the number of areas (number of timepoints)
		double areaMean = HelperFunctions.arraySum(areas) / areas.length;
		return areaMean;
	}

	public double signedAreaMean() {
		double[] areas = signedArea();
		// Mean area calculated by the sum of all areas divided by the number of areas (number of timepoints)
		double areaMean = HelperFunctions.arraySum(areas) / areas.length;
		return areaMean;
	}

	public double areaSupremum() {
		double[] areas = getAreas();
		double areaSup = HelperFunctions.max(areas);
		return areaSup;
	}

	public double squareAreaSupremum() {
		double[] areas = squareAreas();
		double areaSup = HelperFunctions.max(areas);
		return areaSup;
	}

	public double signedAreaSupremum() {
		double[] areas = signedArea();
		double areaSup = HelperFunctions.max(areas);
		return areaSup;
	}

	public double ksSupremum() {
		double[] dists = ksDistance();
		double ksSup = HelperFunctions.max(dists);
		return ksSup;
	}
	*/
}