""" Dataset Handlers for ESRGAN """
import os
import numpy as np
from absl import logging
from functools import partial
import tensorflow as tf
# import tensorflow_datasets as tfds
from util.util import scale, make_tf_callable_generator
from util.setup import ancillary_path
import pickle


def scale_down(method="bicubic", dimension=512, size=None, factor=4):
    """Build a mapping function that derives a low-resolution image from a
    high-resolution one.

    Args:
        method (default: "bicubic"): Interpolation method passed to
            ``tf.image.resize`` when scaling down.
        dimension (default: 512): Side length of the square size used when
            ``size`` is not provided.
        size (default: None): [height, width] of the high-resolution image.
            Falls back to ``(dimension, dimension)`` when falsy.
        factor (default: 4): Downscaling factor; the low-resolution output is
            resized to ``[height // factor, width // factor]``.

    Returns:
        A function ``scale_fn(image, *args, **kwargs)`` returning
        ``(low_resolution, high_resolution)``, both clipped to [0, 255].
        The resolved size is exposed as the attribute ``scale_fn.size``.
    """
    if not size:
        size = (dimension, dimension)

    def scale_fn(image, *args, **kwargs):
        high_resolution = image
        # NOTE: random cropping to `size` is currently disabled; the input
        # image is used as-is.
        # if not kwargs.get("no_random_crop", None):
        #     high_resolution = tf.image.random_crop(
        #         image, [size[0], size[1], image.shape[-1]])
        low_resolution = tf.image.resize(
            high_resolution,
            [size[0] // factor, size[1] // factor],
            method=method)
        low_resolution = tf.clip_by_value(low_resolution, 0, 255)
        high_resolution = tf.clip_by_value(high_resolution, 0, 255)
        return low_resolution, high_resolution

    scale_fn.size = size
    return scale_fn


def augment_image(
        brightness_delta=0.05,
        contrast_factor=(0.7, 1.3),
        saturation=(0.6, 1.6)):
    """Build a mapping function applying random paired geometric augmentation.

    The returned function applies the *same* random transform to both inputs
    so that they stay aligned:
      * with probability ~0.2 the pair is returned unchanged;
      * otherwise, the pair is rotated by 90, 180 or 270 degrees
        (``tf.image.rot90``) with probability ~0.5, and then, independently,
        flipped left-right (``tf.image.flip_left_right``) with
        probability ~0.5.

    Args:
        brightness_delta: Currently unused; accepted for API compatibility.
        contrast_factor: Currently unused; accepted for API compatibility.
            (Default is a tuple to avoid a shared mutable default.)
        saturation: Currently unused; accepted for API compatibility.
            (Default is a tuple to avoid a shared mutable default.)

    Returns:
        ``augment_fn(data, label, *args, **kwargs)`` returning the (possibly)
        transformed ``(data, label)`` pair.
    """
    def augment_fn(data, label, *args, **kwargs):
        # Augmentation path, taken ~80% of the time (see the final tf.cond).
        def augment_steps_fn(data, label):
            # Rotate both images by the same multiple of 90 degrees (~50%).
            def rotate_fn(data, label):
                # k in {1, 2, 3}: maxval is exclusive.
                times = tf.random.uniform(
                    minval=1, maxval=4, dtype=tf.int32, shape=[])
                return (tf.image.rot90(data, times),
                        tf.image.rot90(label, times))

            data, label = tf.cond(
                tf.less_equal(tf.random.uniform([]), 0.5),
                lambda: rotate_fn(data, label),
                lambda: (data, label))

            # Flip both images left-right (~50%).
            def flip_fn(data, label):
                return (tf.image.flip_left_right(data),
                        tf.image.flip_left_right(label))

            data, label = tf.cond(
                tf.less_equal(tf.random.uniform([]), 0.5),
                lambda: flip_fn(data, label),
                lambda: (data, label))
            return data, label

        # Return the pair unchanged ~20% of the time.
        return tf.cond(
            tf.less_equal(tf.random.uniform([]), 0.2),
            lambda: (data, label),
            partial(augment_steps_fn, data, label))

    return augment_fn


# def reform_dataset(dataset, types, size, num_elems=None):
#   """ Helper function to convert the output_dtype of the dataset
#       from (tf.float32, tf.uint8) to desired dtype
#       Args:
#         dataset: Source dataset(image-label dataset) to convert.
#         types: tuple / list of target datatype.
#         size: [height, width] threshold of the images.
#         num_elems: Number of Data points to store
#     Returns:
#       with the images of dimension >= Args.size and types = Args.types
#   """
#   _carrier = {"num_elems": num_elems}
#
#   def generator_fn():
#     for idx, data in enumerate(dataset, 1):
#       if _carrier["num_elems"]:
#         if not idx % _carrier["num_elems"]:
#           raise StopIteration
#       if data[0].shape[0] >= size[0] and data[0].shape[1] >= size[1]:
#         yield data[0], data[1]
#       else:
#         continue
#   return
#       generator_fn, types, (tf.TensorShape([None, None, 3]), tf.TensorShape(None)))


# Setup scaling parameters dictionary passed to `scale(...)`.
# L2 entries are applied last, so they override L1b entries on key collisions.
# NOTE: pickle.load can execute arbitrary code; these ancillary files must
# come from a trusted source.
mean_std_dct = {}
mean_std_file = os.path.join(ancillary_path, 'mean_std_lo_hi_l2.pkl')
with open(mean_std_file, 'rb') as f:
    mean_std_dct_l2 = pickle.load(f)

mean_std_file = os.path.join(ancillary_path, 'mean_std_lo_hi_l1b.pkl')
with open(mean_std_file, 'rb') as f:
    mean_std_dct_l1b = pickle.load(f)

mean_std_dct.update(mean_std_dct_l1b)
mean_std_dct.update(mean_std_dct_l2)


class OpdNpyDataset:
    """tf.data input pipeline over .npy files of reflectance / cloud optical depth.

    Batches of shuffled file indices are mapped through ``data_function``,
    which reads and preprocesses the files via ``read_numpy_file_s``. The
    result is then cached and prefetched. The pipeline is exposed as
    ``self.dataset`` and yields ``(low_resolution, high_resolution)`` pairs.
    """

    # Historical cache location, kept as the default for backward compatibility.
    _DEFAULT_CACHE_FILENAME = '/ships22/cloud/scratch/Satellite_Output/GOES-16/global/NREL_2023/2023_east_cf/cld_opd_train/cld_opd_train'

    def __init__(self, filenames, hr_size, lr_size, batch_size=128,
                 cache_filename=_DEFAULT_CACHE_FILENAME):
        """
        Args:
            filenames: Sequence of .npy file paths, indexed by position.
            hr_size: Side length used to crop the high-resolution image.
            lr_size: Side length the cropped image is resized to for the
                low-resolution input.
            batch_size (default: 128): Number of *files* per batch. Each file
                may contribute several samples.
            cache_filename: Passed to ``tf.data.Dataset.cache``. Use a
                distinct path per dataset (e.g. train vs. validation):
                TensorFlow reuses an existing cache file regardless of
                ``filenames``. Pass ``''`` to cache in memory.
        """
        self.filenames = filenames
        self.num_files = len(filenames)
        self.hr_size = hr_size
        self.lr_size = lr_size

        # int32 so that dataset elements match data_function's
        # input_signature (np.arange defaults to int64 on most platforms).
        file_idxs = np.arange(self.num_files, dtype=np.int32)
        dataset = tf.data.Dataset.from_tensor_slices(file_idxs)
        dataset = dataset.shuffle(2000, reshuffle_each_iteration=True)
        dataset = dataset.batch(batch_size)
        dataset = dataset.map(self.data_function, num_parallel_calls=8)
        # NOTE(review): cache() comes after shuffle() and map(). The first
        # epoch's batch order and random augmentation are therefore replayed
        # unchanged on later epochs. Confirm this is intended.
        dataset = dataset.cache(filename=cache_filename)
        dataset = dataset.prefetch(buffer_size=1)
        self.dataset = dataset

    def read_numpy_file_s(self, f_idxs):
        """Load, scale, crop, resize and augment the files at ``f_idxs``.

        Args:
            f_idxs: Iterable of integer indices into ``self.filenames``.

        Returns:
            ``(low_resolution, high_resolution)`` float32 tensors. Both are
            multiplied by 255 and clipped to [0, 255]:
              * low_resolution: bicubic resize to ``lr_size x lr_size`` of the
                cropped [reflectance, optical depth] stack (2 channels);
              * high_resolution: the cropped optical-depth channel only,
                with a trailing channel axis of size 1.
            The leading axis concatenates the samples of all loaded files.
        """
        opd_s = []
        refl_s = []
        for fi in f_idxs:
            data = np.load(self.filenames[fi])
            # Indexing assumes a 4-D array with channels on axis 1. Channel 0
            # is scaled as 'refl_0_65um_nom' and channel 1 as 'cld_opd_dcomp'.
            refl = scale(data[:, 0, :, :], 'refl_0_65um_nom', mean_std_dct)
            refl_s.append(refl.astype(np.float32))

            opd = scale(data[:, 1, :, :], 'cld_opd_dcomp', mean_std_dct)
            opd_s.append(opd.astype(np.float32))

        opd = np.concatenate(opd_s)
        refl = np.concatenate(refl_s)
        hr_image = np.stack([refl, opd], axis=3)
        hr_image = tf.image.crop_to_bounding_box(
            hr_image, 0, 0, self.hr_size, self.hr_size)

        # Alternatives tried: method='nearest', or strided subsampling
        # hr_image[:, ::4, ::4, :].
        low_resolution = tf.image.resize(
            hr_image, [self.lr_size, self.lr_size], method='bicubic')
        low_resolution = tf.math.multiply(low_resolution, 255.0)
        low_resolution = tf.clip_by_value(low_resolution, 0, 255.0)

        # The target keeps only the optical-depth channel.
        hr_image = tf.math.multiply(hr_image[:, :, :, 1], 255.0)
        high_resolution = tf.clip_by_value(hr_image, 0, 255.0)
        high_resolution = tf.expand_dims(high_resolution, axis=3)

        low_resolution, high_resolution = augment_image()(
            low_resolution, high_resolution)
        return low_resolution, high_resolution

    @tf.function(input_signature=[tf.TensorSpec(None, tf.int32)])
    def data_function(self, indexes):
        """Graph-mode wrapper that runs ``read_numpy_file_s`` on ``indexes``.

        Returns a list of two float32 tensors ``[low_resolution,
        high_resolution]`` with statically unknown shapes.
        """
        out = tf.numpy_function(
            self.read_numpy_file_s, [indexes], [tf.float32, tf.float32])
        return out