Skip to content
Snippets Groups Projects
Commit 3767a477 authored by tomrink's avatar tomrink
Browse files

snapshot...

parent afc6a33e
Branches
No related tags found
No related merge requests found
......@@ -13,7 +13,7 @@ import h5py
LOG_DEVICE_PLACEMENT = False
PROC_BATCH_SIZE = 50
PROC_BATCH_SIZE = 2048
PROC_BATCH_BUFFER_SIZE = 50000
NumClasses = 2
......@@ -212,45 +212,33 @@ class UNET:
# print(e)
def get_in_mem_data_batch(self, idxs, is_training):
np.sort(idxs)
if is_training:
train_data = []
train_label = []
for k in idxs:
f = self.train_data_files[k]
f = self.train_data_files
nda = np.load(f)
train_data.append(nda)
data = nda[idxs, :, :, :]
f = self.train_label_files[k]
f = self.train_label_files
nda = np.load(f)
train_label.append(nda[:, 0, :, :])
data = np.concatenate(train_data)
label = np.concatenate(train_label)
label = nda[idxs, 0, :, :]
label = np.expand_dims(label, axis=3)
else:
test_data = []
test_label = []
for k in idxs:
f = self.test_data_files[k]
f = self.test_data_files
nda = np.load(f)
test_data.append(nda)
data = nda[idxs, :, :, :]
f = self.test_label_files[k]
f = self.test_label_files
nda = np.load(f)
test_label.append(nda[:, 0, :, :])
data = np.concatenate(test_data)
label = np.concatenate(test_label)
label = nda[idxs, 0, :, :]
label = np.expand_dims(label, axis=3)
data = data.astype(np.float32)
label = label.astype(np.float32)
data_norm = []
for idx, param in enumerate(emis_params):
tmp = normalize(data[:, idx, :, :], param, mean_std_dct)
for k, param in enumerate(emis_params):
tmp = normalize(data[:, k, :, :], param, mean_std_dct)
data_norm.append(tmp)
data = np.stack(data_norm, axis=3)
......@@ -365,27 +353,21 @@ class UNET:
print('num test samples: ', tst_idxs.shape[0])
print('setup_pipeline: Done')
def setup_pipeline_files(self, data_files, label_files, perc=0.20):
num_files = len(data_files)
num_test_files = int(num_files * perc)
num_train_files = num_files - num_test_files
num_test_files = 1
num_train_files = 3
def setup_pipeline_files(self, train_data_files, train_label_files, test_data_files, test_label_files):
self.train_data_files = data_files[0:num_train_files]
self.train_label_files = label_files[0:num_train_files]
self.test_data_files = data_files[num_train_files:]
self.test_label_files = label_files[num_train_files:]
self.train_data_files = train_data_files
self.train_label_files = train_label_files
self.test_data_files = test_data_files
self.test_label_files = test_label_files
trn_idxs = np.arange(num_train_files)
trn_idxs = np.arange(10503)
np.random.shuffle(trn_idxs)
tst_idxs = np.arange(num_test_files)
tst_idxs = np.arange(1167)
self.get_train_dataset(trn_idxs)
self.get_test_dataset(tst_idxs)
self.num_data_samples = num_train_files * 1000 # approximately
self.num_data_samples = 10503
print('datetime: ', now)
print('training and test data: ')
......@@ -895,10 +877,11 @@ class UNET:
self.build_evaluation()
self.do_training()
def run_test(self, directory):
data_files = glob.glob(directory+'l1b_*.npy')
label_files = [f.replace('l1b', 'l2') for f in data_files]
self.setup_pipeline_files(data_files, label_files)
# def run_test(self, directory):
# data_files = glob.glob(directory+'l1b_*.npy')
# label_files = [f.replace('l1b', 'l2') for f in data_files]
def run_test(self, train_data_files, train_label_files, vld_data_files, vld_label_files):
self.setup_pipeline_files(train_data_files, train_label_files, vld_data_files, vld_label_files)
self.build_model()
self.build_training()
self.build_evaluation()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment