From 81c9a2483f1aee90c53b79693cd54976de64d289 Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Thu, 6 May 2021 22:26:45 -0500
Subject: [PATCH] tile_extract: write separate train/test files; augment training tiles only

---
 modules/icing/pirep_goes.py | 195 +++++++++++++++++++++++-------------
 1 file changed, 127 insertions(+), 68 deletions(-)

diff --git a/modules/icing/pirep_goes.py b/modules/icing/pirep_goes.py
index 42f9f70a..84985078 100644
--- a/modules/icing/pirep_goes.py
+++ b/modules/icing/pirep_goes.py
@@ -1010,7 +1010,8 @@ def fov_extract(outfile='/home/rink/fovs_l1b_out.h5', train_params=l1b_ds_list,
     h5f_expl.close()
 
 
-def tile_extract(outfile='/home/rink/tiles_l1b_out.h5', train_params=l1b_ds_list, ds_types=l1b_ds_types, augment=False):
+def tile_extract(trnfile='/home/rink/tiles_l1b_train.h5', tstfile='/home/rink/tiles_l1b_test.h5',
+                 train_params=l1b_ds_list, ds_types=l1b_ds_types, augment=False, split=0.2):
     icing_int_s = []
     ice_time_s = []
     no_ice_time_s = []
@@ -1041,36 +1042,11 @@ def tile_extract(outfile='/home/rink/tiles_l1b_out.h5', train_params=l1b_ds_list
             for ds_name in train_params:
                 dat = f[ds_name][i, 12:28, 12:28]
                 icing_data_dct[ds_name].append(dat)
-                if augment:
-                    if icing_int[i] >= 3:
-                        icing_data_dct[ds_name].append(np.fliplr(dat))
-                        icing_data_dct[ds_name].append(np.flipud(dat))
-                        icing_data_dct[ds_name].append(np.rot90(dat))
 
             icing_int_s.append(icing_int[i])
-            if augment:
-                if icing_int[i] >= 3:
-                    icing_int_s.append(icing_int[i])
-                    icing_int_s.append(icing_int[i])
-                    icing_int_s.append(icing_int[i])
-
             ice_time_s.append(times[i])
-            if augment:
-                if icing_int[i] >= 3:
-                    ice_time_s.append(times[i])
-                    ice_time_s.append(times[i])
-                    ice_time_s.append(times[i])
-
             ice_lon_s.append(lons[i])
             ice_lat_s.append(lats[i])
-            if augment:
-                if icing_int[i] >= 3:
-                    ice_lon_s.append(lons[i])
-                    ice_lat_s.append(lats[i])
-                    ice_lon_s.append(lons[i])
-                    ice_lat_s.append(lats[i])
-                    ice_lon_s.append(lons[i])
-                    ice_lat_s.append(lats[i])
 
         print(fname)
 
@@ -1123,6 +1099,85 @@ def tile_extract(outfile='/home/rink/tiles_l1b_out.h5', train_params=l1b_ds_list
     for ds_name in train_params:
         data_dct[ds_name] = np.concatenate([icing_data_dct[ds_name], no_icing_data_dct[ds_name]])
 
+    trn_idxs, tst_idxs = split_data(icing_intensity.shape[0], shuffle=False, split=split)
+
+    trn_data_dct = {}
+    for ds_name in train_params:
+        trn_data_dct[ds_name] = data_dct[ds_name][trn_idxs,]
+    trn_icing_intensity = icing_intensity[trn_idxs,]
+    trn_icing_times = icing_times[trn_idxs,]
+    trn_icing_lons = icing_lons[trn_idxs,]
+    trn_icing_lats = icing_lats[trn_idxs,]
+
+    #  Data augmentation -------------------------------------------------------------
+    trn_data_dct_aug = {ds_name: [] for ds_name in train_params}
+    trn_icing_intensity_aug = []
+    trn_icing_times_aug = []
+    trn_icing_lons_aug = []
+    trn_icing_lats_aug = []
+    if augment:
+        for k in range(trn_icing_intensity.shape[0]):
+            iceint = trn_icing_intensity[k]
+            icetime = trn_icing_times[k]
+            icelon = trn_icing_lons[k]
+            icelat = trn_icing_lats[k]
+            if iceint >= 3:
+                for ds_name in train_params:
+                    dat = trn_data_dct[ds_name]
+                    trn_data_dct_aug[ds_name].append(np.fliplr(dat[k,]))
+                    trn_data_dct_aug[ds_name].append(np.flipud(dat[k,]))
+                    trn_data_dct_aug[ds_name].append(np.rot90(dat[k,]))
+
+                trn_icing_intensity_aug.append(iceint)
+                trn_icing_intensity_aug.append(iceint)
+                trn_icing_intensity_aug.append(iceint)
+
+                trn_icing_times_aug.append(icetime)
+                trn_icing_times_aug.append(icetime)
+                trn_icing_times_aug.append(icetime)
+
+                trn_icing_lons_aug.append(icelon)
+                trn_icing_lons_aug.append(icelon)
+                trn_icing_lons_aug.append(icelon)
+
+                trn_icing_lats_aug.append(icelat)
+                trn_icing_lats_aug.append(icelat)
+                trn_icing_lats_aug.append(icelat)
+
+    # np.stack raises ValueError on an empty list, which is what these lists are when
+    # augment is off or no training tile has intensity >= 3, so merge only when there is
+    # at least one augmented sample.
+    if trn_icing_intensity_aug:
+        for ds_name in train_params:
+            aug_tiles = np.stack(trn_data_dct_aug[ds_name])
+            trn_data_dct[ds_name] = np.concatenate([trn_data_dct[ds_name], aug_tiles])
+
+        # Each metadata array is extended from its own list so rows stay aligned with the tiles.
+        trn_icing_intensity = np.concatenate([trn_icing_intensity, np.stack(trn_icing_intensity_aug)])
+        trn_icing_times = np.concatenate([trn_icing_times, np.stack(trn_icing_times_aug)])
+        trn_icing_lons = np.concatenate([trn_icing_lons, np.stack(trn_icing_lons_aug)])
+        trn_icing_lats = np.concatenate([trn_icing_lats, np.stack(trn_icing_lats_aug)])
+
+    # do sort
+    ds_indexes = np.argsort(trn_icing_times)
+    for ds_name in train_params:
+        trn_data_dct[ds_name] = trn_data_dct[ds_name][ds_indexes]
+    trn_icing_intensity = trn_icing_intensity[ds_indexes]
+    trn_icing_times = trn_icing_times[ds_indexes]
+    trn_icing_lons = trn_icing_lons[ds_indexes]
+    trn_icing_lats = trn_icing_lats[ds_indexes]
+
+    write_file(trnfile, train_params, trn_data_dct, trn_icing_intensity, trn_icing_times, trn_icing_lons,
+               trn_icing_lats)
+
+    tst_data_dct = {}
+    for ds_name in train_params:
+        tst_data_dct[ds_name] = data_dct[ds_name][tst_idxs,]
+    tst_icing_intensity = icing_intensity[tst_idxs,]
+    tst_icing_times = icing_times[tst_idxs,]
+    tst_icing_lons = icing_lons[tst_idxs,]
+    tst_icing_lats = icing_lats[tst_idxs,]
+
     # Do shuffle
     # ds_indexes = np.arange(num_ice + num_no_ice)
     # np.random.shuffle(ds_indexes)
@@ -1135,49 +1190,55 @@ def tile_extract(outfile='/home/rink/tiles_l1b_out.h5', train_params=l1b_ds_list
     # icing_lats = icing_lats[ds_indexes]
 
     # do sort
-    ds_indexes = np.argsort(icing_times)
+    ds_indexes = np.argsort(tst_icing_times)
     for ds_name in train_params:
-        data_dct[ds_name] = data_dct[ds_name][ds_indexes]
-    icing_intensity = icing_intensity[ds_indexes]
-    icing_times = icing_times[ds_indexes]
-    icing_lons = icing_lons[ds_indexes]
-    icing_lats = icing_lats[ds_indexes]
-
-    h5f_expl = h5py.File(a_clvr_file, 'r')
-    h5f_out = h5py.File(outfile, 'w')
-
-    for idx, ds_name in enumerate(train_params):
-        dt = ds_types[idx]
-        data = data_dct[ds_name]
-        h5f_out.create_dataset(ds_name, data=data, dtype=dt)
+        tst_data_dct[ds_name] = tst_data_dct[ds_name][ds_indexes]
+    tst_icing_intensity = tst_icing_intensity[ds_indexes]
+    tst_icing_times = tst_icing_times[ds_indexes]
+    tst_icing_lons = tst_icing_lons[ds_indexes]
+    tst_icing_lats = tst_icing_lats[ds_indexes]
 
-    icing_int_ds = h5f_out.create_dataset('icing_intensity', data=icing_intensity, dtype='i4')
-    icing_int_ds.attrs.create('long_name', data='From PIREP. -1:No Icing, 1:Trace, 2:Light, 3:Light Moderate, 4:Moderate, 5:Moderate Severe, 6:Severe')
+    write_file(tstfile, train_params, tst_data_dct, tst_icing_intensity, tst_icing_times, tst_icing_lons,
+               tst_icing_lats)  # write_file takes train_params right after the output path
 
-    time_ds = h5f_out.create_dataset('time', data=icing_times, dtype='f4')
-    time_ds.attrs.create('units', data='seconds since 1970-1-1 00:00:00')
-    time_ds.attrs.create('long_name', data='PIREP time')
-
-    lon_ds = h5f_out.create_dataset('longitude', data=icing_lons, dtype='f4')
-    lon_ds.attrs.create('units', data='degrees_east')
-    lon_ds.attrs.create('long_name', data='PIREP longitude')
-
-    lat_ds = h5f_out.create_dataset('latitude', data=icing_lats, dtype='f4')
-    lat_ds.attrs.create('units', data='degrees_north')
-    lat_ds.attrs.create('long_name', data='PIREP latitude')
-
-    # copy relevant attributes
-    for ds_name in train_params:
-        h5f_ds = h5f_out[ds_name]
-        h5f_ds.attrs.create('standard_name', data=h5f_expl[ds_name].attrs.get('standard_name'))
-        h5f_ds.attrs.create('long_name', data=h5f_expl[ds_name].attrs.get('long_name'))
-        h5f_ds.attrs.create('units', data=h5f_expl[ds_name].attrs.get('units'))
-        attr = h5f_expl[ds_name].attrs.get('actual_range')
-        if attr is not None:
-            h5f_ds.attrs.create('actual_range', data=attr)
-        attr = h5f_expl[ds_name].attrs.get('flag_values')
-        if attr is not None:
-            h5f_ds.attrs.create('flag_values', data=attr)
+    # h5f_expl = h5py.File(a_clvr_file, 'r')
+    # h5f_out = h5py.File(outfile, 'w')
+    #
+    # for idx, ds_name in enumerate(train_params):
+    #     dt = ds_types[idx]
+    #     data = data_dct[ds_name]
+    #     h5f_out.create_dataset(ds_name, data=data, dtype=dt)
+    #
+    # icing_int_ds = h5f_out.create_dataset('icing_intensity', data=icing_intensity, dtype='i4')
+    # icing_int_ds.attrs.create('long_name', data='From PIREP. -1:No Icing, 1:Trace, 2:Light, 3:Light Moderate, 4:Moderate, 5:Moderate Severe, 6:Severe')
+    #
+    # time_ds = h5f_out.create_dataset('time', data=icing_times, dtype='f4')
+    # time_ds.attrs.create('units', data='seconds since 1970-1-1 00:00:00')
+    # time_ds.attrs.create('long_name', data='PIREP time')
+    #
+    # lon_ds = h5f_out.create_dataset('longitude', data=icing_lons, dtype='f4')
+    # lon_ds.attrs.create('units', data='degrees_east')
+    # lon_ds.attrs.create('long_name', data='PIREP longitude')
+    #
+    # lat_ds = h5f_out.create_dataset('latitude', data=icing_lats, dtype='f4')
+    # lat_ds.attrs.create('units', data='degrees_north')
+    # lat_ds.attrs.create('long_name', data='PIREP latitude')
+    #
+    # # copy relevant attributes
+    # for ds_name in train_params:
+    #     h5f_ds = h5f_out[ds_name]
+    #     h5f_ds.attrs.create('standard_name', data=h5f_expl[ds_name].attrs.get('standard_name'))
+    #     h5f_ds.attrs.create('long_name', data=h5f_expl[ds_name].attrs.get('long_name'))
+    #     h5f_ds.attrs.create('units', data=h5f_expl[ds_name].attrs.get('units'))
+    #     attr = h5f_expl[ds_name].attrs.get('actual_range')
+    #     if attr is not None:
+    #         h5f_ds.attrs.create('actual_range', data=attr)
+    #     attr = h5f_expl[ds_name].attrs.get('flag_values')
+    #     if attr is not None:
+    #         h5f_ds.attrs.create('flag_values', data=attr)
+    #
+    # h5f_out.close()
+    # h5f_expl.close()
 
     # --- close files
     for h5f in h5_s_icing:
@@ -1186,8 +1247,6 @@ def tile_extract(outfile='/home/rink/tiles_l1b_out.h5', train_params=l1b_ds_list
     for h5f in h5_s_no_icing:
         h5f.close()
 
-    h5f_out.close()
-    h5f_expl.close()
 
 
 def write_file(outfile, train_params, data_dct, icing_intensity, icing_times, icing_lons, icing_lats):
-- 
GitLab