diff --git a/modules/deeplearning/mc_dropout_regression.py b/modules/deeplearning/mc_dropout_regression.py index fbf1873fb0bf2aeb208d01dad1baa5709c46aacb..924f57f9d16dbf7a2cf7dccc7079e7ebe7fc4ba8 100644 --- a/modules/deeplearning/mc_dropout_regression.py +++ b/modules/deeplearning/mc_dropout_regression.py @@ -8,15 +8,15 @@ def create_mnist_dataset(): return tf.keras.datasets.mnist.load_data() -def create_toy_regression_dataset(xmin=-10., xmax=10, noise_std=.2): - x_trn = np.linspace(-2.0, 2.0, 1000, dtype=np.float32) - y_trn = np.sin(x_trn) + np.random.normal(0, noise_std, size=1000).astype(np.float32) +def create_toy_regression_dataset(xmin=-10., xmax=10, num_points=1000, noise_std=.05): + x_trn = np.linspace(-2.0, 2.0, num_points, dtype=np.float32) + y_trn = np.sin(x_trn) + np.random.normal(0, noise_std, size=num_points).astype(np.float32) - x_gt = np.linspace(xmin, xmax, 1000, dtype=np.float32) + x_gt = np.linspace(xmin, xmax, num_points, dtype=np.float32) y_gt = np.sin(x_gt) - x_tst = np.linspace(xmin, xmax, 1000, dtype=np.float32) - y_tst = np.sin(x_tst) + np.random.normal(0, noise_std, size=1000).astype(np.float32) + x_tst = np.linspace(xmin, xmax, num_points, dtype=np.float32) + y_tst = np.sin(x_tst) + np.random.normal(0, noise_std, size=num_points).astype(np.float32) return x_gt, y_gt, x_trn, y_trn, x_tst, y_tst @@ -199,18 +199,18 @@ def predict(model, x, samples=20): y_std = np.sqrt(y_variance) return y_mean, y_std -def run(): +def run(num_points=1000, samples=20): xmin = -10. xmax = 10. - x_gt, y_gt, x_trn, y_trn, x_tst, y_tst = create_toy_regression_dataset(xmin=xmin, xmax=xmax, noise_std=0.05) + x_gt, y_gt, x_trn, y_trn, x_tst, y_tst = create_toy_regression_dataset(xmin=xmin, xmax=xmax, num_points=num_points, noise_std=0.05) model = define() model = train(x_trn[:, np.newaxis], y_trn[:, np.newaxis], model) - yhat_mean, yhat_std = predict(model, x_tst[:, np.newaxis], samples=20) + yhat_mean, yhat_std = predict(model, x_tst[:, np.newaxis], samples=samples) plot_regression_model_analysis(gt=(x_gt, y_gt), trn=(x_trn, y_trn),