Skip to content
Snippets Groups Projects
Commit 99cc43de authored by tomrink's avatar tomrink
Browse files

add method to scale data

parent 512c1dc0
No related branches found
No related tags found
No related merge requests found
import glob import glob
import tensorflow as tf import tensorflow as tf
from util.setup import logdir, modeldir, cachepath, now, ancillary_path, home_dir from util.setup import logdir, modeldir, cachepath, now, ancillary_path, home_dir
from util.util import EarlyStop, normalize, make_for_full_domain_predict from util.util import EarlyStop, normalize, scale, make_for_full_domain_predict
import os, datetime import os, datetime
import numpy as np import numpy as np
import pickle import pickle
...@@ -26,7 +26,7 @@ BATCH_SIZE = 128 ...@@ -26,7 +26,7 @@ BATCH_SIZE = 128
NUM_EPOCHS = 60 NUM_EPOCHS = 60
TRACK_MOVING_AVERAGE = False TRACK_MOVING_AVERAGE = False
EARLY_STOP = True EARLY_STOP = False
TRIPLET = False TRIPLET = False
CONV3D = False CONV3D = False
...@@ -54,11 +54,14 @@ mean_std_dct.update(mean_std_dct_l2) ...@@ -54,11 +54,14 @@ mean_std_dct.update(mean_std_dct_l2)
emis_params = ['temp_10_4um_nom', 'temp_11_0um_nom', 'temp_12_0um_nom', 'temp_13_3um_nom', 'temp_3_9um_nom', emis_params = ['temp_10_4um_nom', 'temp_11_0um_nom', 'temp_12_0um_nom', 'temp_13_3um_nom', 'temp_3_9um_nom',
'temp_6_7um_nom'] 'temp_6_7um_nom']
l2_params = ['cloud_fraction', 'cld_temp_acha', 'cld_press_acha']
# -- Zero out params (Experimentation Only) ------------ # -- Zero out params (Experimentation Only) ------------
zero_out_params = ['cld_reff_dcomp', 'cld_opd_dcomp', 'iwc_dcomp', 'lwc_dcomp'] zero_out_params = ['cld_reff_dcomp', 'cld_opd_dcomp', 'iwc_dcomp', 'lwc_dcomp']
DO_ZERO_OUT = False DO_ZERO_OUT = False
label_param = l2_params[1]
def build_conv2d_block(conv, num_filters, activation, block_name, padding='SAME'): def build_conv2d_block(conv, num_filters, activation, block_name, padding='SAME'):
with tf.name_scope(block_name): with tf.name_scope(block_name):
...@@ -240,7 +243,7 @@ class UNET: ...@@ -240,7 +243,7 @@ class UNET:
data_norm.append(tmp) data_norm.append(tmp)
data = np.stack(data_norm, axis=3) data = np.stack(data_norm, axis=3)
# label = normalize(label, 'M15', mean_std_dct) label = scale(label, label_param, mean_std_dct)
if is_training and DO_AUGMENT: if is_training and DO_AUGMENT:
data_ud = np.flip(data, axis=1) data_ud = np.flip(data, axis=1)
...@@ -358,14 +361,14 @@ class UNET: ...@@ -358,14 +361,14 @@ class UNET:
self.test_data_files = test_data_files self.test_data_files = test_data_files
self.test_label_files = test_label_files self.test_label_files = test_label_files
trn_idxs = np.arange(10503) trn_idxs = np.arange(20925)
np.random.shuffle(trn_idxs) np.random.shuffle(trn_idxs)
tst_idxs = np.arange(1167) tst_idxs = np.arange(2325)
self.get_train_dataset(trn_idxs) self.get_train_dataset(trn_idxs)
self.get_test_dataset(tst_idxs) self.get_test_dataset(tst_idxs)
self.num_data_samples = 10503 self.num_data_samples = 20925
print('datetime: ', now) print('datetime: ', now)
print('training and test data: ') print('training and test data: ')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment