From 8fdaec386074626d914776df583b5e949afb5ae9 Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Thu, 21 Jan 2021 10:20:29 -0600
Subject: [PATCH] distribute training across GPUs with tf.distribute.MirroredStrategy
---
modules/deeplearning/cloudheight.py | 34 +++++++++++++++++++++--------
1 file changed, 25 insertions(+), 9 deletions(-)
diff --git a/modules/deeplearning/cloudheight.py b/modules/deeplearning/cloudheight.py
index eacd2fc3..c77ac2f5 100644
--- a/modules/deeplearning/cloudheight.py
+++ b/modules/deeplearning/cloudheight.py
@@ -20,6 +20,7 @@ PROC_BATCH_SIZE = 40
NumLabels = 1
BATCH_SIZE = 384
NUM_EPOCHS = 200
+GLOBAL_BATCH_SIZE = BATCH_SIZE * 3
TRACK_MOVING_AVERAGE = False
@@ -261,6 +262,8 @@ class CloudHeightNN:
# Memory growth must be set before GPUs have been initialized
print(e)
+ self.strategy = tf.distribute.MirroredStrategy()
+
def get_in_mem_data_batch(self, time_keys):
images = []
vprof = []
@@ -422,7 +425,7 @@ class CloudHeightNN:
dataset = tf.data.Dataset.from_tensor_slices(time_keys)
dataset = dataset.batch(PROC_BATCH_SIZE)
dataset = dataset.map(self.data_function, num_parallel_calls=8)
- dataset = dataset.shuffle(PROC_BATCH_SIZE*BATCH_SIZE)
+ dataset = dataset.shuffle(PROC_BATCH_SIZE*GLOBAL_BATCH_SIZE)
dataset = dataset.prefetch(buffer_size=1)
self.train_dataset = dataset
@@ -614,7 +617,7 @@ class CloudHeightNN:
self.logits = logits
def build_training(self):
- self.loss = tf.keras.losses.MeanSquaredError()
+ self.loss = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)
# decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps)
initial_learning_rate = 0.0016
@@ -665,6 +668,7 @@ class CloudHeightNN:
with tf.GradientTape() as tape:
pred = self.model(inputs, training=True)
loss = self.loss(labels, pred)
+ loss = tf.nn.compute_average_loss(loss, global_batch_size=GLOBAL_BATCH_SIZE)
total_loss = loss
if len(self.model.losses) > 0:
reg_loss = tf.math.add_n(self.model.losses)
@@ -683,10 +687,20 @@ class CloudHeightNN:
labels = mini_batch[2]
pred = self.model(inputs, training=False)
t_loss = self.loss(labels, pred)
+ t_loss = tf.nn.compute_average_loss(t_loss, global_batch_size=GLOBAL_BATCH_SIZE)
self.test_loss(t_loss)
self.test_accuracy(labels, pred)
+ @tf.function
+ def distributed_train_step(self, dataset_inputs):
+ per_replica_losses = self.strategy.run(self.train_step, args=(dataset_inputs,))
+ return self.strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
+
+ @tf.function
+ def distributed_test_step(self, dataset_inputs):
+ return self.strategy.run(self.test_step, args=(dataset_inputs,))
+
def predict(self, mini_batch):
inputs = [mini_batch[0], mini_batch[1], mini_batch[3]]
labels = mini_batch[2]
@@ -748,10 +762,11 @@ class CloudHeightNN:
for abi, temp, lbfp, sfc in self.train_dataset:
trn_ds = tf.data.Dataset.from_tensor_slices((abi, temp, lbfp, sfc))
- trn_ds = trn_ds.batch(BATCH_SIZE)
- for mini_batch in trn_ds:
+ trn_ds = trn_ds.batch(GLOBAL_BATCH_SIZE)
+ trn_dist_ds = self.strategy.experimental_distribute_dataset(trn_ds)
+ for mini_batch in trn_dist_ds:
if self.learningRateSchedule is not None:
- loss = self.train_step(mini_batch)
+ loss = self.distributed_train_step(mini_batch)
if (step % 100) == 0:
@@ -765,9 +780,10 @@ class CloudHeightNN:
for abi_tst, temp_tst, lbfp_tst, sfc_tst in self.test_dataset:
tst_ds = tf.data.Dataset.from_tensor_slices((abi_tst, temp_tst, lbfp_tst, sfc_tst))
- tst_ds = tst_ds.batch(BATCH_SIZE)
- for mini_batch_test in tst_ds:
- self.test_step(mini_batch_test)
+ tst_ds = tst_ds.batch(GLOBAL_BATCH_SIZE)
+ tst_dist_ds = self.strategy.experimental_distribute_dataset(tst_ds)
+ for mini_batch_test in tst_dist_ds:
+ self.distributed_test_step(mini_batch_test)
with self.writer_valid.as_default():
tf.summary.scalar('loss_val', self.test_loss.result(), step=step)
@@ -841,7 +857,7 @@ class CloudHeightNN:
print('acc_5', self.num_5, self.accuracy_5.result())
def run(self, matchup_dict, train_dict=None, valid_dict=None):
- with tf.device('/device:GPU:'+str(self.gpu_device)):
+ with self.strategy.scope():
self.setup_pipeline(matchup_dict, train_dict=train_dict, valid_test_dict=valid_dict)
self.build_model()
self.build_training()
--
GitLab