Commit 88ea29a8 authored by tomrink

minor

parent e5d98d99
@@ -2,7 +2,6 @@ import glob
import tensorflow as tf
from util.setup import logdir, modeldir, cachepath, now, ancillary_path, home_dir
from util.util import EarlyStop, normalize, make_for_full_domain_predict
import matplotlib.pyplot as plt
import os, datetime
import numpy as np
@@ -187,102 +186,6 @@ class UNET:
        tf.debugging.set_log_device_placement(LOG_DEVICE_PLACEMENT)
        # Doesn't seem to play well with SLURM
        # gpus = tf.config.experimental.list_physical_devices('GPU')
        # if gpus:
        #     try:
        #         # Currently, memory growth needs to be the same across GPUs
        #         for gpu in gpus:
        #             tf.config.experimental.set_memory_growth(gpu, True)
        #         logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        #         print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
        #     except RuntimeError as e:
        #         # Memory growth must be set before GPUs have been initialized
        #         print(e)
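As an aside, the memory-growth pattern commented out above is still the supported TensorFlow 2.x recipe; here is a minimal standalone sketch, assuming TF >= 2.1 (where list_physical_devices left the experimental namespace). The SLURM caveat noted above still applies.

import tensorflow as tf

# Enable per-process GPU memory growth; must run before any GPU is initialized.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            # Memory growth must be set uniformly across all visible GPUs.
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        # Raised when GPUs were already initialized before this call.
        print(e)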
    # def get_in_mem_data_batch(self, idxs, is_training):
    #
    #     # sort these to use as numpy indexing arrays
    #     nd_idxs = np.array(idxs)
    #     nd_idxs = np.sort(nd_idxs)
    #
    #     data = []
    #     for param in self.train_params:
    #         nda = self.get_parameter_data(param, nd_idxs, is_training)
    #         nda = normalize(nda, param, mean_std_dct)
    #         if DO_ZERO_OUT and is_training:
    #             try:
    #                 zero_out_params.index(param)
    #                 nda[:,] = 0.0
    #             except ValueError:
    #                 pass
    #         data.append(nda)
    #     data = np.stack(data)
    #     data = data.astype(np.float32)
    #     data = np.transpose(data, axes=(1, 2, 3, 0))
    #
    #     data_alt = self.get_scalar_data(nd_idxs, is_training)
    #
    #     label = self.get_label_data(nd_idxs, is_training)
    #     label = np.where(label == -1, 0, label)
    #
    #     # binary, two class
    #     if NumClasses == 2:
    #         label = np.where(label != 0, 1, label)
    #         label = label.reshape((label.shape[0], 1))
    #     elif NumClasses == 3:
    #         label = np.where(np.logical_or(label == 1, label == 2), 1, label)
    #         label = np.where(np.invert(np.logical_or(label == 0, label == 1)), 2, label)
    #         label = label.reshape((label.shape[0], 1))
    #
    #     if is_training and DO_AUGMENT:
    #         data_ud = np.flip(data, axis=1)
    #         data_alt_ud = np.copy(data_alt)
    #         label_ud = np.copy(label)
    #
    #         data_lr = np.flip(data, axis=2)
    #         data_alt_lr = np.copy(data_alt)
    #         label_lr = np.copy(label)
    #
    #         data = np.concatenate([data, data_ud, data_lr])
    #         data_alt = np.concatenate([data_alt, data_alt_ud, data_alt_lr])
    #         label = np.concatenate([label, label_ud, label_lr])
    #
    #     return data, data_alt, label
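For context, the class-collapsing step in the removed method above maps the raw icing labels down to two or three classes; for NumClasses == 3 the pair of np.where calls sends labels 1 and 2 to class 1 and everything larger to class 2. A worked example (label values here are illustrative):

import numpy as np

# Raw labels after the -1 -> 0 remap above.
label = np.array([0, 1, 2, 3, 4, 5, 6])
label = np.where(np.logical_or(label == 1, label == 2), 1, label)
label = np.where(np.invert(np.logical_or(label == 0, label == 1)), 2, label)
print(label)  # [0 1 1 2 2 2 2]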
    def get_in_mem_data_batch_salt(self, idxs, is_training):
        data = []
        for k in idxs:
            if is_training:
                nda = plt.imread(self.train_data_files[k])
            else:
                nda = plt.imread(self.test_data_files[k])
            data.append(nda[0:64, 0:64])
        data = np.stack(data)
        data = data.astype(np.float32)

        label = []
        for k in idxs:
            if is_training:
                nda = plt.imread(self.train_label_files[k])
            else:
                nda = plt.imread(self.test_label_files[k])
            label.append(nda[0:64, 0:64])
        label = np.stack(label)
        label = label.astype(np.int32)

        if is_training and DO_AUGMENT:
            data_ud = np.flip(data, axis=1)
            label_ud = np.flip(label, axis=1)

            data_lr = np.flip(data, axis=2)
            label_lr = np.flip(label, axis=2)

            data = np.concatenate([data, data_ud, data_lr])
            label = np.concatenate([label, label_ud, label_lr])

        return data, data, label
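The DO_AUGMENT branch above triples each training batch by appending vertical (axis=1) and horizontal (axis=2) flips of every (N, H, W) tile; a minimal illustration:

import numpy as np

batch = np.array([[[1, 2],
                   [3, 4]]], dtype=np.float32)   # shape (1, 2, 2)
flip_ud = np.flip(batch, axis=1)                 # rows reversed:    [[3, 4], [1, 2]]
flip_lr = np.flip(batch, axis=2)                 # columns reversed: [[2, 1], [4, 3]]
augmented = np.concatenate([batch, flip_ud, flip_lr])
print(augmented.shape)                           # (3, 2, 2) -- batch size tripled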
    def get_in_mem_data_batch(self, idxs, is_training):
        if is_training:
            train_data = []
@@ -298,6 +201,7 @@ class UNET:
            data = np.concatenate(train_data)
            data = np.expand_dims(data, axis=3)
            label = np.concatenate(train_label)
            label = np.expand_dims(label, axis=3)
        else:
@@ -336,38 +240,6 @@ class UNET:
        return data, data, label
    # def get_parameter_data(self, param, nd_idxs, is_training):
    #     if is_training:
    #         if param in self.train_params_l1b:
    #             h5f = self.h5f_l1b_trn
    #         else:
    #             h5f = self.h5f_l2_trn
    #     else:
    #         if param in self.train_params_l1b:
    #             h5f = self.h5f_l1b_tst
    #         else:
    #             h5f = self.h5f_l2_tst
    #
    #     nda = h5f[param][nd_idxs,]
    #     return nda
    #
    # def get_label_data(self, nd_idxs, is_training):
    #     # Note: labels will be same for nd_idxs across both L1B and L2
    #     if is_training:
    #         if self.h5f_l1b_trn is not None:
    #             h5f = self.h5f_l1b_trn
    #         else:
    #             h5f = self.h5f_l2_trn
    #     else:
    #         if self.h5f_l1b_tst is not None:
    #             h5f = self.h5f_l1b_tst
    #         else:
    #             h5f = self.h5f_l2_tst
    #
    #     label = h5f['icing_intensity'][nd_idxs]
    #     label = label.astype(np.int32)
    #     return label
    def get_in_mem_data_batch_train(self, idxs):
        return self.get_in_mem_data_batch(idxs, True)
@@ -437,55 +309,6 @@ class UNET:
        dataset = dataset.map(self.data_function_evaluate, num_parallel_calls=8)
        self.eval_dataset = dataset
    # def setup_pipeline(self, filename_l1b_trn, filename_l1b_tst, filename_l2_trn, filename_l2_tst, trn_idxs=None, tst_idxs=None, seed=None):
    #     if filename_l1b_trn is not None:
    #         self.h5f_l1b_trn = h5py.File(filename_l1b_trn, 'r')
    #     if filename_l1b_tst is not None:
    #         self.h5f_l1b_tst = h5py.File(filename_l1b_tst, 'r')
    #     if filename_l2_trn is not None:
    #         self.h5f_l2_trn = h5py.File(filename_l2_trn, 'r')
    #     if filename_l2_tst is not None:
    #         self.h5f_l2_tst = h5py.File(filename_l2_tst, 'r')
    #
    #     if trn_idxs is None:
    #         # Note: time is same across both L1B and L2 for idxs
    #         if self.h5f_l1b_trn is not None:
    #             h5f = self.h5f_l1b_trn
    #         else:
    #             h5f = self.h5f_l2_trn
    #         time = h5f['time']
    #         trn_idxs = np.arange(time.shape[0])
    #         if seed is not None:
    #             np.random.seed(seed)
    #         np.random.shuffle(trn_idxs)
    #
    #         if self.h5f_l1b_tst is not None:
    #             h5f = self.h5f_l1b_tst
    #         else:
    #             h5f = self.h5f_l2_tst
    #         time = h5f['time']
    #         tst_idxs = np.arange(time.shape[0])
    #         if seed is not None:
    #             np.random.seed(seed)
    #         np.random.shuffle(tst_idxs)
    #
    #     self.num_data_samples = trn_idxs.shape[0]
    #
    #     self.get_train_dataset(trn_idxs)
    #     self.get_test_dataset(tst_idxs)
    #
    #     print('datetime: ', now)
    #     print('training and test data: ')
    #     print(filename_l1b_trn)
    #     print(filename_l1b_tst)
    #     print(filename_l2_trn)
    #     print(filename_l2_tst)
    #     print('---------------------------')
    #     print('num train samples: ', self.num_data_samples)
    #     print('BATCH SIZE: ', BATCH_SIZE)
    #     print('num test samples: ', tst_idxs.shape[0])
    #     print('setup_pipeline: Done')
    def setup_pipeline(self, data_nda, label_nda, perc=0.20):
        num_samples = data_nda.shape[0]
@@ -572,47 +395,6 @@ class UNET:
        self.get_evaluate_dataset(idxs)
    def setup_salt_pipeline(self, data_path, label_path, perc=0.15):
        data_files = np.array(glob.glob(data_path+'*.png'))
        label_files = np.array(glob.glob(label_path+'*.png'))

        num_data_samples = len(data_files)
        idxs = np.arange(num_data_samples)
        np.random.shuffle(idxs)

        num_train = int(num_data_samples*(1-perc))
        self.num_data_samples = num_train
        # np.sort returns a sorted copy, so the result must be kept
        trn_idxs = np.sort(idxs[0:num_train])
        tst_idxs = np.sort(idxs[num_train:])

        self.train_data_files = data_files[trn_idxs]
        self.train_label_files = label_files[trn_idxs]
        self.test_data_files = data_files[tst_idxs]
        self.test_label_files = label_files[tst_idxs]

        trn_idxs = np.arange(len(trn_idxs))
        tst_idxs = np.arange(len(tst_idxs))
        np.random.shuffle(trn_idxs)

        self.get_train_dataset(trn_idxs)
        self.get_test_dataset(tst_idxs)

        f = open(home_dir+'/salt_test_files.pkl', 'wb')
        pickle.dump((self.test_data_files, self.test_label_files), f)
        f.close()
    def setup_salt_restore(self, test_files='/Users/tomrink/salt_test_files.pkl'):
        tup = pickle.load(open(test_files, 'rb'))
        self.test_data_files = tup[0]
        self.test_label_files = tup[1]
        self.num_data_samples = len(self.test_data_files)
        tst_idxs = np.arange(self.num_data_samples)
        self.get_test_dataset(tst_idxs)
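Taken together, setup_salt_pipeline draws a random perc test split from the PNG pairs and pickles the held-out file lists, while setup_salt_restore reloads that exact split in a later session. A hypothetical round trip (the paths and no-argument constructor are placeholders, not shown in this diff):

model = UNET()
model.setup_salt_pipeline('/path/to/salt/images/', '/path/to/salt/masks/', perc=0.15)
# ... training happens elsewhere; later, recover the identical test split:
model2 = UNET()
model2.setup_salt_restore(test_files=home_dir + '/salt_test_files.pkl')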
    def build_unet(self):
        print('build_unet')
        # padding = "VALID"
@@ -623,7 +405,6 @@ class UNET:
        activation = tf.nn.leaky_relu
        momentum = 0.99

        # num_filters = len(self.train_params) * 4
        num_filters = self.n_chans * 4
        input_2d = self.inputs[0]
@@ -1075,21 +856,7 @@ class UNET:
        preds = np.argmax(preds, axis=1)
        self.test_preds = preds
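For reference, the argmax above collapses per-class scores of shape (num_samples, num_classes) into hard integer labels, e.g.:

import numpy as np

scores = np.array([[0.9, 0.1],
                   [0.2, 0.8],
                   [0.4, 0.6]])    # three samples, two classes
print(np.argmax(scores, axis=1))   # [0 1 1]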
    def run(self, filename_l1b_trn, filename_l1b_tst, filename_l2_trn, filename_l2_tst):
        self.setup_pipeline(filename_l1b_trn, filename_l1b_tst, filename_l2_trn, filename_l2_tst)
        self.build_model()
        self.build_training()
        self.build_evaluation()
        self.do_training()

    def run_salt(self, data_path='/Users/tomrink/data/salt/train/images/', label_path='/Users/tomrink/data/salt/train/masks/'):
        self.setup_salt_pipeline(data_path, label_path)
        self.build_model()
        self.build_training()
        self.build_evaluation()
        self.do_training()
    def run_test(self, directory):
    def run(self, directory):
        data_files = glob.glob(directory+'mod_res*.npy')
        label_files = [f.replace('mod', 'img') for f in data_files]
        self.setup_pipeline_files(data_files, label_files)
@@ -1098,14 +865,6 @@ class UNET:
        self.build_evaluation()
        self.do_training()
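The renamed run method pairs every mod_res*.npy file in the directory with its img-prefixed counterpart and then trains. A hypothetical invocation (path and constructor are placeholders; note the trailing slash, since the glob pattern is concatenated directly onto the directory string):

model = UNET()
model.run('/path/to/npy_tiles/')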
    def run_restore_test(self, data_path='/Users/tomrink/data/salt/train/images/', label_path='/Users/tomrink/data/salt/train/masks/', ckpt_dir=None):
        # self.setup_salt_pipeline(data_path, label_path)
        self.setup_salt_restore()
        self.build_model()
        self.build_training()
        self.build_evaluation()
        self.restore(ckpt_dir)

    def run_restore(self, filename_l1b, filename_l2, ckpt_dir):
        self.setup_test_pipeline(filename_l1b, filename_l2)
        self.build_model()