Skip to content
Snippets Groups Projects
Commit 005381d9 authored by tomrink's avatar tomrink
Browse files

make 3 class now

parent 40120ee4
No related branches found
No related tags found
No related merge requests found
...@@ -15,7 +15,7 @@ LOG_DEVICE_PLACEMENT = False ...@@ -15,7 +15,7 @@ LOG_DEVICE_PLACEMENT = False
PROC_BATCH_SIZE = 4 PROC_BATCH_SIZE = 4
PROC_BATCH_BUFFER_SIZE = 5000 PROC_BATCH_BUFFER_SIZE = 5000
NumClasses = 5 NumClasses = 3
if NumClasses == 2: if NumClasses == 2:
NumLogits = 1 NumLogits = 1
else: else:
...@@ -172,6 +172,7 @@ def get_label_data(grd_k): ...@@ -172,6 +172,7 @@ def get_label_data(grd_k):
def get_label_data_5cat(grd_k): def get_label_data_5cat(grd_k):
grd_k = np.where(np.isnan(grd_k), 0, grd_k) grd_k = np.where(np.isnan(grd_k), 0, grd_k)
# grd_u = np.where(np.logical_and(grd_k > 0.45, grd_k < 0.55), 1, 0)
grd_k = np.where(grd_k < 0.5, 0, 1) grd_k = np.where(grd_k < 0.5, 0, 1)
a = grd_k[:, 0::2, 0::2] a = grd_k[:, 0::2, 0::2]
...@@ -192,8 +193,17 @@ def get_label_data_5cat(grd_k): ...@@ -192,8 +193,17 @@ def get_label_data_5cat(grd_k):
s[cat_3] = 3 s[cat_3] = 3
s[cat_4] = 4 s[cat_4] = 4
# a = grd_u[:, 0::2, 0::2]
# b = grd_u[:, 1::2, 0::2]
# c = grd_u[:, 0::2, 1::2]
# d = grd_u[:, 1::2, 1::2]
# s_u = a + b + c + d
# cat_u = (s_u == 4)
# s[cat_u] = 5
return s return s
class SRCNN: class SRCNN:
def __init__(self): def __init__(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment