diff --git a/modules/deeplearning/srcnn.py b/modules/deeplearning/srcnn.py index a30a2c3b8118a587d14656633a1d1ebb5e775129..9d1858d2e24707838adeac6b5df6f30888be7a76 100644 --- a/modules/deeplearning/srcnn.py +++ b/modules/deeplearning/srcnn.py @@ -644,7 +644,6 @@ class SRCNN: self.restore(ckpt_dir) def run_evaluate(self, nda_lr, param, ckpt_dir): - # self.setup_eval_pipeline(filename) self.num_data_samples = 80000 self.build_model() self.build_training() @@ -661,10 +660,17 @@ def prepare(param_idx=1, filename='/Users/tomrink/data_valid_40.npy'): return nda_lr -def run_evaluate_static(nda_lr, param='temp_11_0um_nom', ckpt_dir='/Users/tomrink/tf_model_sres/run-20220805173619/'): +def run_evaluate_static(in_file, out_file, param='temp_11_0um_nom', ckpt_dir='/Users/tomrink/tf_model_sres/run-20220805173619/'): + nda = np.load(in_file) + nda = nda[:, data_idx, 2:133:2, 2:133:2] + nda = np.expand_dims(nda, axis=3) + nn = SRCNN() - out_sr = nn.run_evaluate(nda_lr, param, ckpt_dir) - return out_sr + out_sr = nn.run_evaluate(nda, param, ckpt_dir) + if out_file is not None: + np.save(out_file, out_sr) + else: + return out_sr if __name__ == "__main__":