diff --git a/modules/deeplearning/srcnn_cld_frac.py b/modules/deeplearning/srcnn_cld_frac.py
index 57b9f948db4dc97d78b19385f1bf3bb8ef14c9dc..39d509728732e529eb78e946409e6ff46013bd6f 100644
--- a/modules/deeplearning/srcnn_cld_frac.py
+++ b/modules/deeplearning/srcnn_cld_frac.py
@@ -179,13 +179,20 @@ def get_grid_cell_mean(grd_k):
     return s
 
 
-def get_label_data(grd_k):
-    grd_k = np.where(np.isnan(grd_k), 0, grd_k)
-    grd_k = np.where((grd_k >= 0.0) & (grd_k < 0.3), 0, grd_k)
-    grd_k = np.where((grd_k >= 0.3) & (grd_k < 0.7), 1, grd_k)
-    grd_k = np.where((grd_k >= 0.7) & (grd_k <= 1.0), 2, grd_k)
-
-    return grd_k
+def get_label_data(grd):
+    blen, ylen, xlen = grd.shape
+    grd = grd.flatten()
+    grd = np.where(np.isnan(grd), 0, grd)
+    cat_0 = np.logical_and(grd >= 0.0, grd < 0.3)
+    cat_1 = np.logical_and(grd >= 0.3, grd < 0.7)
+    cat_2 = np.logical_and(grd >= 0.7, grd <= 1.0)
+
+    grd[cat_0] = 0
+    grd[cat_1] = 1
+    grd[cat_2] = 2
+
+    grd = grd.reshape((blen, ylen, xlen))
+    return grd
 
 
 class SRCNN: