Skip to content
Snippets Groups Projects
Commit 7f2b45f9 authored by tomrink's avatar tomrink
Browse files

snapshot...

parent f20f00e5
Branches
No related tags found
No related merge requests found
......@@ -18,7 +18,7 @@ PROC_BATCH_SIZE = 2046
PROC_BATCH_BUFFER_SIZE = 50000
NumLabels = 1
BATCH_SIZE = 256
NUM_EPOCHS = 60
NUM_EPOCHS = 80
TRACK_MOVING_AVERAGE = False
......@@ -243,7 +243,7 @@ class IcingIntensityNN:
dataset = dataset.map(self.data_function, num_parallel_calls=8)
self.test_dataset = dataset
def setup_pipeline(self, filename, train_idxs=None, test_idxs=None):
def setup_pipeline(self, filename):
self.filename = filename
self.h5f = h5py.File(filename, 'r')
time = self.h5f['time']
......@@ -592,16 +592,16 @@ class IcingIntensityNN:
self.predict(mini_batch_test)
print('loss, acc: ', self.test_loss.result(), self.test_accuracy.result())
def run(self, filename, filename_l1b=None, train_dict=None, valid_dict=None):
def run(self, filename, filename_l1b=None):
with tf.device('/device:GPU:'+str(self.gpu_device)):
self.setup_pipeline(filename, train_idxs=train_dict, test_idxs=valid_dict)
self.setup_pipeline(filename)
self.build_model()
self.build_training()
self.build_evaluation()
self.do_training()
def run_restore(self, filename, ckpt_dir, train_dict=None, valid_dict=None):
self.setup_pipeline(filename, train_idxs=train_dict, test_idxs=valid_dict)
def run_restore(self, filename, ckpt_dir):
self.setup_pipeline(filename)
self.build_model()
self.build_training()
self.build_evaluation()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment