diff --git a/modules/deeplearning/icing_cnn.py b/modules/deeplearning/icing_cnn.py index 8a93537a2c854e25f0d88cb6c7f93d797e5037aa..a5ce63e6fd744660781df4ff0e38719766b558d0 100644 --- a/modules/deeplearning/icing_cnn.py +++ b/modules/deeplearning/icing_cnn.py @@ -1,7 +1,7 @@ import tensorflow as tf import tensorflow_addons as tfa from util.setup import logdir, modeldir, cachepath -from util.util import homedir +from util.util import homedir, EarlyStop import os, datetime import numpy as np @@ -22,6 +22,7 @@ BATCH_SIZE = 256 NUM_EPOCHS = 200 TRACK_MOVING_AVERAGE = True +EARLY_STOP = True TRIPLET = False CONV3D = False @@ -591,6 +592,9 @@ class IcingIntensityNN: step = 0 total_time = 0 + if EARLY_STOP: + es = EarlyStop() + for epoch in range(NUM_EPOCHS): self.train_loss.reset_states() self.train_accuracy.reset_states() @@ -666,6 +670,9 @@ class IcingIntensityNN: pickle.dump(self.in_mem_data_cache, f) f.close() + if es.check_stop(self.test_loss.result().numpy()): + break + print('total time: ', total_time) self.writer_train.close() self.writer_valid.close()