From c1572fe94a5e7cdee1eea35b29e76cbde8cf4c3a Mon Sep 17 00:00:00 2001 From: tomrink <rink@ssec.wisc.edu> Date: Thu, 25 Apr 2024 14:03:13 -0500 Subject: [PATCH] snapshot... --- modules/machine_learning/classification.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/modules/machine_learning/classification.py b/modules/machine_learning/classification.py index 814af787..3e30531f 100644 --- a/modules/machine_learning/classification.py +++ b/modules/machine_learning/classification.py @@ -50,13 +50,13 @@ def plot_confusion_matrix(cm, classes, plt.xlabel('Predicted label') -def get_csv_as_dataframe(csv_file, reduce_frac=None): +def get_csv_as_dataframe(csv_file, reduce_frac=None, random_state=42): icing_df = pd.read_csv(csv_file) + # Random selection of reduce_frac of the rows if reduce_frac is not None: - icing_df = icing_df.sample(frac=reduce_frac) - print(icing_df.describe()) - print(icing_df.shape) + icing_df = icing_df.sample(axis=0, frac=reduce_frac, random_state=random_state) + return icing_df @@ -72,9 +72,8 @@ def get_feature_target_data(data_frame, standardize=True): # Remove rows with NaN values # icing_df = icing_df.dropna() - print('num obs, features: ', icing_df.shape) - x = np.asarray(icing_df[params]) + print('num obs, features: ', x.shape) if standardize: x = preprocessing.StandardScaler().fit(x).transform(x) -- GitLab