Skip to content
Snippets Groups Projects
Commit a92bca6f authored by tomrink's avatar tomrink
Browse files

snapshot...

parent 08ca488f
No related branches found
No related tags found
No related merge requests found
......@@ -215,9 +215,6 @@ class ESPCN:
# self.X_img = tf.keras.Input(shape=(32, 32, self.n_chans))
self.inputs.append(self.X_img)
self.inputs.append(tf.keras.Input(shape=(None, None, self.n_chans)))
# self.inputs.append(tf.keras.Input(shape=(36, 36, self.n_chans)))
# self.inputs.append(tf.keras.Input(shape=(32, 32, self.n_chans)))
self.DISK_CACHE = False
......@@ -271,7 +268,7 @@ class ESPCN:
data = np.concatenate([data, data_ud, data_lr])
label = np.concatenate([label, label_ud, label_lr])
return data, data, label
return data, label
def get_in_mem_data_batch_train(self, idxs):
return self.get_in_mem_data_batch(idxs, True)
......@@ -290,28 +287,22 @@ class ESPCN:
data = np.transpose(data, axes=(1, 2, 0))
data = np.expand_dims(data, axis=0)
nda = np.zeros([1])
nda[0] = self.flight_level
nda = tf.one_hot(nda, 5).numpy()
nda = np.expand_dims(nda, axis=0)
nda = np.expand_dims(nda, axis=0)
return data, nda
return data
@tf.function(input_signature=[tf.TensorSpec(None, tf.int32)])
def data_function(self, indexes):
out = tf.numpy_function(self.get_in_mem_data_batch_train, [indexes], [tf.float32, tf.float32, tf.float32])
out = tf.numpy_function(self.get_in_mem_data_batch_train, [indexes], [tf.float32, tf.float32])
return out
@tf.function(input_signature=[tf.TensorSpec(None, tf.int32)])
def data_function_test(self, indexes):
out = tf.numpy_function(self.get_in_mem_data_batch_test, [indexes], [tf.float32, tf.float32, tf.float32])
out = tf.numpy_function(self.get_in_mem_data_batch_test, [indexes], [tf.float32, tf.float32])
return out
@tf.function(input_signature=[tf.TensorSpec(None, tf.int32)])
def data_function_evaluate(self, indexes):
# TODO: modify for user specified altitude
out = tf.numpy_function(self.get_in_mem_data_batch_eval, [indexes], [tf.float32, tf.float32])
out = tf.numpy_function(self.get_in_mem_data_batch_eval, [indexes], [tf.float32])
return out
def get_train_dataset(self, indexes):
......@@ -364,28 +355,9 @@ class ESPCN:
print('num test samples: ', tst_idxs.shape[0])
print('setup_pipeline: Done')
def setup_test_pipeline(self, filename_l1b, filename_l2, seed=None, shuffle=False):
if filename_l1b is not None:
self.h5f_l1b_tst = h5py.File(filename_l1b, 'r')
if filename_l2 is not None:
self.h5f_l2_tst = h5py.File(filename_l2, 'r')
if self.h5f_l1b_tst is not None:
h5f = self.h5f_l1b_tst
else:
h5f = self.h5f_l2_tst
time = h5f['time']
tst_idxs = np.arange(time.shape[0])
self.num_data_samples = len(tst_idxs)
if seed is not None:
np.random.seed(seed)
if shuffle:
np.random.shuffle(tst_idxs)
self.get_test_dataset(tst_idxs)
print('num test samples: ', tst_idxs.shape[0])
def setup_test_pipeline(self, filename):
self.test_data_files = [filename]
self.get_test_dataset([0])
print('setup_test_pipeline: Done')
def setup_eval_pipeline(self, data_dct, num_tiles=1):
......@@ -509,8 +481,8 @@ class ESPCN:
@tf.function
def train_step(self, mini_batch):
inputs = [mini_batch[0], mini_batch[1]]
labels = mini_batch[2]
inputs = [mini_batch[0]]
labels = mini_batch[1]
with tf.GradientTape() as tape:
pred = self.model(inputs, training=True)
loss = self.loss(labels, pred)
......@@ -530,8 +502,8 @@ class ESPCN:
@tf.function
def test_step(self, mini_batch):
inputs = [mini_batch[0], mini_batch[1]]
labels = mini_batch[2]
inputs = [mini_batch[0]]
labels = mini_batch[1]
pred = self.model(inputs, training=False)
t_loss = self.loss(labels, pred)
......@@ -547,8 +519,8 @@ class ESPCN:
# self.test_false_pos(labels, pred)
def predict(self, mini_batch):
inputs = [mini_batch[0], mini_batch[1]]
labels = mini_batch[2]
inputs = [mini_batch[0]]
labels = mini_batch[1]
pred = self.model(inputs, training=False)
t_loss = self.loss(labels, pred)
......@@ -557,14 +529,6 @@ class ESPCN:
self.test_loss(t_loss)
self.test_accuracy(labels, pred)
# if NumClasses == 2:
# self.test_auc(labels, pred)
# self.test_recall(labels, pred)
# self.test_precision(labels, pred)
# self.test_true_neg(labels, pred)
# self.test_true_pos(labels, pred)
# self.test_false_neg(labels, pred)
# self.test_false_pos(labels, pred)
def reset_test_metrics(self):
self.test_loss.reset_states()
......@@ -628,8 +592,8 @@ class ESPCN:
proc_batch_cnt = 0
n_samples = 0
for data0, data1, label in self.train_dataset:
trn_ds = tf.data.Dataset.from_tensor_slices((data0, data1, label))
for data, label in self.train_dataset:
trn_ds = tf.data.Dataset.from_tensor_slices((data, label))
trn_ds = trn_ds.batch(BATCH_SIZE)
for mini_batch in trn_ds:
if self.learningRateSchedule is not None:
......@@ -644,8 +608,8 @@ class ESPCN:
tf.summary.scalar('num_epochs', epoch, step=step)
self.reset_test_metrics()
for data0_tst, data1_tst, label_tst in self.test_dataset:
tst_ds = tf.data.Dataset.from_tensor_slices((data0_tst, data1_tst, label_tst))
for data_tst, label_tst in self.test_dataset:
tst_ds = tf.data.Dataset.from_tensor_slices((data_tst, label_tst))
tst_ds = tst_ds.batch(BATCH_SIZE)
for mini_batch_test in tst_ds:
self.test_step(mini_batch_test)
......@@ -676,7 +640,7 @@ class ESPCN:
print('train loss: ', loss.numpy())
proc_batch_cnt += 1
n_samples += data0.shape[0]
n_samples += data.shape[0]
print('proc_batch_cnt: ', proc_batch_cnt, n_samples)
t1 = datetime.datetime.now().timestamp()
......@@ -684,8 +648,8 @@ class ESPCN:
total_time += (t1-t0)
self.reset_test_metrics()
for data0, data1, label in self.test_dataset:
ds = tf.data.Dataset.from_tensor_slices((data0, data1, label))
for data, label in self.test_dataset:
ds = tf.data.Dataset.from_tensor_slices((data, label))
ds = ds.batch(BATCH_SIZE)
for mini_batch in ds:
self.test_step(mini_batch)
......@@ -756,9 +720,8 @@ class ESPCN:
ds = ds.batch(BATCH_SIZE)
for mini_batch_test in ds:
self.predict(mini_batch_test)
f1, mcc = self.get_metrics()
print('loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy(), self.test_recall.result().numpy(),
self.test_precision.result().numpy(), self.test_auc.result().numpy(), f1.numpy(), mcc.numpy())
print('loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy())
labels = np.concatenate(self.test_labels)
self.test_labels = labels
......@@ -807,18 +770,13 @@ class ESPCN:
self.build_evaluation()
self.do_training()
def run_restore(self, filename_l1b, filename_l2, ckpt_dir):
self.setup_test_pipeline(filename_l1b, filename_l2)
def run_restore(self, filename, ckpt_dir):
self.setup_test_pipeline(filename)
self.build_model()
self.build_training()
self.build_evaluation()
self.restore(ckpt_dir)
if self.h5f_l1b_tst is not None:
self.h5f_l1b_tst.close()
if self.h5f_l2_tst is not None:
self.h5f_l2_tst.close()
def run_evaluate(self, filename, ckpt_dir):
data_dct, ll, cc = make_for_full_domain_predict(filename, name_list=self.train_params)
self.setup_eval_pipeline(data_dct, len(ll))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment