Source code for supernnova.training.train_randomforest

import numpy as np
from sklearn.ensemble import RandomForestClassifier

from ..utils import data_utils as du
from ..utils import training_utils as tu
from ..utils import logging_utils


def train(settings):
    """Train random forest models

    - Features are stored in a .FITRES file found in settings.processed_dir.
    - Carry out a train/val split based on predefined splits
    - Compute evaluation metrics on train and validation set
    - Save trained models to settings.models_dir

    Args:
        settings: (ExperimentSettings) custom class to hold hyperparameters

    Returns:
        None. The fitted classifier is handed to
        ``tu.save_randomforest_model`` rather than returned.

    Raises:
        AssertionError: if ``settings.source_data`` is not ``"saltfit"`` or
            ``settings.nb_classes`` is not 2.
    """
    # NOTE: these checks are kept as asserts so callers still see
    # AssertionError; be aware they are skipped when Python runs with -O.
    assert settings.source_data == "saltfit", logging_utils.str_to_redstr(
        "Only saltfit is a valid data source for random forest"
    )
    assert settings.nb_classes == 2, logging_utils.str_to_redstr(
        "Binary classification is the only task allowed for RandomForest"
    )

    clf = RandomForestClassifier(
        bootstrap=settings.bootstrap,
        min_samples_leaf=settings.min_samples_leaf,
        n_estimators=settings.n_estimators,
        min_samples_split=settings.min_samples_split,
        criterion=settings.criterion,
        max_features=settings.max_features,
        max_depth=settings.max_depth,
        random_state=settings.seed,
        n_jobs=-1,
    )

    ###################
    # Data
    ###################
    # Load the fit features; SNID is captured here and used below to map
    # the predefined split onto row positions of this table.
    df_data = du.load_fitfile(settings)
    SNID = df_data["SNID"].values

    # Load training and validation SNID to make our train/val split.
    # Rows flagged 0 in "dataset_photometry_2classes" go to the training
    # set, rows flagged 1 to the validation set; each subset is then
    # truncated to its leading settings.data_fraction share.
    sn_df = du.load_HDF5_SNinfo(settings)

    df_train = sn_df[sn_df["dataset_photometry_2classes"] == 0]
    n_train = len(df_train)
    df_train = df_train[: int(settings.data_fraction * n_train)]

    df_val = sn_df[sn_df["dataset_photometry_2classes"] == 1]
    n_val = len(df_val)
    df_val = df_val[: int(settings.data_fraction * n_val)]

    # Compute train/val split index.
    # np.isin replaces np.in1d, which is deprecated since NumPy 2.0; for the
    # 1-D SNID array both produce the same boolean membership mask.
    idx_train = np.where(np.isin(SNID, df_train.SNID.values))[0]
    idx_val = np.where(np.isin(SNID, df_val.SNID.values))[0]

    # Add redshift features if required by settings
    df_data = du.add_redshift_features(settings, df_data)

    ###################
    # Training
    ###################
    # Prepare features and target
    X = df_data[settings.randomforest_features].values
    y = df_data["target_2classes"].values

    logging_utils.print_green("Features", ",".join(settings.randomforest_features))

    # Apply train/val split
    X_train, X_val = X[idx_train], X[idx_val]
    y_train, y_val = y[idx_train], y[idx_val]

    # Fit and evaluate model
    clf = tu.train_and_evaluate_randomforest_model(clf, X_train, y_train, X_val, y_val)

    # save the model to disk
    tu.save_randomforest_model(settings, clf)