diff --git a/modules/deeplearning/icing_cnn.py b/modules/deeplearning/icing_cnn.py index 0fcde07636b9673ab57d4a3846da769a5a9f7cef..81e6b716ddef1831b9e4e31ff81ee5f677c946db 100644 --- a/modules/deeplearning/icing_cnn.py +++ b/modules/deeplearning/icing_cnn.py @@ -308,7 +308,7 @@ class IcingIntensityNN: print('num test samples: ', tst_idxs.shape[0]) print('setup_pipeline: Done') - def setup_test_pipeline(self, filename, seed=None): + def setup_test_pipeline(self, filename, seed=None, shuffle=False): self.filename_tst = filename self.h5f_tst = h5py.File(filename, 'r') @@ -317,7 +317,8 @@ class IcingIntensityNN: self.num_data_samples = len(tst_idxs) if seed is not None: np.random.seed(seed) - np.random.shuffle(tst_idxs) + if shuffle: + np.random.shuffle(tst_idxs) self.get_test_dataset(tst_idxs)