Commit 3767a477 authored by tomrink

snapshot...

parent afc6a33e
@@ -13,7 +13,7 @@ import h5py
 LOG_DEVICE_PLACEMENT = False
-PROC_BATCH_SIZE = 50
+PROC_BATCH_SIZE = 2048
 PROC_BATCH_BUFFER_SIZE = 50000
 NumClasses = 2
@@ -212,45 +212,33 @@ class UNET:
             # print(e)
     def get_in_mem_data_batch(self, idxs, is_training):
+        np.sort(idxs)
         if is_training:
-            train_data = []
-            train_label = []
-            for k in idxs:
-                f = self.train_data_files[k]
-                nda = np.load(f)
-                train_data.append(nda)
-                f = self.train_label_files[k]
-                nda = np.load(f)
-                train_label.append(nda[:, 0, :, :])
-            data = np.concatenate(train_data)
-            label = np.concatenate(train_label)
+            f = self.train_data_files
+            nda = np.load(f)
+            data = nda[idxs, :, :, :]
+            f = self.train_label_files
+            nda = np.load(f)
+            label = nda[idxs, 0, :, :]
             label = np.expand_dims(label, axis=3)
         else:
-            test_data = []
-            test_label = []
-            for k in idxs:
-                f = self.test_data_files[k]
-                nda = np.load(f)
-                test_data.append(nda)
-                f = self.test_label_files[k]
-                nda = np.load(f)
-                test_label.append(nda[:, 0, :, :])
-            data = np.concatenate(test_data)
-            label = np.concatenate(test_label)
+            f = self.test_data_files
+            nda = np.load(f)
+            data = nda[idxs, :, :, :]
+            f = self.test_label_files
+            nda = np.load(f)
+            label = nda[idxs, 0, :, :]
             label = np.expand_dims(label, axis=3)
         data = data.astype(np.float32)
         label = label.astype(np.float32)
         data_norm = []
-        for idx, param in enumerate(emis_params):
-            tmp = normalize(data[:, idx, :, :], param, mean_std_dct)
+        for k, param in enumerate(emis_params):
+            tmp = normalize(data[:, k, :, :], param, mean_std_dct)
             data_norm.append(tmp)
         data = np.stack(data_norm, axis=3)
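Note on the hunk above: per-sample .npy files are replaced by one array per split, and a batch is built by slicing the requested indices out of that array. A minimal standalone sketch of the same pattern follows; the helper name load_batch, the (sample, channel, y, x) layout, and the mean/std arguments are illustrative assumptions, not code from this commit.

import numpy as np

def load_batch(data_path, label_path, idxs, means, stds):
    # Load the whole split once and slice the requested samples; sorting the
    # indices keeps reads roughly sequential.
    idxs = np.sort(idxs)
    data = np.load(data_path)[idxs, :, :, :]      # (batch, channel, y, x)
    label = np.load(label_path)[idxs, 0, :, :]    # first label plane only
    label = np.expand_dims(label, axis=3)         # (batch, y, x, 1)
    data = data.astype(np.float32)
    label = label.astype(np.float32)
    # Standardize each channel, then stack channels last: (batch, y, x, channel)
    norm = [(data[:, k, :, :] - means[k]) / stds[k] for k in range(data.shape[1])]
    return np.stack(norm, axis=3), label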
@@ -365,27 +353,21 @@ class UNET:
         print('num test samples: ', tst_idxs.shape[0])
         print('setup_pipeline: Done')
-    def setup_pipeline_files(self, data_files, label_files, perc=0.20):
-        num_files = len(data_files)
-        num_test_files = int(num_files * perc)
-        num_train_files = num_files - num_test_files
-        num_test_files = 1
-        num_train_files = 3
-        self.train_data_files = data_files[0:num_train_files]
-        self.train_label_files = label_files[0:num_train_files]
-        self.test_data_files = data_files[num_train_files:]
-        self.test_label_files = label_files[num_train_files:]
+    def setup_pipeline_files(self, train_data_files, train_label_files, test_data_files, test_label_files):
+        self.train_data_files = train_data_files
+        self.train_label_files = train_label_files
+        self.test_data_files = test_data_files
+        self.test_label_files = test_label_files
-        trn_idxs = np.arange(num_train_files)
+        trn_idxs = np.arange(10503)
         np.random.shuffle(trn_idxs)
-        tst_idxs = np.arange(num_test_files)
+        tst_idxs = np.arange(1167)
         self.get_train_dataset(trn_idxs)
         self.get_test_dataset(tst_idxs)
-        self.num_data_samples = num_train_files * 1000  # approximately
+        self.num_data_samples = 10503
         print('datetime: ', now)
         print('training and test data: ')
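The dataset builders called in this hunk (get_train_dataset, get_test_dataset) are not shown in the commit; one common way to feed an index-driven, in-memory batch loader such as get_in_mem_data_batch into tf.data is sketched below. The wrapper name make_index_dataset and the exact wiring are assumptions, not the repository's implementation.

import numpy as np
import tensorflow as tf

def make_index_dataset(idxs, batch_fn, batch_size=2048, shuffle=True):
    # Stream sample indices, batch them, and let a numpy-side function
    # (one that loads and normalizes, like get_in_mem_data_batch) build
    # each (data, label) pair outside the TF graph.
    ds = tf.data.Dataset.from_tensor_slices(np.asarray(idxs, dtype=np.int32))
    if shuffle:
        ds = ds.shuffle(buffer_size=len(idxs))
    ds = ds.batch(batch_size)
    ds = ds.map(lambda i: tf.numpy_function(batch_fn, [i], [tf.float32, tf.float32]),
                num_parallel_calls=tf.data.AUTOTUNE)
    return ds.prefetch(tf.data.AUTOTUNE)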
@@ -895,10 +877,11 @@ class UNET:
         self.build_evaluation()
         self.do_training()
-    def run_test(self, directory):
-        data_files = glob.glob(directory+'l1b_*.npy')
-        label_files = [f.replace('l1b', 'l2') for f in data_files]
-        self.setup_pipeline_files(data_files, label_files)
+    # def run_test(self, directory):
+    #     data_files = glob.glob(directory+'l1b_*.npy')
+    #     label_files = [f.replace('l1b', 'l2') for f in data_files]
+    def run_test(self, train_data_files, train_label_files, vld_data_files, vld_label_files):
+        self.setup_pipeline_files(train_data_files, train_label_files, vld_data_files, vld_label_files)
         self.build_model()
         self.build_training()
         self.build_evaluation()
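With the signature change above, run_test now takes the four pre-split .npy files directly instead of globbing a directory. A hedged usage example, with a placeholder module name and file paths:

from unet import UNET  # module name assumed; the class name comes from this file

nn = UNET()
nn.run_test('train_l1b.npy', 'train_l2.npy',
            'valid_l1b.npy', 'valid_l2.npy')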