diff --git a/modules/machine_learning/classification.py b/modules/machine_learning/classification.py index 814af787a6a38b53110691c41780326f9c6cc097..3e30531fe61677ea71b5f54e0168be436d5dea6d 100644 --- a/modules/machine_learning/classification.py +++ b/modules/machine_learning/classification.py @@ -50,13 +50,13 @@ def plot_confusion_matrix(cm, classes, plt.xlabel('Predicted label') -def get_csv_as_dataframe(csv_file, reduce_frac=None): +def get_csv_as_dataframe(csv_file, reduce_frac=None, random_state=42): icing_df = pd.read_csv(csv_file) + # Random selection of reduce_frac of the rows if reduce_frac is not None: - icing_df = icing_df.sample(frac=reduce_frac) - print(icing_df.describe()) - print(icing_df.shape) + icing_df = icing_df.sample(axis=0, frac=reduce_frac, random_state=random_state) + return icing_df @@ -72,9 +72,8 @@ def get_feature_target_data(data_frame, standardize=True): # Remove rows with NaN values # icing_df = icing_df.dropna() - print('num obs, features: ', icing_df.shape) - x = np.asarray(icing_df[params]) + print('num obs, features: ', x.shape) if standardize: x = preprocessing.StandardScaler().fit(x).transform(x)