diff --git a/modules/deeplearning/srcnn_l1b_l2.py b/modules/deeplearning/srcnn_l1b_l2.py index 89df1a27b08c53028cb28ecca874b049399c7683..483bcacd0b9e0a7c2a71ed31ad29092975f6b16c 100644 --- a/modules/deeplearning/srcnn_l1b_l2.py +++ b/modules/deeplearning/srcnn_l1b_l2.py @@ -718,7 +718,7 @@ def analyze(): print(grd_lr.shape) leny, lenx = grd_lr.shape rnd = np.random.normal(loc=0, scale=0.001, size=grd_lr.size) - grd_lr += rnd + grd_lr += rnd.reshape(grd_lr.shape) grd_lr = np.where(grd_lr < 0, 0, grd_lr) grd_lr = np.where(grd_lr > 1, 1, grd_lr)