import numpy as np
import logging
import datetime
from pathlib import Path
import os
from glob import glob
import xarray as xr

# set up logger
logging.basicConfig(level='DEBUG')

# MODIFY THIS FOR DIFFERENT DATA
data_files = glob('/data/daves/ERA5/grids/era5test.nc') # to grab more netcdf files, use *.nc
dir_out = '/data/sreiner/AR_grids/binary_era5' # output directory
num_levels = 9 # 200 250 300 400 500 600 700 850 975

# DO NOT MODIFY BELOW HERE

logfile = f'{dir_out}/stats.log'


def read_level_stats(path, n_levels):
    """Read per-level log10(q) stretch statistics from a tab-separated log.

    Each of the first ``n_levels`` lines is expected to look like
    ``<label>\\t<min>\\t<max>\\t<range>``.

    Parsing is all-or-nothing: a partially read file must not leak a few
    entries into the caller's lists, because the lists are indexed by level
    position later on.

    Args:
        path: Path of the stats log.
        n_levels: Number of lines (pressure levels) to read.

    Returns:
        ``(mins, maxs, ranges)`` as lists of floats, or ``None`` when the
        file does not exist or cannot be parsed.
    """
    mins, maxs, ranges = [], [], []
    try:
        with open(path, 'r') as f:
            for _ in range(n_levels):
                fields = f.readline().split('\t')
                # A short file yields '' -> [''] here, so indexing raises
                # IndexError; a non-numeric field raises ValueError.
                mins.append(float(fields[1]))
                maxs.append(float(fields[2]))
                ranges.append(float(fields[3]))
    except FileNotFoundError:
        return None
    except (ValueError, IndexError) as e:
        logging.error(e)
        return None
    return mins, maxs, ranges


_stats = read_level_stats(logfile, num_levels)
extend_range = _stats is None  # flag to determine if we need to extend the range
if extend_range:
    newmin, newmax, newrange = [], [], []
else:
    newmin, newmax, newrange = _stats

# Convert every input NetCDF file into one raw binary file per
# (time step, pressure level, variable) under {dir_out}/{YYYY-MM-DD}/.
#   u, v  : written with the dtype the dataset provides (no conversion)
#   cc    : cc * 100 truncated to uint8 (presumably cloud fraction 0-1 -> percent; verify)
#   r     : truncated to uint8
#   q     : log10(q) linearly stretched with the per-level min/range from
#           the stats log, scaled by 256 and saturated to uint8


def _to_uint8(arr):
    """Return ``arr`` truncated to uint8 after saturating it to [0, 255].

    Casting out-of-range floats directly to uint8 wraps around (e.g. 256 -> 0)
    and casting NaN is undefined, so NaN is mapped to 0 and +/-inf and other
    out-of-range values are clipped first.
    In-range values truncate exactly as a plain ``astype('uint8')`` would.
    """
    arr = np.nan_to_num(np.asarray(arr, dtype=np.float64), nan=0.0, posinf=255.0, neginf=0.0)
    return np.clip(arr, 0, 255).astype('uint8')


def _write_bin(path, arr):
    """Write ``arr`` to ``path`` as raw C-order bytes with no header."""
    with open(path, 'wb') as wb:
        wb.write(np.ascontiguousarray(arr).tobytes())


logging.debug(f'data files: {data_files}')

if extend_range:
    # A partially parsed stats log can leave stale entries in these lists;
    # start empty so that entry j always belongs to level j.
    newrange.clear()
    newmin.clear()
    newmax.clear()

for data in data_files:
    logging.debug(f'reading netcdf data: {data}')
    then = datetime.datetime.now()
    # Context manager closes the NetCDF handle once the arrays are in memory.
    with xr.open_dataset(data, chunks={'latitude': 'auto', 'longitude': 'auto'}) as ds:
        levels = ds.level.values
        time = ds.time.values
        cc = ds.cc.values
        r = ds.r.values
        q = ds.q.values
        u = ds.u.values
        v = ds.v.values
    now = datetime.datetime.now()
    logging.debug(f'time to load data: {now - then}')

    if not extend_range and len(newmin) < len(levels):
        # Fail before writing anything rather than with an IndexError mid-file.
        raise ValueError(
            f'{logfile} has stats for {len(newmin)} levels but {data} has {len(levels)}'
        )

    date = str(time[0])[:10]  # just get date

    img_out = f'{dir_out}/{date}'
    logging.debug(f'output directory: {img_out}')
    Path(img_out).mkdir(parents=True, exist_ok=True)  # create nested directory for output files

    for i, t in enumerate(time):
        date_time = str(t)[:19]  # depends only on t, so computed once per step
        stats_lines = []
        for j, plvl in enumerate(levels):
            _write_bin(f'{img_out}/{date_time}_u_{plvl}.bin', u[i][j])
            _write_bin(f'{img_out}/{date_time}_v_{plvl}.bin', v[i][j])
            _write_bin(f'{img_out}/{date_time}_cc_{plvl}.bin', _to_uint8(cc[i][j] * 100))
            _write_bin(f'{img_out}/{date_time}_r_{plvl}.bin', _to_uint8(r[i][j]))

            q_lvl = q[i][j]
            # log10(0) is -inf, which nanmin/nanmax do NOT skip, and log10 of
            # a negative is NaN with a warning: mask non-positive values first.
            q_new = np.log10(np.where(q_lvl > 0, q_lvl, np.nan))
            if extend_range:
                min_prs = np.nanmin(q_new)
                max_prs = np.nanmax(q_new)
                nrange = max_prs - min_prs
                newrange.append(nrange * 1.2)
                newmin.append(min_prs - .5 * (nrange * 0.2))  # min for log10 extended 20%
                newmax.append(max_prs + .5 * (nrange * .2))  # recalculate max extended 20%
                stats_lines.append(
                    f'pressure_lvl_{plvl}\t{newmin[j]:.6f}\t{newmax[j]:.6f}\t{newrange[j]:.6f}\n'
                )

            q_stretch = ((q_new - newmin[j]) / (newrange[j])) * 256
            # Later time steps/files may fall outside the first step's range;
            # saturate instead of letting uint8 wrap around.
            _write_bin(f'{img_out}/{date_time}_q_{plvl}.bin', _to_uint8(q_stretch))

        if extend_range:
            # Rewrite (not append) so a previously corrupt log is replaced.
            with open(logfile, 'w') as f:
                f.writelines(stats_lines)
            extend_range = False