Skip to content
Snippets Groups Projects
Commit d57d5f65 authored by rink's avatar rink
Browse files

add option to pass in absolute tolerance

parent aa5f8fa6
No related branches found
No related tags found
No related merge requests found
...@@ -8,15 +8,15 @@ def create_mnist_dataset(): ...@@ -8,15 +8,15 @@ def create_mnist_dataset():
return tf.keras.datasets.mnist.load_data() return tf.keras.datasets.mnist.load_data()
def create_toy_regression_dataset(xmin=-10., xmax=10, noise_std=.2): def create_toy_regression_dataset(xmin=-10., xmax=10, num_points=1000, noise_std=.05):
x_trn = np.linspace(-2.0, 2.0, 1000, dtype=np.float32) x_trn = np.linspace(-2.0, 2.0, num_points, dtype=np.float32)
y_trn = np.sin(x_trn) + np.random.normal(0, noise_std, size=1000).astype(np.float32) y_trn = np.sin(x_trn) + np.random.normal(0, noise_std, size=num_points).astype(np.float32)
x_gt = np.linspace(xmin, xmax, 1000, dtype=np.float32) x_gt = np.linspace(xmin, xmax, num_points, dtype=np.float32)
y_gt = np.sin(x_gt) y_gt = np.sin(x_gt)
x_tst = np.linspace(xmin, xmax, 1000, dtype=np.float32) x_tst = np.linspace(xmin, xmax, num_points, dtype=np.float32)
y_tst = np.sin(x_tst) + np.random.normal(0, noise_std, size=1000).astype(np.float32) y_tst = np.sin(x_tst) + np.random.normal(0, noise_std, size=num_points).astype(np.float32)
return x_gt, y_gt, x_trn, y_trn, x_tst, y_tst return x_gt, y_gt, x_trn, y_trn, x_tst, y_tst
...@@ -199,18 +199,18 @@ def predict(model, x, samples=20): ...@@ -199,18 +199,18 @@ def predict(model, x, samples=20):
y_std = np.sqrt(y_variance) y_std = np.sqrt(y_variance)
return y_mean, y_std return y_mean, y_std
def run(): def run(num_points=1000, samples=20):
xmin = -10. xmin = -10.
xmax = 10. xmax = 10.
x_gt, y_gt, x_trn, y_trn, x_tst, y_tst = create_toy_regression_dataset(xmin=xmin, xmax=xmax, noise_std=0.05) x_gt, y_gt, x_trn, y_trn, x_tst, y_tst = create_toy_regression_dataset(xmin=xmin, xmax=xmax, num_points=num_points, noise_std=0.05)
model = define() model = define()
model = train(x_trn[:, np.newaxis], model = train(x_trn[:, np.newaxis],
y_trn[:, np.newaxis], y_trn[:, np.newaxis],
model) model)
yhat_mean, yhat_std = predict(model, x_tst[:, np.newaxis], samples=20) yhat_mean, yhat_std = predict(model, x_tst[:, np.newaxis], samples=samples)
plot_regression_model_analysis(gt=(x_gt, y_gt), plot_regression_model_analysis(gt=(x_gt, y_gt),
trn=(x_trn, y_trn), trn=(x_trn, y_trn),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment