From b048d747a6b3ddd430492fed2f972070ed26cfed Mon Sep 17 00:00:00 2001
From: rink <rink@ssec.wisc.edu>
Date: Wed, 2 Apr 2025 10:44:50 -0500
Subject: [PATCH] add option to pass in absolute tolerance

---
 modules/deeplearning/quantile_regression.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/modules/deeplearning/quantile_regression.py b/modules/deeplearning/quantile_regression.py
index 93e0b888..14838bbb 100644
--- a/modules/deeplearning/quantile_regression.py
+++ b/modules/deeplearning/quantile_regression.py
@@ -25,10 +25,12 @@ def make_data(num_points=1000):
     # epsilon = np.random.normal(0, X/2, size=(num_points, 1))  # Noise increasing with X
     epsilon = np.random.normal(0, 0.5 + X/6, size=(num_points, 1))
     # Y = 2 + 1.5 * X + epsilon  # Linear relationship with variance increasing
-    Y = 1 + np.exp(X / 4) + epsilon
+    # Y = 1 + np.exp(X / 4) + epsilon
+    Y = 1 + np.exp(X / 4) + np.sin((2*np.pi/5)*X)
+    Y_eps = Y + epsilon
 
     # Split into training and test sets
-    X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2, random_state=42)
+    X_train, X_test, Y_train, Y_test = train_test_split(X, Y_eps, test_size=0.2, random_state=42)
     return X_train, X_test, Y_train, Y_test, X, Y
 
 
@@ -78,7 +80,7 @@ def build_mse_model():
     model.compile(optimizer='adam', loss=tf.keras.losses.MeanSquaredError())
     return model
 
-def run(num_points=1000):
+def run(num_points=1000, num_plot_pts=200):
     # Define quantiles
     quantiles = [0.1, 0.5, 0.9]
     models = {}
@@ -93,7 +95,7 @@ def run(num_points=1000):
         models[q].fit(X_train, Y_train, epochs=100, batch_size=32, verbose=0)
 
     # Generate test data predictions
-    X_range = np.linspace(X.min(), X.max(), 100).reshape(-1, 1)
+    X_range = np.linspace(X.min(), X.max(), num_plot_pts).reshape(-1, 1)
     predictions = {q: models[q].predict(X_range) for q in quantiles}
 
     model = build_mae_model()
@@ -119,7 +121,7 @@ def run(num_points=1000):
     plt.plot(X_range, predictions[0.9], label="Quantile 0.9", color='blue')
     plt.plot(X_range, mae_predictions, label="MAE", color='magenta')
     plt.plot(X_range, mse_predictions, label="MSE", color='cyan')
-    plt.plot(X_range, bulk_predictions, label="Bulk Quantile Model (Wimmers)", color='black')
+    plt.plot(X_range, bulk_predictions, label="Bulk Quantile Model (Wimmers)", color='orange')
     plt.xlabel("X")
     plt.ylabel("Y")
     plt.legend()
-- 
GitLab