diff --git a/modules/deeplearning/unet.py b/modules/deeplearning/unet.py index b1ca9c9e30d0c28d0926bdfdfb7e42b09c6960fb..e16946ef7b6259fcc01c012c34f5f8f0e72ea22c 100644 --- a/modules/deeplearning/unet.py +++ b/modules/deeplearning/unet.py @@ -76,38 +76,6 @@ zero_out_params = ['cld_reff_dcomp', 'cld_opd_dcomp', 'iwc_dcomp', 'lwc_dcomp'] DO_ZERO_OUT = False -def build_residual_block_1x1(input_layer, num_filters, activation, block_name, padding='SAME', drop_rate=0.5, - do_drop_out=True, do_batch_norm=True): - - with tf.name_scope(block_name): - skip = input_layer - if do_drop_out: - input_layer = tf.keras.layers.Dropout(drop_rate)(input_layer) - if do_batch_norm: - input_layer = tf.keras.layers.BatchNormalization()(input_layer) - conv = tf.keras.layers.Conv2D(num_filters, kernel_size=1, strides=1, padding=padding, activation=activation)(input_layer) - print(conv.shape) - - # if do_drop_out: - # conv = tf.keras.layers.Dropout(drop_rate)(conv) - # if do_batch_norm: - # conv = tf.keras.layers.BatchNormalization()(conv) - # conv = tf.keras.layers.Conv2D(num_filters, kernel_size=1, strides=1, padding=padding, activation=activation)(conv) - # print(conv.shape) - - if do_drop_out: - conv = tf.keras.layers.Dropout(drop_rate)(conv) - if do_batch_norm: - conv = tf.keras.layers.BatchNormalization()(conv) - conv = tf.keras.layers.Conv2D(num_filters, kernel_size=1, strides=1, padding=padding, activation=None)(conv) - - conv = conv + skip - conv = tf.keras.layers.LeakyReLU()(conv) - print(conv.shape) - - return conv - - def build_conv2d_block(conv, num_filters, activation, block_name, padding='SAME'): with tf.name_scope(block_name): skip = conv @@ -550,7 +518,7 @@ class UNET: self.get_evaluate_dataset(idxs) - def build_cnn(self): + def build_unet(self): print('build_cnn') # padding = "VALID" padding = "SAME" @@ -665,45 +633,9 @@ class UNET: activation = tf.nn.softmax # For multi-class # Called logits, but these are actually probabilities, see activation - logits = tf.keras.layers.Conv2D(1, kernel_size=1, strides=1, padding=padding, name='probability', activation=activation)(conv) - - print(logits.shape) - - return logits - - def build_fcl(self, input_layer): - print('build fully connected layer') - num_filters = input_layer.shape[3] - - drop_rate = 0.5 + self.logits = tf.keras.layers.Conv2D(1, kernel_size=1, strides=1, padding=padding, name='probability', activation=activation)(conv) - # activation = tf.nn.relu - # activation = tf.nn.elu - activation = tf.nn.leaky_relu - # padding = "VALID" - padding = "SAME" - - conv = build_residual_block_1x1(input_layer, num_filters, activation, 'Residual_Block_1', padding=padding) - - conv = build_residual_block_1x1(conv, num_filters, activation, 'Residual_Block_2', padding=padding) - - conv = build_residual_block_1x1(conv, num_filters, activation, 'Residual_Block_3', padding=padding) - - #conv = build_residual_block_1x1(conv, num_filters, activation, 'Residual_Block_4', padding=padding) - - print(conv.shape) - - if NumClasses == 2: - activation = tf.nn.sigmoid # For binary - else: - activation = tf.nn.softmax # For multi-class - - # Called logits, but these are actually probabilities, see activation - logits = tf.keras.layers.Conv2D(1, kernel_size=1, strides=1, padding=padding, name='probability', activation=activation)(conv) - - print(logits.shape) - - self.logits = logits + print(self.logits.shape) def build_training(self): if NumClasses == 2: @@ -981,11 +913,7 @@ class UNET: f.close() def build_model(self): - cnn = self.build_cnn() - print(cnn.shape, self.inputs[1].shape) - if self.USE_FLIGHT_ALTITUDE: - cnn = tf.keras.layers.concatenate([cnn, self.inputs[1]]) - self.build_fcl(cnn) + self.build_unet() self.model = tf.keras.Model(self.inputs, self.logits) def restore(self, ckpt_dir): @@ -1070,165 +998,6 @@ class UNET: self.do_evaluate(ckpt_dir) -def run_restore_static(filename_l1b, filename_l2, ckpt_dir_s_path, day_night='DAY', use_flight_altitude=False): - ckpt_dir_s = os.listdir(ckpt_dir_s_path) - cm_s = [] - prob_s = [] - labels = None - - for ckpt in ckpt_dir_s: - ckpt_dir = ckpt_dir_s_path + ckpt - if not os.path.isdir(ckpt_dir): - continue - nn = IcingIntensityFCN(day_night=day_night, use_flight_altitude=use_flight_altitude) - nn.run_restore(filename_l1b, filename_l2, ckpt_dir) - cm_s.append(tf.math.confusion_matrix(nn.test_labels.flatten(), nn.test_preds.flatten())) - prob_s.append(nn.test_probs.flatten()) - if labels is None: # These should be the same - labels = nn.test_labels.flatten() - - num = len(cm_s) - cm_avg = cm_s[0] - prob_avg = prob_s[0] - for k in range(num-1): - cm_avg += cm_s[k+1] - prob_avg += prob_s[k+1] - cm_avg /= num - prob_avg /= num - - return labels, prob_avg, cm_avg - - -def run_evaluate_static_avg(data_dct, ll, cc, ckpt_dir_s_path, day_night='DAY', flight_level=4, - use_flight_altitude=False, prob_thresh=0.5, - satellite='GOES16', domain='FD'): - num_elems = len(cc) - num_lines = len(ll) - cc = np.array(cc) - ll = np.array(ll) - - ckpt_dir_s = os.listdir(ckpt_dir_s_path) - - nav = get_navigation(satellite, domain) - - prob_s = [] - for ckpt in ckpt_dir_s: - ckpt_dir = ckpt_dir_s_path + ckpt - if not os.path.isdir(ckpt_dir): - continue - nn = IcingIntensityFCN(day_night=day_night, use_flight_altitude=use_flight_altitude) - nn.flight_level = flight_level - nn.setup_eval_pipeline(data_dct, num_lines * num_elems) - nn.build_model() - nn.build_training() - nn.build_evaluation() - nn.do_evaluate(ckpt_dir) - prob_s.append(nn.test_probs) - - num = len(prob_s) - prob_avg = prob_s[0] - for k in range(num-1): - prob_avg += prob_s[k+1] - prob_avg /= num - probs = prob_avg - - if NumClasses == 2: - preds = np.where(probs > prob_thresh, 1, 0) - else: - preds = np.argmax(probs, axis=1) - preds_2d = preds.reshape((num_lines, num_elems)) - - ll, cc = np.meshgrid(ll, cc, indexing='ij') - cc = cc.flatten() - ll = ll.flatten() - - ice_mask = preds == 1 - ice_cc = cc[ice_mask] - ice_ll = ll[ice_mask] - - ice_lons, ice_lats = nav.lc_to_earth(ice_cc, ice_ll) - - return ice_lons, ice_lats, preds_2d - - -# def run_evaluate_static_fcn(data_dct, ckpt_dir_s_path, day_night='DAY', l1b_or_l2='both', prob_thresh=0.5, -# flight_levels=[0, 1, 2, 3, 4], use_flight_altitude=False): -# -# ckpt_dir_s = os.listdir(ckpt_dir_s_path) -# ckpt_dir = ckpt_dir_s_path + ckpt_dir_s[0] -# -# if not use_flight_altitude: -# flight_levels = [0] -# -# probs_dct = {flvl: None for flvl in flight_levels} -# preds_dct = {flvl: None for flvl in flight_levels} -# -# nn = IcingIntensityFCN(day_night=day_night, l1b_or_l2=l1b_or_l2, use_flight_altitude=use_flight_altitude) -# nn.num_data_samples = 1 -# nn.build_model() -# nn.build_training() -# nn.build_evaluation() -# -# ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=nn.model) -# ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3) -# ckpt.restore(ckpt_manager.latest_checkpoint) -# -# for flvl in flight_levels: -# nn.flight_level = flvl -# nn.setup_eval_pipeline(data_dct) -# nn.do_evaluate(ckpt_dir) -# -# probs = nn.test_probs -# if NumClasses == 2: -# preds = np.where(probs > prob_thresh, 1, 0) -# else: -# preds = np.argmax(probs, axis=1) -# -# probs_dct[flvl] = probs -# preds_dct[flvl] = preds -# -# return preds_dct, probs_dct - - -def run_evaluate_static_fcn(data_dct, num_tiles, ckpt_dir_s_path, day_night='DAY', l1b_or_l2='both', prob_thresh=0.5, - flight_levels=[0, 1, 2, 3, 4], use_flight_altitude=False): - - ckpt_dir_s = os.listdir(ckpt_dir_s_path) - ckpt_dir = ckpt_dir_s_path + ckpt_dir_s[0] - - if not use_flight_altitude: - flight_levels = [0] - - probs_dct = {flvl: None for flvl in flight_levels} - preds_dct = {flvl: None for flvl in flight_levels} - - nn = IcingIntensityFCN(day_night=day_night, l1b_or_l2=l1b_or_l2, use_flight_altitude=use_flight_altitude) - nn.num_data_samples = num_tiles - nn.build_model() - nn.build_training() - nn.build_evaluation() - - ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=nn.model) - ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3) - ckpt.restore(ckpt_manager.latest_checkpoint) - - for flvl in flight_levels: - nn.flight_level = flvl - nn.setup_eval_pipeline(data_dct, num_tiles) - nn.do_evaluate(ckpt_dir) - - probs = nn.test_probs - if NumClasses == 2: - preds = np.where(probs > prob_thresh, 1, 0) - else: - preds = np.argmax(probs, axis=1) - - probs_dct[flvl] = probs - preds_dct[flvl] = preds - - return preds_dct, probs_dct - - if __name__ == "__main__": - nn = IcingIntensityFCN() + nn = UNET() nn.run('matchup_filename')