Commit a92bca6f authored by tomrink

snapshot...

parent 08ca488f
@@ -215,9 +215,6 @@ class ESPCN:
         # self.X_img = tf.keras.Input(shape=(32, 32, self.n_chans))
         self.inputs.append(self.X_img)
-        self.inputs.append(tf.keras.Input(shape=(None, None, self.n_chans)))
-        # self.inputs.append(tf.keras.Input(shape=(36, 36, self.n_chans)))
-        # self.inputs.append(tf.keras.Input(shape=(32, 32, self.n_chans)))

         self.DISK_CACHE = False
@@ -271,7 +268,7 @@ class ESPCN:
             data = np.concatenate([data, data_ud, data_lr])
             label = np.concatenate([label, label_ud, label_lr])

-        return data, data, label
+        return data, label

     def get_in_mem_data_batch_train(self, idxs):
         return self.get_in_mem_data_batch(idxs, True)
@@ -290,28 +287,22 @@ class ESPCN:
         data = np.transpose(data, axes=(1, 2, 0))
         data = np.expand_dims(data, axis=0)

-        nda = np.zeros([1])
-        nda[0] = self.flight_level
-        nda = tf.one_hot(nda, 5).numpy()
-        nda = np.expand_dims(nda, axis=0)
-        nda = np.expand_dims(nda, axis=0)
-
-        return data, nda
+        return data

     @tf.function(input_signature=[tf.TensorSpec(None, tf.int32)])
     def data_function(self, indexes):
-        out = tf.numpy_function(self.get_in_mem_data_batch_train, [indexes], [tf.float32, tf.float32, tf.float32])
+        out = tf.numpy_function(self.get_in_mem_data_batch_train, [indexes], [tf.float32, tf.float32])
         return out

     @tf.function(input_signature=[tf.TensorSpec(None, tf.int32)])
     def data_function_test(self, indexes):
-        out = tf.numpy_function(self.get_in_mem_data_batch_test, [indexes], [tf.float32, tf.float32, tf.float32])
+        out = tf.numpy_function(self.get_in_mem_data_batch_test, [indexes], [tf.float32, tf.float32])
         return out

     @tf.function(input_signature=[tf.TensorSpec(None, tf.int32)])
     def data_function_evaluate(self, indexes):
         # TODO: modify for user specified altitude
-        out = tf.numpy_function(self.get_in_mem_data_batch_eval, [indexes], [tf.float32, tf.float32])
+        out = tf.numpy_function(self.get_in_mem_data_batch_eval, [indexes], [tf.float32])
         return out

     def get_train_dataset(self, indexes):
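Context for the Tout changes in this hunk: tf.numpy_function requires the list of output dtypes to match, element for element, the tuple returned by the wrapped Python callable, so once the batch functions stop returning the auxiliary flight-level array the extra tf.float32 entry has to go as well. A minimal, self-contained sketch of that pairing follows; the toy_batch helper and its shapes are illustrative assumptions, not code from this commit.

# Illustrative sketch: the Tout list must line up with the arrays the callable returns.
import numpy as np
import tensorflow as tf

def toy_batch(indexes):
    # stand-in for get_in_mem_data_batch_train: returns a (data, label) pair
    n = indexes.shape[0]
    data = np.zeros((n, 8, 8, 1), dtype=np.float32)
    label = np.zeros((n, 16, 16, 1), dtype=np.float32)
    return data, label

@tf.function(input_signature=[tf.TensorSpec(None, tf.int32)])
def toy_data_function(indexes):
    # two returned arrays -> exactly two dtype entries in Tout
    return tf.numpy_function(toy_batch, [indexes], [tf.float32, tf.float32])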
@@ -364,28 +355,9 @@ class ESPCN:

         print('num test samples: ', tst_idxs.shape[0])
         print('setup_pipeline: Done')

-    def setup_test_pipeline(self, filename_l1b, filename_l2, seed=None, shuffle=False):
-
-        if filename_l1b is not None:
-            self.h5f_l1b_tst = h5py.File(filename_l1b, 'r')
-        if filename_l2 is not None:
-            self.h5f_l2_tst = h5py.File(filename_l2, 'r')
-
-        if self.h5f_l1b_tst is not None:
-            h5f = self.h5f_l1b_tst
-        else:
-            h5f = self.h5f_l2_tst
-
-        time = h5f['time']
-        tst_idxs = np.arange(time.shape[0])
-        self.num_data_samples = len(tst_idxs)
-        if seed is not None:
-            np.random.seed(seed)
-        if shuffle:
-            np.random.shuffle(tst_idxs)
-
-        self.get_test_dataset(tst_idxs)
-        print('num test samples: ', tst_idxs.shape[0])
+    def setup_test_pipeline(self, filename):
+        self.test_data_files = [filename]
+        self.get_test_dataset([0])
         print('setup_test_pipeline: Done')

     def setup_eval_pipeline(self, data_dct, num_tiles=1):
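After this hunk the test pipeline is driven by a single preprocessed file rather than separate L1b/L2 HDF5 inputs. A hedged usage sketch is below; the no-argument constructor call and the file path are assumptions for illustration only.

# Hypothetical usage of the simplified test pipeline (placeholder path).
model = ESPCN()
model.setup_test_pipeline('/data/espcn_test_tiles.h5')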
@@ -509,8 +481,8 @@ class ESPCN:

     @tf.function
     def train_step(self, mini_batch):
-        inputs = [mini_batch[0], mini_batch[1]]
-        labels = mini_batch[2]
+        inputs = [mini_batch[0]]
+        labels = mini_batch[1]
         with tf.GradientTape() as tape:
             pred = self.model(inputs, training=True)
             loss = self.loss(labels, pred)
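With the auxiliary input dropped, train_step feeds the model a single image tensor inside the usual GradientTape update. The rest of the step is cut off by the hunk, so the sketch below completes the pattern with a stand-in model, loss, and optimizer; those names are illustrative assumptions, not the repository's objects.

# Minimal single-input gradient step (illustrative model/optimizer).
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Conv2D(1, 3, padding='same')])
loss_fn = tf.keras.losses.MeanSquaredError()
optimizer = tf.keras.optimizers.Adam()

@tf.function
def train_step(mini_batch):
    inputs = [mini_batch[0]]
    labels = mini_batch[1]
    with tf.GradientTape() as tape:
        pred = model(inputs, training=True)
        loss = loss_fn(labels, pred)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss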
@@ -530,8 +502,8 @@ class ESPCN:

     @tf.function
     def test_step(self, mini_batch):
-        inputs = [mini_batch[0], mini_batch[1]]
-        labels = mini_batch[2]
+        inputs = [mini_batch[0]]
+        labels = mini_batch[1]
         pred = self.model(inputs, training=False)
         t_loss = self.loss(labels, pred)
@@ -547,8 +519,8 @@ class ESPCN:
             # self.test_false_pos(labels, pred)

     def predict(self, mini_batch):
-        inputs = [mini_batch[0], mini_batch[1]]
-        labels = mini_batch[2]
+        inputs = [mini_batch[0]]
+        labels = mini_batch[1]
         pred = self.model(inputs, training=False)
         t_loss = self.loss(labels, pred)
@@ -557,14 +529,6 @@ class ESPCN:

         self.test_loss(t_loss)
         self.test_accuracy(labels, pred)
-        # if NumClasses == 2:
-        #     self.test_auc(labels, pred)
-        #     self.test_recall(labels, pred)
-        #     self.test_precision(labels, pred)
-        #     self.test_true_neg(labels, pred)
-        #     self.test_true_pos(labels, pred)
-        #     self.test_false_neg(labels, pred)
-        #     self.test_false_pos(labels, pred)

     def reset_test_metrics(self):
         self.test_loss.reset_states()
@@ -628,8 +592,8 @@ class ESPCN:
         proc_batch_cnt = 0
         n_samples = 0

-        for data0, data1, label in self.train_dataset:
-            trn_ds = tf.data.Dataset.from_tensor_slices((data0, data1, label))
+        for data, label in self.train_dataset:
+            trn_ds = tf.data.Dataset.from_tensor_slices((data, label))
             trn_ds = trn_ds.batch(BATCH_SIZE)
             for mini_batch in trn_ds:
                 if self.learningRateSchedule is not None:
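The training loop in this hunk re-wraps each (data, label) pair pulled from self.train_dataset in an in-memory tf.data.Dataset so it can be sliced into fixed-size mini-batches. A standalone sketch of that pattern follows; the array shapes and the BATCH_SIZE value are assumptions for illustration.

# Illustrative only: slice a (data, label) pair into mini-batches.
import numpy as np
import tensorflow as tf

BATCH_SIZE = 4
data = np.zeros((10, 8, 8, 1), dtype=np.float32)
label = np.zeros((10, 16, 16, 1), dtype=np.float32)

ds = tf.data.Dataset.from_tensor_slices((data, label)).batch(BATCH_SIZE)
for x, y in ds:
    print(x.shape, y.shape)  # batches of up to BATCH_SIZE samples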
@@ -644,8 +608,8 @@ class ESPCN:
                     tf.summary.scalar('num_epochs', epoch, step=step)

                     self.reset_test_metrics()
-                    for data0_tst, data1_tst, label_tst in self.test_dataset:
-                        tst_ds = tf.data.Dataset.from_tensor_slices((data0_tst, data1_tst, label_tst))
+                    for data_tst, label_tst in self.test_dataset:
+                        tst_ds = tf.data.Dataset.from_tensor_slices((data_tst, label_tst))
                         tst_ds = tst_ds.batch(BATCH_SIZE)
                         for mini_batch_test in tst_ds:
                             self.test_step(mini_batch_test)
@@ -676,7 +640,7 @@ class ESPCN:
                     print('train loss: ', loss.numpy())

             proc_batch_cnt += 1
-            n_samples += data0.shape[0]
+            n_samples += data.shape[0]
             print('proc_batch_cnt: ', proc_batch_cnt, n_samples)

         t1 = datetime.datetime.now().timestamp()
@@ -684,8 +648,8 @@ class ESPCN:
         total_time += (t1-t0)

         self.reset_test_metrics()
-        for data0, data1, label in self.test_dataset:
-            ds = tf.data.Dataset.from_tensor_slices((data0, data1, label))
+        for data, label in self.test_dataset:
+            ds = tf.data.Dataset.from_tensor_slices((data, label))
             ds = ds.batch(BATCH_SIZE)
             for mini_batch in ds:
                 self.test_step(mini_batch)
@@ -756,9 +720,8 @@ class ESPCN:
         ds = ds.batch(BATCH_SIZE)
         for mini_batch_test in ds:
             self.predict(mini_batch_test)
-        f1, mcc = self.get_metrics()
-        print('loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy(), self.test_recall.result().numpy(),
-              self.test_precision.result().numpy(), self.test_auc.result().numpy(), f1.numpy(), mcc.numpy())
+        print('loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy())

         labels = np.concatenate(self.test_labels)
         self.test_labels = labels
@@ -807,18 +770,13 @@ class ESPCN:
         self.build_evaluation()
         self.do_training()

-    def run_restore(self, filename_l1b, filename_l2, ckpt_dir):
-        self.setup_test_pipeline(filename_l1b, filename_l2)
+    def run_restore(self, filename, ckpt_dir):
+        self.setup_test_pipeline(filename)
         self.build_model()
         self.build_training()
         self.build_evaluation()
         self.restore(ckpt_dir)
-        if self.h5f_l1b_tst is not None:
-            self.h5f_l1b_tst.close()
-        if self.h5f_l2_tst is not None:
-            self.h5f_l2_tst.close()

     def run_evaluate(self, filename, ckpt_dir):
         data_dct, ll, cc = make_for_full_domain_predict(filename, name_list=self.train_params)
         self.setup_eval_pipeline(data_dct, len(ll))
...