diff --git a/modules/machine_learning/classification.py b/modules/machine_learning/classification.py index 3e30531fe61677ea71b5f54e0168be436d5dea6d..0d4c399330e844497b90b9f11f912865ba633397 100644 --- a/modules/machine_learning/classification.py +++ b/modules/machine_learning/classification.py @@ -57,6 +57,18 @@ def get_csv_as_dataframe(csv_file, reduce_frac=None, random_state=42): if reduce_frac is not None: icing_df = icing_df.sample(axis=0, frac=reduce_frac, random_state=random_state) + # # remove approximately half of rows where column_name equals to column_value + # column_name = 'icing_intensity' + # column_value = -1 + # if column_name in icing_df.columns: + # df_to_reduce = icing_df[icing_df[column_name] == column_value] + # icing_df = icing_df[icing_df[column_name] != column_value] + # + # if reduce_frac is not None: + # df_to_reduce = df_to_reduce.sample(axis=0, frac=0.5, random_state=random_state) + # + # icing_df = pd.concat([icing_df, df_to_reduce]) + return icing_df