Commit ea399e9b authored by tomrink

snapshot...

parent 5ed99af6
@@ -210,63 +210,24 @@ class IcingIntensityNN:
        label = label.astype(np.int32)
        label = np.where(label == -1, 0, label)

        # Augmentation, TODO: work into Dataset.map (efficiency)
        data_aug = []
        label_aug = []
        for k in range(label.shape[0]):
            if label[k] == 3 or label[k] == 4 or label[k] == 5 or label[k] == 6:
                data_aug.append(tf.image.flip_up_down(data[k,]).numpy())
                data_aug.append(tf.image.flip_left_right(data[k,]).numpy())
                data_aug.append(tf.image.rot90(data[k,]).numpy())
                label_aug.append(label[k])
                label_aug.append(label[k])
                label_aug.append(label[k])

        data_aug = np.stack(data_aug)
        label_aug = np.stack(label_aug)

        data = np.concatenate([data, data_aug])
        label = np.concatenate([label, label_aug])
        # --------------------------------------------------------

        # binary, two class
        if NumClasses == 2:
            label = np.where(label != 0, 1, label)
            label = label.reshape((label.shape[0], 1))
        elif NumClasses == 3:
            label = np.where(np.logical_or(label == 1, label == 2), 1, label)
            label = np.where(np.invert(np.logical_or(label == 0, label == 1)), 2, label)
            label = label.reshape((label.shape[0], 1))

        if CACHE_DATA_IN_MEM:
            self.in_mem_data_cache[key] = (data, label)

        return data, label
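The TODO above suggests folding this augmentation into the tf.data pipeline. A minimal sketch of one way to do that in a Dataset.map stage, assuming the mapped elements are (data, label) batches with data shaped (batch, height, width, channels) and a rank-1 int32 label still carrying the original 0-6 intensity values; augment_batch is an illustrative name, not the repository's implementation:

# Sketch only: augment the higher-intensity samples (labels 3-6) inside the
# graph instead of in the numpy loader.
def augment_batch(data, label):
    mask = tf.logical_and(label >= 3, label <= 6)
    sel = tf.boolean_mask(data, mask)
    sel_lbl = tf.boolean_mask(label, mask)
    aug = tf.concat([tf.image.flip_up_down(sel),
                     tf.image.flip_left_right(sel),
                     tf.image.rot90(sel)], axis=0)
    aug_lbl = tf.concat([sel_lbl, sel_lbl, sel_lbl], axis=0)
    return tf.concat([data, aug], axis=0), tf.concat([label, aug_lbl], axis=0)

# dataset = dataset.map(augment_batch, num_parallel_calls=8)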
    def get_in_mem_data_batch_test(self, idxs):
        key = frozenset(idxs)
        if CACHE_DATA_IN_MEM:
            tup = self.in_mem_data_cache.get(key)
            if tup is not None:
                return tup[0], tup[1]

        # sort these to use as numpy indexing arrays
        nd_idxs = np.array(idxs)
        nd_idxs = np.sort(nd_idxs)

        data = []
        for param in train_params:
            nda = self.h5f[param][nd_idxs, ]
            nda = normalize(nda, param, mean_std_dct)
            data.append(nda)
        data = np.stack(data)
        data = data.astype(np.float32)
        data = np.transpose(data, axes=(1, 2, 3, 0))

        label = self.h5f['icing_intensity'][nd_idxs]
        label = label.astype(np.int32)
        label = np.where(label == -1, 0, label)
        # # Augmentation, TODO: work into Dataset.map (efficiency)
        # data_aug = []
        # label_aug = []
        # for k in range(label.shape[0]):
        #     if label[k] == 3 or label[k] == 4 or label[k] == 5 or label[k] == 6:
        #         data_aug.append(tf.image.flip_up_down(data[k,]).numpy())
        #         data_aug.append(tf.image.flip_left_right(data[k,]).numpy())
        #         data_aug.append(tf.image.rot90(data[k,]).numpy())
        #         label_aug.append(label[k])
        #         label_aug.append(label[k])
        #         label_aug.append(label[k])
        #
        # data_aug = np.stack(data_aug)
        # label_aug = np.stack(label_aug)
        #
        # data = np.concatenate([data, data_aug])
        # label = np.concatenate([label, label_aug])
        # # --------------------------------------------------------

        # binary, two class
        if NumClasses == 2:
@@ -282,11 +243,6 @@ class IcingIntensityNN:
        return data, label

    @tf.function(input_signature=[tf.TensorSpec(None, tf.int32)])
    def data_function_test(self, indexes):
        out = tf.numpy_function(self.get_in_mem_data_batch_test, [indexes], [tf.float32, tf.int32])
        return out
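For context, tf.numpy_function lets the eager, numpy/h5py-backed loaders above run inside the graph traced by @tf.function; the [tf.float32, tf.int32] list declares the dtypes of the two arrays the wrapped callable returns. A small self-contained sketch of the same pattern, using a hypothetical load_batch stand-in rather than the class methods in this commit:

import numpy as np
import tensorflow as tf

def load_batch(indexes):
    # Stand-in for get_in_mem_data_batch: return (data, label) numpy arrays
    # for the requested sample indexes.
    idxs = np.asarray(indexes)
    data = np.random.rand(idxs.size, 16, 16, 3).astype(np.float32)
    label = np.zeros((idxs.size, 1), dtype=np.int32)
    return data, label

@tf.function(input_signature=[tf.TensorSpec(None, tf.int32)])
def data_fn(indexes):
    return tf.numpy_function(load_batch, [indexes], [tf.float32, tf.int32])

ds = tf.data.Dataset.from_tensor_slices(np.arange(32, dtype=np.int32))
ds = ds.batch(8).map(data_fn, num_parallel_calls=8)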
    @tf.function(input_signature=[tf.TensorSpec(None, tf.int32)])
    def data_function(self, indexes):
        out = tf.numpy_function(self.get_in_mem_data_batch, [indexes], [tf.float32, tf.int32])
@@ -307,7 +263,7 @@ class IcingIntensityNN:
        dataset = tf.data.Dataset.from_tensor_slices(indexes)
        dataset = dataset.batch(PROC_BATCH_SIZE)
        dataset = dataset.map(self.data_function_test, num_parallel_calls=8)
        dataset = dataset.map(self.data_function, num_parallel_calls=8)
        self.test_dataset = dataset
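A brief usage sketch of the finished pipeline; the constructor call, HDF5 file name, and printed shapes are placeholders, not names taken from this commit:

nn = IcingIntensityNN()                 # assumed default constructor
nn.setup_pipeline('icing_training.h5')  # hypothetical file path
for data, label in nn.test_dataset:
    # Each element is one (data, label) batch produced by the numpy-backed loader.
    print(data.shape, label.shape)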
    def setup_pipeline(self, filename, trn_idxs=None, tst_idxs=None, seed=None):
......