Skip to content
Snippets Groups Projects
Commit f60c7853 authored by tomrink's avatar tomrink
Browse files

minor...

parent 9b257830
No related branches found
No related tags found
No related merge requests found
import tensorflow as tf import tensorflow as tf
import tensorflow_addons as tfa import tensorflow_addons as tfa
from util.setup import logdir, modeldir, cachepath from util.setup import logdir, modeldir, cachepath
from util.util import homedir from util.util import homedir, EarlyStop
import os, datetime import os, datetime
import numpy as np import numpy as np
...@@ -22,6 +22,7 @@ BATCH_SIZE = 256 ...@@ -22,6 +22,7 @@ BATCH_SIZE = 256
NUM_EPOCHS = 200 NUM_EPOCHS = 200
TRACK_MOVING_AVERAGE = True TRACK_MOVING_AVERAGE = True
EARLY_STOP = True
TRIPLET = False TRIPLET = False
CONV3D = False CONV3D = False
...@@ -591,6 +592,9 @@ class IcingIntensityNN: ...@@ -591,6 +592,9 @@ class IcingIntensityNN:
step = 0 step = 0
total_time = 0 total_time = 0
if EARLY_STOP:
es = EarlyStop()
for epoch in range(NUM_EPOCHS): for epoch in range(NUM_EPOCHS):
self.train_loss.reset_states() self.train_loss.reset_states()
self.train_accuracy.reset_states() self.train_accuracy.reset_states()
...@@ -666,6 +670,9 @@ class IcingIntensityNN: ...@@ -666,6 +670,9 @@ class IcingIntensityNN:
pickle.dump(self.in_mem_data_cache, f) pickle.dump(self.in_mem_data_cache, f)
f.close() f.close()
if es.check_stop(self.test_loss.result().numpy()):
break
print('total time: ', total_time) print('total time: ', total_time)
self.writer_train.close() self.writer_train.close()
self.writer_valid.close() self.writer_valid.close()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment