Skip to content
Snippets Groups Projects
Commit a22815c1 authored by tomrink's avatar tomrink
Browse files

snapshot...

parent 70c47460
Branches
No related tags found
No related merge requests found
......@@ -2,8 +2,10 @@ import gc
import glob
import tensorflow as tf
from util.augment import augment_image
from util.setup import logdir, modeldir, now, ancillary_path
from util.util import EarlyStop, normalize, denormalize, scale, descale, get_grid_values_all, resample_2d_linear, smooth_2d
from util.util import EarlyStop, normalize, denormalize, scale, descale, get_grid_values_all, resample_2d_linear,\
smooth_2d, make_tf_callable_generator
import os, datetime
import numpy as np
import pickle
......@@ -343,28 +345,9 @@ class SRCNN:
label = scale(label, label_param, mean_std_dct)
label = label[:, self.y_128, self.x_128]
label = np.where(np.isnan(label), 0.0, label)
label = np.expand_dims(label, axis=3)
data = data.astype(np.float32)
label = label.astype(np.float32)
if is_training and DO_AUGMENT:
# data_ud = np.flip(data, axis=1)
# label_ud = np.flip(label, axis=1)
#
# data_lr = np.flip(data, axis=2)
# label_lr = np.flip(label, axis=2)
#
# data = np.concatenate([data, data_ud, data_lr])
# label = np.concatenate([label, label_ud, label_lr])
data_rot = np.rot90(data, axes=(1, 2))
label_rot = np.rot90(label, axes=(1, 2))
data = np.concatenate([data, data_rot])
label = np.concatenate([label, label_rot])
return data, label
def get_in_mem_data_batch_train(self, idxs):
......@@ -383,22 +366,35 @@ class SRCNN:
out = tf.numpy_function(self.get_in_mem_data_batch_test, [indexes], [tf.float32, tf.float32])
return out
def get_train_dataset(self, indexes):
indexes = list(indexes)
def get_train_dataset(self, num_files):
def integer_gen(limit):
n = 0
while n < limit:
yield n
n += 1
num_gen = integer_gen(num_files)
gen = make_tf_callable_generator(num_gen)
dataset = tf.data.Dataset.from_tensor_slices(indexes)
dataset = tf.data.Dataset.from_generator(gen, output_types=tf.int32)
dataset = dataset.batch(PROC_BATCH_SIZE)
dataset = dataset.map(self.data_function, num_parallel_calls=8)
dataset = dataset.cache()
if DO_AUGMENT:
dataset = dataset.shuffle(PROC_BATCH_BUFFER_SIZE)
dataset = dataset.map(augment_image(), num_parallel_calls=8)
dataset = dataset.cache()
dataset = dataset.shuffle(PROC_BATCH_BUFFER_SIZE, reshuffle_each_iteration=False)
dataset = dataset.prefetch(buffer_size=1)
self.train_dataset = dataset
def get_test_dataset(self, indexes):
indexes = list(indexes)
def get_test_dataset(self, num_files):
def integer_gen(limit):
n = 0
while n < limit:
yield n
n += 1
num_gen = integer_gen(num_files)
gen = make_tf_callable_generator(num_gen)
dataset = tf.data.Dataset.from_tensor_slices(indexes)
dataset = tf.data.Dataset.from_generator(gen, output_types=tf.int32)
dataset = dataset.batch(PROC_BATCH_SIZE)
dataset = dataset.map(self.data_function_test, num_parallel_calls=8)
dataset = dataset.cache()
......@@ -410,22 +406,17 @@ class SRCNN:
self.test_data_files = test_data_files
self.test_label_files = test_label_files
trn_idxs = np.arange(len(train_data_files))
np.random.shuffle(trn_idxs)
tst_idxs = np.arange(len(test_data_files))
self.get_train_dataset(trn_idxs)
self.get_test_dataset(tst_idxs)
self.get_train_dataset(len(train_data_files))
self.get_test_dataset(len(test_data_files))
self.num_data_samples = num_train_samples # approximately
print('datetime: ', now)
print('training and test data: ')
print('---------------------------')
print('num train samples: ', self.num_data_samples)
print('num train files: ', len(train_data_files))
print('BATCH SIZE: ', BATCH_SIZE)
print('num test samples: ', tst_idxs.shape[0])
print('num test files: ', len(test_data_files))
print('setup_pipeline: Done')
def setup_test_pipeline(self, test_data_files, test_label_files):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment