diff --git a/modules/deeplearning/srcnn_l1b_l2.py b/modules/deeplearning/srcnn_l1b_l2.py index 5e123688ab9d25a4a9ad0ae555dfd25244b48b9b..72ad40aa3066e34a346c2980ca58797c473797e5 100644 --- a/modules/deeplearning/srcnn_l1b_l2.py +++ b/modules/deeplearning/srcnn_l1b_l2.py @@ -738,18 +738,18 @@ class SRCNN: self.reset_test_metrics() - for data in self.eval_dataset: - pred = self.model([data], training=False) - pred = pred.numpy() - if label_param != 'cloud_probability': - pred = denormalize(pred, label_param, mean_std_dct) - print(pred.min(), pred.max()) - - # pred = self.model([data], training=False) - # self.test_probs = pred - # pred = pred.numpy() - # if label_param != 'cloud_probability': - # pred = denormalize(pred, label_param, mean_std_dct) + # for data in self.eval_dataset: + # pred = self.model([data], training=False) + # pred = pred.numpy() + # if label_param != 'cloud_probability': + # pred = denormalize(pred, label_param, mean_std_dct) + # print(pred.min(), pred.max()) + + pred = self.model([data], training=False) + self.test_probs = pred + pred = pred.numpy() + if label_param != 'cloud_probability': + pred = denormalize(pred, label_param, mean_std_dct) return pred