""" Dataset Handlers for ESRGAN """

import os
import pickle
from functools import partial

import numpy as np
import tensorflow as tf
from absl import logging
# import tensorflow_datasets as tfds

from util.util import scale, make_tf_callable_generator
from util.setup import ancillary_path


def scale_down(method="bicubic", dimension=512, size=None, factor=4):
""" Scales down function based on the parameters provided.
Args:
method (default: bicubic): Interpolation method to be used for Scaling down the image.
dimension (default: 256): Dimension of the high resolution counterpart.
size (default: None): [height, width] of the image.
factor (default: 4): Factor by which the model enhances the low resolution image.
Returns:
tf.data.Dataset mappable python function based on the configuration.
"""
if not size:
size = (dimension, dimension)
size_ = {"size": size}
def scale_fn(image, *args, **kwargs):
size = size_["size"]
high_resolution = image
# if not kwargs.get("no_random_crop", None):
# high_resolution = tf.image.random_crop(
# image, [size[0], size[1], image.shape[-1]])
low_resolution = tf.image.resize(
high_resolution,
[size[0] // factor, size[1] // factor],
method=method)
low_resolution = tf.clip_by_value(low_resolution, 0, 255)
high_resolution = tf.clip_by_value(high_resolution, 0, 255)
return low_resolution, high_resolution
scale_fn.size = size_["size"]
return scale_fn
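

# Illustrative usage sketch (not part of the original pipeline); `hr_dataset` is a
# hypothetical tf.data.Dataset that yields single high resolution images:
#
#     scale_fn = scale_down(method="bicubic", dimension=512, factor=4)
#     paired_ds = hr_dataset.map(scale_fn, num_parallel_calls=tf.data.AUTOTUNE)
#     # each element is now a (low_resolution, high_resolution) pair

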
def augment_image(
        brightness_delta=0.05,
        contrast_factor=[0.7, 1.3],
        saturation=[0.6, 1.6]):
    """ Helper function used for augmentation of images in the dataset.

    Note: the brightness / contrast / saturation arguments are accepted but not
    currently used; the returned function only applies random rotation and flipping.

    Args:
        brightness_delta: maximum delta for randomly adjusting the brightness of the image.
        contrast_factor: list / tuple of minimum and maximum contrast factors.
            None, if not to be used.
        saturation: list / tuple of minimum and maximum saturation factors.
            None, if not to be used.
    Returns:
        tf.data.Dataset mappable function for image augmentation
    """
    def augment_fn(data, label, *args, **kwargs):
        # Augmenting data (~ 80%)
        def augment_steps_fn(data, label):
            # Randomly rotating image (~50%)
            def rotate_fn(data, label):
                times = tf.random.uniform(minval=1, maxval=4, dtype=tf.int32, shape=[])
                return (tf.image.rot90(data, times),
                        tf.image.rot90(label, times))

            data, label = tf.cond(
                tf.less_equal(tf.random.uniform([]), 0.5),
                lambda: rotate_fn(data, label),
                lambda: (data, label))

            # Randomly flipping image (~50%)
            def flip_fn(data, label):
                return (tf.image.flip_left_right(data),
                        tf.image.flip_left_right(label))

            data, label = tf.cond(
                tf.less_equal(tf.random.uniform([]), 0.5),
                lambda: flip_fn(data, label),
                lambda: (data, label))

            return data, label

        # Randomly returning unchanged data (~20%)
        return tf.cond(
            tf.less_equal(tf.random.uniform([]), 0.2),
            lambda: (data, label),
            partial(augment_steps_fn, data, label))

    return augment_fn
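

# Illustrative usage sketch (assumption, not in the original file): augment_image() is
# mapped over (low_resolution, high_resolution) pairs, e.g. the output of scale_down():
#
#     augment_fn = augment_image()
#     paired_ds = paired_ds.map(augment_fn, num_parallel_calls=tf.data.AUTOTUNE)
#     # ~80% of pairs get the same random rotation / flip applied to both images

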
# def reform_dataset(dataset, types, size, num_elems=None):
# """ Helper function to convert the output_dtype of the dataset
# from (tf.float32, tf.uint8) to desired dtype
# Args:
# dataset: Source dataset(image-label dataset) to convert.
# types: tuple / list of target datatype.
# size: [height, width] threshold of the images.
# num_elems: Number of Data points to store
# Returns:
# tf.data.Dataset with the images of dimension >= Args.size and types = Args.types
# """
# _carrier = {"num_elems": num_elems}
#
# def generator_fn():
# for idx, data in enumerate(dataset, 1):
# if _carrier["num_elems"]:
# if not idx % _carrier["num_elems"]:
# raise StopIteration
# if data[0].shape[0] >= size[0] and data[0].shape[1] >= size[1]:
# yield data[0], data[1]
# else:
# continue
# return tf.data.Dataset.from_generator(
# generator_fn, types, (tf.TensorShape([None, None, 3]), tf.TensorShape(None)))
# setup scaling parameters dictionary
mean_std_dct = {}

mean_std_file = ancillary_path + 'mean_std_lo_hi_l2.pkl'
with open(mean_std_file, 'rb') as f:
    mean_std_dct_l2 = pickle.load(f)

mean_std_file = ancillary_path + 'mean_std_lo_hi_l1b.pkl'
with open(mean_std_file, 'rb') as f:
    mean_std_dct_l1b = pickle.load(f)

mean_std_dct.update(mean_std_dct_l1b)
mean_std_dct.update(mean_std_dct_l2)
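
# Assumption: these pickled dictionaries map variable names used below
# (e.g. 'refl_0_65um_nom', 'cld_opd_dcomp') to the normalization statistics
# consumed by util.util.scale(); their exact structure is defined by the
# ancillary files, not by this module.

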
class OpdNpyDataset:
    def __init__(self, filenames, hr_size, lr_size, batch_size=128):
        self.filenames = filenames
        self.num_files = len(filenames)
        self.hr_size = hr_size
        self.lr_size = lr_size

        # def integer_gen(limit):
        #     n = 0
        #     while n < limit:
        #         yield n
        #         n += 1
        #
        # num_gen = integer_gen(self.num_files)
        # gen = make_tf_callable_generator(num_gen)
        # dataset = tf.data.Dataset.from_generator(gen, output_types=tf.int32)
        # dataset = dataset.batch(batch_size)
        # dataset = dataset.map(self.data_function, num_parallel_calls=8)
        # # These execute w/o an iteration on dataset?
        # # dataset = dataset.map(scale_down(), num_parallel_calls=1)
        # # dataset = dataset.map(augment_image(), num_parallel_calls=1)
        #
        # dataset = dataset.cache()
        # dataset = dataset.prefetch(buffer_size=1)

        # Build the input pipeline from shuffled file indices; each batch of indices
        # is turned into (low_resolution, high_resolution) batches by data_function.
        file_idxs = np.arange(self.num_files)
        dataset = tf.data.Dataset.from_tensor_slices(list(file_idxs))
        dataset = dataset.shuffle(2000, reshuffle_each_iteration=True)
        dataset = dataset.batch(batch_size)
        dataset = dataset.map(self.data_function, num_parallel_calls=8)
        dataset = dataset.cache(filename='/ships22/cloud/scratch/Satellite_Output/GOES-16/global/NREL_2023/2023_east_cf/cld_opd_train/cld_opd_train')
        dataset = dataset.prefetch(buffer_size=1)
        self.dataset = dataset

    def read_numpy_file_s(self, f_idxs):
        opd_s = []
        refl_s = []
        for fi in f_idxs:
            fname = self.filenames[fi]
            data = np.load(fname)

            refl = data[0, ]
            refl = scale(refl, 'refl_0_65um_nom', mean_std_dct)
            refl = refl.astype(np.float32)
            refl_s.append(refl)

            opd = data[1, ]
            opd = scale(opd, 'cld_opd_dcomp', mean_std_dct)
            opd = opd.astype(np.float32)
            opd_s.append(opd)

        opd = np.concatenate(opd_s)
        refl = np.concatenate(refl_s)

        # Stack reflectance and cloud optical depth as a two-channel high resolution image.
        hr_image = np.stack([refl, opd], axis=3)
        hr_image = tf.image.crop_to_bounding_box(hr_image, 0, 0, self.hr_size, self.hr_size)

        # The low resolution input keeps both channels; values are mapped to [0, 255].
        low_resolution = tf.image.resize(hr_image, [self.lr_size, self.lr_size], method='bicubic')
        low_resolution = tf.math.multiply(low_resolution, 255.0)
        low_resolution = tf.clip_by_value(low_resolution, 0, 255)

        # The high resolution target is the cloud optical depth channel only.
        hr_image = tf.math.multiply(hr_image[:, :, :, 1], 255.0)
        high_resolution = tf.clip_by_value(hr_image, 0, 255)
        high_resolution = tf.expand_dims(high_resolution, axis=3)

        low_resolution, high_resolution = augment_image()(low_resolution, high_resolution)

        return low_resolution, high_resolution

    @tf.function(input_signature=[tf.TensorSpec(None, tf.int32)])
    def data_function(self, indexes):
        # Wrap the numpy reader so it can run inside the tf.data pipeline.
        out = tf.numpy_function(self.read_numpy_file_s, [indexes], [tf.float32, tf.float32])
        return out
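

# Illustrative usage sketch (assumption, not in the original file): building and
# iterating the training pipeline. The glob pattern and sizes are hypothetical.
#
#     import glob
#     files = sorted(glob.glob('/path/to/cld_opd_train/*.npy'))
#     opd_ds = OpdNpyDataset(files, hr_size=64, lr_size=16, batch_size=128)
#     for low_res, high_res in opd_ds.dataset:
#         # low_res:  [..., lr_size, lr_size, 2]  (reflectance + cloud optical depth)
#         # high_res: [..., hr_size, hr_size, 1]  (cloud optical depth only)
#         ...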