Skip to content
Snippets Groups Projects
Commit 3cfb756b authored by tomrink's avatar tomrink
Browse files

snapshot...

parent a0ed1dbf
No related branches found
No related tags found
No related merge requests found
...@@ -15,7 +15,7 @@ LOG_DEVICE_PLACEMENT = False ...@@ -15,7 +15,7 @@ LOG_DEVICE_PLACEMENT = False
CACHE_DATA_IN_MEM = True CACHE_DATA_IN_MEM = True
PROC_BATCH_SIZE = 2046 PROC_BATCH_SIZE = 4096
PROC_BATCH_BUFFER_SIZE = 50000 PROC_BATCH_BUFFER_SIZE = 50000
NumClasses = 3 NumClasses = 3
NumLogits = 1 NumLogits = 1
...@@ -210,6 +210,25 @@ class IcingIntensityNN: ...@@ -210,6 +210,25 @@ class IcingIntensityNN:
label = label.astype(np.int32) label = label.astype(np.int32)
label = np.where(label == -1, 0, label) label = np.where(label == -1, 0, label)
# Augmentation, TODO: work into Dataset.map (efficiency)
data_aug = []
label_aug = []
for k in range(label.shape[0]):
if label[k] == 3 or label[k] == 4 or label[k] == 5 or label[k] == 6:
data_aug.append(tf.image.flip_up_down(data[k,]).numpy())
data_aug.append(tf.image.flip_left_right(data[k,]).numpy())
data_aug.append(tf.image.rot90(data[k,]).numpy())
label_aug.append(label[k])
label_aug.append(label[k])
label_aug.append(label[k])
data_aug = np.stack(data_aug)
label_aug = np.stack(label_aug)
data = np.concatenate([data, data_aug])
label = np.concatenate([label, label_aug])
# --------------------------------------------------------
# binary, two class # binary, two class
if NumClasses == 2: if NumClasses == 2:
label = np.where(label != 0, 1, label) label = np.where(label != 0, 1, label)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment