diff --git a/modules/deeplearning/icing_cnn.py b/modules/deeplearning/icing_cnn.py index 8aea3c05978b262d75230eb133772ab14d6758c8..fc0ad7d87113f817f1ac66fba96ccbe14660cf06 100644 --- a/modules/deeplearning/icing_cnn.py +++ b/modules/deeplearning/icing_cnn.py @@ -584,8 +584,8 @@ class IcingIntensityNN: self.test_loss.reset_states() self.test_accuracy.reset_states() - for abi_tst, temp_tst, lbfp_tst in self.test_dataset: - ds = tf.data.Dataset.from_tensor_slices((abi_tst, temp_tst, lbfp_tst)) + for data0, label in self.test_dataset: + ds = tf.data.Dataset.from_tensor_slices((data0, label)) ds = ds.batch(BATCH_SIZE) for mini_batch_test in ds: self.predict(mini_batch_test) @@ -599,8 +599,8 @@ class IcingIntensityNN: self.build_evaluation() self.do_training() - def run_restore(self, matchup_dict, ckpt_dir): - self.setup_pipeline(None, None, matchup_dict) + def run_restore(self, filename, ckpt_dir, train_dict=None, valid_dict=None): + self.setup_pipeline(filename, train_idxs=train_dict, test_idxs=valid_dict) self.build_model() self.build_training() self.build_evaluation()