Skip to content
Snippets Groups Projects
Commit abe9ab43 authored by tomrink's avatar tomrink
Browse files

minor...

parent f8243d43
Branches
No related tags found
No related merge requests found
......@@ -297,6 +297,21 @@ class IcingIntensityNN:
print('num test samples: ', tst_idxs.shape[0])
print('setup_pipeline: Done')
def setup_test_pipeline(self, filename, seed=None):
self.filename_tst = filename
self.h5f_tst = h5py.File(filename, 'r')
time = self.h5f_tst['time']
tst_idxs = np.arange(time.shape[0])
if seed is not None:
np.random.seed(seed)
np.random.shuffle(tst_idxs)
self.get_test_dataset(tst_idxs)
print('num test samples: ', tst_idxs.shape[0])
print('setup_test_pipeline: Done')
def build_1d_cnn(self):
print('build_1d_cnn')
# padding = 'VALID'
......@@ -667,7 +682,7 @@ class IcingIntensityNN:
self.test_labels = labels
self.test_preds = preds
def run(self, filename_trn, filename_tst, filename_l1b=None):
def run(self, filename_trn, filename_tst):
with tf.device('/device:GPU:'+str(self.gpu_device)):
self.setup_pipeline(filename_trn, filename_tst)
self.build_model()
......@@ -675,8 +690,8 @@ class IcingIntensityNN:
self.build_evaluation()
self.do_training()
def run_restore(self, filename, ckpt_dir):
self.setup_pipeline(filename)
def run_restore(self, filename_tst, ckpt_dir):
self.setup_test_pipeline(filename_tst)
self.build_model()
self.build_training()
self.build_evaluation()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment