Skip to content
Snippets Groups Projects
Commit e9e080f3 authored by tomrink's avatar tomrink
Browse files

snapshot...

parent 7f14d0c8
No related branches found
No related tags found
No related merge requests found
...@@ -2,6 +2,7 @@ import tensorflow as tf ...@@ -2,6 +2,7 @@ import tensorflow as tf
from util.setup import logdir, modeldir, cachepath, now, ancillary_path, home_dir from util.setup import logdir, modeldir, cachepath, now, ancillary_path, home_dir
from util.util import EarlyStop, normalize, make_for_full_domain_predict, get_training_parameters from util.util import EarlyStop, normalize, make_for_full_domain_predict, get_training_parameters
from util.geos_nav import get_navigation from util.geos_nav import get_navigation
from util.augment import augment_image
import os, datetime import os, datetime
import numpy as np import numpy as np
...@@ -315,26 +316,26 @@ class IcingIntensityFCN: ...@@ -315,26 +316,26 @@ class IcingIntensityFCN:
label = np.where(np.invert(np.logical_or(label == 0, label == 1)), 2, label) label = np.where(np.invert(np.logical_or(label == 0, label == 1)), 2, label)
label = label.reshape((label.shape[0], 1)) label = label.reshape((label.shape[0], 1))
if is_training and DO_AUGMENT: # if is_training and DO_AUGMENT:
data_ud = np.flip(data, axis=1) # data_ud = np.flip(data, axis=1)
data_alt_ud = np.copy(data_alt) # data_alt_ud = np.copy(data_alt)
label_ud = np.copy(label) # label_ud = np.copy(label)
#
data_lr = np.flip(data, axis=2) # data_lr = np.flip(data, axis=2)
data_alt_lr = np.copy(data_alt) # data_alt_lr = np.copy(data_alt)
label_lr = np.copy(label) # label_lr = np.copy(label)
#
data_r1 = np.rot90(data, k=1, axes=(1, 2)) # data_r1 = np.rot90(data, k=1, axes=(1, 2))
data_alt_r1 = np.copy(data_alt) # data_alt_r1 = np.copy(data_alt)
label_r1 = np.copy(label) # label_r1 = np.copy(label)
#
data_r2 = np.rot90(data, k=1, axes=(1, 2)) # data_r2 = np.rot90(data, k=1, axes=(1, 2))
data_alt_r2 = np.copy(data_alt) # data_alt_r2 = np.copy(data_alt)
label_r2 = np.copy(label) # label_r2 = np.copy(label)
#
data = np.concatenate([data, data_ud, data_lr, data_r1, data_r2]) # data = np.concatenate([data, data_ud, data_lr, data_r1, data_r2])
data_alt = np.concatenate([data_alt, data_alt_ud, data_alt_lr, data_alt_r1, data_alt_r2]) # data_alt = np.concatenate([data_alt, data_alt_ud, data_alt_lr, data_alt_r1, data_alt_r2])
label = np.concatenate([label, label_ud, label_lr, label_r1, label_r2]) # label = np.concatenate([label, label_ud, label_lr, label_r1, label_r2])
return data, data_alt, label return data, data_alt, label
...@@ -475,8 +476,9 @@ class IcingIntensityFCN: ...@@ -475,8 +476,9 @@ class IcingIntensityFCN:
dataset = dataset.batch(PROC_BATCH_SIZE) dataset = dataset.batch(PROC_BATCH_SIZE)
dataset = dataset.map(self.data_function, num_parallel_calls=8) dataset = dataset.map(self.data_function, num_parallel_calls=8)
dataset = dataset.cache() dataset = dataset.cache()
dataset = dataset.shuffle(PROC_BATCH_BUFFER_SIZE, reshuffle_each_iteration=True)
if DO_AUGMENT: if DO_AUGMENT:
dataset = dataset.shuffle(PROC_BATCH_BUFFER_SIZE) dataset = dataset.map(augment_image(), num_parallel_calls=8)
dataset = dataset.prefetch(buffer_size=1) dataset = dataset.prefetch(buffer_size=1)
self.train_dataset = dataset self.train_dataset = dataset
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment