import datetime, os
from datetime import timezone
import glob
import numpy as np
import re
from netCDF4 import Dataset
from pathlib import Path

from util.util import GenericException
from util.lon_lat_grid import LonLatGrid
from util.geos_nav import GEOSNavigation
from util.util import homedir


def get_parameters_clavrx(filename=homedir+'data/clavrx/clavrx_OR_ABI-L1b-RadF-M6C01_G16_s20192930000343.level2.nc'):
    """Return the names of the 2-D variables in a CLAVR-x level-2 netCDF file.

    Variables whose name contains 'latitude' or 'longitude' are excluded.

    Args:
        filename: path of the file to inspect; defaults to a fixed reference file
            under ``homedir``.

    Returns:
        list of variable names, in the file's variable order.
    """
    # Context manager guarantees the file is closed even if inspecting it raises.
    with Dataset(filename, 'r') as rg:
        return [name for name, var in rg.variables.items()
                if len(var.shape) == 2 and 'latitude' not in name and 'longitude' not in name]


def get_parameters_fmwk_amvs(filename='/ships19/cloud/scratch/4TH_AMV_INTERCOMPARISON/FMWK2_AMV/GOES16_ABI_2KM_FD_2019293_0020_34_WINDS_AMV_EN-14CT.nc'):
    """Return the names of the per-AMV variables in a Framework AMV netCDF file.

    A variable is kept when it is 1-D and has the same length as the 'Longitude'
    variable (which is itself included).

    Args:
        filename: path of the file to inspect; defaults to a fixed reference file.

    Returns:
        list of variable names, in the file's variable order.
    """
    # Context manager guarantees the file is closed even if inspecting it raises.
    with Dataset(filename, 'r') as rg:
        var_s = rg.variables
        num_amvs = var_s['Longitude'].shape[0]
        return [name for name, var in var_s.items()
                if len(var.shape) == 1 and var.shape[0] == num_amvs]


class Files:
    """Time-sorted collection of data files matched by a glob pattern.

    Each file is assumed to cover ``[start, start + file_time_span minutes)``. The
    start time comes from :meth:`get_datetime`, which subclasses must implement.

    Attributes:
        flist: numpy array of matched paths, sorted by start time. Entries are
            ``pathlib.Path`` objects when ``files_path`` is a directory, ``str`` otherwise.
        ftimes: float array of shape (n_files, 2) holding (start, end) POSIX
            timestamps, in the same order as ``flist``.
        span_seconds: ``file_time_span`` expressed in whole seconds.
    """

    def __init__(self, files_path, file_time_span, pattern):
        """Find the files and compute their time intervals.

        Args:
            files_path: directory searched recursively for ``pattern``, or a path
                prefix that is concatenated with ``pattern`` and globbed.
            file_time_span: time covered by each file, in minutes.
            pattern: glob pattern for file names.

        Raises:
            GenericException: if no file matches.
        """
        if os.path.isdir(files_path):
            self.flist = list(Path(files_path).rglob(pattern))
        else:
            self.flist = glob.glob(files_path + pattern)
        if not self.flist:
            raise GenericException('no matching files found in: ' + files_path + pattern)

        span = datetime.timedelta(minutes=file_time_span)
        # total_seconds(), not .seconds: .seconds is only the seconds *component*
        # of the timedelta and silently wraps to 0 for spans of one day or longer.
        self.span_seconds = int(span.total_seconds())

        ftimes = []
        for pname in self.flist:
            dto_start = self.get_datetime(pname)
            ftimes.append((dto_start.timestamp(), (dto_start + span).timestamp()))

        self.ftimes = np.array(ftimes)
        self.flist = np.array(self.flist)
        # Stable sort keeps discovery order for files sharing a start time.
        sidxs = np.argsort(self.ftimes[:, 0], kind='stable')
        self.ftimes = self.ftimes[sidxs, :]
        self.flist = self.flist[sidxs]

    def get_datetime(self, pathname):
        """Return the timezone-aware start time of ``pathname``; subclasses must override."""
        raise NotImplementedError(type(self).__name__ + ' must implement get_datetime()')

    def get_file_containing_time(self, timestamp):
        """Find the first file whose interval [start, end) contains ``timestamp``.

        Args:
            timestamp: POSIX timestamp in seconds.

        Returns:
            (path, start_timestamp, index), or (None, None, None) if no file matches.
        """
        inside = (timestamp >= self.ftimes[:, 0]) & (timestamp < self.ftimes[:, 1])
        hits = np.flatnonzero(inside)
        if hits.size == 0:
            return None, None, None
        k = int(hits[0])
        return self.flist[k], self.ftimes[k, 0], k

    def get_file(self, timestamp, window=None):
        """Return the file whose start time is nearest to ``timestamp``.

        Args:
            timestamp: POSIX timestamp in seconds.
            window: maximum allowed distance between ``timestamp`` and the file's
                start, in minutes. Defaults to the file time span.

        Returns:
            (path, start_timestamp, index) if the nearest start lies strictly within
            ``window``, otherwise (None, None, None).
        """
        if window is None:
            window = self.span_seconds
        else:
            # total_seconds() so windows of a day or more do not wrap (see __init__).
            window = datetime.timedelta(minutes=window).total_seconds()
        dist = np.abs(self.ftimes[:, 0] - timestamp)
        midx = np.argmin(dist)
        if dist[midx] < window:
            return self.flist[midx], self.ftimes[midx, 0], midx
        return None, None, None

    def get_parameters(self):
        """Return the parameter names for this source; the base class has none (None)."""
        return None


class GOESL1B(Files):
    """GOES ABI Level-1b radiance files for one band.

    Matches ``OR_ABI-L1b-RadC-M*C<band>*.nc``; each file is treated as covering
    10 minutes from the ``_s`` start time embedded in its name.
    """

    def __init__(self, files_path, band='14'):
        pattern = 'OR_ABI-L1b-RadC-M*C' + band + '*.nc'
        super().__init__(files_path, 10, pattern)

    def get_datetime(self, pathname):
        """Parse the ``_sYYYYJJJHHMM`` token of the file name as a UTC datetime."""
        _, name = os.path.split(pathname)
        token = re.search('_s\\d{11}', name).group()
        parsed = datetime.datetime.strptime(token, '_s%Y%j%H%M')
        return parsed.replace(tzinfo=timezone.utc)


# GOES-16, CONUS. TODO: Generalize to G-16 and FD, MESO
class CLAVRx(Files):
    """CLAVR-x level-2 files produced from GOES ABI L1b (``clavrx_OR_ABI-L1b*.level2.nc``).

    Each file is treated as covering 10 minutes. Parameter names come from
    get_parameters_clavrx() called with its default reference file, not from the
    files found under ``files_path``.
    """

    def __init__(self, files_path):
        super().__init__(files_path, 10, 'clavrx_OR_ABI-L1b*.level2.nc')
        self.params = get_parameters_clavrx()

    def get_datetime(self, pathname):
        """Parse the ``_sYYYYJJJHHMM`` token of the file name as a UTC datetime."""
        name = os.path.basename(pathname)
        match = re.search('_s\\d{11}', name)
        parsed = datetime.datetime.strptime(match.group(), '_s%Y%j%H%M')
        return parsed.replace(tzinfo=timezone.utc)

    def get_parameters(self):
        """Return the cached list of 2-D parameter names."""
        return self.params

    def get_navigation(self, h5f):
        """Return a fixed GEOSNavigation; ``h5f`` is accepted but not consulted."""
        return GEOSNavigation(sub_lon=-75.0,
                              CFAC=5.6E-05, COFF=-0.101332,
                              LFAC=-5.6E-05, LOFF=0.128212,
                              num_elems=2500, num_lines=1500)


class CLAVRx_VIIRS(Files):
    """CLAVR-x level-2 VIIRS files (``clavrx*viirs*level2.h5``), each treated as 5 minutes long.

    NOTE(review): parameter names come from get_parameters_clavrx()'s default
    ABI reference file rather than from a VIIRS file — confirm this is intended.
    """

    def __init__(self, files_path):
        super().__init__(files_path, 5, 'clavrx*viirs*level2.h5')
        self.params = get_parameters_clavrx()

    def get_datetime(self, pathname):
        """Concatenate fields 1 and 2 of the '.'-split file name and parse as ``A%Y%j%H%M`` (UTC)."""
        fields = os.path.basename(pathname).split('.')
        stamp = fields[1] + fields[2]
        parsed = datetime.datetime.strptime(stamp, 'A%Y%j%H%M')
        return parsed.replace(tzinfo=timezone.utc)

    def get_parameters(self):
        """Return the cached list of parameter names."""
        return self.params

    def get_navigation(self, h5f):
        """Build a LonLatGrid (reduce=2) from the file's 'longitude' and 'latitude' arrays.

        Negative longitudes are shifted by +360 before the grid is built.
        """
        lons = h5f['longitude'][:, :]
        lats = h5f['latitude'][:, :]
        shifted = np.where(lons < 0, lons + 360, lons)
        return LonLatGrid(shifted, lats, reduce=2)


class RAOBfiles(Files):
    """RAOB sounding files named ``raob_soundings<YYYYmmdd_HHMM>.cdf``."""

    def __init__(self, files_path, file_time_span=10):
        super().__init__(files_path, file_time_span, 'raob_soundings*.cdf')

    def get_datetime(self, pathname):
        """Parse the text between 'raob_soundings' and the first '.' as ``%Y%m%d_%H%M`` (UTC)."""
        name = os.path.basename(pathname)
        after_prefix = name.split('raob_soundings')[1]
        stamp = after_prefix.split('.')[0]
        parsed = datetime.datetime.strptime(stamp, '%Y%m%d_%H%M')
        return parsed.replace(tzinfo=timezone.utc)


class GFSfiles(Files):
    """GFS files named ``gfs.YYMMDDHH_F012.h5``, each treated as covering 10 minutes."""

    def __init__(self, files_path):
        super().__init__(files_path, 10, 'gfs*.h5')

    def get_datetime(self, pathname):
        """Parse the YYMMDDHH time in the name (UTC) and add 12 hours.

        Presumably this yields the valid time of the F012 forecast — confirm.
        """
        init = datetime.datetime.strptime(os.path.basename(pathname), 'gfs.%y%m%d%H_F012.h5')
        return init.replace(tzinfo=timezone.utc) + datetime.timedelta(hours=12)


class AMVFiles:
    """Time-sorted collection of AMV-style product files plus variable metadata.

    Files are found with ``glob.glob(files_path + pattern)``. Each file is assumed to
    cover ``[start, start + file_time_span minutes)``, where ``start`` comes from
    :meth:`get_datetime`, which subclasses must implement.

    Attributes:
        flist: numpy array of matched paths, sorted by start time.
        ftimes: float array of shape (n_files, 2) holding (start, end) POSIX
            timestamps, in the same order as ``flist``.
        span_seconds: ``file_time_span`` expressed in whole seconds.
        band, elem_name, line_name, lon_name, lat_name, params, out_params,
        meta_dict, qc_params: stored as given (``out_params`` defaults to ``params``).
    """

    def __init__(self, files_path, file_time_span, pattern, band='14', elem_name=None, line_name=None, lat_name=None,
                 lon_name=None, params=None, out_params=None, meta_dict=None, qc_params=None):
        """Find the files, compute their time intervals and store the metadata.

        Args:
            files_path: path prefix concatenated with ``pattern`` and globbed.
            file_time_span: time covered by each file, in minutes.
            pattern: glob pattern appended to ``files_path``.

        Raises:
            GenericException: if no file matches.
        """
        self.flist = glob.glob(files_path + pattern)
        if not self.flist:
            raise GenericException('no matching files found in: ' + files_path + pattern)

        self.band = band
        span = datetime.timedelta(minutes=file_time_span)
        # total_seconds(), not .seconds: .seconds is only the seconds *component*
        # of the timedelta and silently wraps to 0 for spans of one day or longer.
        self.span_seconds = int(span.total_seconds())

        ftimes = []
        for pname in self.flist:
            dto_start = self.get_datetime(pname)
            ftimes.append((dto_start.timestamp(), (dto_start + span).timestamp()))

        self.ftimes = np.array(ftimes)
        self.flist = np.array(self.flist)
        # Stable sort keeps discovery order for files sharing a start time.
        sidxs = np.argsort(self.ftimes[:, 0], kind='stable')
        self.ftimes = self.ftimes[sidxs, :]
        self.flist = self.flist[sidxs]

        self.elem_name = elem_name
        self.line_name = line_name
        self.lon_name = lon_name
        self.lat_name = lat_name
        self.params = params
        self.out_params = params if out_params is None else out_params
        self.meta_dict = meta_dict
        self.qc_params = qc_params

    def get_datetime(self, pathname):
        """Return the timezone-aware start time of ``pathname``; subclasses must override."""
        raise NotImplementedError(type(self).__name__ + ' must implement get_datetime()')

    def get_navigation(self):
        """Return the navigation object for this source; the base class has none (None)."""
        return None

    def get_file_containing_time(self, timestamp):
        """Find the first file whose interval [start, end) contains ``timestamp``.

        Args:
            timestamp: POSIX timestamp in seconds.

        Returns:
            (path, start_timestamp, index), or (None, None, None) if no file matches.
        """
        inside = (timestamp >= self.ftimes[:, 0]) & (timestamp < self.ftimes[:, 1])
        hits = np.flatnonzero(inside)
        if hits.size == 0:
            return None, None, None
        k = int(hits[0])
        return self.flist[k], self.ftimes[k, 0], k

    def get_file(self, timestamp, window=None):
        """Return the file whose start time is nearest to ``timestamp``.

        Args:
            timestamp: POSIX timestamp in seconds.
            window: maximum allowed distance between ``timestamp`` and the file's
                start, in minutes. Defaults to the file time span.

        Returns:
            (path, start_timestamp, index) if the nearest start lies strictly within
            ``window``, otherwise (None, None, None).
        """
        if window is None:
            window = self.span_seconds
        else:
            # total_seconds() so windows of a day or more do not wrap (see __init__).
            window = datetime.timedelta(minutes=window).total_seconds()
        dist = np.abs(self.ftimes[:, 0] - timestamp)
        midx = np.argmin(dist)
        if dist[midx] < window:
            return self.flist[midx], self.ftimes[midx, 0], midx
        return None, None, None

    def get_parameters(self):
        """Return the input parameter names."""
        return self.params

    def get_out_parameters(self):
        """Return the output parameter names."""
        return self.out_params

    def get_meta_dict(self):
        """Return the per-parameter metadata mapping."""
        return self.meta_dict

    def get_qc_params(self):
        """Return the QC parameter name(s)."""
        return self.qc_params

    def filter(self, qc_param):
        """Return a selection based on ``qc_param``; the base class applies none (None)."""
        return None


class Framework(AMVFiles):
    """Framework AMV wind files (``*WINDS_AMV_EN-<band>*.nc``) with QC variable 'QI'."""

    def __init__(self, files_path, file_time_span, band='14'):
        elem_name, line_name = 'Element', 'Line'
        lon_name, lat_name = 'Longitude', 'Latitude'

        # Parameter list comes from get_parameters_fmwk_amvs()'s default reference
        # file; the coordinate/index variables are dropped (list.remove raises
        # ValueError if one is missing).
        params = get_parameters_fmwk_amvs()
        for coord in (elem_name, line_name, lon_name, lat_name):
            params.remove(coord)

        out_params = ['Lon', 'Lat', 'Element', 'Line', 'pressure', 'wind_speed', 'wind_direction']
        meta_dict = {
            'Lon': ('degrees east', 'f4'),
            'Lat': ('degrees north', 'f4'),
            'Element': (None, 'i4'),
            'Line': (None, 'i4'),
            'pressure': ('hPa', 'f4'),
            'wind_speed': ('m s-1', 'f4'),
            'wind_direction': ('degrees', 'f4'),
        }

        super().__init__(files_path, file_time_span, '*WINDS_AMV_EN-' + band + '*.nc', band=band,
                         elem_name=elem_name, line_name=line_name, lat_name=lat_name, lon_name=lon_name,
                         params=params, out_params=out_params, meta_dict=meta_dict, qc_params='QI')

    def get_navigation(self):
        """Return a GEOSNavigation with sub-satellite longitude -75.0."""
        return GEOSNavigation(sub_lon=-75.0)

    def get_datetime(self, pathname):
        """Parse fields 4 and 5 of the '_'-split file name as ``%Y%j`` + ``%H%M`` (UTC)."""
        toks = os.path.basename(pathname).split('_')
        parsed = datetime.datetime.strptime(toks[4] + toks[5], '%Y%j%H%M')
        return parsed.replace(tzinfo=timezone.utc)

    def filter(self, qc_param):
        """Return ``qc_param > 50`` (element-wise for arrays)."""
        return qc_param > 50


class FrameworkCloudHeight(AMVFiles):
    """Framework cloud-height product files (``*_CLOUD_HEIGHT_EN*.nc``).

    Reads 'CldTopPres', 'CldTopHght' and 'CldOptDpth', with coordinate variable
    names Element/Line/Longitude/Latitude passed to AMVFiles.
    """

    def __init__(self, files_path, file_time_span):
        elem_name = 'Element'
        line_name = 'Line'
        lon_name = 'Longitude'
        lat_name = 'Latitude'

        params = ['CldTopPres', 'CldTopHght', 'CldOptDpth']
        out_params = list(params)
        # Optical depth is dimensionless, so it gets no units string (it was
        # labelled 'km'); other unitless entries in this module also use None.
        meta_dict = {'CldTopPres': ('hPa', 'f4'), 'CldTopHght': ('km', 'f4'), 'CldOptDpth': (None, 'f4')}

        super().__init__(files_path, file_time_span, '*_CLOUD_HEIGHT_EN*.nc', band=None, elem_name=elem_name,
                         params=params, line_name=line_name, lat_name=lat_name, lon_name=lon_name,
                         out_params=out_params, meta_dict=meta_dict)

    def get_navigation(self):
        """Return a GEOSNavigation with sub-satellite longitude -75.0."""
        return GEOSNavigation(sub_lon=-75.0)

    def get_datetime(self, pathname):
        """Parse fields 4 and 5 of the '_'-split file name as ``%Y%j`` + ``%H%M`` (UTC)."""
        toks = os.path.basename(pathname).split('_')
        dto = datetime.datetime.strptime(toks[4] + toks[5], '%Y%j%H%M')
        return dto.replace(tzinfo=timezone.utc)


class FrameworkCloudPhase(AMVFiles):
    """Framework cloud-phase product files (``*_CLOUD_PHASE_EN*.nc``): 'CloudPhase' and 'CloudType'."""

    def __init__(self, files_path, file_time_span):
        names = {'elem_name': 'Element', 'line_name': 'Line',
                 'lon_name': 'Longitude', 'lat_name': 'Latitude'}
        params = ['CloudPhase', 'CloudType']
        out_params = ['CloudPhase', 'CloudType']
        meta_dict = {'CloudPhase': (None, 'i1'), 'CloudType': (None, 'i1')}

        super().__init__(files_path, file_time_span, '*_CLOUD_PHASE_EN*.nc', band=None,
                         params=params, out_params=out_params, meta_dict=meta_dict, **names)

    def get_navigation(self):
        """Return a GEOSNavigation with sub-satellite longitude -75.0."""
        return GEOSNavigation(sub_lon=-75.0)

    def get_datetime(self, pathname):
        """Parse fields 4 and 5 of the '_'-split file name as ``%Y%j`` + ``%H%M`` (UTC)."""
        parts = os.path.basename(pathname).split('_')
        stamp = parts[4] + parts[5]
        return datetime.datetime.strptime(stamp, '%Y%j%H%M').replace(tzinfo=timezone.utc)


class OpsCloudPhase(AMVFiles):
    """OPS cloud-phase files (``OR_ABI-L2-ACTPF*.nc``) with a single 'Phase' parameter.

    No element/line/lon/lat variable names are configured.
    """

    def __init__(self, files_path, file_time_span):
        super().__init__(files_path, file_time_span, 'OR_ABI-L2-ACTPF*.nc', band=None,
                         elem_name=None, line_name=None, lat_name=None, lon_name=None,
                         params=['Phase'], out_params=['Phase'],
                         meta_dict={'Phase': (None, 'i1')})

    def get_navigation(self):
        """Return a GEOSNavigation with sub-satellite longitude -75.0."""
        return GEOSNavigation(sub_lon=-75.0)

    def get_datetime(self, pathname):
        """Parse field 3 of the '_'-split file name, minus its last 3 characters, as ``s%Y%j%H%M`` (UTC).

        The dropped characters are presumably seconds and tenths of a second.
        """
        start_field = os.path.basename(pathname).split('_')[3]
        parsed = datetime.datetime.strptime(start_field[:-3], 's%Y%j%H%M')
        return parsed.replace(tzinfo=timezone.utc)


class OPS(AMVFiles):
    """OPS AMV files matching ``OR_ABI-L2-DMWF*C<band>*.nc``.

    Longitude/latitude variables are 'lon'/'lat'; no element/line variable names
    are configured, although the output list includes 'Element' and 'Line'.
    """

    def __init__(self, files_path, file_time_span, band='14'):
        params = ['pressure', 'wind_speed', 'wind_direction']
        out_params = ['Lon', 'Lat', 'Element', 'Line'] + ['pressure', 'wind_speed', 'wind_direction']
        meta_dict = {
            'Lon': ('degrees east', 'f4'),
            'Lat': ('degrees north', 'f4'),
            'Element': (None, 'i4'),
            'Line': (None, 'i4'),
            'pressure': ('hPa', 'f4'),
            'wind_speed': ('m s-1', 'f4'),
            'wind_direction': ('degrees', 'f4'),
        }
        pattern = 'OR_ABI-L2-DMWF*C' + band + '*.nc'

        super().__init__(files_path, file_time_span, pattern, band=band,
                         elem_name=None, line_name=None, lat_name='lat', lon_name='lon',
                         params=params, out_params=out_params, meta_dict=meta_dict)

    def get_navigation(self):
        """Return a GEOSNavigation with sub-satellite longitude -75.0 (-75.2 was considered)."""
        return GEOSNavigation(sub_lon=-75.0)

    def get_datetime(self, pathname):
        """Parse field 3 of the '_'-split file name, minus its last 3 characters, as ``s%Y%j%H%M`` (UTC)."""
        start_field = os.path.basename(pathname).split('_')[3]
        trimmed = start_field[:-3]
        return datetime.datetime.strptime(trimmed, 's%Y%j%H%M').replace(tzinfo=timezone.utc)


class CarrStereo(AMVFiles):
    """Carr stereo wind files matching ``tdw_qc_GOES*ch_<band>.nc``.

    Input 'V_3D' is listed among params while outputs list 'V_3D_u' and 'V_3D_v'.
    Navigation uses a sub-satellite longitude of -137.0.
    """

    def __init__(self, files_path, file_time_span, band='14'):
        params = ['V_3D', 'H_3D', 'pres', 'Fcst_Spd', 'Fcst_Dir', 'SatZen',
                  'InversionFlag', 'CloudPhase', 'CloudType']
        out_params = ['Lon', 'Lat', 'Element', 'Line', 'V_3D_u', 'V_3D_v', 'H_3D', 'pres',
                      'Fcst_Spd', 'Fcst_Dir', 'SatZen', 'InversionFlag', 'CloudPhase', 'CloudType']
        meta_dict = {
            'H_3D': ('m', 'f4'),
            'pres': ('hPa', 'f4'),
            'Fcst_Spd': ('m s-1', 'f4'),
            'Fcst_Dir': ('degree', 'f4'),
            'SatZen': ('degree', 'f4'),
            'InversionFlag': (None, 'u1'),
            'CloudPhase': (None, 'u1'),
            'CloudType': (None, 'u1'),
            'V_3D_u': ('m s-1', 'f4'),
            'V_3D_v': ('m s-1', 'f4'),
            'Lon': ('degrees east', 'f4'),
            'Lat': ('degrees north', 'f4'),
            'Element': (None, 'i4'),
            'Line': (None, 'i4'),
        }

        super().__init__(files_path, file_time_span, 'tdw_qc_GOES*ch_' + band + '.nc', band=band,
                         elem_name='Element', line_name='Line', lat_name='Lat', lon_name='Lon',
                         params=params, out_params=out_params, meta_dict=meta_dict)

    def get_navigation(self):
        """Return a GEOSNavigation with sub-satellite longitude -137.0."""
        return GEOSNavigation(sub_lon=-137.0)

    def get_datetime(self, pathname):
        """Parse field 3 of the '_'-split file name as ``%Y%j.%H%M.ch`` (UTC)."""
        field = os.path.basename(pathname).split('_')[3]
        parsed = datetime.datetime.strptime(field, '%Y%j.%H%M.ch')
        return parsed.replace(tzinfo=timezone.utc)


def get_datasource(files_path, source, file_time_span=10, band='14'):
    """Construct the file-collection object for the named data source.

    Args:
        files_path: passed through to the source class.
        source: one of 'OPS', 'FMWK', 'CARR', 'FMWK_CLD_HGT', 'FMWK_CLD_PHASE',
            'OPS_CLD_PHASE', 'GFS', 'RAOB'.
        file_time_span: per-file time span in minutes (ignored by 'GFS').
        band: channel string, used only by 'OPS', 'FMWK' and 'CARR'.

    Raises:
        GenericException: if ``source`` is not recognised.
    """
    # Ordered (name, factory) pairs; lambdas defer construction until matched.
    dispatch = (
        ('OPS', lambda: OPS(files_path, file_time_span, band=band)),
        ('FMWK', lambda: Framework(files_path, file_time_span, band=band)),
        ('CARR', lambda: CarrStereo(files_path, file_time_span, band=band)),
        ('FMWK_CLD_HGT', lambda: FrameworkCloudHeight(files_path, file_time_span)),
        ('FMWK_CLD_PHASE', lambda: FrameworkCloudPhase(files_path, file_time_span)),
        ('OPS_CLD_PHASE', lambda: OpsCloudPhase(files_path, file_time_span)),
        ('GFS', lambda: GFSfiles(files_path)),
        ('RAOB', lambda: RAOBfiles(files_path, file_time_span)),
    )
    for name, make in dispatch:
        if source == name:
            return make()
    raise GenericException('Unknown data source type')