Skip to content
Snippets Groups Projects
Commit 7c404597 authored by tomrink's avatar tomrink
Browse files

snapshot...

parent cd722805
No related branches found
No related tags found
No related merge requests found
......@@ -11,6 +11,8 @@ import h5py
import xarray as xr
import gc
AUTOTUNE = tf.data.AUTOTUNE
LOG_DEVICE_PLACEMENT = False
PROC_BATCH_SIZE = 4
......@@ -30,7 +32,7 @@ EARLY_STOP = True
NOISE_TRAINING = False
NOISE_STDDEV = 0.01
DO_AUGMENT = True
DO_AUGMENT = False
DO_SMOOTH = False
SIGMA = 1.0
......@@ -319,9 +321,12 @@ class SRCNN:
data_norm = []
for param in data_params_half:
idx = params.index(param)
tmp = input_data[:, idx, :, :]
tmp = tmp.copy()
# If next 2 uncommented, take out get_grid_cell_mean
# idx = params.index(param)
# tmp = input_data[:, idx, :, :]
idx = params_i.index(param)
tmp = input_label[:, idx, :, :]
tmp = get_grid_cell_mean(tmp)
tmp = tmp[:, slc_y, slc_x]
tmp = normalize(tmp, param, mean_std_dct)
data_norm.append(tmp)
......@@ -329,7 +334,6 @@ class SRCNN:
for param in data_params_full:
idx = params_i.index(param)
tmp = input_label[:, idx, :, :]
tmp = tmp.copy()
lo, hi, std, avg = get_min_max_std(tmp)
lo = normalize(lo, param, mean_std_dct)
......@@ -340,8 +344,10 @@ class SRCNN:
data_norm.append(hi[:, slc_y, slc_x])
data_norm.append(avg[:, slc_y, slc_x])
# ---------------------------------------------------
tmp = input_data[:, label_idx, :, :]
tmp = tmp.copy()
# If next uncommented, take out get_grid_cell_mean
# tmp = input_data[:, label_idx, :, :]
tmp = input_label[:, label_idx_i, :, :]
tmp = get_grid_cell_mean(tmp)
tmp = tmp[:, slc_y, slc_x]
data_norm.append(tmp)
# ---------
......@@ -351,7 +357,6 @@ class SRCNN:
# -----------------------------------------------------
# -----------------------------------------------------
label = input_label[:, label_idx_i, :, :]
label = label.copy()
label = label[:, y_128, x_128]
if NumClasses == 5:
label = get_label_data_5cat(label)
......@@ -397,11 +402,11 @@ class SRCNN:
dataset = tf.data.Dataset.from_tensor_slices(indexes)
dataset = dataset.batch(PROC_BATCH_SIZE)
dataset = dataset.map(self.data_function, num_parallel_calls=8)
dataset = dataset.map(self.data_function, num_parallel_calls=AUTOTUNE)
dataset = dataset.cache()
if DO_AUGMENT:
dataset = dataset.shuffle(PROC_BATCH_BUFFER_SIZE)
dataset = dataset.prefetch(buffer_size=1)
dataset = dataset.prefetch(buffer_size=AUTOTUNE)
self.train_dataset = dataset
def get_test_dataset(self, indexes):
......@@ -409,7 +414,7 @@ class SRCNN:
dataset = tf.data.Dataset.from_tensor_slices(indexes)
dataset = dataset.batch(PROC_BATCH_SIZE)
dataset = dataset.map(self.data_function_test, num_parallel_calls=8)
dataset = dataset.map(self.data_function_test, num_parallel_calls=AUTOTUNE)
dataset = dataset.cache()
self.test_dataset = dataset
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment