Skip to content
Snippets Groups Projects
Commit 9ceda9ef authored by tomrink's avatar tomrink
Browse files

snapshot...

parent 6282f56e
Branches
No related tags found
No related merge requests found
......@@ -62,10 +62,11 @@ IMG_DEPTH = 1
label_param = 'cloud_probability'
params = ['temp_11_0um_nom', 'refl_0_65um_nom', label_param]
params_i = ['refl_0_65um_nom', label_param]
data_params_half = ['temp_11_0um_nom']
data_params_full = ['refl_0_65um_nom']
label_idx = params.index(label_param)
label_idx = params_i.index(label_param)
# label_idx = 0
print('data_params_half: ', data_params_half)
......@@ -350,7 +351,6 @@ class SRCNN:
label_s.append(nda)
input_data = np.concatenate(data_s)
input_label = np.concatenate(label_s)
input_label = input_label[:, 0, :, :]
data_norm = []
for param in data_params_half:
......@@ -366,8 +366,9 @@ class SRCNN:
data_norm.append(tmp)
for param in data_params_full:
idx = params.index(param)
tmp = input_data[:, idx, :, :]
idx = params_i.index(param)
# tmp = input_data[:, idx, :, :]
tmp = input_label[:, idx, :, :]
tmp = tmp.copy()
lo, hi, std, avg = get_min_max_std(tmp)
......@@ -381,7 +382,8 @@ class SRCNN:
data_norm.append(avg[:, 0:66, 0:66])
# data_norm.append(std[:, 0:66, 0:66])
# ---------------------------------------------------
tmp = input_data[:, label_idx, :, :]
# tmp = input_data[:, label_idx, :, :]
tmp = input_data[:, 2, :, :]
tmp = tmp.copy()
tmp = np.where(np.isnan(tmp), 0, tmp)
if DO_ESPCN:
......@@ -399,7 +401,7 @@ class SRCNN:
data = data.astype(np.float32)
# -----------------------------------------------------
# -----------------------------------------------------
label = input_label
label = input_label[:, label_idx, :, :]
label = label.copy()
label = label[:, y_128, x_128]
if NumClasses == 5:
......@@ -799,10 +801,10 @@ class SRCNN:
return pred
def run(self, directory, ckpt_dir=None, num_data_samples=50000):
train_data_files = glob.glob(directory+'data_train*.npy')
valid_data_files = glob.glob(directory+'data_valid*.npy')
train_label_files = glob.glob(directory+'label_train*.npy')
valid_label_files = glob.glob(directory+'label_valid*.npy')
train_data_files = glob.glob(directory+'train_mres_*.npy')
valid_data_files = glob.glob(directory+'valid_mres*.npy')
train_label_files = glob.glob(directory+'train_ires*.npy')
valid_label_files = glob.glob(directory+'valid_ires_*.npy')
self.setup_pipeline(train_data_files, train_label_files, valid_data_files, valid_label_files, num_data_samples)
# train_data_files = glob.glob(directory+'data_train_*.npy')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment