Skip to content
Snippets Groups Projects
Commit cc67961c authored by tomrink's avatar tomrink
Browse files

snapshot...

parent 5b47bb27
Branches
No related tags found
No related merge requests found
......@@ -184,25 +184,21 @@ def upsample_nearest(tmp):
def get_label_data(grd_k):
num, leny, lenx = grd_k.shape
grd_k = np.where(np.isnan(grd_k), 0, grd_k)
grd_k = np.where(grd_k < 0.5, 0, 1)
grd_down_2x = []
for t in range(num):
a = grd_k[t, 0::2, 0::2]
b = grd_k[t, 1::2, 0::2]
c = grd_k[t, 0::2, 1::2]
d = grd_k[t, 1::2, 1::2]
s_t = a + b + c + d
s_t = np.where(s_t == 0, 0, s_t)
s_t = np.where(s_t == 1, 1, s_t)
s_t = np.where(s_t == 2, 1, s_t)
s_t = np.where(s_t == 3, 1, s_t)
s_t = np.where(s_t == 4, 2, s_t)
grd_down_2x.append(s_t)
a = grd_k[:, 0::2, 0::2]
b = grd_k[:, 1::2, 0::2]
c = grd_k[:, 0::2, 1::2]
d = grd_k[:, 1::2, 1::2]
s_t = a + b + c + d
s_t = np.where(s_t == 0, 0, s_t)
s_t = np.where(s_t == 1, 1, s_t)
s_t = np.where(s_t == 2, 1, s_t)
s_t = np.where(s_t == 3, 1, s_t)
s_t = np.where(s_t == 4, 2, s_t)
return np.stack(grd_down_2x)
return s_t
class SRCNN:
......@@ -519,9 +515,14 @@ class SRCNN:
conv = conv_b
print(conv.shape)
if NumClasses == 2:
final_activation = tf.nn.sigmoid # For binary
else:
final_activation = tf.nn.softmax # For multi-class
if not DO_ESPCN:
# This is effectively a Dense layer
self.logits = tf.keras.layers.Conv2D(NumLogits, kernel_size=1, strides=1, padding=padding, name='regression')(conv)
self.logits = tf.keras.layers.Conv2D(NumLogits, kernel_size=1, strides=1, padding=padding, activation=final_activation)(conv)
else:
conv = tf.keras.layers.Conv2D(num_filters * (factor ** 2), 3, padding=padding, activation=activation)(conv)
print(conv.shape)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment