From a22815c11fc2d3d63ce51085f9b70b187e98b7c4 Mon Sep 17 00:00:00 2001 From: tomrink <rink@ssec.wisc.edu> Date: Thu, 3 Aug 2023 09:41:46 -0500 Subject: [PATCH] snapshot... --- modules/deeplearning/cloud_opd_srcnn_abi.py | 65 +++++++++------------ 1 file changed, 28 insertions(+), 37 deletions(-) diff --git a/modules/deeplearning/cloud_opd_srcnn_abi.py b/modules/deeplearning/cloud_opd_srcnn_abi.py index 4815fced..58b91b00 100644 --- a/modules/deeplearning/cloud_opd_srcnn_abi.py +++ b/modules/deeplearning/cloud_opd_srcnn_abi.py @@ -2,8 +2,10 @@ import gc import glob import tensorflow as tf +from util.augment import augment_image from util.setup import logdir, modeldir, now, ancillary_path -from util.util import EarlyStop, normalize, denormalize, scale, descale, get_grid_values_all, resample_2d_linear, smooth_2d +from util.util import EarlyStop, normalize, denormalize, scale, descale, get_grid_values_all, resample_2d_linear,\ + smooth_2d, make_tf_callable_generator import os, datetime import numpy as np import pickle @@ -343,28 +345,9 @@ class SRCNN: label = scale(label, label_param, mean_std_dct) label = label[:, self.y_128, self.x_128] - label = np.where(np.isnan(label), 0.0, label) label = np.expand_dims(label, axis=3) - - data = data.astype(np.float32) label = label.astype(np.float32) - if is_training and DO_AUGMENT: - # data_ud = np.flip(data, axis=1) - # label_ud = np.flip(label, axis=1) - # - # data_lr = np.flip(data, axis=2) - # label_lr = np.flip(label, axis=2) - # - # data = np.concatenate([data, data_ud, data_lr]) - # label = np.concatenate([label, label_ud, label_lr]) - - data_rot = np.rot90(data, axes=(1, 2)) - label_rot = np.rot90(label, axes=(1, 2)) - - data = np.concatenate([data, data_rot]) - label = np.concatenate([label, label_rot]) - return data, label def get_in_mem_data_batch_train(self, idxs): @@ -383,22 +366,35 @@ class SRCNN: out = tf.numpy_function(self.get_in_mem_data_batch_test, [indexes], [tf.float32, tf.float32]) return out - def get_train_dataset(self, indexes): - indexes = list(indexes) + def get_train_dataset(self, num_files): + def integer_gen(limit): + n = 0 + while n < limit: + yield n + n += 1 + num_gen = integer_gen(num_files) + gen = make_tf_callable_generator(num_gen) - dataset = tf.data.Dataset.from_tensor_slices(indexes) + dataset = tf.data.Dataset.from_generator(gen, output_types=tf.int32) dataset = dataset.batch(PROC_BATCH_SIZE) dataset = dataset.map(self.data_function, num_parallel_calls=8) - dataset = dataset.cache() if DO_AUGMENT: - dataset = dataset.shuffle(PROC_BATCH_BUFFER_SIZE) + dataset = dataset.map(augment_image(), num_parallel_calls=8) + dataset = dataset.cache() + dataset = dataset.shuffle(PROC_BATCH_BUFFER_SIZE, reshuffle_each_iteration=False) dataset = dataset.prefetch(buffer_size=1) self.train_dataset = dataset - def get_test_dataset(self, indexes): - indexes = list(indexes) + def get_test_dataset(self, num_files): + def integer_gen(limit): + n = 0 + while n < limit: + yield n + n += 1 + num_gen = integer_gen(num_files) + gen = make_tf_callable_generator(num_gen) - dataset = tf.data.Dataset.from_tensor_slices(indexes) + dataset = tf.data.Dataset.from_generator(gen, output_types=tf.int32) dataset = dataset.batch(PROC_BATCH_SIZE) dataset = dataset.map(self.data_function_test, num_parallel_calls=8) dataset = dataset.cache() @@ -410,22 +406,17 @@ class SRCNN: self.test_data_files = test_data_files self.test_label_files = test_label_files - trn_idxs = np.arange(len(train_data_files)) - np.random.shuffle(trn_idxs) - - tst_idxs = np.arange(len(test_data_files)) - - self.get_train_dataset(trn_idxs) - self.get_test_dataset(tst_idxs) + self.get_train_dataset(len(train_data_files)) + self.get_test_dataset(len(test_data_files)) self.num_data_samples = num_train_samples # approximately print('datetime: ', now) print('training and test data: ') print('---------------------------') - print('num train samples: ', self.num_data_samples) + print('num train files: ', len(train_data_files)) print('BATCH SIZE: ', BATCH_SIZE) - print('num test samples: ', tst_idxs.shape[0]) + print('num test files: ', len(test_data_files)) print('setup_pipeline: Done') def setup_test_pipeline(self, test_data_files, test_label_files): -- GitLab