From c1572fe94a5e7cdee1eea35b29e76cbde8cf4c3a Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Thu, 25 Apr 2024 14:03:13 -0500
Subject: [PATCH] snapshot...

---
 modules/machine_learning/classification.py | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/modules/machine_learning/classification.py b/modules/machine_learning/classification.py
index 814af787..3e30531f 100644
--- a/modules/machine_learning/classification.py
+++ b/modules/machine_learning/classification.py
@@ -50,13 +50,13 @@ def plot_confusion_matrix(cm, classes,
     plt.xlabel('Predicted label')
 
 
-def get_csv_as_dataframe(csv_file, reduce_frac=None):
+def get_csv_as_dataframe(csv_file, reduce_frac=None, random_state=42):
     icing_df = pd.read_csv(csv_file)
+
     # Random selection of reduce_frac of the rows
     if reduce_frac is not None:
-        icing_df = icing_df.sample(frac=reduce_frac)
-    print(icing_df.describe())
-    print(icing_df.shape)
+        icing_df = icing_df.sample(axis=0, frac=reduce_frac, random_state=random_state)
+
     return icing_df
 
 
@@ -72,9 +72,8 @@ def get_feature_target_data(data_frame, standardize=True):
     # Remove rows with NaN values
     # icing_df = icing_df.dropna()
 
-    print('num obs, features: ', icing_df.shape)
-
     x = np.asarray(icing_df[params])
+    print('num obs, features: ', x.shape)
     if standardize:
         x = preprocessing.StandardScaler().fit(x).transform(x)
 
-- 
GitLab