From b048d747a6b3ddd430492fed2f972070ed26cfed Mon Sep 17 00:00:00 2001 From: rink <rink@ssec.wisc.edu> Date: Wed, 2 Apr 2025 10:44:50 -0500 Subject: [PATCH] add option to pass in absolute tolerance --- modules/deeplearning/quantile_regression.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/modules/deeplearning/quantile_regression.py b/modules/deeplearning/quantile_regression.py index 93e0b888..14838bbb 100644 --- a/modules/deeplearning/quantile_regression.py +++ b/modules/deeplearning/quantile_regression.py @@ -25,10 +25,12 @@ def make_data(num_points=1000): # epsilon = np.random.normal(0, X/2, size=(num_points, 1)) # Noise increasing with X epsilon = np.random.normal(0, 0.5 + X/6, size=(num_points, 1)) # Y = 2 + 1.5 * X + epsilon # Linear relationship with variance increasing - Y = 1 + np.exp(X / 4) + epsilon + # Y = 1 + np.exp(X / 4) + epsilon + Y = 1 + np.exp(X / 4) + np.sin((2*np.pi/5)*X) + Y_eps = Y + epsilon # Split into training and test sets - X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2, random_state=42) + X_train, X_test, Y_train, Y_test = train_test_split(X, Y_eps, test_size=0.2, random_state=42) return X_train, X_test, Y_train, Y_test, X, Y @@ -78,7 +80,7 @@ def build_mse_model(): model.compile(optimizer='adam', loss=tf.keras.losses.MeanSquaredError()) return model -def run(num_points=1000): +def run(num_points=1000, num_plot_pts=200): # Define quantiles quantiles = [0.1, 0.5, 0.9] models = {} @@ -93,7 +95,7 @@ def run(num_points=1000): models[q].fit(X_train, Y_train, epochs=100, batch_size=32, verbose=0) # Generate test data predictions - X_range = np.linspace(X.min(), X.max(), 100).reshape(-1, 1) + X_range = np.linspace(X.min(), X.max(), num_plot_pts).reshape(-1, 1) predictions = {q: models[q].predict(X_range) for q in quantiles} model = build_mae_model() @@ -119,7 +121,7 @@ def run(num_points=1000): plt.plot(X_range, predictions[0.9], label="Quantile 0.9", color='blue') plt.plot(X_range, mae_predictions, label="MAE", color='magenta') plt.plot(X_range, mse_predictions, label="MSE", color='cyan') - plt.plot(X_range, bulk_predictions, label="Bulk Quantile Model (Wimmers)", color='black') + plt.plot(X_range, bulk_predictions, label="Bulk Quantile Model (Wimmers)", color='orange') plt.xlabel("X") plt.ylabel("Y") plt.legend() -- GitLab