From 8d0e71839bee50c6ffc43c5ed201183f4bfb87a2 Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Mon, 4 Sep 2023 14:12:34 -0500
Subject: [PATCH] snapshot...

---
 .../deeplearning/cloud_fraction_fcn_abi.py    | 67 +++++++++++--------
 1 file changed, 39 insertions(+), 28 deletions(-)

diff --git a/modules/deeplearning/cloud_fraction_fcn_abi.py b/modules/deeplearning/cloud_fraction_fcn_abi.py
index 3f86ac83..c4e619ab 100644
--- a/modules/deeplearning/cloud_fraction_fcn_abi.py
+++ b/modules/deeplearning/cloud_fraction_fcn_abi.py
@@ -1,8 +1,9 @@
 import tensorflow as tf
 
 from util.plot_cm import confusion_matrix_values
+from util.augment import augment_image
 from util.setup_cloud_fraction import logdir, modeldir, now, ancillary_path
-from util.util import EarlyStop, normalize, denormalize, get_grid_values_all
+from util.util import EarlyStop, normalize, denormalize, get_grid_values_all, make_tf_callable_generator
 import glob
 import os, datetime
 import numpy as np
@@ -39,6 +40,9 @@ DO_SMOOTH = False
 SIGMA = 1.0
 DO_ZERO_OUT = False
 
+# CACHE_FILE = '/scratch/long/rink/cld_opd_abi_128x128_cache'
+CACHE_FILE = ''
+
 # setup scaling parameters dictionary
 mean_std_dct = {}
 mean_std_file = ancillary_path+'mean_std_lo_hi_l2.pkl'
@@ -164,11 +168,11 @@ def get_label_data_5cat(grd_k):
         grd_k[:, 0::4, 2::4] + grd_k[:, 1::4, 2::4] + grd_k[:, 2::4, 2::4] + grd_k[:, 3::4, 2::4] + \
         grd_k[:, 0::4, 3::4] + grd_k[:, 1::4, 3::4] + grd_k[:, 2::4, 3::4] + grd_k[:, 3::4, 3::4]
 
-    cat_0 = np.logical_and(s >= 0, s < 2)
-    cat_1 = np.logical_and(s >= 2, s < 6)
+    cat_0 = np.logical_and(s >= 0, s < 1)
+    cat_1 = np.logical_and(s >= 1, s < 6)
     cat_2 = np.logical_and(s >= 6, s < 11)
-    cat_3 = np.logical_and(s >= 11, s < 15)
-    cat_4 = np.logical_and(s >= 15, s <= 16)
+    cat_3 = np.logical_and(s >= 11, s <= 15)
+    cat_4 = np.logical_and(s > 15, s <= 16)
 
     s[cat_0] = 0
     s[cat_1] = 1
@@ -381,24 +385,37 @@ class SRCNN:
         out = tf.numpy_function(self.get_in_mem_data_batch_test, [indexes], [tf.float32, tf.float32])
         return out
 
-    def get_train_dataset(self, indexes):
-        indexes = list(indexes)
+    def get_train_dataset(self, num_files):
+        def integer_gen(limit):
+            n = 0
+            while n < limit:
+                yield n
+                n += 1
+        num_gen = integer_gen(num_files)
+        gen = make_tf_callable_generator(num_gen)
 
-        dataset = tf.data.Dataset.from_tensor_slices(indexes)
+        dataset = tf.data.Dataset.from_generator(gen, output_types=tf.int32)
         dataset = dataset.batch(PROC_BATCH_SIZE)
-        dataset = dataset.map(self.data_function, num_parallel_calls=AUTOTUNE)
-        dataset = dataset.cache()
+        dataset = dataset.map(self.data_function, num_parallel_calls=8)
+        dataset = dataset.cache(filename=CACHE_FILE)
+        dataset = dataset.shuffle(PROC_BATCH_BUFFER_SIZE, reshuffle_each_iteration=True)
         if DO_AUGMENT:
-            dataset = dataset.shuffle(PROC_BATCH_BUFFER_SIZE)
-        dataset = dataset.prefetch(buffer_size=AUTOTUNE)
+            dataset = dataset.map(augment_image(), num_parallel_calls=8)
+        dataset = dataset.prefetch(buffer_size=1)
         self.train_dataset = dataset
 
-    def get_test_dataset(self, indexes):
-        indexes = list(indexes)
+    def get_test_dataset(self, num_files):
+        def integer_gen(limit):
+            n = 0
+            while n < limit:
+                yield n
+                n += 1
+        num_gen = integer_gen(num_files)
+        gen = make_tf_callable_generator(num_gen)
 
-        dataset = tf.data.Dataset.from_tensor_slices(indexes)
+        dataset = tf.data.Dataset.from_generator(gen, output_types=tf.int32)
         dataset = dataset.batch(PROC_BATCH_SIZE)
-        dataset = dataset.map(self.data_function_test, num_parallel_calls=AUTOTUNE)
+        dataset = dataset.map(self.data_function_test, num_parallel_calls=8)
         dataset = dataset.cache()
         self.test_dataset = dataset
 
@@ -408,29 +425,23 @@ class SRCNN:
         self.test_data_files = test_data_files
         self.test_label_files = test_label_files
 
-        trn_idxs = np.arange(len(train_data_files))
-        np.random.shuffle(trn_idxs)
-
-        tst_idxs = np.arange(len(test_data_files))
-
-        self.get_train_dataset(trn_idxs)
-        self.get_test_dataset(tst_idxs)
+        self.get_train_dataset(len(train_data_files))
+        self.get_test_dataset(len(test_data_files))
 
         self.num_data_samples = num_train_samples  # approximately
 
         print('datetime: ', now)
         print('training and test data: ')
         print('---------------------------')
-        print('num train samples: ', self.num_data_samples)
+        print('num train files: ', len(train_data_files))
         print('BATCH SIZE: ', BATCH_SIZE)
-        print('num test samples: ', tst_idxs.shape[0])
+        print('num test files: ', len(test_data_files))
         print('setup_pipeline: Done')
 
     def setup_test_pipeline(self, test_data_files, test_label_files):
         self.test_data_files = test_data_files
         self.test_label_files = test_label_files
-        tst_idxs = np.arange(len(test_data_files))
-        self.get_test_dataset(tst_idxs)
+        self.get_test_dataset(len(test_data_files))
         print('setup_test_pipeline: Done')
 
     def build_srcnn(self, do_drop_out=False, do_batch_norm=False, drop_rate=0.5, factor=2):
@@ -742,7 +753,7 @@ class SRCNN:
         self.num_data_samples = 1000
 
         valid_data_files = glob.glob(directory + 'valid*mres*.npy')
-        valid_label_files = glob.glob(directory + 'valid*ires*.npy')
+        valid_label_files = [f.replace('mres', 'ires') for f in valid_data_files]
         self.setup_test_pipeline(valid_data_files, valid_label_files)
 
         self.build_model()
-- 
GitLab