diff --git a/modules/deeplearning/unet.py b/modules/deeplearning/unet.py
index bab6c4d57ca1a18561f7ee57b83aebb59f4d93c7..5c66a3cd64cba08afdbad09cc93bac434e113243 100644
--- a/modules/deeplearning/unet.py
+++ b/modules/deeplearning/unet.py
@@ -2,7 +2,6 @@ import glob
 import tensorflow as tf
 from util.setup import logdir, modeldir, cachepath, now, ancillary_path, home_dir
 from util.util import EarlyStop, normalize, make_for_full_domain_predict
-import matplotlib.pyplot as plt
 import os, datetime
 import numpy as np
 
@@ -187,102 +186,6 @@ class UNET:
 
         tf.debugging.set_log_device_placement(LOG_DEVICE_PLACEMENT)
 
-        # Doesn't seem to play well with SLURM
-        # gpus = tf.config.experimental.list_physical_devices('GPU')
-        # if gpus:
-        #     try:
-        #         # Currently, memory growth needs to be the same across GPUs
-        #         for gpu in gpus:
-        #             tf.config.experimental.set_memory_growth(gpu, True)
-        #         logical_gpus = tf.config.experimental.list_logical_devices('GPU')
-        #         print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
-        #     except RuntimeError as e:
-        #         # Memory growth must be set before GPUs have been initialized
-        #         print(e)
-
-    # def get_in_mem_data_batch(self, idxs, is_training):
-    #
-    #     # sort these to use as numpy indexing arrays
-    #     nd_idxs = np.array(idxs)
-    #     nd_idxs = np.sort(nd_idxs)
-    #
-    #     data = []
-    #     for param in self.train_params:
-    #         nda = self.get_parameter_data(param, nd_idxs, is_training)
-    #         nda = normalize(nda, param, mean_std_dct)
-    #         if DO_ZERO_OUT and is_training:
-    #             try:
-    #                 zero_out_params.index(param)
-    #                 nda[:,] = 0.0
-    #             except ValueError:
-    #                 pass
-    #         data.append(nda)
-    #     data = np.stack(data)
-    #     data = data.astype(np.float32)
-    #     data = np.transpose(data, axes=(1, 2, 3, 0))
-    #
-    #     data_alt = self.get_scalar_data(nd_idxs, is_training)
-    #
-    #     label = self.get_label_data(nd_idxs, is_training)
-    #     label = np.where(label == -1, 0, label)
-    #
-    #     # binary, two class
-    #     if NumClasses == 2:
-    #         label = np.where(label != 0, 1, label)
-    #         label = label.reshape((label.shape[0], 1))
-    #     elif NumClasses == 3:
-    #         label = np.where(np.logical_or(label == 1, label == 2), 1, label)
-    #         label = np.where(np.invert(np.logical_or(label == 0, label == 1)), 2, label)
-    #         label = label.reshape((label.shape[0], 1))
-    #
-    #     if is_training and DO_AUGMENT:
-    #         data_ud = np.flip(data, axis=1)
-    #         data_alt_ud = np.copy(data_alt)
-    #         label_ud = np.copy(label)
-    #
-    #         data_lr = np.flip(data, axis=2)
-    #         data_alt_lr = np.copy(data_alt)
-    #         label_lr = np.copy(label)
-    #
-    #         data = np.concatenate([data, data_ud, data_lr])
-    #         data_alt = np.concatenate([data_alt, data_alt_ud, data_alt_lr])
-    #         label = np.concatenate([label, label_ud, label_lr])
-    #
-    #     return data, data_alt, label
-
-    def get_in_mem_data_batch_salt(self, idxs, is_training):
-        data = []
-        for k in idxs:
-            if is_training:
-                nda = plt.imread(self.train_data_files[k])
-            else:
-                nda = plt.imread(self.test_data_files[k])
-            data.append(nda[0:64, 0:64])
-        data = np.stack(data)
-        data = data.astype(np.float32)
-
-        label = []
-        for k in idxs:
-            if is_training:
-                nda = plt.imread(self.train_label_files[k])
-            else:
-                nda = plt.imread(self.test_label_files[k])
-            label.append(nda[0:64, 0:64])
-        label = np.stack(label)
-        label = label.astype(np.int32)
-
-        if is_training and DO_AUGMENT:
-            data_ud = np.flip(data, axis=1)
-            label_ud = np.flip(label, axis=1)
-
-            data_lr = np.flip(data, axis=2)
-            label_lr = np.flip(label, axis=2)
-
-            data = np.concatenate([data, data_ud, data_lr])
-            label = np.concatenate([label, label_ud, label_lr])
-
-        return data, data, label
-
     def get_in_mem_data_batch(self, idxs, is_training):
         if is_training:
             train_data = []
@@ -298,6 +201,7 @@ class UNET:
 
             data = np.concatenate(train_data)
             data = np.expand_dims(data, axis=3)
+
             label = np.concatenate(train_label)
             label = np.expand_dims(label, axis=3)
         else:
@@ -336,38 +240,6 @@ class UNET:
 
         return data, data, label
 
-    # def get_parameter_data(self, param, nd_idxs, is_training):
-    #     if is_training:
-    #         if param in self.train_params_l1b:
-    #             h5f = self.h5f_l1b_trn
-    #         else:
-    #             h5f = self.h5f_l2_trn
-    #     else:
-    #         if param in self.train_params_l1b:
-    #             h5f = self.h5f_l1b_tst
-    #         else:
-    #             h5f = self.h5f_l2_tst
-    #
-    #     nda = h5f[param][nd_idxs,]
-    #     return nda
-    #
-    # def get_label_data(self, nd_idxs, is_training):
-    #     # Note: labels will be same for nd_idxs across both L1B and L2
-    #     if is_training:
-    #         if self.h5f_l1b_trn is not None:
-    #             h5f = self.h5f_l1b_trn
-    #         else:
-    #             h5f = self.h5f_l2_trn
-    #     else:
-    #         if self.h5f_l1b_tst is not None:
-    #             h5f = self.h5f_l1b_tst
-    #         else:
-    #             h5f = self.h5f_l2_tst
-    #
-    #     label = h5f['icing_intensity'][nd_idxs]
-    #     label = label.astype(np.int32)
-    #     return label
-
     def get_in_mem_data_batch_train(self, idxs):
         return self.get_in_mem_data_batch(idxs, True)
 
@@ -437,55 +309,6 @@ class UNET:
         dataset = dataset.map(self.data_function_evaluate, num_parallel_calls=8)
         self.eval_dataset = dataset
 
-    # def setup_pipeline(self, filename_l1b_trn, filename_l1b_tst, filename_l2_trn, filename_l2_tst, trn_idxs=None, tst_idxs=None, seed=None):
-    #     if filename_l1b_trn is not None:
-    #         self.h5f_l1b_trn = h5py.File(filename_l1b_trn, 'r')
-    #     if filename_l1b_tst is not None:
-    #         self.h5f_l1b_tst = h5py.File(filename_l1b_tst, 'r')
-    #     if filename_l2_trn is not None:
-    #         self.h5f_l2_trn = h5py.File(filename_l2_trn, 'r')
-    #     if filename_l2_tst is not None:
-    #         self.h5f_l2_tst = h5py.File(filename_l2_tst, 'r')
-    #
-    #     if trn_idxs is None:
-    #         # Note: time is same across both L1B and L2 for idxs
-    #         if self.h5f_l1b_trn is not None:
-    #             h5f = self.h5f_l1b_trn
-    #         else:
-    #             h5f = self.h5f_l2_trn
-    #         time = h5f['time']
-    #         trn_idxs = np.arange(time.shape[0])
-    #         if seed is not None:
-    #             np.random.seed(seed)
-    #         np.random.shuffle(trn_idxs)
-    #
-    #         if self.h5f_l1b_tst is not None:
-    #             h5f = self.h5f_l1b_tst
-    #         else:
-    #             h5f = self.h5f_l2_tst
-    #         time = h5f['time']
-    #         tst_idxs = np.arange(time.shape[0])
-    #         if seed is not None:
-    #             np.random.seed(seed)
-    #         np.random.shuffle(tst_idxs)
-    #
-    #     self.num_data_samples = trn_idxs.shape[0]
-    #
-    #     self.get_train_dataset(trn_idxs)
-    #     self.get_test_dataset(tst_idxs)
-    #
-    #     print('datetime: ', now)
-    #     print('training and test data: ')
-    #     print(filename_l1b_trn)
-    #     print(filename_l1b_tst)
-    #     print(filename_l2_trn)
-    #     print(filename_l2_tst)
-    #     print('---------------------------')
-    #     print('num train samples: ', self.num_data_samples)
-    #     print('BATCH SIZE: ', BATCH_SIZE)
-    #     print('num test samples: ', tst_idxs.shape[0])
-    #     print('setup_pipeline: Done')
-
     def setup_pipeline(self, data_nda, label_nda, perc=0.20):
         num_samples = data_nda.shape[0]
 
@@ -572,47 +395,6 @@ class UNET:
 
         self.get_evaluate_dataset(idxs)
 
-    def setup_salt_pipeline(self, data_path, label_path, perc=0.15):
-        data_files = np.array(glob.glob(data_path+'*.png'))
-        label_files = np.array(glob.glob(label_path+'*.png'))
-
-        num_data_samples = len(data_files)
-        idxs = np.arange(num_data_samples)
-        np.random.shuffle(idxs)
-        num_train = int(num_data_samples*(1-perc))
-        self.num_data_samples = num_train
-        trn_idxs = idxs[0:num_train]
-        np.sort(trn_idxs)
-        tst_idxs = idxs[num_train:]
-        np.sort(tst_idxs)
-
-        self.train_data_files = data_files[trn_idxs]
-        self.train_label_files = label_files[trn_idxs]
-
-        self.test_data_files = data_files[tst_idxs]
-        self.test_label_files = label_files[tst_idxs]
-
-        trn_idxs = np.arange(len(trn_idxs))
-        tst_idxs = np.arange(len(tst_idxs))
-
-        np.random.shuffle(trn_idxs)
-        self.get_train_dataset(trn_idxs)
-        self.get_test_dataset(tst_idxs)
-
-        f = open(home_dir+'/salt_test_files.pkl', 'wb')
-        pickle.dump((self.test_data_files, self.test_label_files), f)
-        f.close()
-
-    def setup_salt_restore(self, test_files='/Users/tomrink/salt_test_files.pkl'):
-        tup = pickle.load(open(test_files, 'rb'))
-
-        self.test_data_files = tup[0]
-        self.test_label_files = tup[1]
-        self.num_data_samples = len(self.test_data_files)
-
-        tst_idxs = np.arange(self.num_data_samples)
-        self.get_test_dataset(tst_idxs)
-
     def build_unet(self):
         print('build_cnn')
         # padding = "VALID"
@@ -623,7 +405,6 @@ class UNET:
         activation = tf.nn.leaky_relu
         momentum = 0.99
 
-        # num_filters = len(self.train_params) * 4
         num_filters = self.n_chans * 4
 
         input_2d = self.inputs[0]
@@ -1075,21 +856,7 @@ class UNET:
         preds = np.argmax(preds, axis=1)
         self.test_preds = preds
 
-    def run(self, filename_l1b_trn, filename_l1b_tst, filename_l2_trn, filename_l2_tst):
-        self.setup_pipeline(filename_l1b_trn, filename_l1b_tst, filename_l2_trn, filename_l2_tst)
-        self.build_model()
-        self.build_training()
-        self.build_evaluation()
-        self.do_training()
-
-    def run_salt(self, data_path='/Users/tomrink/data/salt/train/images/', label_path='/Users/tomrink/data/salt/train/masks/'):
-        self.setup_salt_pipeline(data_path, label_path)
-        self.build_model()
-        self.build_training()
-        self.build_evaluation()
-        self.do_training()
-
-    def run_test(self, directory):
+    def run(self, directory):
         data_files = glob.glob(directory+'mod_res*.npy')
         label_files = [f.replace('mod', 'img') for f in data_files]
         self.setup_pipeline_files(data_files, label_files)
@@ -1098,14 +865,6 @@ class UNET:
         self.build_evaluation()
         self.do_training()
 
-    def run_restore_test(self, data_path='/Users/tomrink/data/salt/train/images/', label_path='/Users/tomrink/data/salt/train/masks/', ckpt_dir=None):
-        #self.setup_salt_pipeline(data_path, label_path)
-        self.setup_salt_restore()
-        self.build_model()
-        self.build_training()
-        self.build_evaluation()
-        self.restore(ckpt_dir)
-
     def run_restore(self, filename_l1b, filename_l2, ckpt_dir):
         self.setup_test_pipeline(filename_l1b, filename_l2)
         self.build_model()
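
Usage note: with run_salt(), run_restore_test(), and the old run() removed, run_test() (renamed to run()) is the sole training entry point. A minimal invocation sketch, assuming UNET() takes no constructor arguments (its __init__ is outside this diff) and that `directory` ends with a path separator, since run() globs directory + 'mod_res*.npy' and derives label files by replacing 'mod' with 'img':

    from modules.deeplearning.unet import UNET

    unet = UNET()                # assumption: no-arg constructor
    unet.run('/path/to/tiles/')  # setup_pipeline_files -> build_model ->
                                 #   build_training -> build_evaluation -> do_training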