Skip to content
Snippets Groups Projects
Commit 7c404597 authored by tomrink's avatar tomrink
Browse files

snapshot...

parent cd722805
Branches
No related tags found
No related merge requests found
......@@ -11,6 +11,8 @@ import h5py
import xarray as xr
import gc
AUTOTUNE = tf.data.AUTOTUNE
LOG_DEVICE_PLACEMENT = False
PROC_BATCH_SIZE = 4
......@@ -30,7 +32,7 @@ EARLY_STOP = True
NOISE_TRAINING = False
NOISE_STDDEV = 0.01
DO_AUGMENT = True
DO_AUGMENT = False
DO_SMOOTH = False
SIGMA = 1.0
......@@ -319,9 +321,12 @@ class SRCNN:
data_norm = []
for param in data_params_half:
idx = params.index(param)
tmp = input_data[:, idx, :, :]
tmp = tmp.copy()
# If next 2 uncommented, take out get_grid_cell_mean
# idx = params.index(param)
# tmp = input_data[:, idx, :, :]
idx = params_i.index(param)
tmp = input_label[:, idx, :, :]
tmp = get_grid_cell_mean(tmp)
tmp = tmp[:, slc_y, slc_x]
tmp = normalize(tmp, param, mean_std_dct)
data_norm.append(tmp)
......@@ -329,7 +334,6 @@ class SRCNN:
for param in data_params_full:
idx = params_i.index(param)
tmp = input_label[:, idx, :, :]
tmp = tmp.copy()
lo, hi, std, avg = get_min_max_std(tmp)
lo = normalize(lo, param, mean_std_dct)
......@@ -340,8 +344,10 @@ class SRCNN:
data_norm.append(hi[:, slc_y, slc_x])
data_norm.append(avg[:, slc_y, slc_x])
# ---------------------------------------------------
tmp = input_data[:, label_idx, :, :]
tmp = tmp.copy()
# If next uncommented, take out get_grid_cell_mean
# tmp = input_data[:, label_idx, :, :]
tmp = input_label[:, label_idx_i, :, :]
tmp = get_grid_cell_mean(tmp)
tmp = tmp[:, slc_y, slc_x]
data_norm.append(tmp)
# ---------
......@@ -351,7 +357,6 @@ class SRCNN:
# -----------------------------------------------------
# -----------------------------------------------------
label = input_label[:, label_idx_i, :, :]
label = label.copy()
label = label[:, y_128, x_128]
if NumClasses == 5:
label = get_label_data_5cat(label)
......@@ -397,11 +402,11 @@ class SRCNN:
dataset = tf.data.Dataset.from_tensor_slices(indexes)
dataset = dataset.batch(PROC_BATCH_SIZE)
dataset = dataset.map(self.data_function, num_parallel_calls=8)
dataset = dataset.map(self.data_function, num_parallel_calls=AUTOTUNE)
dataset = dataset.cache()
if DO_AUGMENT:
dataset = dataset.shuffle(PROC_BATCH_BUFFER_SIZE)
dataset = dataset.prefetch(buffer_size=1)
dataset = dataset.prefetch(buffer_size=AUTOTUNE)
self.train_dataset = dataset
def get_test_dataset(self, indexes):
......@@ -409,7 +414,7 @@ class SRCNN:
dataset = tf.data.Dataset.from_tensor_slices(indexes)
dataset = dataset.batch(PROC_BATCH_SIZE)
dataset = dataset.map(self.data_function_test, num_parallel_calls=8)
dataset = dataset.map(self.data_function_test, num_parallel_calls=AUTOTUNE)
dataset = dataset.cache()
self.test_dataset = dataset
......
0% Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment