diff --git a/modules/deeplearning/quantile_regression.py b/modules/deeplearning/quantile_regression.py index cb038e960458ad87f1958acc8d2fd94d1a5314b3..f5ff789398b3ca750abab9ce94a8daac011f65c1 100644 --- a/modules/deeplearning/quantile_regression.py +++ b/modules/deeplearning/quantile_regression.py @@ -22,7 +22,7 @@ def true_func(x): # Y = 2 + 1.5 * X + epsilon # Linear relationship with variance increasing # Y = 1 + np.exp(X / 4) + epsilon # Y = 1 + np.exp(X / 4) + np.sin((2*np.pi/5)*X) - return 1 + np.exp(x / 4) + 3*np.sin((2*np.pi/5)*x) + return 1 + np.exp(x / 3) + 4*np.sin((2*np.pi/5)*x) + (0.5 + x/5)*np.cos((2*np.pi/2)*x) # Generate synthetic dataset def make_data(num_points=1000): @@ -85,7 +85,7 @@ def run(num_points=1000, num_plot_pts=200): # Plot the results plt.figure(figsize=(8, 6)) - plt.scatter(X_test[::4, 0], Y_test[::4, 0], alpha=0.3, label="Test Data") + # plt.scatter(X_test[::4, 0], Y_test[::4, 0], alpha=0.3, label="Test Data") plt.plot(X_range, predictions[0.05], label="Quantile 0.05", color='red') plt.plot(X_range, predictions[0.5], label="Quantile 0.5 (Median)", color='green') plt.plot(X_range, predictions[0.95], label="Quantile 0.95", color='blue')