Skip to content
Snippets Groups Projects
Commit 0814369f authored by tomrink's avatar tomrink
Browse files

snapshot...

parent b8583bc0
No related branches found
No related tags found
No related merge requests found
......@@ -360,9 +360,8 @@ class ESPCN:
self.get_test_dataset([0])
print('setup_test_pipeline: Done')
def setup_eval_pipeline(self, data_dct, num_tiles=1):
self.data_dct = data_dct
idxs = np.arange(num_tiles)
def setup_eval_pipeline(self, filename):
idxs = [0]
self.num_data_samples = idxs.shape[0]
self.get_evaluate_dataset(idxs)
......@@ -398,26 +397,26 @@ class ESPCN:
conv = tf.keras.layers.BatchNormalization()(conv)
print(conv.shape)
conv = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, padding=padding, activation=None)(conv)
conv = tf.keras.layers.BatchNormalization()(conv)
print(conv.shape)
conv = conv + skip
conv = tf.keras.layers.LeakyReLU()(conv)
print(conv.shape)
# conv = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, padding=padding, activation=None)(conv)
# conv = tf.keras.layers.BatchNormalization()(conv)
# print(conv.shape)
#
# conv = conv + skip
# conv = tf.keras.layers.LeakyReLU()(conv)
# print(conv.shape)
conv = tf.keras.layers.Conv2D(num_filters/2, kernel_size=3, strides=1, padding=padding, activation=activation)(conv)
conv = tf.keras.layers.Conv2D(num_filters // 2, kernel_size=3, strides=1, padding=padding, activation=activation)(conv)
conv = tf.keras.layers.BatchNormalization()(conv)
print(conv.shape)
conv = tf.keras.layers.Conv2D(num_filters/2, kernel_size=3, strides=1, padding=padding, activation=activation)(conv)
conv = tf.keras.layers.Conv2D(num_filters // 2, kernel_size=3, strides=1, padding=padding, activation=activation)(conv)
conv = tf.keras.layers.BatchNormalization()(conv)
print(conv.shape)
# conv = tf.keras.layers.Conv2D(4, kernel_size=3, strides=1, padding=padding, activation=activation)(conv)
# print(conv.shape)
conv = tf.keras.layers.Conv2DTranspose(4, kernel_size=3, strides=2, padding=padding, activation=activation)(conv)
conv = tf.keras.layers.Conv2DTranspose(num_filters // 4, kernel_size=3, strides=2, padding=padding, activation=activation)(conv)
print(conv.shape)
self.logits = tf.keras.layers.Conv2D(1, kernel_size=1, strides=1, padding=padding, name='probability', activation=tf.nn.sigmoid)(conv)
......@@ -741,26 +740,21 @@ class ESPCN:
self.test_preds = preds
def do_evaluate(self, prob_thresh=0.5):
def do_evaluate(self, nda_lr, ckpt_dir, prob_thresh=0.5):
self.reset_test_metrics()
ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model)
ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)
pred_s = []
ckpt.restore(ckpt_manager.latest_checkpoint)
for data in self.eval_dataset:
print(data[0].shape, data[1].shape)
pred = self.model([data])
print(pred.shape, np.histogram(pred.numpy()))
data = normalize(nda_lr, data_param, mean_std_dct)
data = np.expand_dims(data, axis=0)
data = np.expand_dims(data, axis=3)
preds = np.concatenate(pred_s)
preds = preds[:,0]
self.test_probs = preds
self.reset_test_metrics()
if NumClasses == 2:
preds = np.where(preds > prob_thresh, 1, 0)
else:
preds = np.argmax(preds, axis=1)
self.test_preds = preds
pred = self.model([data])
self.test_probs = pred
def run(self, directory):
train_data_files = glob.glob(directory+'data_train*.npy')
......@@ -783,8 +777,7 @@ class ESPCN:
self.restore(ckpt_dir)
def run_evaluate(self, filename, ckpt_dir):
data_dct, ll, cc = make_for_full_domain_predict(filename, name_list=self.train_params)
self.setup_eval_pipeline(data_dct, len(ll))
self.setup_eval_pipeline(filename)
self.build_model()
self.build_training()
self.build_evaluation()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment