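"""Download ERA5 data via the CDS API and split it into daily files.

This script downloads ERA5 single-level data with the CDS API, then splits each
monthly NetCDF file into one file per day and per variable.
Please see the installation instructions: https://cds.climate.copernicus.eu/api-how-to
You need a valid CDS API key and the `cdsapi` package (pip install cdsapi);
xarray with the netCDF4 engine, numpy and pandas are also required.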

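The request currently downloads the following single-level variables:
  - sp (surface pressure) and tcw (total column water)

A list of model levels (intended for u, v, q on selected levels) and the extra
surface variables tp and e are also defined in the script, but they are not
requested at present.

Adjust the in-file settings (area, grid, skip_exist) if needed, then run with:

python download_era5_ml.py -p BASE_FOLDER -y YEAR [-m MONTH] [-t HOURS] [-o OUTPUT_FOLDER]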
"""
from pathlib import Path
import calendar
import cdsapi
import pandas as pd
import os
import numpy as np
import xarray as xr
import argparse
import sys

class Unbuffered(object):
    """Wrap a text stream so that every write is flushed immediately.

    Attributes not defined on the wrapper (for example ``fileno`` or
    ``encoding``, if the wrapped stream has them) are looked up on the
    wrapped stream.
    """

    def __init__(self, stream):
        # The wrapped stream; all writes are forwarded to it and then flushed.
        self.stream = stream

    def write(self, data):
        """Write *data* to the wrapped stream, flush it, and return the result.

        The wrapped stream's ``write`` return value (the number of characters
        written, for standard text streams) is passed through so the wrapper
        honours the usual file-object contract.
        """
        written = self.stream.write(data)
        self.stream.flush()
        return written

    def writelines(self, datas):
        """Write every item of *datas* to the wrapped stream, then flush it."""
        self.stream.writelines(datas)
        self.stream.flush()

    def __getattr__(self, attr):
        """Delegate any attribute not found on the wrapper to the stream."""
        # __getattr__ only runs for missing attributes. If ``stream`` itself is
        # missing (e.g. an instance created without __init__), looking it up
        # via self.stream would re-enter __getattr__ forever; raise instead.
        if attr == "stream":
            raise AttributeError(attr)
        return getattr(self.stream, attr)

# Route all standard output (e.g. the progress print() calls below) through a
# wrapper that flushes after every write, so messages appear immediately even
# when stdout is block-buffered (typically the case when redirected to a file).
sys.stdout = Unbuffered(sys.stdout)

def generate_dates(year, month):
    """Return every calendar date of a month as ``"YYYY-MM-DD"``-style strings.

    Args:
        year: Year number; formatted as-is (not zero-padded).
        month: Month number, 1-12.

    Returns:
        A list with one string per day of the month, in ascending order.
        Leap years are handled by ``calendar.monthrange``.

    Raises:
        calendar.IllegalMonthError: (a ValueError) if *month* is outside 1-12.
    """
    # monthrange() returns (weekday of the 1st, number of days in the month);
    # only the day count is needed here.
    _, number_of_days = calendar.monthrange(year, month)
    return [f"{year}-{month:02d}-{day:02d}" for day in range(1, number_of_days + 1)]


# --- Command-line interface -------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument("-p", action="store", dest="Path", help="Base folder", required=True, type=str)
parser.add_argument(
    "-m", action="store", dest="Month",
    help="Month to download (1-12); all twelve months if omitted",
    required=False, type=int,
)
parser.add_argument("-y", action="store", dest="Year", help="Year", required=True, type=int)
parser.add_argument(
    "-t", action="store", dest="Dt",
    help="Time interval in hours (default: 1, i.e. hourly)",
    required=False, type=int,
)
parser.add_argument("-o", action="store", dest="Out", help="Output folder", required=False, type=str)
opt = parser.parse_args()

# Fail fast on values that would otherwise only blow up later
# (calendar.monthrange for the month, range() with a zero/negative step).
if opt.Month is not None and not 1 <= opt.Month <= 12:
    parser.error(f"-m must be between 1 and 12, got {opt.Month}")
if opt.Dt is not None and opt.Dt <= 0:
    parser.error(f"-t must be a positive number of hours, got {opt.Dt}")

# --- Settings -------------------------------------------------------------
target_dir = opt.Path
skip_exist = True  # do not re-download monthly files that already exist

start_year = opt.Year
end_year = opt.Year
year = opt.Year
# A single month if -m is given, otherwise the whole year.
months = range(1, 13) if opt.Month is None else [opt.Month]

years = range(start_year, end_year + 1)  # not used below: only `year` is processed

area = [75, -180, -75, 180]  # None for global, or [N, W, S, E]
grid = [0.5, 0.5]

# Hours of the day requested from CDS: every hour by default, every opt.Dt
# hours otherwise.
hour_step = 1 if opt.Dt is None else opt.Dt
times = [f"{hour:02d}:00" for hour in range(0, 24, hour_step)]

# Not referenced by the request below.
levels = (
    "20/40/60/80/90/95/100/105/110/115/120/123/125/128/130/131/132/133/134/135/136/137"
)

# Not referenced by the request below.
surface_variables = {
    "tp": "total_precipitation",
    "e": "evaporation",
    "sp": "surface_pressure",
    "tcw": "total_column_water",
}


def _download_month(target, year, month, times, area, grid):
    """Download one month of ERA5 single-level data from CDS into *target*.

    Requests surface pressure and total column water for every day of the
    month at the given *times*, with the given *area* and *grid*, as an
    unarchived NetCDF file.

    Returns:
        True if the request succeeded, False if it raised. On failure the
        error is printed and any (partial) *target* file is deleted so that a
        later run with ``skip_exist`` does not treat it as a finished download.
    """
    try:
        date_list = generate_dates(year, month)
        print(f"Requesting download of the file for {month}/{year}")
        client = cdsapi.Client()
        client.retrieve(
            "reanalysis-era5-single-levels",
            {
                "product_type": "reanalysis",
                "variable": [
                    "surface_pressure", "total_column_water",
                ],
                "date": date_list,
                "time": times,
                "area": area,
                "grid": grid,
                "data_format": "netcdf",
                "download_format": "unarchived",
            },
            str(target),
        )
    except Exception as e:  # per-month boundary: report, clean up, move on
        print(e)
        print(
            f"Request failed for {month:02d}/{year:04d}. Skipping this month. You can"
            " run the code again; while 'skip_exist' is True, months that are"
            " already downloaded are not downloaded again."
        )
        if target.exists():
            target.unlink()
        return False
    return True


def _split_by_day_and_variable(input_file, out_dir, dt):
    """Write one NetCDF file per calendar day and data variable of *input_file*.

    Files are named ``ERA5_<YYYY-MM-DD>_<variable>.nc`` inside *out_dir*;
    files that already exist are skipped. If *dt* is None every time step of
    the day is kept; otherwise only the hours 0, dt, 2*dt, ... are selected
    (each must be present in the input, else ``sel`` raises KeyError).
    """
    with xr.open_dataset(input_file, engine="netcdf4") as ds:
        # The time coordinate of these files is 'valid_time'.
        day_of_step = ds.valid_time.dt.floor("D")
        dates = np.unique(day_of_step.values)
        ds.load()

        for date in dates:
            date_str = np.datetime_as_string(date, unit="D")
            if dt is None:
                # Every time step whose calendar date is `date` (matched by
                # date, not by position, so gaps cannot shift the selection).
                selector = day_of_step == date
            else:
                # Explicit timestamps; this also subsamples a monthly file that
                # was previously downloaded at a finer interval.
                selector = date + np.arange(0, 24, dt) * np.timedelta64(1, "h")

            for var in ds.data_vars:
                output_file = os.path.join(out_dir, f"ERA5_{date_str}_{var}.nc")
                if os.path.isfile(output_file):
                    print(output_file, ' exists. Skipping...')
                    continue

                daily_data = ds[var].sel(valid_time=selector)
                daily_data.to_netcdf(output_file)
                print(f"Saved {output_file}")


# --- Main loop ------------------------------------------------------------
outfolder = Path(target_dir)
outfolder.mkdir(exist_ok=True, parents=True)

if opt.Out is None:
    outfolder2 = outfolder
else:
    outfolder2 = Path(opt.Out)
    outfolder2.mkdir(exist_ok=True, parents=True)

for month in months:
    outfile = f"ERA5_single_levels_{year:04d}_{month:02d}_inst.nc"
    monthly_file = outfolder / outfile
    if monthly_file.exists() and skip_exist:
        print(
            f"{monthly_file} already exists, skipping. Set skip_exist = False to force re-download"
        )
    elif not _download_month(monthly_file, year, month, times, area, grid):
        continue  # the download failed, so there is no monthly file to split

    input_file = str(monthly_file)
    print(input_file)
    _split_by_day_and_variable(input_file, outfolder2, opt.Dt)