Skip to content
Snippets Groups Projects
Commit 38fc2d56 authored by tomrink's avatar tomrink
Browse files

snapshot...

parent 1917df70
No related branches found
No related tags found
No related merge requests found
...@@ -67,6 +67,7 @@ data_params_half = ['temp_11_0um_nom'] ...@@ -67,6 +67,7 @@ data_params_half = ['temp_11_0um_nom']
data_params_full = ['refl_0_65um_nom'] data_params_full = ['refl_0_65um_nom']
label_idx = params.index(label_param) label_idx = params.index(label_param)
# label_idx = 0
print('data_params_half: ', data_params_half) print('data_params_half: ', data_params_half)
print('data_params_full: ', data_params_full) print('data_params_full: ', data_params_full)
...@@ -274,14 +275,14 @@ class SRCNN: ...@@ -274,14 +275,14 @@ class SRCNN:
self.OUT_OF_RANGE = False self.OUT_OF_RANGE = False
self.abi = None # self.abi = None
self.temp = None # self.temp = None
self.wv = None # self.wv = None
self.lbfp = None # self.lbfp = None
self.sfc = None # self.sfc = None
self.in_mem_data_cache = {} # self.in_mem_data_cache = {}
self.in_mem_data_cache_test = {} # self.in_mem_data_cache_test = {}
self.model = None self.model = None
self.optimizer = None self.optimizer = None
...@@ -313,10 +314,10 @@ class SRCNN: ...@@ -313,10 +314,10 @@ class SRCNN:
self.test_data_files = None self.test_data_files = None
self.test_label_files = None self.test_label_files = None
self.train_data_nda = None # self.train_data_nda = None
self.train_label_nda = None # self.train_label_nda = None
self.test_data_nda = None # self.test_data_nda = None
self.test_label_nda = None # self.test_label_nda = None
# self.n_chans = len(data_params_half) + len(data_params_full) + 1 # self.n_chans = len(data_params_half) + len(data_params_full) + 1
self.n_chans = 5 self.n_chans = 5
...@@ -343,6 +344,27 @@ class SRCNN: ...@@ -343,6 +344,27 @@ class SRCNN:
continue continue
data_s.append(nda) data_s.append(nda)
input_data = np.concatenate(data_s) input_data = np.concatenate(data_s)
input_label = input_data[:, label_idx, :, :]
# if is_training:
# data_files = self.train_data_files
# label_files = self.train_label_files
# else:
# data_files = self.test_data_files
# label_files = self.test_label_files
#
# data_s = []
# label_s = []
# for k in idxs:
# f = data_files[k]
# nda = np.load(f)
# data_s.append(nda)
#
# f = label_files[k]
# nda = np.load(f)
# label_s.append(nda)
# input_data = np.concatenate(data_s)
# input_label = np.concatenate(label_s)
DO_ADD_NOISE = False DO_ADD_NOISE = False
if is_training and NOISE_TRAINING: if is_training and NOISE_TRAINING:
...@@ -379,7 +401,7 @@ class SRCNN: ...@@ -379,7 +401,7 @@ class SRCNN:
data_norm.append(avg[:, 0:66, 0:66]) data_norm.append(avg[:, 0:66, 0:66])
# data_norm.append(std[:, 0:66, 0:66]) # data_norm.append(std[:, 0:66, 0:66])
# --------------------------------------------------- # ---------------------------------------------------
tmp = input_data[:, label_idx, :, :] tmp = input_label
tmp = tmp.copy() tmp = tmp.copy()
tmp = np.where(np.isnan(tmp), 0, tmp) tmp = np.where(np.isnan(tmp), 0, tmp)
if DO_ESPCN: if DO_ESPCN:
...@@ -403,7 +425,7 @@ class SRCNN: ...@@ -403,7 +425,7 @@ class SRCNN:
data = data.astype(np.float32) data = data.astype(np.float32)
# ----------------------------------------------------- # -----------------------------------------------------
# ----------------------------------------------------- # -----------------------------------------------------
label = input_data[:, label_idx, :, :] label = input_label
label = label.copy() label = label.copy()
label = label[:, y_128, x_128] label = label[:, y_128, x_128]
label = get_label_data(label) label = get_label_data(label)
...@@ -466,13 +488,18 @@ class SRCNN: ...@@ -466,13 +488,18 @@ class SRCNN:
dataset = dataset.cache() dataset = dataset.cache()
self.test_dataset = dataset self.test_dataset = dataset
def setup_pipeline(self, train_data_files, test_data_files, num_train_samples): def setup_pipeline(self, train_data_files, train_label_files, test_data_files, test_label_files, num_train_samples):
# self.train_data_files = train_data_files
# self.train_label_files = train_label_files
# self.test_data_files = test_data_files
# self.test_label_files = test_label_files
self.train_data_files = train_data_files self.train_data_files = train_data_files
self.test_data_files = test_data_files self.test_data_files = test_data_files
trn_idxs = np.arange(len(train_data_files)) trn_idxs = np.arange(len(train_data_files))
np.random.shuffle(trn_idxs) np.random.shuffle(trn_idxs)
tst_idxs = np.arange(len(test_data_files)) tst_idxs = np.arange(len(test_data_files))
self.get_train_dataset(trn_idxs) self.get_train_dataset(trn_idxs)
...@@ -795,10 +822,16 @@ class SRCNN: ...@@ -795,10 +822,16 @@ class SRCNN:
return pred return pred
def run(self, directory, ckpt_dir=None, num_data_samples=50000): def run(self, directory, ckpt_dir=None, num_data_samples=50000):
# train_data_files = glob.glob(directory+'data_train*.npy')
# valid_data_files = glob.glob(directory+'data_valid*.npy')
# train_label_files = glob.glob(directory+'label_train*.npy')
# valid_label_files = glob.glob(directory+'label_valid*.npy')
# self.setup_pipeline(train_data_files, train_label_files, valid_data_files, valid_label_files, num_data_samples)
train_data_files = glob.glob(directory+'data_train_*.npy') train_data_files = glob.glob(directory+'data_train_*.npy')
valid_data_files = glob.glob(directory+'data_valid_*.npy') valid_data_files = glob.glob(directory+'data_valid_*.npy')
self.setup_pipeline(train_data_files, None, valid_data_files, None, num_data_samples)
self.setup_pipeline(train_data_files, valid_data_files, num_data_samples)
self.build_model() self.build_model()
self.build_training() self.build_training()
self.build_evaluation() self.build_evaluation()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment