diff --git a/modules/machine_learning/classification.py b/modules/machine_learning/classification.py index c588ceb29cc9c9d41c78e29a1740024ef74d76d5..79813e32c46a9641a78e0d8ddd21c9a18d8c2eaa 100644 --- a/modules/machine_learning/classification.py +++ b/modules/machine_learning/classification.py @@ -67,12 +67,11 @@ def plot_confusion_matrix(cm, classes, plt.xlabel('Predicted label') -def get_csv_as_dataframe(csv_file, reduce_frac=None, random_state=42): +def get_csv_as_dataframe(csv_file, reduce_frac=1.0, random_state=42): icing_df = pd.read_csv(csv_file) # Random selection of reduce_frac of the rows - if reduce_frac is not None: - icing_df = icing_df.sample(axis=0, frac=reduce_frac, random_state=random_state) + icing_df = icing_df.sample(axis=0, frac=reduce_frac, random_state=random_state) # # remove approximately half of rows where column_name equals to column_value # column_name = 'icing_intensity'