import tensorflow as tf
from util.setup import logdir, modeldir, cachepath
import subprocess

import os, datetime
import numpy as np
import xarray as xr
import pickle

from deeplearning.amv_raob import get_bounding_gfs_files, convert_file, get_images, get_interpolated_profile, \
    split_matchup, shuffle_dict, get_interpolated_scalar, get_num_samples, get_time_tuple_utc, get_profile

LOG_DEVICE_PLACEMENT = False

CACHE_DATA_IN_MEM = True
CACHE_GFS = True

PROC_BATCH_SIZE = 60
PROC_BATCH_BUFFER_SIZE = 50000
NumLabels = 1
BATCH_SIZE = 256
NUM_EPOCHS = 200


TRACK_MOVING_AVERAGE = False

DAY_NIGHT = 'ANY'

TRIPLET = False
CONV3D = False

abi_2km_channels = ['14', '08', '11', '13', '15', '16']
# abi_2km_channels = ['08', '09', '10']
abi_hkm_channels = []
# abi_channels = abi_2km_channels + abi_hkm_channels
abi_channels = abi_2km_channels

abi_mean = {'08': 236.014, '14': 275.229, '02': 0.049, '11': 273.582, '13': 275.796, '15': 272.928, '16': 260.956, '09': 244.502, '10': 252.375}
abi_std = {'08': 7.598, '14': 20.443, '02': 0.082, '11': 19.539, '13': 20.431, '15': 20.104, '16': 15.720, '09': 9.827, '10': 11.765}
abi_valid_range = {'02': [0.001, 120], '08': [150, 350], '14': [150, 350], '11': [150, 350], '13': [150, 350], '15': [150, 350], '16': [150, 350], '09': [150, 350], '10': [150, 350]}
abi_half_width = {'08': 12, '14': 12, '02': 48, '11': 12, '13': 12, '15': 12, '16': 12, '09': 12, '10': 12}
#abi_half_width = {'08': 6, '14': 6, '02': 24, '11': 6, '13': 6, '15': 6, '16': 6, '09': 6, '10': 6}
#abi_half_width = {'08': 3, '14': 3, '02': 12, '11': 3, '13': 3, '15': 3, '16': 3, '09': 3, '10': 3}
abi_stride = {'08': 1, '14': 1, '02': 4, '11': 1, '13': 1, '15': 1, '16': 1, '09': 1, '10': 1}
img_width = 24
#img_width = 12
#img_width = 6
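
# Illustrative sketch (not called by the pipeline): one plausible use of the
# tables above -- mask a channel to its valid range, then standardize with the
# per-channel mean/std. The real normalization happens inside get_images
# (do_norm=True), so treat this as an assumption about its behavior.
def _normalize_abi_channel_example(img, ch):
    lo, hi = abi_valid_range[ch]
    img = np.where((img >= lo) & (img <= hi), img, np.nan)  # mask out-of-range
    return (img - abi_mean[ch]) / abi_std[ch]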

NUM_VERT_LEVELS = 26
NUM_VERT_PARAMS = 2

gfs_mean_temp = [225.481110,
                 218.950729,
                 215.830338,
                 212.063187,
                 209.348038,
                 208.787033,
                 213.728928,
                 218.298264,
                 223.061020,
                 229.190445,
                 236.095215,
                 242.589493,
                 248.333237,
                 253.357071,
                 257.768646,
                 261.599396,
                 264.793671,
                 267.667603,
                 270.408478,
                 272.841919,
                 274.929138,
                 276.826294,
                 277.786865,
                 278.834198,
                 279.980408,
                 281.308380]
gfs_mean_temp = np.array(gfs_mean_temp)
gfs_mean_temp = np.reshape(gfs_mean_temp, (1, gfs_mean_temp.shape[0]))

gfs_std_temp = [13.037852,
                11.669035,
                10.775956,
                10.428216,
                11.705231,
                12.352798,
                8.892235,
                7.101064,
                8.505628,
                10.815929,
                12.139559,
                12.720000,
                12.929382,
                13.023590,
                13.135534,
                13.543551,
                14.449997,
                15.241049,
                15.638563,
                15.943666,
                16.178715,
                16.458992,
                16.700863,
                17.109579,
                17.630177,
                18.080544]
gfs_std_temp = np.array(gfs_std_temp)
gfs_std_temp = np.reshape(gfs_std_temp, (1, gfs_std_temp.shape[0]))

mean_std_dict = {'temperature': (gfs_mean_temp, gfs_std_temp), 'surface temperature': (279.35, 22.81),
                 'MSL pressure': (1010.64, 13.46), 'tropopause temperature': (208.17, 11.36), 'tropopause pressure': (219.62, 78.79)}

valid_range_dict = {'temperature': (150, 350), 'surface temperature': (150, 350), 'MSL pressure': (800, 1050),
                    'tropopause temperature': (150, 250), 'tropopause pressure': (100, 500)}
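
# Sketch of how the (1, 26) mean/std arrays above broadcast over a batch of
# GFS temperature profiles; presumably the do_norm=True path in the amv_raob
# helpers does the equivalent (assumption -- those helpers are external).
def _normalize_temp_profiles_example(profiles):
    # profiles: (n_samples, NUM_VERT_LEVELS) array of temperatures in K
    return (profiles - gfs_mean_temp) / gfs_std_temp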


def build_residual_block(input_layer, drop_rate, num_neurons, activation, block_name, doDropout=True, doBatchNorm=True):
    with tf.name_scope(block_name):
        if doDropout:
            fc = tf.keras.layers.Dropout(drop_rate)(input_layer)
            fc = tf.keras.layers.Dense(num_neurons, activation=activation)(fc)
        else:
            fc = tf.keras.layers.Dense(num_neurons, activation=activation)(input_layer)
        if doBatchNorm:
            fc = tf.keras.layers.BatchNormalization()(fc)
        print(fc.shape)
        fc_skip = fc

        if doDropout:
            fc = tf.keras.layers.Dropout(drop_rate)(fc)
        fc = tf.keras.layers.Dense(num_neurons, activation=activation)(fc)
        if doBatchNorm:
            fc = tf.keras.layers.BatchNormalization()(fc)
        print(fc.shape)

        if doDropout:
            fc = tf.keras.layers.Dropout(drop_rate)(fc)
        fc = tf.keras.layers.Dense(num_neurons, activation=activation)(fc)
        if doBatchNorm:
            fc = tf.keras.layers.BatchNormalization()(fc)
        print(fc.shape)

        if doDropout:
            fc = tf.keras.layers.Dropout(drop_rate)(fc)
        fc = tf.keras.layers.Dense(num_neurons, activation=None)(fc)
        if doBatchNorm:
            fc = tf.keras.layers.BatchNormalization()(fc)

        fc = fc + fc_skip
        fc = tf.keras.layers.LeakyReLU()(fc)
        print(fc.shape)

    return fc
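
# Minimal sketch of the block in isolation: a (batch, 16) input keeps its
# width because every Dense layer uses num_neurons units, which is what makes
# the skip-connection addition shape-compatible.
def _residual_block_example():
    x = tf.keras.Input(shape=(16,))
    y = build_residual_block(x, 0.5, 16, tf.nn.leaky_relu, 'example_block')
    return tf.keras.Model(x, y)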


class CloudHeightNN:
    
    def __init__(self, gpu_device=0, datapath=None):
        self.train_data = None
        self.train_label = None
        self.test_data = None
        self.test_label = None
        self.test_data_denorm = None
        
        self.train_dataset = None
        self.inner_train_dataset = None
        self.test_dataset = None
        self.X_img = None
        self.X_prof = None
        self.X_u = None
        self.X_v = None
        self.X_sfc = None
        self.inputs = []
        self.y = None
        self.handle = None
        self.inner_handle = None
        self.in_mem_batch = None
        self.matchup_dict = None

        self.logits = None

        self.predict_data = None
        self.predict_dataset = None
        self.mean_list = None
        self.std_list = None
        
        self.training_op = None
        self.correct = None
        self.accuracy = None
        self.loss = None
        self.pred_class = None
        self.gpu_device = gpu_device
        self.variable_averages = None

        self.global_step = None

        self.writer_train = None
        self.writer_valid = None

        self.OUT_OF_RANGE = False

        self.abi = None
        self.temp = None
        self.wv = None
        self.lbfp = None
        self.sfc = None

        self.in_mem_data_cache = {}

        self.model = None
        self.optimizer = None
        self.train_loss = None
        self.train_accuracy = None
        self.test_loss = None
        self.test_accuracy = None

        self.accuracy_0 = None
        self.accuracy_1 = None
        self.accuracy_2 = None
        self.accuracy_3 = None
        self.accuracy_4 = None
        self.accuracy_5 = None

        self.num_0 = 0
        self.num_1 = 0
        self.num_2 = 0
        self.num_3 = 0
        self.num_4 = 0
        self.num_5 = 0

        self.learningRateSchedule = None
        self.num_data_samples = None
        self.initial_learning_rate = None

        n_chans = len(abi_channels)
        if TRIPLET:
            n_chans *= 3
        self.X_img = tf.keras.Input(shape=(img_width, img_width, n_chans))
        self.X_prof = tf.keras.Input(shape=(NUM_VERT_LEVELS, NUM_VERT_PARAMS))
        self.X_sfc = tf.keras.Input(shape=(2,))

        self.inputs.append(self.X_img)
        self.inputs.append(self.X_prof)
        self.inputs.append(self.X_sfc)
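        # Three model inputs: an ABI image patch (img_width x img_width x
        # n_chans), a temperature/RH profile (NUM_VERT_LEVELS x NUM_VERT_PARAMS),
        # and a 2-vector of ancillary scalars (currently zero-filled in
        # get_in_mem_data_batch).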

        self.DISK_CACHE = True

        if datapath is not None:
            self.DISK_CACHE = False
            with open(datapath, 'rb') as f:
                self.in_mem_data_cache = pickle.load(f)

        tf.debugging.set_log_device_placement(LOG_DEVICE_PLACEMENT)

        gpus = tf.config.experimental.list_physical_devices('GPU')
        if gpus:
            try:
                # Currently, memory growth needs to be the same across GPUs
                for gpu in gpus:
                    tf.config.experimental.set_memory_growth(gpu, True)
                logical_gpus = tf.config.experimental.list_logical_devices('GPU')
                print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
            except RuntimeError as e:
                # Memory growth must be set before GPUs have been initialized
                print(e)

    def get_in_mem_data_batch(self, time_keys):
        images = []
        vprof = []
        label = []
        sfc = []

        for key in time_keys:
            if CACHE_DATA_IN_MEM:
                tup = self.in_mem_data_cache.get(key)
                if tup is not None:
                    images.append(tup[0])
                    vprof.append(tup[1])
                    label.append(tup[2])
                    sfc.append(tup[3])
                    continue

            obs = self.matchup_dict.get(key)
            if obs is None:
                print('no entry for: ', key)
                continue
            timestamp = obs[0][0]
            print('not found in cache, processing key: ', key, get_time_tuple_utc(timestamp)[0])

            gfs_0, time_0, gfs_1, time_1 = get_bounding_gfs_files(timestamp)
            # The later bounding file is optional (the single-file path below
            # falls back to get_profile), but gfs_0 is required downstream.
            if gfs_0 is None:
                print('no GFS for: ', get_time_tuple_utc(timestamp)[0])
                continue
            try:
                gfs_0 = convert_file(gfs_0)
                if gfs_1 is not None:
                    gfs_1 = convert_file(gfs_1)
            except Exception as exc:
                print(get_time_tuple_utc(timestamp)[0])
                print(exc)
                continue

            ds_1 = None
            try:
                ds_0 = xr.open_dataset(gfs_0)
                if gfs_1 is not None:
                    ds_1 = xr.open_dataset(gfs_1)
            except Exception as exc:
                print(exc)
                continue

            lons = obs[:, 2]
            lats = obs[:, 1]

            half_width = [abi_half_width.get(ch) for ch in abi_2km_channels]
            strides = [abi_stride.get(ch) for ch in abi_2km_channels]

            img_a_s, img_a_s_l, img_a_s_r, idxs_a = get_images(lons, lats, timestamp, abi_2km_channels, half_width, strides, do_norm=True, daynight=DAY_NIGHT)
            if idxs_a.size == 0:
                print('no images for: ', timestamp)
                continue

            idxs_b = None
            if len(abi_hkm_channels) > 0:
                half_width = [abi_half_width.get(ch) for ch in abi_hkm_channels]
                strides = [abi_stride.get(ch) for ch in abi_hkm_channels]

                img_b_s, img_b_s_l, img_b_s_r, idxs_b = get_images(lons, lats, timestamp, abi_hkm_channels, half_width, strides, do_norm=True, daynight=DAY_NIGHT)
                if idxs_b.size == 0:
                    print('no hkm images for: ', timestamp)
                    continue

            if idxs_b is None:
                common_idxs = idxs_a
                img_a_s = img_a_s[:, common_idxs, :, :]
                img_s = img_a_s
                if TRIPLET:
                    img_a_s_l = img_a_s_l[:, common_idxs, :, :]
                    img_a_s_r = img_a_s_r[:, common_idxs, :, :]
                    img_s_l = img_a_s_l
                    img_s_r = img_a_s_r
            else:
                common_idxs = np.intersect1d(idxs_a, idxs_b)
                img_a_s = img_a_s[:, common_idxs, :, :]
                img_b_s = img_b_s[:, common_idxs, :, :]
                img_s = np.vstack([img_a_s, img_b_s])
                # TODO: Triplet support

            lons = lons[common_idxs]
            lats = lats[common_idxs]

            if ds_1 is not None:
                ndb = get_interpolated_profile(ds_0, ds_1, time_0, time_1, 'temperature', timestamp, lons, lats, do_norm=True)
            else:
                ndb = get_profile(ds_0, 'temperature', lons, lats, do_norm=True)
            if ndb is None:
                continue

            if ds_1 is not None:
                ndf = get_interpolated_profile(ds_0, ds_1, time_0, time_1, 'rh', timestamp, lons, lats, do_norm=False)
            else:
                ndf = get_profile(ds_0, 'rh', lons, lats, do_norm=False)
            if ndf is None:
                continue
            ndf /= 100.0  # relative humidity: percent -> fraction
            ndb = np.stack((ndb, ndf), axis=2)  # (n_samples, levels, [temp, rh])

            #ndd = get_interpolated_scalar(ds_0, ds_1, time_0, time_1, 'MSL pressure', timestamp, lons, lats, do_norm=False)
            #ndd /= 1000.0

            #nde = get_interpolated_scalar(ds_0, ds_1, time_0, time_1, 'surface temperature', timestamp, lons, lats, do_norm=True)

            # label/truth
            # Level of best fit (LBF)
            ndc = obs[common_idxs, 3]
            # AMV Predicted
            # ndc = obs[common_idxs, 4]
            ndc /= 1000.0

            nda = np.transpose(img_s, axes=[1, 2, 3, 0])
            if TRIPLET or CONV3D:
                nda_l = np.transpose(img_s_l, axes=[1, 2, 3, 0])
                nda_r = np.transpose(img_s_r, axes=[1, 2, 3, 0])
                if CONV3D:
                    nda = np.stack((nda_l, nda, nda_r), axis=4)
                    nda = np.transpose(nda, axes=[0, 1, 2, 4, 3])
                else:
                    nda = np.concatenate([nda, nda_l, nda_r], axis=3)

            images.append(nda)
            vprof.append(ndb)
            label.append(ndc)
            # nds = np.stack([ndd, nde], axis=1)
            nds = np.zeros((len(lons), 2))
            sfc.append(nds)

            if not CACHE_GFS:
                # gfs_1 may be None (single bounding file), so remove each
                # converted file individually.
                os.remove(gfs_0)
                if gfs_1 is not None:
                    os.remove(gfs_1)

            if CACHE_DATA_IN_MEM:
                self.in_mem_data_cache[key] = (nda, ndb, ndc, nds)

            ds_0.close()
            if ds_1 is not None:
               ds_1.close()

        images = np.concatenate(images)

        label = np.concatenate(label)
        label = np.reshape(label, (label.shape[0], 1))

        vprof = np.concatenate(vprof)

        sfc = np.concatenate(sfc)

        return images, vprof, label, sfc
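
    # data_function bridges the eager numpy pipeline into tf.data: the dtype
    # list handed to tf.numpy_function must match what get_in_mem_data_batch
    # actually returns (declared here as float32 images and float64
    # profiles, labels, and sfc arrays).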

    @tf.function(input_signature=[tf.TensorSpec(None, tf.int32)])
    def data_function(self, time_keys):
        out = tf.numpy_function(self.get_in_mem_data_batch, [time_keys], [tf.float32, tf.float64, tf.float64, tf.float64])
        return out

    def get_train_dataset(self, time_keys):
        time_keys = list(time_keys)

        dataset = tf.data.Dataset.from_tensor_slices(time_keys)
        dataset = dataset.batch(PROC_BATCH_SIZE)
        dataset = dataset.map(self.data_function, num_parallel_calls=8)
        dataset = dataset.shuffle(PROC_BATCH_BUFFER_SIZE)
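        # Note: this shuffles whole processed batches (each covering
        # PROC_BATCH_SIZE keys), not individual samples; the buffer size
        # counts batches.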
        dataset = dataset.prefetch(buffer_size=1)
        self.train_dataset = dataset

    def get_test_dataset(self, time_keys):
        time_keys = list(time_keys)

        dataset = tf.data.Dataset.from_tensor_slices(time_keys)
        dataset = dataset.batch(PROC_BATCH_SIZE)
        dataset = dataset.map(self.data_function, num_parallel_calls=8)
        self.test_dataset = dataset

    def setup_pipeline(self, matchup_dict, train_dict=None, valid_test_dict=None):
        self.matchup_dict = matchup_dict

        if train_dict is None:
            if valid_test_dict is not None:
                self.matchup_dict = valid_test_dict
                valid_keys = list(valid_test_dict.keys())
                self.get_test_dataset(valid_keys)
                self.num_data_samples = get_num_samples(valid_test_dict, valid_keys)
                print('num test samples: ', self.num_data_samples)
                print('setup_pipeline: Done')
                return

            train_dict, valid_test_dict = split_matchup(matchup_dict, perc=0.10)

        train_dict = shuffle_dict(train_dict)
        train_keys = list(train_dict.keys())

        self.get_train_dataset(train_keys)

        self.num_data_samples = get_num_samples(train_dict, train_keys)
        print('num data samples: ', self.num_data_samples)
        print('BATCH SIZE: ', BATCH_SIZE)

        valid_keys = list(valid_test_dict.keys())
        self.get_test_dataset(valid_keys)
        print('num test samples: ', get_num_samples(valid_test_dict, valid_keys))

        print('setup_pipeline: Done')

    def build_1d_cnn(self):
        print('build_1d_cnn')
        # padding = 'VALID'
        padding = 'SAME'

        # activation = tf.nn.relu
        # activation = tf.nn.elu
        activation = tf.nn.leaky_relu

        num_filters = 6

        conv = tf.keras.layers.Conv1D(num_filters, 5, strides=1, padding=padding)(self.inputs[1])
        conv = tf.keras.layers.MaxPool1D(padding=padding)(conv)
        print(conv)

        num_filters *= 2
        conv = tf.keras.layers.Conv1D(num_filters, 3, strides=1, padding=padding)(conv)
        conv = tf.keras.layers.MaxPool1D(padding=padding)(conv)
        print(conv)

        num_filters *= 2
        conv = tf.keras.layers.Conv1D(num_filters, 3, strides=1, padding=padding)(conv)
        conv = tf.keras.layers.MaxPool1D(padding=padding)(conv)
        print(conv)

        num_filters *= 2
        conv = tf.keras.layers.Conv1D(num_filters, 3, strides=1, padding=padding)(conv)
        conv = tf.keras.layers.MaxPool1D(padding=padding)(conv)
        print(conv)

        flat = tf.keras.layers.Flatten()(conv)
        print(flat)
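        # With NUM_VERT_LEVELS = 26 and 'SAME' pooling, the four pool-by-2
        # stages give lengths 13 -> 7 -> 4 -> 2, and the final conv has 48
        # filters, so flat carries 2 * 48 = 96 features per sample.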

        return flat

    def build_cnn(self):
        print('build_cnn')
        # padding = "VALID"
        padding = "SAME"

        # activation = tf.nn.relu
        # activation = tf.nn.elu
        activation = tf.nn.leaky_relu
        momentum = 0.99

        num_filters = 8

        conv = tf.keras.layers.Conv2D(num_filters, 5, strides=[1, 1], padding=padding, activation=activation)(self.inputs[0])
        conv = tf.keras.layers.MaxPool2D(padding=padding)(conv)
        conv = tf.keras.layers.BatchNormalization()(conv)
        print(conv.shape)

        num_filters *= 2
        conv = tf.keras.layers.Conv2D(num_filters, 3, strides=[1, 1], padding=padding, activation=activation)(conv)
        conv = tf.keras.layers.MaxPool2D(padding=padding)(conv)
        conv = tf.keras.layers.BatchNormalization()(conv)
        print(conv.shape)

        num_filters *= 2
        conv = tf.keras.layers.Conv2D(num_filters, 3, strides=[1, 1], padding=padding, activation=activation)(conv)
        conv = tf.keras.layers.MaxPool2D(padding=padding)(conv)
        conv = tf.keras.layers.BatchNormalization()(conv)
        print(conv.shape)

        num_filters *= 2
        conv = tf.keras.layers.Conv2D(num_filters, 3, strides=[1, 1], padding=padding, activation=activation)(conv)
        conv = tf.keras.layers.MaxPool2D(padding=padding)(conv)
        conv = tf.keras.layers.BatchNormalization()(conv)
        print(conv.shape)

        num_filters *= 2
        conv = tf.keras.layers.Conv2D(num_filters, 3, strides=[1, 1], padding=padding, activation=activation)(conv)
        conv = tf.keras.layers.MaxPool2D(padding=padding)(conv)
        conv = tf.keras.layers.BatchNormalization()(conv)
        print(conv.shape)

        flat = tf.keras.layers.Flatten()(conv)
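        # With img_width = 24, five pool-by-2 stages give 12 -> 6 -> 3 -> 2 -> 1,
        # and the final conv has 128 filters, so flat carries
        # 1 * 1 * 128 = 128 features per sample.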

        return flat

    def build_anc_dnn(self):
        print('build_anc_dnn')
        drop_rate = 0.5

        # activation = tf.nn.relu
        # activation = tf.nn.elu
        activation = tf.nn.leaky_relu
        momentum = 0.99

        n_hidden = self.X_sfc.shape[1]

        with tf.name_scope("Residual_Block_6"):
            fc = tf.keras.layers.Dropout(drop_rate)(self.inputs[2])
            fc = tf.keras.layers.Dense(4*n_hidden, activation=activation)(fc)
            fc = tf.keras.layers.BatchNormalization()(fc)
            print(fc.shape)
            fc_skip = fc

            fc = tf.keras.layers.Dropout(drop_rate)(fc)
            fc = tf.keras.layers.Dense(4*n_hidden, activation=activation)(fc)
            fc = tf.keras.layers.BatchNormalization()(fc)
            print(fc.shape)

            fc = tf.keras.layers.Dropout(drop_rate)(fc)
            fc = tf.keras.layers.Dense(4*n_hidden, activation=None)(fc)
            fc = tf.keras.layers.BatchNormalization()(fc)
            fc = fc + fc_skip
            fc = tf.keras.layers.LeakyReLU()(fc)
            print(fc.shape)

        return fc

    def build_dnn(self, input_layer=None):
        print('build fully connected layer')
        drop_rate = 0.5

        # activation = tf.nn.relu
        # activation = tf.nn.elu
        activation = tf.nn.leaky_relu
        momentum = 0.99
        
        if input_layer is not None:
            flat = input_layer
            n_hidden = input_layer.shape[1]
        else:
            flat = self.X_img
            n_hidden = self.X_img.shape[1]

        fac = 1

        fc = build_residual_block(flat, drop_rate, fac*n_hidden, activation, 'Residual_Block_1')

        fc = build_residual_block(fc, drop_rate, fac*n_hidden, activation, 'Residual_Block_2')

        fc = build_residual_block(fc, drop_rate, fac*n_hidden, activation, 'Residual_Block_3')

        fc = build_residual_block(fc, drop_rate, fac*n_hidden, activation, 'Residual_Block_4')

        fc = build_residual_block(fc, drop_rate, fac*n_hidden, activation, 'Residual_Block_5')

        fc = tf.keras.layers.Dense(n_hidden, activation=activation)(fc)
        fc = tf.keras.layers.BatchNormalization()(fc)
        print(fc.shape)

        logits = tf.keras.layers.Dense(NumLabels)(fc)
        print(logits.shape)
        
        self.logits = logits

    def build_training(self):
        self.loss = tf.keras.losses.MeanSquaredError()

        # decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps)
        initial_learning_rate = 0.0016
        decay_rate = 0.95
        steps_per_epoch = int(self.num_data_samples/BATCH_SIZE)  # one epoch
        # decay_steps = int(steps_per_epoch / 2)
        decay_steps = 2 * steps_per_epoch
        print('initial rate, decay rate, steps/epoch, decay steps: ', initial_learning_rate, decay_rate, steps_per_epoch, decay_steps)

        self.learningRateSchedule = tf.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate, decay_steps, decay_rate)
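        # With staircase decay off (the default), the rate after `step` updates
        # is initial_learning_rate * decay_rate ** (step / decay_steps); e.g.
        # after two epochs (step == decay_steps) it is 0.0016 * 0.95 = 0.00152.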

        optimizer = tf.keras.optimizers.Adam(learning_rate=self.learningRateSchedule)

        if TRACK_MOVING_AVERAGE:
            # Legacy TF1-style wiring: tf.control_dependencies has no effect in
            # eager mode and ema.apply() does not return an optimizer, so this
            # path is non-functional under TF2 (it is disabled above).
            ema = tf.train.ExponentialMovingAverage(decay=0.999)

            with tf.control_dependencies([optimizer]):
                optimizer = ema.apply(self.model.trainable_variables)

        self.optimizer = optimizer
        self.initial_learning_rate = initial_learning_rate

    def build_evaluation(self):
        self.train_accuracy = tf.keras.metrics.MeanAbsoluteError(name='train_accuracy')
        self.test_accuracy = tf.keras.metrics.MeanAbsoluteError(name='test_accuracy')
        self.train_loss = tf.keras.metrics.Mean(name='train_loss')
        self.test_loss = tf.keras.metrics.Mean(name='test_loss')

        self.accuracy_0 = tf.keras.metrics.MeanAbsoluteError(name='acc_0')
        self.accuracy_1 = tf.keras.metrics.MeanAbsoluteError(name='acc_1')
        self.accuracy_2 = tf.keras.metrics.MeanAbsoluteError(name='acc_2')
        self.accuracy_3 = tf.keras.metrics.MeanAbsoluteError(name='acc_3')
        self.accuracy_4 = tf.keras.metrics.MeanAbsoluteError(name='acc_4')
        self.accuracy_5 = tf.keras.metrics.MeanAbsoluteError(name='acc_5')

    def build_predict(self):
        _, pred = tf.nn.top_k(self.logits)
        self.pred_class = pred

        if TRACK_MOVING_AVERAGE:
            self.variable_averages = tf.train.ExponentialMovingAverage(0.999, self.global_step)
            self.variable_averages.apply(self.model.trainable_variables)

    @tf.function
    def train_step(self, mini_batch):
        inputs = [mini_batch[0], mini_batch[1], mini_batch[3]]
        labels = mini_batch[2]
        with tf.GradientTape() as tape:
            pred = self.model(inputs, training=True)
            loss = self.loss(labels, pred)
            total_loss = loss
            if len(self.model.losses) > 0:
                reg_loss = tf.math.add_n(self.model.losses)
                total_loss = loss + reg_loss
        gradients = tape.gradient(total_loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))

        self.train_loss(loss)
        self.train_accuracy(labels, pred)

        return loss

    @tf.function
    def test_step(self, mini_batch):
        inputs = [mini_batch[0], mini_batch[1], mini_batch[3]]
        labels = mini_batch[2]
        pred = self.model(inputs, training=False)
        t_loss = self.loss(labels, pred)

        self.test_loss(t_loss)
        self.test_accuracy(labels, pred)

    def predict(self, mini_batch):
        inputs = [mini_batch[0], mini_batch[1], mini_batch[3]]
        labels = mini_batch[2]
        pred = self.model(inputs, training=False)
        t_loss = self.loss(labels, pred)

        self.test_loss(t_loss)
        self.test_accuracy(labels, pred)
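        # The interval masks below assume labels are LBF pressures scaled by
        # 1/1000 in get_in_mem_data_batch, so e.g. [0.2, 0.4) covers 200-400 hPa.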

        m = np.logical_and(labels >= 0.01, labels < 0.2)
        self.num_0 += np.sum(m)
        self.accuracy_0(labels[m], pred[m])

        m = np.logical_and(labels >= 0.2, labels < 0.4)
        self.num_1 += np.sum(m)
        self.accuracy_1(labels[m], pred[m])

        m = np.logical_and(labels >= 0.4, labels < 0.6)
        self.num_2 += np.sum(m)
        self.accuracy_2(labels[m], pred[m])

        m = np.logical_and(labels >= 0.6, labels < 0.8)
        self.num_3 += np.sum(m)
        self.accuracy_3(labels[m], pred[m])

        m = np.logical_and(labels >= 0.8, labels < 1.15)
        self.num_4 += np.sum(m)
        self.accuracy_4(labels[m], pred[m])

        m = np.logical_and(labels >= 0.01, labels < 0.5)
        self.num_5 += np.sum(m)
        self.accuracy_5(labels[m], pred[m])

    def do_training(self, ckpt_dir=None):

        if ckpt_dir is None:
            if not os.path.exists(modeldir):
                os.mkdir(modeldir)
            ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model)
            ckpt_manager = tf.train.CheckpointManager(ckpt, modeldir, max_to_keep=3)
        else:
            ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model)
            ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)

        self.writer_train = tf.summary.create_file_writer(os.path.join(logdir, 'plot_train'))
        self.writer_valid = tf.summary.create_file_writer(os.path.join(logdir, 'plot_valid'))

        step = 0
        total_time = 0

        for epoch in range(NUM_EPOCHS):
            self.train_loss.reset_states()
            self.train_accuracy.reset_states()

            t0 = datetime.datetime.now().timestamp()

            proc_batch_cnt = 0
            n_samples = 0

            for abi, temp, lbfp, sfc in self.train_dataset:
                trn_ds = tf.data.Dataset.from_tensor_slices((abi, temp, lbfp, sfc))
                trn_ds = trn_ds.batch(BATCH_SIZE)
                for mini_batch in trn_ds:
                    # build_training always sets learningRateSchedule before
                    # training starts, and `loss` is referenced unconditionally
                    # below, so train_step runs for every mini-batch.
                    loss = self.train_step(mini_batch)

                    if (step % 100) == 0:

                        with self.writer_train.as_default():
                            tf.summary.scalar('loss_trn', loss.numpy(), step=step)
                            tf.summary.scalar('num_train_steps', step, step=step)
                            tf.summary.scalar('num_epochs', epoch, step=step)

                        self.test_loss.reset_states()
                        self.test_accuracy.reset_states()

                        for abi_tst, temp_tst, lbfp_tst, sfc_tst in self.test_dataset:
                            tst_ds = tf.data.Dataset.from_tensor_slices((abi_tst, temp_tst, lbfp_tst, sfc_tst))
                            tst_ds = tst_ds.batch(BATCH_SIZE)
                            for mini_batch_test in tst_ds:
                                self.test_step(mini_batch_test)

                        with self.writer_valid.as_default():
                            tf.summary.scalar('loss_val', self.test_loss.result(), step=step)
                            tf.summary.scalar('acc_val', self.test_accuracy.result(), step=step)
                            tf.summary.scalar('num_train_steps', step, step=step)
                            tf.summary.scalar('num_epochs', epoch, step=step)

                        print('****** test loss, acc: ', self.test_loss.result(), self.test_accuracy.result())

                    step += 1
                    print('train loss: ', loss.numpy())

                proc_batch_cnt += 1
                n_samples += abi.shape[0]
                print('proc_batch_cnt: ', proc_batch_cnt, n_samples)

            t1 = datetime.datetime.now().timestamp()
            print('End of Epoch: ', epoch+1, 'elapsed time: ', (t1-t0))
            total_time += (t1-t0)

            self.test_loss.reset_states()
            self.test_accuracy.reset_states()
            for abi, temp, lbfp, sfc in self.test_dataset:
                ds = tf.data.Dataset.from_tensor_slices((abi, temp, lbfp, sfc))
                ds = ds.batch(BATCH_SIZE)
                for mini_batch in ds:
                    self.test_step(mini_batch)

            print('loss, acc: ', self.test_loss.result(), self.test_accuracy.result())
            ckpt_manager.save()

            if self.DISK_CACHE and epoch == 0:
                with open(cachepath, 'wb') as f:
                    pickle.dump(self.in_mem_data_cache, f)

        print('total time: ', total_time)
        self.writer_train.close()
        self.writer_valid.close()

    def build_model(self):
        flat = self.build_cnn()
        flat_1d = self.build_1d_cnn()
        # flat_anc = self.build_anc_dnn()
        # flat = tf.keras.layers.concatenate([flat, flat_1d, flat_anc])
        flat = tf.keras.layers.concatenate([flat, flat_1d])
        self.build_dnn(flat)
        self.model = tf.keras.Model(self.inputs, self.logits)

    def restore(self, ckpt_dir):

        ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model)
        ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)

        ckpt.restore(ckpt_manager.latest_checkpoint)

        self.test_loss.reset_states()
        self.test_accuracy.reset_states()

        for abi_tst, temp_tst, lbfp_tst, sfc_tst in self.test_dataset:
            ds = tf.data.Dataset.from_tensor_slices((abi_tst, temp_tst, lbfp_tst, sfc_tst))
            ds = ds.batch(BATCH_SIZE)
            for mini_batch_test in ds:
                self.predict(mini_batch_test)
        print('loss, acc: ', self.test_loss.result(), self.test_accuracy.result())
        print('acc_0', self.num_0, self.accuracy_0.result())
        print('acc_1', self.num_1, self.accuracy_1.result())
        print('acc_2', self.num_2, self.accuracy_2.result())
        print('acc_3', self.num_3, self.accuracy_3.result())
        print('acc_4', self.num_4, self.accuracy_4.result())
        print('acc_5', self.num_5, self.accuracy_5.result())

    def run(self, matchup_dict, train_dict=None, valid_dict=None):
        with tf.device('/device:GPU:'+str(self.gpu_device)):
            self.setup_pipeline(matchup_dict, train_dict=train_dict, valid_test_dict=valid_dict)
            self.build_model()
            self.build_training()
            self.build_evaluation()
            self.do_training()

    def run_restore(self, matchup_dict, ckpt_dir):
        self.setup_pipeline(None, None, matchup_dict)
        self.build_model()
        self.build_training()
        self.build_evaluation()
        self.restore(ckpt_dir)


if __name__ == "__main__":
    # run() expects a matchup dictionary (time key -> observation array), not a
    # bare filename; loading it with pickle here is an assumption about how the
    # matchup file is stored.
    with open('matchup_filename', 'rb') as f:
        matchup_dict = pickle.load(f)
    nn = CloudHeightNN()
    nn.run(matchup_dict)