Skip to content
Snippets Groups Projects
Commit e9e080f3 authored by tomrink's avatar tomrink
Browse files

snapshot...

parent 7f14d0c8
Branches
No related tags found
No related merge requests found
...@@ -2,6 +2,7 @@ import tensorflow as tf ...@@ -2,6 +2,7 @@ import tensorflow as tf
from util.setup import logdir, modeldir, cachepath, now, ancillary_path, home_dir from util.setup import logdir, modeldir, cachepath, now, ancillary_path, home_dir
from util.util import EarlyStop, normalize, make_for_full_domain_predict, get_training_parameters from util.util import EarlyStop, normalize, make_for_full_domain_predict, get_training_parameters
from util.geos_nav import get_navigation from util.geos_nav import get_navigation
from util.augment import augment_image
import os, datetime import os, datetime
import numpy as np import numpy as np
...@@ -315,26 +316,26 @@ class IcingIntensityFCN: ...@@ -315,26 +316,26 @@ class IcingIntensityFCN:
label = np.where(np.invert(np.logical_or(label == 0, label == 1)), 2, label) label = np.where(np.invert(np.logical_or(label == 0, label == 1)), 2, label)
label = label.reshape((label.shape[0], 1)) label = label.reshape((label.shape[0], 1))
if is_training and DO_AUGMENT: # if is_training and DO_AUGMENT:
data_ud = np.flip(data, axis=1) # data_ud = np.flip(data, axis=1)
data_alt_ud = np.copy(data_alt) # data_alt_ud = np.copy(data_alt)
label_ud = np.copy(label) # label_ud = np.copy(label)
#
data_lr = np.flip(data, axis=2) # data_lr = np.flip(data, axis=2)
data_alt_lr = np.copy(data_alt) # data_alt_lr = np.copy(data_alt)
label_lr = np.copy(label) # label_lr = np.copy(label)
#
data_r1 = np.rot90(data, k=1, axes=(1, 2)) # data_r1 = np.rot90(data, k=1, axes=(1, 2))
data_alt_r1 = np.copy(data_alt) # data_alt_r1 = np.copy(data_alt)
label_r1 = np.copy(label) # label_r1 = np.copy(label)
#
data_r2 = np.rot90(data, k=1, axes=(1, 2)) # data_r2 = np.rot90(data, k=1, axes=(1, 2))
data_alt_r2 = np.copy(data_alt) # data_alt_r2 = np.copy(data_alt)
label_r2 = np.copy(label) # label_r2 = np.copy(label)
#
data = np.concatenate([data, data_ud, data_lr, data_r1, data_r2]) # data = np.concatenate([data, data_ud, data_lr, data_r1, data_r2])
data_alt = np.concatenate([data_alt, data_alt_ud, data_alt_lr, data_alt_r1, data_alt_r2]) # data_alt = np.concatenate([data_alt, data_alt_ud, data_alt_lr, data_alt_r1, data_alt_r2])
label = np.concatenate([label, label_ud, label_lr, label_r1, label_r2]) # label = np.concatenate([label, label_ud, label_lr, label_r1, label_r2])
return data, data_alt, label return data, data_alt, label
...@@ -475,8 +476,9 @@ class IcingIntensityFCN: ...@@ -475,8 +476,9 @@ class IcingIntensityFCN:
dataset = dataset.batch(PROC_BATCH_SIZE) dataset = dataset.batch(PROC_BATCH_SIZE)
dataset = dataset.map(self.data_function, num_parallel_calls=8) dataset = dataset.map(self.data_function, num_parallel_calls=8)
dataset = dataset.cache() dataset = dataset.cache()
dataset = dataset.shuffle(PROC_BATCH_BUFFER_SIZE, reshuffle_each_iteration=True)
if DO_AUGMENT: if DO_AUGMENT:
dataset = dataset.shuffle(PROC_BATCH_BUFFER_SIZE) dataset = dataset.map(augment_image(), num_parallel_calls=8)
dataset = dataset.prefetch(buffer_size=1) dataset = dataset.prefetch(buffer_size=1)
self.train_dataset = dataset self.train_dataset = dataset
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment