Commit c287379f authored by tomrink

snapshot...

parent b30ed759
@@ -3,7 +3,7 @@ import tensorflow as tf
 from util.plot_cm import confusion_matrix_values
 from util.augment import augment_image
 from util.setup_cloud_fraction import logdir, modeldir, now, ancillary_path
-from util.util import EarlyStop, normalize, denormalize, scale, descale, get_grid_values_all
+from util.util import EarlyStop, normalize, denormalize, scale, descale, get_grid_values_all, make_tf_callable_generator
 import glob
 import os, datetime
 import numpy as np
@@ -364,24 +364,37 @@ class SRCNN:
         out = tf.numpy_function(self.get_in_mem_data_batch_test, [indexes], [tf.float32, tf.float32])
         return out
 
-    def get_train_dataset(self, indexes):
-        indexes = list(indexes)
+    def get_train_dataset(self, num_files):
+        def integer_gen(limit):
+            n = 0
+            while n < limit:
+                yield n
+                n += 1
 
-        dataset = tf.data.Dataset.from_tensor_slices(indexes)
+        num_gen = integer_gen(num_files)
+        gen = make_tf_callable_generator(num_gen)
+
+        dataset = tf.data.Dataset.from_generator(gen, output_types=tf.int32)
         dataset = dataset.batch(PROC_BATCH_SIZE)
-        dataset = dataset.map(self.data_function, num_parallel_calls=AUTOTUNE)
+        dataset = dataset.map(self.data_function, num_parallel_calls=8)
         dataset = dataset.cache()
+        dataset = dataset.shuffle(PROC_BATCH_BUFFER_SIZE, reshuffle_each_iteration=True)
         if DO_AUGMENT:
-            dataset = dataset.shuffle(PROC_BATCH_BUFFER_SIZE)
-        dataset = dataset.prefetch(buffer_size=AUTOTUNE)
+            dataset = dataset.map(augment_image(), num_parallel_calls=8)
+        dataset = dataset.prefetch(buffer_size=1)
         self.train_dataset = dataset
 
-    def get_test_dataset(self, indexes):
-        indexes = list(indexes)
+    def get_test_dataset(self, num_files):
+        def integer_gen(limit):
+            n = 0
+            while n < limit:
+                yield n
+                n += 1
 
-        dataset = tf.data.Dataset.from_tensor_slices(indexes)
+        num_gen = integer_gen(num_files)
+        gen = make_tf_callable_generator(num_gen)
+
+        dataset = tf.data.Dataset.from_generator(gen, output_types=tf.int32)
         dataset = dataset.batch(PROC_BATCH_SIZE)
-        dataset = dataset.map(self.data_function_test, num_parallel_calls=AUTOTUNE)
+        dataset = dataset.map(self.data_function_test, num_parallel_calls=8)
         dataset = dataset.cache()
         self.test_dataset = dataset
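
Editor's note: both datasets now feed tf.data from a plain Python generator via make_tf_callable_generator. That helper lives in util.util and its body is not part of this diff; the sketch below is a minimal version consistent with how it is called here, given that Dataset.from_generator expects a zero-argument callable while integer_gen(num_files) returns a generator object.

    import tensorflow as tf

    def make_tf_callable_generator(gen_obj):
        # Dataset.from_generator() wants a callable that returns an iterator,
        # but the call sites above build a generator *object* first, so wrap
        # it in a zero-argument closure. The wrapped object is exhausted after
        # one pass, which is why the datasets above call .cache().
        def gen_callable():
            yield from gen_obj
        return gen_callable

    def integer_gen(limit):
        n = 0
        while n < limit:
            yield n
            n += 1

    gen = make_tf_callable_generator(integer_gen(10))
    dataset = tf.data.Dataset.from_generator(gen, output_types=tf.int32)
    for batch in dataset.batch(4):
        print(batch.numpy())   # -> [0 1 2 3], [4 5 6 7], [8 9]
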
@@ -391,29 +404,23 @@ class SRCNN:
         self.test_data_files = test_data_files
         self.test_label_files = test_label_files
 
-        trn_idxs = np.arange(len(train_data_files))
-        np.random.shuffle(trn_idxs)
-        tst_idxs = np.arange(len(test_data_files))
-
-        self.get_train_dataset(trn_idxs)
-        self.get_test_dataset(tst_idxs)
+        self.get_train_dataset(len(train_data_files))
+        self.get_test_dataset(len(test_data_files))
 
         self.num_data_samples = num_train_samples  # approximately
 
         print('datetime: ', now)
         print('training and test data: ')
         print('---------------------------')
-        print('num train samples: ', self.num_data_samples)
+        print('num train files: ', len(train_data_files))
         print('BATCH SIZE: ', BATCH_SIZE)
-        print('num test samples: ', tst_idxs.shape[0])
+        print('num test files: ', len(test_data_files))
         print('setup_pipeline: Done')
 
     def setup_test_pipeline(self, test_data_files, test_label_files):
         self.test_data_files = test_data_files
         self.test_label_files = test_label_files
-        tst_idxs = np.arange(len(test_data_files))
-        self.get_test_dataset(tst_idxs)
+        self.get_test_dataset(len(test_data_files))
         print('setup_test_pipeline: Done')
 
     def build_srcnn(self, do_drop_out=False, do_batch_norm=False, drop_rate=0.5, factor=2):
...
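
Editor's note: one behavioral consequence of the training-pipeline rework is that shuffle() now sits after cache() and reshuffles on every epoch (reshuffle_each_iteration=True), so the batches produced by the expensive data_function are computed once, then only their order (and, under DO_AUGMENT, the augmentation) changes between epochs. A toy illustration of that ordering; values and names here are illustrative, not from the repo.

    import tensorflow as tf

    ds = tf.data.Dataset.range(6)
    ds = ds.map(lambda x: x * 10)  # stand-in for the expensive data_function
    ds = ds.cache()                # map results computed once, replayed each epoch
    ds = ds.shuffle(6, reshuffle_each_iteration=True)  # fresh order per epoch
    ds = ds.prefetch(buffer_size=1)

    for epoch in range(2):
        print([int(v) for v in ds])  # same six values, different order each time
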