diff --git a/modules/deeplearning/icing_cnn.py b/modules/deeplearning/icing_cnn.py index 08a57c1061a5f43c6de9982b55f3033decfa9382..678016013fc255894eacc6b2828a697bbdc5ba9c 100644 --- a/modules/deeplearning/icing_cnn.py +++ b/modules/deeplearning/icing_cnn.py @@ -294,10 +294,14 @@ class IcingIntensityNN: if trn_idxs is None: time = self.h5f_trn['time'] trn_idxs = np.arange(time.shape[0]) + if seed is not None: + np.random.seed(seed) np.random.shuffle(trn_idxs) time = self.h5f_tst['time'] tst_idxs = np.arange(time.shape[0]) + if seed is not None: + np.random.seed(seed) np.random.shuffle(tst_idxs) self.num_data_samples = trn_idxs.shape[0]