diff --git a/modules/deeplearning/srcnn_cld_frac.py b/modules/deeplearning/srcnn_cld_frac.py
index 527a607ddca1ef365e0eabd6b3242465104eb853..b3b7ae3f0e6097e8f6598f923881d4a84ffc71fd 100644
--- a/modules/deeplearning/srcnn_cld_frac.py
+++ b/modules/deeplearning/srcnn_cld_frac.py
@@ -185,26 +185,24 @@ def upsample_nearest(tmp):
def get_label_data(grd_k):
num, leny, lenx = grd_k.shape
- leny_d2x = int(leny / 2)
- lenx_d2x = int(lenx / 2)
- grd_down_2x = np.zeros((num, leny_d2x, lenx_d2x))
+ grd_k = np.where(np.isnan(grd_k), 0, grd_k)
+ grd_k = np.where(grd_k < 0.5, 0, 1)
+
+ grd_down_2x = []
for t in range(num):
- for j in range(leny_d2x):
- for i in range(lenx_d2x):
- cell = grd_k[t, j:j + 2, i:i + 2]
- if np.sum(np.isnan(cell)) == 0:
- cell = np.where(cell < 0.5, 0, 1)
- cnt = np.sum(cell)
- if cnt == 0:
- grd_down_2x[t, j, i] = 0
- elif cnt == 4:
- grd_down_2x[t, j, i] = 2
- else:
- grd_down_2x[t, j, i] = 1
- else:
- grd_down_2x[t, j, i] = 0
-
- return grd_down_2x
+ a = grd_k[t, 0::2, 0::2]
+ b = grd_k[t, 1::2, 0::2]
+ c = grd_k[t, 0::2, 1::2]
+ d = grd_k[t, 1::2, 1::2]
+ s_t = a + b + c + d
+ s_t = np.where(s_t == 0, 0, s_t)
+ s_t = np.where(s_t == 1, 1, s_t)
+ s_t = np.where(s_t == 2, 1, s_t)
+ s_t = np.where(s_t == 3, 1, s_t)
+ s_t = np.where(s_t == 4, 2, s_t)
+ grd_down_2x.append(s_t)
+
+ return np.stack(grd_down_2x)
class SRCNN: