import zarr
import cfgrib
import numpy as np
import xarray as xr
import gcsfs
import calendar
import metview as mv
import time
import os
import tempfile
import argparse


def __init__(self, stream):
   """Store ``stream`` on ``self`` as the ``stream`` attribute.

   NOTE(review): this function takes ``self`` but is defined at module level;
   it looks like the constructor of a stream-wrapper class whose ``class``
   line is missing from the file -- confirm and restore the class if so.
   """
   self.stream = stream
def write(self, data):
   """Write ``data`` to ``self.stream`` and flush it straight away."""
   target = self.stream
   target.write(data)
   # Flush after every write so output is not held back in a buffer.
   target.flush()
def writelines(self, datas):
   """Write every item of ``datas`` to ``self.stream`` and flush it straight away."""
   target = self.stream
   target.writelines(datas)
   # Flush after the batch so output is not held back in a buffer.
   target.flush()
def __getattr__(self, attr):
   """Forward the lookup of ``attr`` to ``self.stream``.

   NOTE(review): like the three definitions above it, this takes ``self`` but
   sits at module level, so it looks like a method of a class whose ``class``
   line is missing -- confirm.  As written it is a *module-level*
   ``__getattr__``, which Python (PEP 562) calls with a single argument (the
   attribute name) whenever a module attribute lookup fails; this
   two-parameter signature cannot accept that call, so such lookups (for
   example during ``from <this module> import name``) raise ``TypeError``.
   """
   return getattr(self.stream, attr)



def generate_dates(year, month):
  """Return every day of the given month as "<year>-MM-DD" strings, in order.

  Args:
    year (int): Calendar year.
    month (int): Calendar month, 1-12.

  Returns:
    list[str]: One date string per day of the month.
  """
  # monthrange gives (weekday of the 1st, number of days); only the length is needed.
  _, days_in_month = calendar.monthrange(year, month)
  stamps = []
  for day in range(1, days_in_month + 1):
    stamps.append(f"{year}-{month:02d}-{day:02d}")
  return stamps



def attribute_fix(ds):
  """Strip the name-based GRIB attributes from every variable of a Dataset.

  In some cases, shortNames in the ecCodes table can cause ambiguity in string
  matching.  To eliminate this ambiguity, the 'GRIB_cfName', 'GRIB_cfVarName'
  and 'GRIB_shortName' attributes are removed from each variable, leaving the
  remaining attributes (per the original author's note, the paramId) as the
  source-of-truth.

  The attribute dictionaries are modified in place, so the Dataset passed in
  is itself changed.

  Args:
    ds (xarray.Dataset): The Dataset to fix attributes for.

  Returns:
    xarray.Dataset: The same Dataset object, with the attributes removed.
  """
  name_keys = ('GRIB_cfName', 'GRIB_cfVarName', 'GRIB_shortName')
  for var in ds:
    attrs = ds[var].attrs
    for key in name_keys:
      # pop with a default: a variable that lacks the key is simply left alone.
      attrs.pop(key, None)
  return ds

def fix_time_problem(surface_slice):
  """Replace the time-like coordinates of a dataset with plain integers.

  'time' and 'valid_time' are cast to int64 and floor-divided by 1e9
  (presumably nanoseconds -> seconds; assumes datetime64[ns] input -- TODO
  confirm).  'step' is converted to whole hours (int32) when it has a
  timedelta64 dtype.  The attributes of all three coordinates are cleared.

  Args:
    surface_slice (xarray.Dataset): Dataset with 'time', 'valid_time' and
      'step' coordinates.

  Returns:
    xarray.Dataset: Dataset with the converted coordinates.
  """
  divisor = 1_000_000_000
  for coord in ("time", "valid_time"):
    as_int = surface_slice[coord].astype("int64") // divisor
    surface_slice = surface_slice.assign_coords(**{coord: as_int})
    surface_slice[coord].attrs = {}

  # Fix 'step' to be integer hours
  step = surface_slice["step"]
  if np.issubdtype(step.dtype, np.timedelta64):
    whole_hours = (step / np.timedelta64(1, "h")).astype("int32")
    surface_slice = surface_slice.assign_coords(step=whole_hours)
  surface_slice["step"].attrs = {}
  return surface_slice

def fix_coords(q_ll):
  """Normalise the coordinate names of a regridded dataset.

  Drops the 'valid_time' and 'step' coordinates, then renames 'hybrid' to
  'model_level' and 'time' to 'valid_time'.

  Args:
    q_ll (xarray.Dataset): Dataset carrying 'valid_time', 'step', 'hybrid'
      and 'time'.

  Returns:
    xarray.Dataset: Dataset with the coordinates dropped and renamed.
  """
  new_names = {"hybrid": "model_level", "time" : "valid_time"}
  # The old 'valid_time' has to go first, otherwise renaming 'time' would clash with it.
  for unwanted in ("valid_time", "step"):
    q_ll = q_ll.reset_coords(unwanted, drop=True)
  return q_ll.rename(new_names)

def _write_latlon_store(fieldset, store_path, mode=None):
  """Convert a regridded Metview fieldset to a Dataset, crop it and write it as zarr.

  Args:
    fieldset: Metview fieldset on the regular lat-lon grid.
    store_path (str): Destination zarr store.
    mode (str | None): Passed to ``Dataset.to_zarr``; ``None`` keeps xarray's
      default, which refuses to write into an existing store.
  """
  # Keep latitudes 75 .. -75 only (slice runs north to south).
  cropped = fieldset.to_dataset().sel(latitude=slice(75, -75))
  fix_coords(cropped).to_zarr(store_path, mode=mode)


def compute_ll_dataset(basepath, ml_wind, ml_moisture, levels, datestring, dt=1):
  """Regrid one day of model-level q, u and v to a 0.5 degree lat-lon grid and store it.

  For the given day the selected levels are downloaded, converted to Metview
  fieldsets, regridded and written to three zarr stores:

    <basepath>ERA5_<datestring>_ml_q.zarr
    <basepath>ERA5_<datestring>_ml_u.zarr
    <basepath>ERA5_<datestring>_ml_v.zarr

  An existing q store is left untouched, and in that case the moisture data
  is not downloaded at all.  The u store is always (re)written: the v store is
  written last and is what ``make_field`` checks to decide a day is done, so a
  u store found here without its v store is a leftover of an interrupted run.
  Writing the v store fails if it already exists (xarray's default mode).

  Args:
    basepath (str): Prefix of the output stores.  It is concatenated directly
      with the file name, so it should end with a path separator.
    ml_wind (xarray.Dataset): Model-level wind dataset (the caller in this
      file passes 'd' and 'vo'), indexed by 'time' and 'hybrid'.
    ml_moisture (xarray.Dataset): Model-level moisture dataset providing 'q',
      indexed by 'time' and 'hybrid'.
    levels (list[int]): Values of the 'hybrid' coordinate to extract.
    datestring (str): Day to process in ISO format (e.g., "2023-09-11").
    dt (int): Step in hours between the time stamps taken from that day.

  Returns:
    None.  The results are written to disk.

  Note:
    - This function relies on the Metview library for data processing.
    - Metview creates temporary files for each fieldset variable.  This can use
      a lot of local disk space, which is why fieldsets are deleted as soon as
      they are no longer needed.
  """
  import metview as mv  # kept local, as in the original implementation

  prefix = basepath+'ERA5_'+datestring+'_ml_'
  q_store = prefix+'q.zarr'
  # Decide up front whether q is needed, so the moisture download can be skipped.
  need_q = not os.path.isdir(q_store)

  times = [datestring+"T%02i:00:00" %i for i in range(0,24,dt)]

  t0 = time.time()
  wind_slice = ml_wind.sel(time=times,hybrid=levels).compute()
  moisture_slice = None
  if need_q:
    moisture_slice = ml_moisture.sel(time=times,hybrid=levels).compute()
  print('Download: ', time.time()-t0)

  wind_fieldset = mv.dataset_to_fieldset(attribute_fix(wind_slice).squeeze())
  del wind_slice

  if need_q:
    moist_fieldset = mv.dataset_to_fieldset(attribute_fix(moisture_slice).squeeze())
    del moisture_slice

    q_gg = moist_fieldset.select(shortName='q')
    q_fll = mv.read(data=q_gg, grid=[0.5,0.5])
    _write_latlon_store(q_fll, q_store)
    del q_gg, q_fll, moist_fieldset
  else:
    print(q_store+' already exists. Skipping...')

  # Compute the u/v wind elements from d/vo, translate them to the N320
  # Gaussian grid and from there to the regular 0.5 degree lat-lon grid.
  uv_wind_spectral = mv.uvwind(data=wind_fieldset,truncation=639)
  uv_wind_gg = mv.read(data=uv_wind_spectral,grid='N320')
  uv_wind_ll = mv.read(data=uv_wind_gg, grid=[0.5,0.5])
  del uv_wind_spectral
  del wind_fieldset

  # u first, v last: the v store doubles as the "day finished" marker.
  u_fll = uv_wind_ll.select(shortName='u')
  # mode="w": replace a u store left behind by an interrupted earlier run
  # instead of failing on it forever.
  _write_latlon_store(u_fll, prefix+'u.zarr', mode="w")
  del u_fll

  v_fll = uv_wind_ll.select(shortName='v')
  _write_latlon_store(v_fll, prefix+'v.zarr')
  del v_fll

  del uv_wind_gg, uv_wind_ll


def make_field(basepath,year,month=None,dt=1):
  """Produce the daily q/u/v lat-lon stores for a range of months.

  Days whose v store (<basepath>ERA5_<date>_ml_v.zarr) already exists are
  skipped; every other day is handed to ``compute_ll_dataset``.

  Args:
    basepath (str): Prefix of the output stores.  It is concatenated directly
      with the file name, so it should end with a path separator.
    year (int): Year to process.
    month (int | None): If None, all twelve months of ``year`` are processed.
      Otherwise that month and the following one are processed (December
      rolls over into January of the next year).
    dt (int): Step in hours between time stamps, passed on unchanged.

  Returns:
    None.
  """
  if month is None:
    # Must be (year, month) pairs like the other branches; a bare
    # range(1, 13) cannot be unpacked by the loop below.
    months = [(year, m) for m in range(1,13)]
  elif month < 12:
    months = [(year, month), (year, month+1)]
  else:
    months = [(year, month), (year+1, 1)]

  # Values of the 'hybrid' (model level) coordinate to extract.
  levels = [20,40,60,80,90,95,100,105,110,115,120,123,125,128,130,131,132,133,134,135,136,137]

  # The remote datasets do not depend on the month, so open them once.
  print('Opening datasets')
  ml_moisture_all = xr.open_zarr("gs://gcp-public-data-arco-era5/co/model-level-moisture.zarr-v2",consolidated=True,)
  ml_moisture = xr.Dataset(data_vars=dict(q=ml_moisture_all.q),attrs=ml_moisture_all.attrs,coords=ml_moisture_all.coords)
  ml_wind_all = xr.open_zarr("gs://gcp-public-data-arco-era5/co/model-level-wind.zarr-v2",consolidated=True,)
  ml_wind = xr.Dataset(data_vars=dict(d=ml_wind_all.d, vo=ml_wind_all.vo),attrs=ml_wind_all.attrs,coords=ml_wind_all.coords)

  for cur_year, cur_month in months:
    for day in generate_dates(cur_year, cur_month):
      v_store = basepath+'ERA5_'+day+'_ml_v.zarr'
      print(day, v_store)
      # The v store is written last, so its presence marks the day as done.
      if os.path.isdir(v_store):
        print(v_store+' already exists. Skipping...')
        continue

      t0 = time.time()
      compute_ll_dataset(basepath, ml_wind, ml_moisture, levels, day, dt)
      print(time.time()-t0)


# Command-line entry point: parse the options and process the requested period.
parser = argparse.ArgumentParser()
# The path is used as a raw prefix for the output stores, hence the trailing separator.
parser.add_argument("-p",action="store",dest="Path",help="Base folder (include the trailing path separator; file names are appended directly)",required=True, type=str)
# -m is a calendar month, not a count: make_field processes it plus the month after it.
parser.add_argument("-m",action="store",dest="Month",help="Calendar month to process (1-12); the following month is processed as well",required=False, type=int)
parser.add_argument("-y",action="store",dest="Year",help="Year",required=True, type=int)
parser.add_argument("-t",action="store",dest="dt",help="Time interval in hours",required=False,type=int,default=1)
opt = parser.parse_args()
print('Make field...')
make_field(opt.Path, opt.Year, opt.Month, opt.dt)
#u

#v

#tp

#e

#sp

#tcw
