From 7786ba8a13f29327309d92153c26b35aecf80def Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Sun, 3 Apr 2022 19:09:10 -0500
Subject: [PATCH] add train valid test output

---
 modules/icing/pirep_goes.py | 25 +++++++++++++++++++++++--
 1 file changed, 23 insertions(+), 2 deletions(-)

diff --git a/modules/icing/pirep_goes.py b/modules/icing/pirep_goes.py
index 0ece9432..16986170 100644
--- a/modules/icing/pirep_goes.py
+++ b/modules/icing/pirep_goes.py
@@ -1230,7 +1230,7 @@ def fov_extract(icing_files, no_icing_files, trnfile='/home/rink/fovs_l1b_train.
         h5f.close()
 
 
-def tile_extract(icing_files, no_icing_files, trnfile='/home/rink/tiles_train.h5', tstfile='/home/rink/tiles_test.h5', L1B_or_L2='L1B',
+def tile_extract(icing_files, no_icing_files, trnfile='/home/rink/tiles_train.h5', vldfile='/home/rink/tiles_valid.h5', tstfile='/home/rink/tiles_test.h5', L1B_or_L2='L1B',
                  cld_mask_name='cloud_mask', augment=False, do_split=True):
     # 16x16
     n_a, n_b = 12, 28
@@ -1361,7 +1361,7 @@ def tile_extract(icing_files, no_icing_files, trnfile='/home/rink/tiles_train.h5
     icing_alt = icing_alt[ds_indexes]
 
     if do_split:
-        trn_idxs, tst_idxs = split_data(icing_times)
+        trn_idxs, vld_idxs, tst_idxs = split_data(icing_times)
     else:
         trn_idxs = np.arange(icing_intensity.shape[0])
         tst_idxs = None
@@ -1468,6 +1468,27 @@ def tile_extract(icing_files, no_icing_files, trnfile='/home/rink/tiles_train.h5
 
         write_file(tstfile, params, param_types, tst_data_dct, tst_icing_intensity, tst_icing_times, tst_icing_lons, tst_icing_lats, tst_icing_alt)
 
+        vld_data_dct = {}
+        for ds_name in params:
+            vld_data_dct[ds_name] = data_dct[ds_name][vld_idxs,]
+        vld_icing_intensity = icing_intensity[vld_idxs,]
+        vld_icing_times = icing_times[vld_idxs,]
+        vld_icing_lons = icing_lons[vld_idxs,]
+        vld_icing_lats = icing_lats[vld_idxs,]
+        vld_icing_alt = icing_alt[vld_idxs,]
+
+        # do sort
+        ds_indexes = np.argsort(vld_icing_times)
+        for ds_name in params:
+            vld_data_dct[ds_name] = vld_data_dct[ds_name][ds_indexes]
+        vld_icing_intensity = vld_icing_intensity[ds_indexes]
+        vld_icing_times = vld_icing_times[ds_indexes]
+        vld_icing_lons = vld_icing_lons[ds_indexes]
+        vld_icing_lats = vld_icing_lats[ds_indexes]
+        vld_icing_alt = vld_icing_alt[ds_indexes]
+
+        write_file(vldfile, params, param_types, vld_data_dct, vld_icing_intensity, vld_icing_times, vld_icing_lons, vld_icing_lats, vld_icing_alt)
+
     # --- close files
     for h5f in h5_s_icing:
         h5f.close()
-- 
GitLab