Skip to content
Snippets Groups Projects
Commit 3d1692f8 authored by tomrink's avatar tomrink
Browse files

snapshot...

parent e1240674
No related branches found
No related tags found
No related merge requests found
...@@ -13,18 +13,17 @@ from util.plot_cm import plot_confusion_matrix ...@@ -13,18 +13,17 @@ from util.plot_cm import plot_confusion_matrix
LOG_DEVICE_PLACEMENT = False LOG_DEVICE_PLACEMENT = False
CACHE_DATA_IN_MEM = True CACHE_DATA_IN_MEM = False
PROC_BATCH_SIZE = 4096 PROC_BATCH_SIZE = 4096
PROC_BATCH_BUFFER_SIZE = 50000 PROC_BATCH_BUFFER_SIZE = 50000
NumClasses = 3 NumClasses = 2
NumLogits = 1 NumLogits = 1
BATCH_SIZE = 256 BATCH_SIZE = 256
NUM_EPOCHS = 50 NUM_EPOCHS = 200
TRACK_MOVING_AVERAGE = False TRACK_MOVING_AVERAGE = False
TRIPLET = False TRIPLET = False
CONV3D = False CONV3D = False
...@@ -187,7 +186,11 @@ class IcingIntensityNN: ...@@ -187,7 +186,11 @@ class IcingIntensityNN:
# Memory growth must be set before GPUs have been initialized # Memory growth must be set before GPUs have been initialized
print(e) print(e)
def get_in_mem_data_batch(self, idxs): def get_in_mem_data_batch(self, idxs, is_training):
h5f = self.h5f_trn
if not is_training:
h5f = self.h5f_tst
key = frozenset(idxs) key = frozenset(idxs)
if CACHE_DATA_IN_MEM: if CACHE_DATA_IN_MEM:
...@@ -201,14 +204,14 @@ class IcingIntensityNN: ...@@ -201,14 +204,14 @@ class IcingIntensityNN:
data = [] data = []
for param in train_params: for param in train_params:
nda = self.h5f[param][nd_idxs, ] nda = h5f[param][nd_idxs, ]
nda = normalize(nda, param, mean_std_dct) nda = normalize(nda, param, mean_std_dct)
data.append(nda) data.append(nda)
data = np.stack(data) data = np.stack(data)
data = data.astype(np.float32) data = data.astype(np.float32)
data = np.transpose(data, axes=(1, 2, 3, 0)) data = np.transpose(data, axes=(1, 2, 3, 0))
label = self.h5f['icing_intensity'][nd_idxs] label = h5f['icing_intensity'][nd_idxs]
label = label.astype(np.int32) label = label.astype(np.int32)
label = np.where(label == -1, 0, label) label = np.where(label == -1, 0, label)
...@@ -256,7 +259,8 @@ class IcingIntensityNN: ...@@ -256,7 +259,8 @@ class IcingIntensityNN:
dataset = tf.data.Dataset.from_tensor_slices(indexes) dataset = tf.data.Dataset.from_tensor_slices(indexes)
dataset = dataset.batch(PROC_BATCH_SIZE) dataset = dataset.batch(PROC_BATCH_SIZE)
dataset = dataset.map(self.data_function, num_parallel_calls=8) dataset = dataset.map(self.data_function, num_parallel_calls=8)
dataset = dataset.shuffle(PROC_BATCH_BUFFER_SIZE) dataset = dataset.cache()
# dataset = dataset.shuffle(PROC_BATCH_BUFFER_SIZE)
dataset = dataset.prefetch(buffer_size=1) dataset = dataset.prefetch(buffer_size=1)
self.train_dataset = dataset self.train_dataset = dataset
...@@ -266,6 +270,7 @@ class IcingIntensityNN: ...@@ -266,6 +270,7 @@ class IcingIntensityNN:
dataset = tf.data.Dataset.from_tensor_slices(indexes) dataset = tf.data.Dataset.from_tensor_slices(indexes)
dataset = dataset.batch(PROC_BATCH_SIZE) dataset = dataset.batch(PROC_BATCH_SIZE)
dataset = dataset.map(self.data_function, num_parallel_calls=8) dataset = dataset.map(self.data_function, num_parallel_calls=8)
dataset = dataset.cache()
self.test_dataset = dataset self.test_dataset = dataset
def setup_pipeline(self, filename_trn, filename_tst, trn_idxs=None, tst_idxs=None, seed=None): def setup_pipeline(self, filename_trn, filename_tst, trn_idxs=None, tst_idxs=None, seed=None):
...@@ -621,6 +626,9 @@ class IcingIntensityNN: ...@@ -621,6 +626,9 @@ class IcingIntensityNN:
self.writer_train.close() self.writer_train.close()
self.writer_valid.close() self.writer_valid.close()
self.h5f_trn.close()
self.h5f_tst.close()
def build_model(self): def build_model(self):
flat = self.build_cnn() flat = self.build_cnn()
# flat_1d = self.build_1d_cnn() # flat_1d = self.build_1d_cnn()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment