Validation documentation

Validation RandomForest

supernnova.validation.validate_randomforest.get_predictions(settings, model_file=None)

Test random forest models on an independent test set.

Features are stored in a .FITRES file found in data_dir. Predefined splits are used to select the test set. Predicted targets and probabilities are saved to preds_dir.

  • settings (ExperimentSettings) – custom class to hold hyperparameters

  • model_file (str) – path to saved randomforest model
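A minimal sketch of the test-set step described above, assuming the features are already loaded into a numpy array, that each row carries a split label where 2 marks the test set (a hypothetical encoding), and that the model follows the scikit-learn `predict_proba` convention. `predict_test_set` and `StubModel` are illustrative names, not part of SuperNNova.

```python
import numpy as np

TEST_SPLIT = 2  # hypothetical encoding: 0 = train, 1 = valid, 2 = test


def predict_test_set(model, features, splits):
    """Select the test rows, then return the mask, predicted targets and probabilities.

    `model` is assumed to expose predict_proba(X) -> (n_samples, n_classes).
    """
    test_mask = splits == TEST_SPLIT
    proba = model.predict_proba(features[test_mask])
    predicted_target = np.argmax(proba, axis=1)  # column index of the most probable class
    return test_mask, predicted_target, proba


class StubModel:
    """Hypothetical stand-in: class 0 probability grows with the first feature."""

    def predict_proba(self, X):
        p0 = 1.0 / (1.0 + np.exp(-X[:, 0]))
        return np.column_stack([p0, 1.0 - p0])


features = np.array([[2.0, 0.0], [-3.0, 1.0], [0.5, 2.0], [-1.0, 0.0]])
splits = np.array([2, 2, 0, 2])
mask, target, proba = predict_test_set(StubModel(), features, splits)
```

With a real scikit-learn classifier, column `i` of `predict_proba` corresponds to `model.classes_[i]`, so the argmax index should be mapped back through `classes_` when the labels are not simply 0 to n-1.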

Validation RNN

supernnova.validation.validate_rnn.find_idx(array, value)

Utility to find the index of the element of array that most closely matches value

  • array (np.array) – The array in which to search

  • value (float) – The value for which we are looking for a match


Returns: (int) the index of the element of array that most closely matches value
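`find_idx` can be sketched in one line of numpy: take the absolute difference to `value` and return the position of the smallest one. This is an illustrative reimplementation, not necessarily the library's code; ties resolve to the first matching index because `np.argmin` returns the first minimum.

```python
import numpy as np


def find_idx(array, value):
    """Return the index of the element of `array` closest to `value`."""
    array = np.asarray(array, dtype=float)
    return int(np.argmin(np.abs(array - value)))


# Illustrative use: locate the observation closest to a given MJD
mjd = np.array([56100.0, 56104.5, 56110.2, 56118.9])
closest = find_idx(mjd, 56109.0)
```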

supernnova.validation.validate_rnn.get_batch_predictions(rnn, X, target)

Utility to obtain predictions for a given batch

  • rnn (torch.nn) – The RNN model

  • X (torch.Tensor) – The batch on which to carry out predictions

  • target (torch.LongTensor) – The true class of each element in the batch


Returns: tuple containing

  • arr_preds (np.array): predictions

  • arr_target (np.array): actual targets
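A sketch of what a batch prediction utility typically does, written in numpy so it runs without PyTorch. It assumes the model outputs unnormalized class scores (logits) of shape `(batch_size, nb_classes)` and that "predictions" means softmax probabilities; both are assumptions, and `batch_predictions` is a hypothetical name.

```python
import numpy as np


def softmax(logits):
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def batch_predictions(logits, target):
    """Convert raw scores for one batch into (arr_preds, arr_target) numpy arrays."""
    arr_preds = softmax(np.asarray(logits, dtype=float))
    arr_target = np.asarray(target)
    return arr_preds, arr_target


logits = np.array([[2.0, 0.0], [0.0, 0.0]])
preds, targets = batch_predictions(logits, [0, 1])
```

With PyTorch, the output tensor would first be moved to numpy with `.detach().cpu().numpy()` before this kind of post-processing.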

supernnova.validation.validate_rnn.get_batch_predictions_MFE(rnn, X, target)

Utility to obtain predictions for a given batch

  • rnn (torch.nn) – The RNN model

  • X (torch.Tensor) – The batch on which to carry out predictions

  • target (torch.LongTensor) – The true class of each element in the batch


Returns: tuple containing

  • arr_preds (np.array): predictions

  • arr_target (np.array): actual targets

supernnova.validation.validate_rnn.get_predictions(settings, model_file=None)

Obtain predictions for a given RNN model, specified either by the settings argument or by a model_file.

  • Models are benchmarked on the test data set

  • Batch size can be controlled to speed up predictions

  • For Bayesian models, multiple predictions are carried out to obtain a distribution of predictions

  • Predictions are computed for full lightcurves and around peak light

  • Predictions are saved to a pickle file (for faster loading)

  • settings (ExperimentSettings) – custom class to hold hyperparameters

  • model_file (str) – Path to saved model weights. Default: None
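For Bayesian models, the idea of carrying out multiple predictions can be illustrated with any stochastic forward pass: run it several times on the same inputs, stack the outputs, and summarise them (here with the median, as the RNN metrics do). `noisy_model` is a hypothetical stand-in for a Bayesian RNN, and `num_inference_samples` is an illustrative parameter name.

```python
import numpy as np


def sample_predictions(stochastic_predict, X, num_inference_samples):
    """Call a stochastic model repeatedly on the same inputs.

    Returns an array of shape (num_inference_samples, batch_size, nb_classes).
    """
    return np.stack([stochastic_predict(X) for _ in range(num_inference_samples)])


rng = np.random.default_rng(0)


def noisy_model(X):
    """Hypothetical stochastic model: fixed class probabilities plus random noise."""
    base = np.column_stack([X, 1.0 - X])
    noisy = np.clip(base + rng.normal(scale=0.05, size=base.shape), 1e-6, None)
    return noisy / noisy.sum(axis=1, keepdims=True)


X = np.array([0.9, 0.2, 0.6])  # hypothetical "true" class-0 probabilities
samples = sample_predictions(noisy_model, X, num_inference_samples=50)
median_probs = np.median(samples, axis=0)  # one summary per lightcurve
predicted_class = median_probs.argmax(axis=1)
```

Per-class medians are taken independently, so each row of `median_probs` need not sum exactly to 1; that is usually acceptable for ranking and classification.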


Test the inference speed of RNN models

  • Models are benchmarked on the test data set

  • Batch size can be controlled to speed up predictions

  • For Bayesian models, multiple predictions are carried out to obtain a distribution of predictions

  • Results are saved to a .csv for future use


Parameters: settings (ExperimentSettings) – custom class to hold hyperparameters
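A rough sketch of an inference-speed benchmark: time a prediction function over the same data at several batch sizes and write lightcurves per second to a .csv file. `benchmark_inference`, the column names and the dummy model are hypothetical; a real benchmark would also warm up the model and, on a GPU, synchronize the device before reading the clock.

```python
import csv
import os
import tempfile
import time

import numpy as np


def benchmark_inference(predict_fn, data, batch_sizes, out_csv):
    """Time predict_fn over `data` at several batch sizes and save the rates to a CSV."""
    rows = []
    for batch_size in batch_sizes:
        start = time.perf_counter()
        for i in range(0, len(data), batch_size):
            predict_fn(data[i : i + batch_size])
        # Guard against a zero reading on very fast runs
        elapsed = max(time.perf_counter() - start, 1e-12)
        rows.append({"batch_size": batch_size, "lightcurves_per_second": len(data) / elapsed})
    with open(out_csv, "w", newline="") as fh:
        writer = csv.DictWriter(fh, fieldnames=["batch_size", "lightcurves_per_second"])
        writer.writeheader()
        writer.writerows(rows)
    return rows


def dummy_predict(batch):
    """Hypothetical stand-in for a model: one score per lightcurve."""
    return batch.mean(axis=(1, 2))


# Hypothetical data: 1000 lightcurves x 20 time steps x 4 features
data = np.random.default_rng(0).normal(size=(1000, 20, 4))
out_path = os.path.join(tempfile.mkdtemp(), "inference_speed.csv")
rows = benchmark_inference(dummy_predict, data, batch_sizes=[1, 64, 512], out_csv=out_path)
```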



Aggregate all pre-computed METRICS files into a single dataframe for analysis.

Save the aggregated metrics as a .csv file.


Parameters: settings (ExperimentSettings) – custom class to hold hyperparameters
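Aggregation can be sketched with pandas by globbing the per-model pickles and concatenating them. The `METRICS*.pickle` file pattern, the `aggregate_metrics` name and the toy rows below are assumptions made for illustration; note that `pd.concat` raises a `ValueError` if no file matches.

```python
import glob
import os
import tempfile

import pandas as pd


def aggregate_metrics(metrics_dir, out_csv):
    """Concatenate every pickled METRICS dataframe in metrics_dir and save one CSV."""
    files = sorted(glob.glob(os.path.join(metrics_dir, "METRICS*.pickle")))
    df = pd.concat([pd.read_pickle(f) for f in files], ignore_index=True)
    df.to_csv(out_csv, index=False)
    return df


# Toy per-model metrics files (values are illustrative only)
tmp = tempfile.mkdtemp()
pd.DataFrame({"model_name": ["rnn_a"], "accuracy": [0.95]}).to_pickle(
    os.path.join(tmp, "METRICS_rnn_a.pickle")
)
pd.DataFrame({"model_name": ["rf_b"], "accuracy": [0.91]}).to_pickle(
    os.path.join(tmp, "METRICS_rf_b.pickle")
)
summary_path = os.path.join(tmp, "summary_metrics.csv")
summary = aggregate_metrics(tmp, summary_path)
```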

supernnova.validation.metrics.get_metrics_singlemodel(settings, prediction_file=None, model_type='rnn')

Launch the computation of all evaluation metrics for a given model, specified either by the settings object or by a prediction file.

Save the results as a pickled dataframe (pickle is used because the dataframe holds numpy arrays, which cannot easily be saved with the to_csv method).

  • settings (ExperimentSettings) – custom class to hold hyperparameters

  • prediction_file (str) – Path to saved predictions. Default: None

  • model_type (str) – Either 'rnn' or 'randomforest'. Default: 'rnn'


Returns: (pandas.DataFrame) holds the performance metrics for this model
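The reason for pickling can be checked directly: a dataframe cell holding a numpy array survives a `to_pickle`/`read_pickle` round trip, whereas `to_csv` stores only its string representation. The column and file names here are illustrative.

```python
import os
import tempfile

import numpy as np
import pandas as pd

# Illustrative one-row metrics table whose cell holds a numpy array
df = pd.DataFrame(
    {"model_name": ["example_model"], "calibration_curve": [np.linspace(0.0, 1.0, 5)]}
)

tmp_dir = tempfile.mkdtemp()
pickle_path = os.path.join(tmp_dir, "METRICS_example.pickle")
csv_path = os.path.join(tmp_dir, "METRICS_example.csv")

df.to_pickle(pickle_path)
df.to_csv(csv_path, index=False)

# Pickle restores the array object itself
from_pickle = pd.read_pickle(pickle_path).loc[0, "calibration_curve"]
# CSV only kept the printed text of the array
from_csv = pd.read_csv(csv_path).loc[0, "calibration_curve"]
```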

supernnova.validation.metrics.get_rnn_performance_metrics_singlemodel(settings, df, host_zspe_list)

Compute performance metrics (accuracy, AUC, purity, etc.) for an RNN model

  • Compute metrics around peak light (i.e. PEAKMJD) and for the full lightcurve.

  • For Bayesian models, compute multiple predictions per lightcurve and then take the median

  • settings (ExperimentSettings) – custom class to hold hyperparameters

  • df (pandas.DataFrame) – dataframe containing a model’s predictions

  • host_zspe_list (list) – available host galaxy spectroscopic redshifts


Returns: (pandas.DataFrame) holds the performance metrics for this dataframe
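A numpy sketch of the three metrics named above, under stated assumptions: `positive_class=0` is treated as the class of interest (a hypothetical choice), purity is computed as precision for that class, AUC uses the pairwise (Mann-Whitney) formulation with ties counted as one half, and a leading sample axis from a Bayesian model is reduced with the median. `performance_metrics` is an illustrative name, not the library function.

```python
import numpy as np


def performance_metrics(probs, target, positive_class=0):
    """Accuracy, AUC and purity for one model's predictions.

    probs: (n_lightcurves, nb_classes), or (n_samples, n_lightcurves, nb_classes)
    for a Bayesian model, in which case the median over samples is used.
    positive_class: hypothetical choice of the class treated as "signal".
    """
    probs = np.asarray(probs, dtype=float)
    target = np.asarray(target)
    if probs.ndim == 3:
        probs = np.median(probs, axis=0)

    pred = probs.argmax(axis=1)
    accuracy = float(np.mean(pred == target))

    # Purity: fraction of objects classified as positive that truly are positive
    selected = pred == positive_class
    purity = float(np.mean(target[selected] == positive_class)) if selected.any() else float("nan")

    # AUC: probability that a random positive outscores a random negative (ties count 1/2)
    score = probs[:, positive_class]
    is_pos = target == positive_class
    pos = score[is_pos][:, None]
    neg = score[~is_pos][None, :]
    auc = float(np.mean((pos > neg) + 0.5 * (pos == neg)))

    return {"accuracy": accuracy, "auc": auc, "purity": purity}


probs = np.array([[0.9, 0.1], [0.4, 0.6], [0.7, 0.3], [0.2, 0.8]])
target = np.array([0, 0, 1, 1])
metrics = performance_metrics(probs, target)
bayesian_metrics = performance_metrics(np.stack([probs] * 3), target)
```

The pairwise AUC is quadratic in the number of lightcurves, which is fine for a sketch; for large test sets, `sklearn.metrics.roc_auc_score` computes the same quantity more efficiently.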

supernnova.validation.metrics.get_randomforest_performance_metrics_singlemodel(settings, df, host_zspe_list)

Compute performance metrics (accuracy, AUC, purity, etc.) for a randomforest model

  • settings (ExperimentSettings) – custom class to hold hyperparameters

  • df (pandas.DataFrame) – dataframe containing a model’s predictions

  • host_zspe_list (list) – available host galaxy spectroscopic redshifts


Returns: (pandas.DataFrame) holds the performance metrics for this dataframe


For each lightcurve, compute the standard deviation of the model’s predictions (this is only valid for Bayesian models, which yield a distribution of predictions).

Then compute the mean and standard deviation of this quantity across all lightcurves. A higher mean indicates a model that is less confident in its predictions.


Parameters: df (pandas.DataFrame) – dataframe containing a model’s predictions


Returns: (pandas.DataFrame) holds the uncertainty metrics for this dataframe
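The two-step uncertainty summary can be sketched as follows, assuming the Bayesian predictions are stored as an array of shape `(n_samples, n_lightcurves, nb_classes)` and that the spread is measured on a single class probability (`class_idx`, a hypothetical parameter). `uncertainty_metrics` is an illustrative name.

```python
import numpy as np


def uncertainty_metrics(samples, class_idx=0):
    """Summarise the spread of Bayesian predictions.

    samples: (n_samples, n_lightcurves, nb_classes) predicted probabilities.
    Returns the mean and std, across lightcurves, of each lightcurve's
    standard deviation over samples for class `class_idx`.
    """
    per_lc_std = np.asarray(samples, dtype=float)[:, :, class_idx].std(axis=0)
    return {"mean_std": float(per_lc_std.mean()), "std_std": float(per_lc_std.std())}


# Two samples for two lightcurves: the first is stable, the second is not
samples = np.array([
    [[0.9, 0.1], [0.2, 0.8]],
    [[0.9, 0.1], [0.8, 0.2]],
])
summary = uncertainty_metrics(samples)
```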

supernnova.validation.metrics.get_entropy_metrics_singlemodel(df, nb_classes)

Compute the entropy of the predictions. Low entropy indicates a model that is very confident in its predictions.

  • df (pandas.DataFrame) – dataframe containing a model’s predictions

  • nb_classes (int) – the number of classes in the classification task


Returns: (pandas.DataFrame) holds the entropy metrics for this dataframe
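A sketch of an entropy metric using the Shannon entropy with natural logarithms. Dividing by `log(nb_classes)`, the maximum possible entropy, is an assumption added here so that 0 means fully confident and 1 means uniform predictions; the clipping avoids `log(0)`. `entropy_metrics` is a hypothetical name.

```python
import numpy as np


def entropy_metrics(probs, nb_classes):
    """Mean Shannon entropy (natural log) of the predicted class distributions."""
    p = np.clip(np.asarray(probs, dtype=float), 1e-12, 1.0)  # avoid log(0)
    entropy = -(p * np.log(p)).sum(axis=-1)
    mean_entropy = float(entropy.mean())
    # Assumption: normalise by the maximum entropy, log(nb_classes)
    return {"mean_entropy": mean_entropy, "normalised_entropy": mean_entropy / np.log(nb_classes)}


probs = np.array([[0.5, 0.5], [1.0, 0.0]])  # one uncertain, one confident prediction
result = entropy_metrics(probs, nb_classes=2)
```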


Compute the probability calibration dataframe. If the calibration curve is close to the identity line, the model is considered well calibrated.


Parameters: df (pandas.DataFrame) – dataframe containing a model’s predictions


Returns: (pandas.DataFrame) holds the calibration metrics for this dataframe
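A calibration (reliability) curve can be sketched by binning the predicted probability of the positive class and comparing each bin's mean prediction with the observed fraction of positives. The equal-width binning scheme and the `calibration_curve` helper here are assumptions; scikit-learn offers a comparable `sklearn.calibration.calibration_curve`.

```python
import numpy as np


def calibration_curve(prob_positive, is_positive, n_bins=10):
    """Mean predicted probability vs observed positive fraction, per non-empty bin."""
    prob_positive = np.asarray(prob_positive, dtype=float)
    is_positive = np.asarray(is_positive, dtype=bool)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    # Digitizing against the interior edges gives bin indices 0..n_bins-1 (1.0 falls in the last bin)
    idx = np.digitize(prob_positive, edges[1:-1])
    mean_pred, frac_pos = [], []
    for b in range(n_bins):
        in_bin = idx == b
        if in_bin.any():  # empty bins are skipped
            mean_pred.append(prob_positive[in_bin].mean())
            frac_pos.append(is_positive[in_bin].mean())
    return np.array(mean_pred), np.array(frac_pos)


mean_pred, frac_pos = calibration_curve([0.1, 0.2, 0.8, 0.9], [0, 0, 1, 1], n_bins=2)
```

Points close to the line `frac_pos == mean_pred` correspond to a well-calibrated model.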

supernnova.validation.metrics.get_classification_stats_singlemodel(df, nb_classes)

Find out how many lightcurves are classified in each class

  • df (pandas.DataFrame) – dataframe containing a model’s predictions

  • nb_classes (int) – the number of classes in the classification task


Returns: (pandas.DataFrame) holds the classification statistics for this dataframe
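Counting lightcurves per class reduces to `np.bincount` on the predicted class indices; `minlength=nb_classes` keeps classes that received no lightcurve in the output. `classification_stats` is a hypothetical name, and the predicted classes are assumed to come from an argmax over the model's probabilities.

```python
import numpy as np


def classification_stats(pred_class, nb_classes):
    """Number and fraction of lightcurves assigned to each class."""
    counts = np.bincount(np.asarray(pred_class, dtype=int), minlength=nb_classes)
    return {"counts": counts, "fractions": counts / counts.sum()}


probs = np.array([[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.1, 0.6, 0.3], [0.3, 0.3, 0.4]])
stats = classification_stats(probs.argmax(axis=1), nb_classes=3)
```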