Skip to content
Snippets Groups Projects
Commit d6807531 authored by tomrink's avatar tomrink
Browse files

minor...

parent 3866d024
No related branches found
No related tags found
No related merge requests found
import tensorflow as tf import tensorflow as tf
import tensorflow_addons as tfa import tensorflow_addons as tfa
from util.setup import logdir, modeldir, cachepath from util.setup import logdir, modeldir, cachepath, now
from util.util import homedir, EarlyStop from util.util import homedir, EarlyStop
from util.geos_nav import GEOSNavigation
from util.plot import make_icing_image
import os, datetime import os, datetime
import numpy as np import numpy as np
import pickle import pickle
import h5py import h5py
from icing.pirep_goes import normalize from icing.pirep_goes import normalize, make_for_full_domain_predict
LOG_DEVICE_PLACEMENT = False LOG_DEVICE_PLACEMENT = False
...@@ -103,6 +105,7 @@ class IcingIntensityNN: ...@@ -103,6 +105,7 @@ class IcingIntensityNN:
self.train_dataset = None self.train_dataset = None
self.inner_train_dataset = None self.inner_train_dataset = None
self.test_dataset = None self.test_dataset = None
self.eval_dataset = None
self.X_img = None self.X_img = None
self.X_prof = None self.X_prof = None
self.X_u = None self.X_u = None
...@@ -172,6 +175,8 @@ class IcingIntensityNN: ...@@ -172,6 +175,8 @@ class IcingIntensityNN:
self.num_data_samples = None self.num_data_samples = None
self.initial_learning_rate = None self.initial_learning_rate = None
self.data_dct = None
n_chans = len(train_params) n_chans = len(train_params)
if TRIPLET: if TRIPLET:
n_chans *= 3 n_chans *= 3
...@@ -253,6 +258,22 @@ class IcingIntensityNN: ...@@ -253,6 +258,22 @@ class IcingIntensityNN:
def get_in_mem_data_batch_test(self, idxs): def get_in_mem_data_batch_test(self, idxs):
return self.get_in_mem_data_batch(idxs, False) return self.get_in_mem_data_batch(idxs, False)
def get_in_mem_data_batch_eval(self, idxs):
# sort these to use as numpy indexing arrays
nd_idxs = np.array(idxs)
nd_idxs = np.sort(nd_idxs)
data = []
for param in train_params:
nda = self.data_dct[param][nd_idxs, ]
nda = normalize(nda, param, mean_std_dct)
data.append(nda)
data = np.stack(data)
data = data.astype(np.float32)
data = np.transpose(data, axes=(1, 2, 3, 0))
return data
@tf.function(input_signature=[tf.TensorSpec(None, tf.int32)]) @tf.function(input_signature=[tf.TensorSpec(None, tf.int32)])
def data_function(self, indexes): def data_function(self, indexes):
out = tf.numpy_function(self.get_in_mem_data_batch_train, [indexes], [tf.float32, tf.int32]) out = tf.numpy_function(self.get_in_mem_data_batch_train, [indexes], [tf.float32, tf.int32])
...@@ -263,6 +284,11 @@ class IcingIntensityNN: ...@@ -263,6 +284,11 @@ class IcingIntensityNN:
out = tf.numpy_function(self.get_in_mem_data_batch_test, [indexes], [tf.float32, tf.int32]) out = tf.numpy_function(self.get_in_mem_data_batch_test, [indexes], [tf.float32, tf.int32])
return out return out
@tf.function(input_signature=[tf.TensorSpec(None, tf.int32)])
def data_function_evaluate(self, indexes):
out = tf.numpy_function(self.get_in_mem_data_batch_eval, [indexes], tf.float32)
return out
def get_train_dataset(self, indexes): def get_train_dataset(self, indexes):
indexes = list(indexes) indexes = list(indexes)
...@@ -283,6 +309,15 @@ class IcingIntensityNN: ...@@ -283,6 +309,15 @@ class IcingIntensityNN:
dataset = dataset.cache() dataset = dataset.cache()
self.test_dataset = dataset self.test_dataset = dataset
def get_evaluate_dataset(self, indexes):
indexes = list(indexes)
dataset = tf.data.Dataset.from_tensor_slices(indexes)
dataset = dataset.batch(PROC_BATCH_SIZE)
dataset = dataset.map(self.data_function_evaluate, num_parallel_calls=8)
dataset = dataset.cache()
self.eval_dataset = dataset
def setup_pipeline(self, filename_trn, filename_tst, trn_idxs=None, tst_idxs=None, seed=None): def setup_pipeline(self, filename_trn, filename_tst, trn_idxs=None, tst_idxs=None, seed=None):
self.filename_trn = filename_trn self.filename_trn = filename_trn
self.h5f_trn = h5py.File(filename_trn, 'r') self.h5f_trn = h5py.File(filename_trn, 'r')
...@@ -330,6 +365,13 @@ class IcingIntensityNN: ...@@ -330,6 +365,13 @@ class IcingIntensityNN:
print('num test samples: ', tst_idxs.shape[0]) print('num test samples: ', tst_idxs.shape[0])
print('setup_test_pipeline: Done') print('setup_test_pipeline: Done')
def setup_eval_pipeline(self, data_dct, num_tiles):
self.data_dct = data_dct
idxs = np.arange(num_tiles)
self.num_data_samples = idxs.shape[0]
self.get_evaluate_dataset(idxs)
def build_1d_cnn(self): def build_1d_cnn(self):
print('build_1d_cnn') print('build_1d_cnn')
# padding = 'VALID' # padding = 'VALID'
...@@ -600,7 +642,14 @@ class IcingIntensityNN: ...@@ -600,7 +642,14 @@ class IcingIntensityNN:
step = 0 step = 0
total_time = 0 total_time = 0
last_test_loss = np.finfo(dtype=np.float).max best_test_loss = np.finfo(dtype=np.float).max
best_test_acc = 0
best_test_recall = 0
best_test_precision = 0
best_test_auc = 0
best_test_f1 = 0
best_test_mcc = 0
if EARLY_STOP: if EARLY_STOP:
es = EarlyStop() es = EarlyStop()
...@@ -682,8 +731,15 @@ class IcingIntensityNN: ...@@ -682,8 +731,15 @@ class IcingIntensityNN:
self.optimizer.assign_average_vars(self.model.trainable_variables) self.optimizer.assign_average_vars(self.model.trainable_variables)
tst_loss = self.test_loss.result().numpy() tst_loss = self.test_loss.result().numpy()
if tst_loss < last_test_loss: if tst_loss < best_test_loss:
last_test_loss = tst_loss best_test_loss = tst_loss
best_test_acc = self.test_accuracy.result().numpy()
best_test_recall = self.test_recall.result().numpy()
best_test_precision = self.test_precision.result().numpy()
best_test_auc = self.test_auc.result().numpy()
best_test_f1 = f1.numpy()
best_test_mcc = mcc.numpy()
ckpt_manager.save() ckpt_manager.save()
if self.DISK_CACHE and epoch == 0: if self.DISK_CACHE and epoch == 0:
...@@ -701,6 +757,10 @@ class IcingIntensityNN: ...@@ -701,6 +757,10 @@ class IcingIntensityNN:
self.h5f_trn.close() self.h5f_trn.close()
self.h5f_tst.close() self.h5f_tst.close()
f = open('/home/rink/best_stats_'+now+'.pkl', 'wb')
pickle.dump((best_test_loss, best_test_acc, best_test_recall, best_test_precision, best_test_auc, best_test_f1, best_test_mcc), f)
f.close()
def build_model(self): def build_model(self):
flat = self.build_cnn() flat = self.build_cnn()
# flat_1d = self.build_1d_cnn() # flat_1d = self.build_1d_cnn()
...@@ -742,6 +802,49 @@ class IcingIntensityNN: ...@@ -742,6 +802,49 @@ class IcingIntensityNN:
self.h5f_tst.close() self.h5f_tst.close()
def do_evaluate(self, ckpt_dir, ll, cc):
ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model)
ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)
ckpt.restore(ckpt_manager.latest_checkpoint)
pred_s = []
for data in self.eval_dataset:
ds = tf.data.Dataset.from_tensor_slices(data)
ds = ds.batch(BATCH_SIZE)
for mini_batch in ds:
pred = self.model([mini_batch], training=False)
pred_s.append(pred)
preds = np.concatenate(pred_s)
if NumClasses == 2:
preds = np.where(preds > 0.6, 1, 0)
else:
preds = np.argmax(preds, axis=1)
print(preds.shape[0], np.sum(preds == 1))
preds = preds[:,0]
cc = np.array(cc)
ll = np.array(ll)
ice_mask = preds == 1
print(cc.shape, ll.shape, ice_mask.shape)
ice_cc = cc[ice_mask]
ice_ll = ll[ice_mask]
nav = GEOSNavigation(sub_lon=-75.0, CFAC=5.6E-05, COFF=-0.101332, LFAC=-5.6E-05, LOFF=0.128212, num_elems=2500,
num_lines=1500)
ice_lons = []
ice_lats = []
for k in range(ice_cc.shape[0]):
lon, lat = nav.lc_to_earth(ice_cc[k], ice_ll[k])
ice_lons.append(lon)
ice_lats.append(lat)
return ice_lons, ice_lats
def run(self, filename_trn, filename_tst): def run(self, filename_trn, filename_tst):
with tf.device('/device:GPU:'+str(self.gpu_device)): with tf.device('/device:GPU:'+str(self.gpu_device)):
self.setup_pipeline(filename_trn, filename_tst) self.setup_pipeline(filename_trn, filename_tst)
...@@ -757,6 +860,15 @@ class IcingIntensityNN: ...@@ -757,6 +860,15 @@ class IcingIntensityNN:
self.build_evaluation() self.build_evaluation()
self.restore(ckpt_dir) self.restore(ckpt_dir)
def run_evaluate(self, filename, ckpt_dir):
data_dct, ll, cc = make_for_full_domain_predict(filename)
self.setup_eval_pipeline(data_dct, len(ll))
self.build_model()
self.build_training()
self.build_evaluation()
ice_lons, ice_lats = self.do_evaluate(ckpt_dir, ll, cc)
return filename, ice_lons, ice_lats
if __name__ == "__main__": if __name__ == "__main__":
nn = IcingIntensityNN() nn = IcingIntensityNN()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment