Skip to content
Snippets Groups Projects
Commit 4ae756f8 authored by tomrink's avatar tomrink
Browse files

snapshot...

parent 17877b40
Branches
No related tags found
No related merge requests found
......@@ -110,19 +110,6 @@ if DO_ESPCN:
y_128 = slice(2, 130)
# gpus = tf.config.list_physical_devices('GPU')
# if gpus:
# try:
# # Currently, memory growth needs to be the same across GPUs
# for gpu in gpus:
# tf.config.experimental.set_memory_growth(gpu, True)
# logical_gpus = tf.config.list_logical_devices('GPU')
# print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
# except RuntimeError as e:
# # Memory growth must be set before GPUs have been initialized
# print(e)
def build_residual_conv2d_block(conv, num_filters, block_name, activation=tf.nn.relu, padding='SAME',
kernel_initializer='he_uniform', scale=None, kernel_size=3,
do_drop_out=True, drop_rate=0.5, do_batch_norm=True):
......@@ -211,6 +198,30 @@ def get_label_data(grd_k):
return s
def get_label_data_5cat(grd_k):
grd_k = np.where(np.isnan(grd_k), 0, grd_k)
grd_k = np.where(grd_k < 0.5, 0, 1)
a = grd_k[:, 0::2, 0::2]
b = grd_k[:, 1::2, 0::2]
c = grd_k[:, 0::2, 1::2]
d = grd_k[:, 1::2, 1::2]
s = a + b + c + d
cat_0 = (s == 0)
cat_1 = (s == 1)
cat_2 = (s == 2)
cat_3 = (s == 3)
cat_4 = (s == 4)
s[cat_0] = 0
s[cat_1] = 1
s[cat_2] = 2
s[cat_3] = 3
s[cat_4] = 4
return s
class SRCNN:
def __init__(self):
......
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment