diff --git a/modules/deeplearning/quantile_regression.py b/modules/deeplearning/quantile_regression.py index 93e0b888e3107234ea0d85f796c4f0ba55166eef..14838bbbe2e01e1af3ce5b9c9c02d162d3262040 100644 --- a/modules/deeplearning/quantile_regression.py +++ b/modules/deeplearning/quantile_regression.py @@ -25,10 +25,12 @@ def make_data(num_points=1000): # epsilon = np.random.normal(0, X/2, size=(num_points, 1)) # Noise increasing with X epsilon = np.random.normal(0, 0.5 + X/6, size=(num_points, 1)) # Y = 2 + 1.5 * X + epsilon # Linear relationship with variance increasing - Y = 1 + np.exp(X / 4) + epsilon + # Y = 1 + np.exp(X / 4) + epsilon + Y = 1 + np.exp(X / 4) + np.sin((2*np.pi/5)*X) + Y_eps = Y + epsilon # Split into training and test sets - X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2, random_state=42) + X_train, X_test, Y_train, Y_test = train_test_split(X, Y_eps, test_size=0.2, random_state=42) return X_train, X_test, Y_train, Y_test, X, Y @@ -78,7 +80,7 @@ def build_mse_model(): model.compile(optimizer='adam', loss=tf.keras.losses.MeanSquaredError()) return model -def run(num_points=1000): +def run(num_points=1000, num_plot_pts=200): # Define quantiles quantiles = [0.1, 0.5, 0.9] models = {} @@ -93,7 +95,7 @@ def run(num_points=1000): models[q].fit(X_train, Y_train, epochs=100, batch_size=32, verbose=0) # Generate test data predictions - X_range = np.linspace(X.min(), X.max(), 100).reshape(-1, 1) + X_range = np.linspace(X.min(), X.max(), num_plot_pts).reshape(-1, 1) predictions = {q: models[q].predict(X_range) for q in quantiles} model = build_mae_model() @@ -119,7 +121,7 @@ def run(num_points=1000): plt.plot(X_range, predictions[0.9], label="Quantile 0.9", color='blue') plt.plot(X_range, mae_predictions, label="MAE", color='magenta') plt.plot(X_range, mse_predictions, label="MSE", color='cyan') - plt.plot(X_range, bulk_predictions, label="Bulk Quantile Model (Wimmers)", color='black') + plt.plot(X_range, bulk_predictions, label="Bulk Quantile Model (Wimmers)", color='orange') plt.xlabel("X") plt.ylabel("Y") plt.legend()