Skip to content
Snippets Groups Projects
Commit 8ffbc901 authored by tomrink's avatar tomrink
Browse files

minor...

parent 7d48a97e
Branches
No related tags found
No related merge requests found
......@@ -36,12 +36,12 @@ NOISE_TRAINING = True
NOISE_STDDEV = 0.10
DO_AUGMENT = True
img_width = 16
mean_std_file = home_dir+'/viirs_emis_rad_mean_std.pkl'
f = open(mean_std_file, 'rb')
mean_std_dct = pickle.load(f)
f.close()
f_stats = open(mean_std_file, 'rb')
mean_std_dct = pickle.load(f_stats)
f_stats.close()
param = 'M15'
# -- Zero out params (Experimentation Only) ------------
zero_out_params = ['cld_reff_dcomp', 'cld_opd_dcomp', 'iwc_dcomp', 'lwc_dcomp']
......@@ -93,11 +93,6 @@ class UNET:
self.inner_handle = None
self.in_mem_batch = None
self.h5f_l1b_trn = None
self.h5f_l1b_tst = None
self.h5f_l2_trn = None
self.h5f_l2_tst = None
self.logits = None
self.predict_data = None
......@@ -120,12 +115,6 @@ class UNET:
self.OUT_OF_RANGE = False
self.abi = None
self.temp = None
self.wv = None
self.lbfp = None
self.sfc = None
self.in_mem_data_cache = {}
self.in_mem_data_cache_test = {}
......@@ -159,11 +148,6 @@ class UNET:
self.test_data_files = None
self.test_label_files = None
self.train_data_nda = None
self.train_label_nda = None
self.test_data_nda = None
self.test_label_nda = None
# self.n_chans = len(self.train_params)
self.n_chans = 1
if TRIPLET:
......@@ -171,7 +155,6 @@ class UNET:
self.X_img = tf.keras.Input(shape=(None, None, self.n_chans))
self.inputs.append(self.X_img)
# self.inputs.append(tf.keras.Input(shape=(None, None, 5)))
self.inputs.append(tf.keras.Input(shape=(None, None, 1)))
self.flight_level = 0
......@@ -188,45 +171,34 @@ class UNET:
def get_in_mem_data_batch(self, idxs, is_training):
if is_training:
train_data = []
train_label = []
for k in idxs:
f = self.train_data_files[k]
nda = np.load(f)
train_data.append(nda)
f = self.train_label_files[k]
nda = np.load(f)
train_label.append(nda)
data = np.concatenate(train_data)
data = np.expand_dims(data, axis=3)
label = np.concatenate(train_label)
label = np.expand_dims(label, axis=3)
data_files = self.train_data_files
label_files = self.train_label_files
else:
test_data = []
test_label = []
for k in idxs:
f = self.test_data_files[k]
nda = np.load(f)
test_data.append(nda)
data_files = self.test_data_files
label_files = self.test_label_files
data_s = []
label_s = []
for k in idxs:
f = data_files[k]
nda = np.load(f)
data_s.append(nda)
f = self.test_label_files[k]
nda = np.load(f)
test_label.append(nda)
f = label_files[k]
nda = np.load(f)
label_s.append(nda)
data = np.concatenate(test_data)
data = np.expand_dims(data, axis=3)
data = np.concatenate(data_s)
data = np.expand_dims(data, axis=3)
label = np.concatenate(test_label)
label = np.expand_dims(label, axis=3)
label = np.concatenate(label_s)
label = np.expand_dims(label, axis=3)
data = data.astype(np.float32)
label = label.astype(np.float32)
data = normalize(data, 'M15', mean_std_dct)
label = normalize(label, 'M15', mean_std_dct)
data = normalize(data, param, mean_std_dct)
label = normalize(label, param, mean_std_dct)
if is_training and DO_AUGMENT:
data_ud = np.flip(data, axis=1)
......@@ -337,24 +309,21 @@ class UNET:
# print('num test samples: ', tst_idxs.shape[0])
# print('setup_pipeline: Done')
def setup_pipeline(self, data_files, label_files, perc=0.20):
num_files = len(data_files)
num_test_files = int(num_files * perc)
num_train_files = num_files - num_test_files
def setup_pipeline(self, train_data_files, train_label_files, test_data_files, test_label_files, num_train_samples):
self.train_data_files = data_files[0:num_train_files]
self.train_label_files = label_files[0:num_train_files]
self.test_data_files = data_files[num_train_files:]
self.test_label_files = label_files[num_train_files:]
self.train_data_files = train_data_files
self.train_label_files = train_label_files
self.test_data_files = test_data_files
self.test_label_files = test_label_files
trn_idxs = np.arange(num_train_files)
trn_idxs = np.arange(len(train_data_files))
np.random.shuffle(trn_idxs)
tst_idxs = np.arange(num_test_files)
tst_idxs = np.arange(len(train_data_files))
self.get_train_dataset(trn_idxs)
self.get_test_dataset(tst_idxs)
self.num_data_samples = num_train_files * 30 # approximately
self.num_data_samples = num_train_samples # approximately
print('datetime: ', now)
print('training and test data: ')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment