Commit d6807531 authored by tomrink

minor...

parent 3866d024
import tensorflow as tf
import tensorflow_addons as tfa
from util.setup import logdir, modeldir, cachepath
from util.setup import logdir, modeldir, cachepath, now
from util.util import homedir, EarlyStop
from util.geos_nav import GEOSNavigation
from util.plot import make_icing_image
import os, datetime
import numpy as np
import pickle
import h5py
from icing.pirep_goes import normalize
from icing.pirep_goes import normalize, make_for_full_domain_predict
LOG_DEVICE_PLACEMENT = False
@@ -103,6 +105,7 @@ class IcingIntensityNN:
self.train_dataset = None
self.inner_train_dataset = None
self.test_dataset = None
self.eval_dataset = None
self.X_img = None
self.X_prof = None
self.X_u = None
@@ -172,6 +175,8 @@ class IcingIntensityNN:
self.num_data_samples = None
self.initial_learning_rate = None
self.data_dct = None
n_chans = len(train_params)
if TRIPLET:
n_chans *= 3
@@ -253,6 +258,22 @@ class IcingIntensityNN:
def get_in_mem_data_batch_test(self, idxs):
return self.get_in_mem_data_batch(idxs, False)
def get_in_mem_data_batch_eval(self, idxs):
# sort these to use as numpy indexing arrays
nd_idxs = np.array(idxs)
nd_idxs = np.sort(nd_idxs)
data = []
for param in train_params:
nda = self.data_dct[param][nd_idxs, ]
nda = normalize(nda, param, mean_std_dct)
data.append(nda)
data = np.stack(data)
data = data.astype(np.float32)
data = np.transpose(data, axes=(1, 2, 3, 0))
return data
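# Illustration (not part of this commit): the stack/transpose above converts a list of
# per-parameter arrays, each shaped (batch, H, W), into one channels-last array shaped
# (batch, H, W, n_chans), the layout the 2D CNN input expects. A minimal sketch with
# hypothetical tile dimensions:
#
#   chans = [np.zeros((4, 16, 16), dtype=np.float32) for _ in range(3)]  # 3 params, 4 tiles
#   batch = np.transpose(np.stack(chans), axes=(1, 2, 3, 0))
#   batch.shape  # (4, 16, 16, 3)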
@tf.function(input_signature=[tf.TensorSpec(None, tf.int32)])
def data_function(self, indexes):
out = tf.numpy_function(self.get_in_mem_data_batch_train, [indexes], [tf.float32, tf.int32])
@@ -263,6 +284,11 @@ class IcingIntensityNN:
out = tf.numpy_function(self.get_in_mem_data_batch_test, [indexes], [tf.float32, tf.int32])
return out
@tf.function(input_signature=[tf.TensorSpec(None, tf.int32)])
def data_function_evaluate(self, indexes):
out = tf.numpy_function(self.get_in_mem_data_batch_eval, [indexes], tf.float32)
return out
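# Note: unlike data_function/data_function_test, which return a (float32, int32)
# (image, label) pair, the evaluate variant returns only a float32 image tensor, since the
# full-domain data it feeds has no labels; tf.numpy_function lets the numpy-based batch
# builder run inside the tf.data pipeline.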
def get_train_dataset(self, indexes):
indexes = list(indexes)
@@ -283,6 +309,15 @@ class IcingIntensityNN:
dataset = dataset.cache()
self.test_dataset = dataset
def get_evaluate_dataset(self, indexes):
indexes = list(indexes)
dataset = tf.data.Dataset.from_tensor_slices(indexes)
dataset = dataset.batch(PROC_BATCH_SIZE)
dataset = dataset.map(self.data_function_evaluate, num_parallel_calls=8)
dataset = dataset.cache()
self.eval_dataset = dataset
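# Note: the dataset elements here are index batches of size PROC_BATCH_SIZE;
# data_function_evaluate materializes the normalized tiles for each batch, and cache()
# keeps the results in memory after the first pass over the domain.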
def setup_pipeline(self, filename_trn, filename_tst, trn_idxs=None, tst_idxs=None, seed=None):
self.filename_trn = filename_trn
self.h5f_trn = h5py.File(filename_trn, 'r')
@@ -330,6 +365,13 @@ class IcingIntensityNN:
print('num test samples: ', tst_idxs.shape[0])
print('setup_test_pipeline: Done')
def setup_eval_pipeline(self, data_dct, num_tiles):
self.data_dct = data_dct
idxs = np.arange(num_tiles)
self.num_data_samples = idxs.shape[0]
self.get_evaluate_dataset(idxs)
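# Illustration (assumed layout, not part of this commit): setup_eval_pipeline expects data_dct
# to map each name in train_params to an array of tiles indexed 0..num_tiles-1, e.g. with a
# hypothetical 16x16 tile size:
#
#   data_dct = {param: np.zeros((num_tiles, 16, 16), dtype=np.float32) for param in train_params}
#   nn.setup_eval_pipeline(data_dct, num_tiles)
#
# In run_evaluate below, this dict comes from make_for_full_domain_predict.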
def build_1d_cnn(self):
print('build_1d_cnn')
# padding = 'VALID'
@@ -600,7 +642,14 @@ class IcingIntensityNN:
step = 0
total_time = 0
last_test_loss = np.finfo(dtype=np.float).max
best_test_loss = np.finfo(dtype=np.float).max
best_test_acc = 0
best_test_recall = 0
best_test_precision = 0
best_test_auc = 0
best_test_f1 = 0
best_test_mcc = 0
if EARLY_STOP:
es = EarlyStop()
@@ -682,8 +731,15 @@ class IcingIntensityNN:
self.optimizer.assign_average_vars(self.model.trainable_variables)
tst_loss = self.test_loss.result().numpy()
if tst_loss < last_test_loss:
last_test_loss = tst_loss
if tst_loss < best_test_loss:
best_test_loss = tst_loss
best_test_acc = self.test_accuracy.result().numpy()
best_test_recall = self.test_recall.result().numpy()
best_test_precision = self.test_precision.result().numpy()
best_test_auc = self.test_auc.result().numpy()
best_test_f1 = f1.numpy()
best_test_mcc = mcc.numpy()
ckpt_manager.save()
if self.DISK_CACHE and epoch == 0:
@@ -701,6 +757,10 @@ class IcingIntensityNN:
self.h5f_trn.close()
self.h5f_tst.close()
f = open('/home/rink/best_stats_'+now+'.pkl', 'wb')
pickle.dump((best_test_loss, best_test_acc, best_test_recall, best_test_precision, best_test_auc, best_test_f1, best_test_mcc), f)
f.close()
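# Illustration (not part of this commit): the saved tuple can be read back with pickle; the
# file name embeds the run's `now` string imported from util.setup:
#
#   with open('/home/rink/best_stats_<now>.pkl', 'rb') as f:
#       best_loss, best_acc, best_recall, best_precision, best_auc, best_f1, best_mcc = pickle.load(f)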
def build_model(self):
flat = self.build_cnn()
# flat_1d = self.build_1d_cnn()
@@ -742,6 +802,49 @@ class IcingIntensityNN:
self.h5f_tst.close()
def do_evaluate(self, ckpt_dir, ll, cc):
ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model)
ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)
ckpt.restore(ckpt_manager.latest_checkpoint)
pred_s = []
for data in self.eval_dataset:
ds = tf.data.Dataset.from_tensor_slices(data)
ds = ds.batch(BATCH_SIZE)
for mini_batch in ds:
pred = self.model([mini_batch], training=False)
pred_s.append(pred)
preds = np.concatenate(pred_s)
if NumClasses == 2:
preds = np.where(preds > 0.6, 1, 0)
else:
preds = np.argmax(preds, axis=1)
print(preds.shape[0], np.sum(preds == 1))
preds = preds[:,0]
cc = np.array(cc)
ll = np.array(ll)
ice_mask = preds == 1
print(cc.shape, ll.shape, ice_mask.shape)
ice_cc = cc[ice_mask]
ice_ll = ll[ice_mask]
nav = GEOSNavigation(sub_lon=-75.0, CFAC=5.6E-05, COFF=-0.101332, LFAC=-5.6E-05, LOFF=0.128212, num_elems=2500,
num_lines=1500)
ice_lons = []
ice_lats = []
for k in range(ice_cc.shape[0]):
lon, lat = nav.lc_to_earth(ice_cc[k], ice_ll[k])
ice_lons.append(lon)
ice_lats.append(lat)
return ice_lons, ice_lats
def run(self, filename_trn, filename_tst):
with tf.device('/device:GPU:'+str(self.gpu_device)):
self.setup_pipeline(filename_trn, filename_tst)
@@ -757,6 +860,15 @@ class IcingIntensityNN:
self.build_evaluation()
self.restore(ckpt_dir)
def run_evaluate(self, filename, ckpt_dir):
data_dct, ll, cc = make_for_full_domain_predict(filename)
self.setup_eval_pipeline(data_dct, len(ll))
self.build_model()
self.build_training()
self.build_evaluation()
ice_lons, ice_lats = self.do_evaluate(ckpt_dir, ll, cc)
return filename, ice_lons, ice_lats
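# Illustration (hypothetical paths, not part of this commit): run_evaluate takes a full-domain
# input file and a checkpoint directory, and returns the lon/lat locations of pixels predicted
# as icing:
#
#   nn = IcingIntensityNN()
#   fname, ice_lons, ice_lats = nn.run_evaluate('/path/to/full_domain.h5', '/path/to/ckpt_dir')
#
# make_icing_image (imported from util.plot above) is presumably used downstream to render
# these points, though it is not called in this diff.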
if __name__ == "__main__":
nn = IcingIntensityNN()
......