diff --git a/modules/deeplearning/cloudheight_mgpus.py b/modules/deeplearning/cloudheight_mgpus.py
index 05bc45c0df8ecbb9922538df759468cfc862f247..ee1e2fb637a20aa9a248c98ce94f6a08f13214f9 100644
--- a/modules/deeplearning/cloudheight_mgpus.py
+++ b/modules/deeplearning/cloudheight_mgpus.py
@@ -817,6 +817,8 @@ class CloudHeightNN:
             print('loss, acc: ', self.test_loss.result(), self.test_accuracy.result())
             ckpt_manager.save()
 
+            self.DISK_CACHE = True
+
             if self.DISK_CACHE and epoch == 0:
                 f = open(cachepath, 'wb')
                 pickle.dump(self.in_mem_data_cache, f)