diff --git a/modules/deeplearning/cloudheight.py b/modules/deeplearning/cloudheight.py
index ceede4815878cda6531113f801b2594c1a6431dc..f6694ebd46df8350b0deeb3a3fe0dfd8059f2ce9 100644
--- a/modules/deeplearning/cloudheight.py
+++ b/modules/deeplearning/cloudheight.py
@@ -797,15 +797,15 @@ class CloudHeightNN:
             print('loss, acc: ', self.test_loss.result(), self.test_accuracy.result())
             ckpt_manager.save()
 
+            if DISK_CACHE and epoch == 0:
+                f = open(cachepath, 'wb')
+                pickle.dump(self.in_mem_data_cache, f)
+                f.close()
+
         print('total time: ', total_time)
         self.writer_train.close()
         self.writer_valid.close()
 
-        if DISK_CACHE:
-            f = open(cachepath, 'wb')
-            pickle.dump(self.in_mem_data_cache, f)
-            f.close()
-
     def build_model(self):
         flat = self.build_cnn()
         flat_1d = self.build_1d_cnn()