diff --git a/modules/util/viirs_l1b_l2.py b/modules/util/viirs_l1b_l2.py index 6d978bb8ba739a9f7ff98c27d5e9cc932d456b8b..880842799ee2680674e2b5d19c181eb3fe1d54fe 100644 --- a/modules/util/viirs_l1b_l2.py +++ b/modules/util/viirs_l1b_l2.py @@ -32,7 +32,7 @@ def run_all(directory, out_directory): cnt = 10 total_num_train_samples = 0 - total_num_test_samples = 0 + total_num_valid_samples = 0 for p in os.scandir(directory): if not p.is_dir(): @@ -105,8 +105,9 @@ def run_all(directory, out_directory): [data_valid_tiles.append(data_tiles[k]) for k in range(n_vld)] [data_train_tiles.append(data_tiles[k]) for k in range(n_vld, num)] + f_cnt += 1 if f_cnt == 10: - cnt += 1 + f_cnt = 0 #label_valid = np.stack(label_valid_tiles) #label_train = np.stack(label_train_tiles) @@ -128,40 +129,11 @@ def run_all(directory, out_directory): print(' file # done: ', cnt) print('num_train_samples, num_valid_samples: ', num_train_samples, num_valid_samples) total_num_train_samples += num_train_samples - total_num_test_samples += num_valid_samples - - f_cnt = 0 - else: - f_cnt += 1 - - # if len(label_train_tiles) == 0 or len(data_train_tiles) == 0: - # continue - # if len(label_train_tiles) != len(data_train_tiles): - # print('weirdness') - # continue + total_num_valid_samples += num_valid_samples - if len(data_train_tiles) == 0: - continue - - #label_valid = np.stack(label_valid_tiles) - #label_train = np.stack(label_train_tiles) - data_valid = np.stack(data_valid_tiles) - data_train = np.stack(data_train_tiles) - - cnt += 1 - np.save(out_directory+'data_train_' + str(cnt), data_train) - np.save(out_directory+'data_valid_' + str(cnt), data_valid) - #np.save(out_directory+'label_train_' + str(cnt), label_train) - #np.save(out_directory+'label_valid_' + str(cnt), label_valid) - - num_train_samples = data_train.shape[0] - num_valid_samples = data_valid.shape[0] - - print('num_train_samples, num_valid_samples: ', num_train_samples, num_valid_samples) - total_num_train_samples += num_train_samples - total_num_test_samples += num_valid_samples + cnt += 1 - print('total_num_train_samples, total_num_valid_samples: ', num_train_samples, num_valid_samples) + print('total_num_train_samples, total_num_valid_samples: ', total_num_train_samples, total_num_valid_samples) def run(data_h5f, label_h5f, data_tiles, label_tiles, mod_tile_width=64, kernel_size=9):