Skip to content
Snippets Groups Projects
Commit dd11d98b authored by tomrink's avatar tomrink
Browse files

snapshot...

parent a0449ea5
Branches
No related tags found
No related merge requests found
......@@ -62,11 +62,11 @@ IMG_DEPTH = 1
label_param = 'cloud_probability'
params = ['temp_11_0um_nom', 'refl_0_65um_nom', label_param]
# params_i = ['refl_0_65um_nom', label_param]
params_i = ['refl_0_65um_nom', label_param]
data_params_half = ['temp_11_0um_nom']
data_params_full = ['refl_0_65um_nom']
# label_idx = params_i.index(label_param)
label_idx_i = params_i.index(label_param)
label_idx = params.index(label_param)
print('data_params_half: ', data_params_half)
......@@ -77,8 +77,8 @@ KERNEL_SIZE = 3 # target size: (128, 128)
N = 1
if KERNEL_SIZE == 3:
slc_x = slice(2, N*128 + 4)
slc_y = slice(2, N*128 + 4)
# slc_x = slice(2, N*128 + 4)
# slc_y = slice(2, N*128 + 4)
slc_x_2 = slice(1, N*128 + 6, 2)
slc_y_2 = slice(1, N*128 + 6, 2)
x_2 = np.arange(int((N*128)/2) + 3)
......@@ -89,6 +89,8 @@ if KERNEL_SIZE == 3:
y_k = slice(1, N*128 + 3)
# x_128 = slice(3, N*128 + 3)
# y_128 = slice(3, N*128 + 3)
slc_x = slice(1, 67)
slc_y = slice(1, 67)
x_128 = slice(4, N*128 + 4)
y_128 = slice(4, N*128 + 4)
elif KERNEL_SIZE == 5:
......@@ -315,42 +317,42 @@ class SRCNN:
tf.debugging.set_log_device_placement(LOG_DEVICE_PLACEMENT)
def get_in_mem_data_batch(self, idxs, is_training):
if is_training:
files = self.train_data_files
else:
files = self.test_data_files
data_s = []
for k in idxs:
f = files[k]
try:
nda = np.load(f)
except Exception:
print(f)
continue
data_s.append(nda)
input_data = np.concatenate(data_s)
# input_label = input_data[:, label_idx, :, :]
# if is_training:
# data_files = self.train_data_files
# label_files = self.train_label_files
# files = self.train_data_files
# else:
# data_files = self.test_data_files
# label_files = self.test_label_files
# files = self.test_data_files
#
# data_s = []
# label_s = []
# for k in idxs:
# f = data_files[k]
# f = files[k]
# try:
# nda = np.load(f)
# except Exception:
# print(f)
# continue
# data_s.append(nda)
#
# f = label_files[k]
# nda = np.load(f)
# label_s.append(nda)
# input_data = np.concatenate(data_s)
# input_label = np.concatenate(label_s)
# # input_label = input_data[:, label_idx, :, :]
if is_training:
data_files = self.train_data_files
label_files = self.train_label_files
else:
data_files = self.test_data_files
label_files = self.test_label_files
data_s = []
label_s = []
for k in idxs:
f = data_files[k]
nda = np.load(f)
data_s.append(nda)
f = label_files[k]
nda = np.load(f)
label_s.append(nda)
input_data = np.concatenate(data_s)
input_label = np.concatenate(label_s)
data_norm = []
for param in data_params_half:
......@@ -360,14 +362,15 @@ class SRCNN:
if DO_ESPCN:
tmp = tmp[:, slc_y_2, slc_x_2]
else: # Half res upsampled to full res:
tmp = get_grid_cell_mean(tmp)
tmp = tmp[:, 1:67, 1:67]
tmp = tmp[:, slc_y, slc_x]
tmp = normalize(tmp, param, mean_std_dct)
data_norm.append(tmp)
for param in data_params_full:
idx = params.index(param)
tmp = input_data[:, idx, :, :]
# idx = params.index(param)
# tmp = input_data[:, idx, :, :]
idx = params_i.index(param)
tmp = input_label[:, idx, :, :]
tmp = tmp.copy()
lo, hi, std, avg = get_min_max_std(tmp)
......@@ -376,18 +379,17 @@ class SRCNN:
hi = normalize(hi, param, mean_std_dct)
avg = normalize(avg, param, mean_std_dct)
data_norm.append(lo[:, 1:67, 1:67])
data_norm.append(hi[:, 1:67, 1:67])
data_norm.append(avg[:, 1:67, 1:67])
# data_norm.append(std[:, 0:66, 0:66])
data_norm.append(lo[:, slc_y, slc_x])
data_norm.append(hi[:, slc_y, slc_x])
data_norm.append(avg[:, slc_y, slc_x])
# data_norm.append(std[:, slc_y, slc_x])
# ---------------------------------------------------
tmp = input_data[:, label_idx, :, :]
tmp = tmp.copy()
if DO_ESPCN:
tmp = tmp[:, slc_y_2, slc_x_2]
else: # Half res upsampled to full res:
tmp = get_grid_cell_mean(tmp)
tmp = tmp[:, 1:67, 1:67]
tmp = tmp[:, slc_y, slc_x]
if label_param != 'cloud_probability':
tmp = normalize(tmp, label_param, mean_std_dct)
data_norm.append(tmp)
......@@ -396,7 +398,8 @@ class SRCNN:
data = data.astype(np.float32)
# -----------------------------------------------------
# -----------------------------------------------------
label = input_data[:, label_idx, :, :]
# label = input_data[:, label_idx, :, :]
label = input_label[:, label_idx_i, :, :]
label = label.copy()
label = label[:, y_128, x_128]
if NumClasses == 5:
......@@ -463,13 +466,13 @@ class SRCNN:
self.test_dataset = dataset
def setup_pipeline(self, train_data_files, train_label_files, test_data_files, test_label_files, num_train_samples):
# self.train_data_files = train_data_files
# self.train_label_files = train_label_files
# self.test_data_files = test_data_files
# self.test_label_files = test_label_files
self.train_data_files = train_data_files
self.train_label_files = train_label_files
self.test_data_files = test_data_files
self.test_label_files = test_label_files
# self.train_data_files = train_data_files
# self.test_data_files = test_data_files
trn_idxs = np.arange(len(train_data_files))
np.random.shuffle(trn_idxs)
......@@ -489,8 +492,9 @@ class SRCNN:
print('num test samples: ', tst_idxs.shape[0])
print('setup_pipeline: Done')
def setup_test_pipeline(self, test_data_files):
def setup_test_pipeline(self, test_data_files, test_label_files):
self.test_data_files = test_data_files
self.test_label_files = test_label_files
tst_idxs = np.arange(len(test_data_files))
self.get_test_dataset(tst_idxs)
print('setup_test_pipeline: Done')
......@@ -796,15 +800,15 @@ class SRCNN:
return pred
def run(self, directory, ckpt_dir=None, num_data_samples=50000):
# train_data_files = glob.glob(directory+'train*mres*.npy')
# valid_data_files = glob.glob(directory+'valid*mres*.npy')
# train_label_files = glob.glob(directory+'train*ires*.npy')
# valid_label_files = glob.glob(directory+'valid*ires*.npy')
# self.setup_pipeline(train_data_files, train_label_files, valid_data_files, valid_label_files, num_data_samples)
train_data_files = glob.glob(directory+'train*mres*.npy')
valid_data_files = glob.glob(directory+'valid*mres*.npy')
train_label_files = glob.glob(directory+'train*ires*.npy')
valid_label_files = glob.glob(directory+'valid*ires*.npy')
self.setup_pipeline(train_data_files, train_label_files, valid_data_files, valid_label_files, num_data_samples)
train_data_files = glob.glob(directory+'data_train_*.npy')
valid_data_files = glob.glob(directory+'data_valid_*.npy')
self.setup_pipeline(train_data_files, None, valid_data_files, None, num_data_samples)
# train_data_files = glob.glob(directory+'data_train_*.npy')
# valid_data_files = glob.glob(directory+'data_valid_*.npy')
# self.setup_pipeline(train_data_files, None, valid_data_files, None, num_data_samples)
self.build_model()
self.build_training()
......@@ -812,9 +816,14 @@ class SRCNN:
self.do_training(ckpt_dir=ckpt_dir)
def run_restore(self, directory, ckpt_dir):
valid_data_files = glob.glob(directory + 'data_valid*.npy')
self.num_data_samples = 1000
self.setup_test_pipeline(valid_data_files)
# valid_data_files = glob.glob(directory + 'data_valid*.npy')
# self.setup_test_pipeline(valid_data_files, None)
valid_data_files = glob.glob(directory + 'valid*mres*.npy')
valid_label_files = glob.glob(directory + 'valid*ires*.npy')
self.setup_test_pipeline(valid_data_files, valid_label_files)
self.build_model()
self.build_training()
self.build_evaluation()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment