Source code for supernnova.paper.superNNova_thread

import os
import pandas as pd
from pathlib import Path
import supernnova.conf as conf
from . import superNNova_plots as sp
from . import superNNova_metrics as sm
from ..utils import logging_utils as lu
from ..visualization import early_prediction

"""
Obtaining metrics and plots for SuperNNova paper

Selection of models is hard coded

Code is far from optimized
"""

"""
Best performing algorithms in SuperNNova
"""
Base = (
    "DES_vanilla_CLF_2_R_None_photometry_DF_1.0_N_global_lstm_32x2_0.05_128_True_mean_C"
)
Var = "DES_variational_CLF_2_R_None_photometry_DF_1.0_N_global_lstm_32x2_0.01_128_True_mean_C_WD_1e-07"
BBB = "DES_bayesian_CLF_2_R_None_photometry_DF_1.0_N_global_lstm_32x2_0.05_128_True_mean_Bayes_0.75_-1.0_-7.0_4.0_3.0_-0.5_-0.1_3.0_2.0"  # _KL_0.1"
RF = "DES_randomforest_CLF_2_R_None_saltfit_DF_1.0_N_global"
list_models = [RF, Base, Var, BBB]
list_models_rnn = [Base, Var, BBB]

# useful formats
Base_salt = Base.replace("photometry", "saltfit")


[docs]def SuperNNova_stats_and_plots(settings): """Reproduce stats and plots used for SuperNNova paper. BEWARE: Selection is hardcoded Args: settings (ExperimentSettings): custom class to hold hyperparameters """ # Load summary statistics df_stats = pd.read_csv(Path(settings.stats_dir) / "summary_stats.csv") # Create latex tables # sm.create_accuracy_latex_tables(df_stats, settings) # Rest of stats and plots in paper # can be ran in debug mode: only printing model names # or in no plot mode: only printing stats SuperNNova_stats_and_plots_thread(df_stats, settings, plots=True, debug=False)
[docs]def SuperNNova_stats_and_plots_thread(df, settings, plots=True, debug=False): """Stats quoted in paper which are not in the latex tables and plots Args: df (pandas.DataFrame) : summary statistics df settings (ExperimentSettings): custom class to hold hyperparameters plots (Boolean optional): make pltos or only printout stats debug (Boolean optional): only print tasks Returns: printout: stats as organized in paper figures (png) : figures for paper at settings.dump_dir/figures/ lightcurves (png): lightcurves used on paper at settings.dump_dir/lightcurves/modelname.*png """ """ Ordered as in paper """ pd.set_option("max_colwidth", 1000) print(lu.str_to_greenstr("STATISTICS USED IN SUPERNNOVA")) # Baseline experiments baseline(df, settings, plots, debug) # Bayesian experiments df_delta, df_delta_ood = sm.get_delta_metrics(df, settings) bayesian(df, df_delta, df_delta_ood, settings, plots, debug) # Towards statistical analyses/cosmology towards_cosmo(df, df_delta, df_delta_ood, settings, plots, debug)
[docs]def baseline(df, settings, plots, debug): """ Baseline RNN """ # 0. Figure example if plots: print( lu.str_to_yellowstr("Plotting candidates for Baseline binary (Figure 2.)") ) model_file = f"{settings.models_dir}/{Base.replace('DES_vanilla_','vanilla_S_0_')}/{Base.replace('DES_vanilla_','vanilla_S_0_')}.pt" if os.path.exists(model_file): if debug: print(model_file) else: model_settings = conf.get_settings_from_dump(model_file) early_prediction.make_early_prediction( model_settings, nb_lcs=20, do_gifs=True ) else: print(lu.str_to_redstr(f"File not found {model_file}")) # 1. Hyper-parameters # saltfit, DF 0.2 sel_criteria = ["DES_vanilla_CLF_2_R_None_saltfit_DF_0.2"] print(lu.str_to_bluestr(f"Hyperparameters {sel_criteria}")) if not debug: sm.get_metric_ranges(df, sel_criteria) # 2. Normalization # saltfit, DF 0.5 sel_criteria = Base_salt.replace("DF_1.0", "DF_0.5").split("global") print(lu.str_to_bluestr(f"Normalization {sel_criteria}")) if not debug: sm.get_metric_ranges(df, sel_criteria) # 3. Comparing with other methods print(lu.str_to_bluestr("Other methods:")) # Figure: accuracy vs. number of SNe if plots: print(lu.str_to_yellowstr("Plotting accuracy vs. SNe (Figure 3.)")) if not debug: sp.performance_plots(settings) # baseline best, saltfit if debug: print(Base_salt) print(Base_salt.replace("DF_1.0", "DF_0.05")) print(Base_salt.replace("DF_1.0", "DF_0.05").replace("None", "zpho")) else: sm.acc_auc_df(df, [Base_salt], data="saltfit") # RF and baseline comparisson with Charnock Moss sm.acc_auc_df( df, [ RF, Base_salt.replace("DF_1.0", "DF_0.05"), Base_salt.replace("DF_1.0", "DF_0.05").replace("None", "zpho"), ], ) # 4. Redshift, contamination # baseline saltfit, 1.0, all redshifts sel_criteria = Base_salt.split("None") print("salt") if debug: print(sel_criteria, Base.split("None")) if not debug: sm.print_contamination(df, sel_criteria, settings, data="saltfit") print("photometry") if not debug: sm.print_contamination(df, Base.split("None"), settings, data="photometry") # Multiclass # Plotting Confusion Matrix for just one seed if plots: print( lu.str_to_yellowstr( "Plotting confusion matrix for multiclass classification (Figure 4.)" ) ) for target in [3, 7]: settings.prediction_files = [ settings.models_dir + "/" + model.strip("DES_").replace("CLF_2", f"S_0_CLF_{target}") + "/" + model.replace("DES_", "PRED_DES_").replace( "CLF_2", f"S_0_CLF_{target}" ) + ".pickle" for model in [Base, Var, BBB] ] if debug: print(settings.prediction_files) else: sp.science_plots(settings, onlycnf=True)
# Uncomment to see some examples of seven-way classification # print(lu.str_to_yellowstr("Plotting candidates for multiclass classification")) # model_file = f"{settings.models_dir}/{Base.replace('DES_vanilla_','vanilla_S_0_').replace('CLF_2','CLF_7')}/{Base.replace('DES_vanilla_','vanilla_S_0_').replace('CLF_2','CLF_7')}.pt" # if os.path.exists(model_file): # model_settings = conf.get_settings_from_dump(model_file) # early_prediction.make_early_prediction(model_settings, nb_lcs=20) # else: # print(lu.str_to_redstr(f"File not found {model_file}"))
[docs]def bayesian(df, df_delta, df_delta_ood, settings, plots, debug): """ Bayesian RNNs: BBB and Variational """ # 2. Variational hyper-parameters sel_criteria = ["DES_variational_CLF_2_R_None_saltfit_DF_0.2_N_global_lstm_32x2_"] print(lu.str_to_bluestr(f"Hyperparameters {sel_criteria}")) if not debug: sm.get_metric_ranges(df, sel_criteria) # 3. BBB hyper-parameters sel_criteria = ["DES_bayesian_CLF_2_R_None_saltfit_DF_0.2_N_global_lstm_32x2_"] print(lu.str_to_bluestr(f"Hyperparameters {sel_criteria}")) if not debug: sm.get_metric_ranges(df, sel_criteria) # 2 and 3. Best models print(lu.str_to_bluestr("Best performing Bayesian accuracies")) if debug: print( [ Var, BBB, Var.replace("None", "zpho"), BBB.replace("None", "zpho"), Var.replace("None", "zspe"), BBB.replace("None", "zspe"), ] ) else: sm.acc_auc_df( df, [ Var, BBB, Var.replace("None", "zpho"), BBB.replace("None", "zpho"), Var.replace("None", "zspe"), BBB.replace("None", "zspe"), ], ) # contamination # baseline saltfit, 1.0, all redshifts for model in [Var, BBB]: sel_criteria = model.split("None") if debug: print("contamination") print(sel_criteria) else: sm.print_contamination(df, sel_criteria, settings, data="photometry") # 4. Uncertainties print(lu.str_to_bluestr("Best performing Bayesian uncertainties")) print("Epistemic behaviour") for model in [Var, BBB]: m_right = model.replace("photometry", "saltfit") m_left = model.replace("photometry", "saltfit").replace("DF_1.0", "DF_0.5") print("salt", m_left, m_right) df_sel = df_delta[ (df_delta["model_name_left"] == m_left) & (df_delta["model_name_right"] == m_right) ] if not debug: sm.nice_df_print( df_sel, keys=[ "all_accuracy_mean_delta", "mean_all_class0_std_dev_mean_delta", "all_entropy_mean_delta", ], ) m_right = model m_left = model.replace("DF_1.0", "DF_0.43") print("complete", m_left, m_right) df_sel = df_delta[ (df_delta["model_name_left"] == m_left) & (df_delta["model_name_right"] == m_right) ] if not debug: sm.nice_df_print( df_sel, keys=[ "all_accuracy_mean_delta", "mean_all_class0_std_dev_mean_delta", "all_entropy_mean_delta", ], ) print("Uncertainty size") df_sel = df[df["model_name_noseed"].isin([Var, BBB])] df_sel = df_sel.round(4) if not debug: sm.nice_df_print( df_sel, keys=["mean_all_class0_std_dev_mean", "mean_all_class0_std_dev_std"] ) if plots: print( lu.str_to_yellowstr( "Plotting candidates for multiclass classification (Fig. 5)" ) ) for model in [Var, BBB]: model_file = ( f"{settings.models_dir}/" + model.replace("DES_", "").replace("CLF_2", "S_0_CLF_7") + "/" + model.replace("DES_", "").replace("CLF_2", "S_0_CLF_7") + ".pt" ) if os.path.exists(model_file): if debug: print(model_file) else: model_settings = conf.get_settings_from_dump(model_file) early_prediction.make_early_prediction( model_settings, nb_lcs=20, do_gifs=True ) else: print(lu.str_to_redstr(f"File not found {model_file}")) print(lu.str_to_yellowstr("Adding gifs for binary classification")) for model in [Var, BBB]: model_file = ( f"{settings.models_dir}/" + model.strip("DES_").replace("CLF_2", "S_0_CLF_2") + "/" + model.strip("DES_").replace("CLF_2", "S_0_CLF_2") + ".pt" ) if os.path.exists(model_file): if debug: print(model_file) else: model_settings = conf.get_settings_from_dump(model_file) early_prediction.make_early_prediction( model_settings, nb_lcs=10, do_gifs=True )
[docs]def towards_cosmo(df, df_delta, df_delta_ood, settings, plots, debug): """ Towards cosmology """ # 1. Calibration print(lu.str_to_bluestr("Calibration")) df_sel = df.copy() df_sel = df_sel.round(4) # rf can't be done with photometry df_sel = df_sel[ ( df_sel["model_name_noseed"].isin( [item.replace("photometry", "saltfit") for item in list_models] ) ) & (df_sel["source_data"] == "saltfit") ] print("saltfit") if not debug: sm.nice_df_print( df_sel, keys=[ "model_name_noseed", "calibration_dispersion_mean", "calibration_dispersion_std", ], ) # without rf print("photometry") df_sel = df.copy() df_sel = df_sel[ (df_sel["model_name_noseed"].isin([Base, Var, BBB])) & (df_sel["source_data"] == "photometry") ] if not debug: sm.nice_df_print( df_sel, keys=[ "model_name_noseed", "calibration_dispersion_mean", "calibration_dispersion_std", ], ) # Calibration vs. training set size # using salt print(lu.str_to_bluestr("Calibration vs. data set size")) print("Baseline") sel_criteria = Base_salt.split("DF_1.0") if debug: print(sel_criteria) else: sm.get_metric_ranges( df, sel_criteria, metric="calibration_dispersion", round_output=5 ) # Calibration vs. dataset nature sel_criteria = [ "DES_vanilla_CLF_2_R_None_saltfit_DF_1.0_N_global_lstm_32x2_0.05_128_True_mean_C" ] if debug: print(sel_criteria) else: sm.get_metric_ranges( df, sel_criteria, metric="calibration_dispersion", round_output=5 ) # Calibration figure if plots: print(lu.str_to_yellowstr("Plotting reliability diagram (Figure 6)")) tmp_pred_files = settings.prediction_files settings.prediction_files = [ settings.models_dir + "/" + model.strip("DES_") .replace("CLF_2", "S_0_CLF_2") .replace("photometry", "saltfit") + "/" + model.replace("DES_", "PRED_DES_") .replace("CLF_2", "S_0_CLF_2") .replace("photometry", "saltfit") + ".pickle" for model in [RF, Base, Var, BBB] ] if debug: print(settings.prediction_files) else: sp.plot_calibration(settings) settings.prediction_files = tmp_pred_files # 2. Representativeness print(lu.str_to_bluestr("Representativeness")) for model in [Base, Var, BBB]: m_left = model.replace("photometry", "saltfit") m_right = model.replace("DF_1.0", "DF_0.43") print(m_left, m_right) df_sel = df_delta[ (df_delta["model_name_left"] == m_left) & (df_delta["model_name_right"] == m_right) ] if not debug: sm.nice_df_print( df_sel, keys=[ "all_accuracy_mean_delta", "mean_all_class0_std_dev_mean_delta", "all_entropy_mean_delta", ], ) # 3. OOD print(lu.str_to_bluestr("Out-of-distribution light-curves")) # OOD type assignement figure if plots: print(lu.str_to_yellowstr("Plotting OOD classification percentages (Figure 8)")) if not debug: sp.create_OOD_classification_plots(df, list_models_rnn, settings) # Get entropy print("binary") df_sel = df_delta_ood[df_delta_ood["model_name_noseed"].isin(list_models_rnn)] if not debug: sm.nice_df_print(df_sel) print("ternary") list_models_sel = [item.replace("CLF_2", "CLF_3") for item in list_models_rnn] df_sel = df_delta_ood[df_delta_ood["model_name_noseed"].isin(list_models_sel)] if not debug: sm.nice_df_print(df_sel) print("seven-way") list_models_sel = [item.replace("CLF_2", "CLF_7") for item in list_models_rnn] df_sel = df_delta_ood[df_delta_ood["model_name_noseed"].isin(list_models_sel)] if not debug: sm.nice_df_print(df_sel) if plots: print( lu.str_to_yellowstr( "Plotting OOD candidates with seven-way classification (Figure 9)" ) ) model_file = ( f"{settings.models_dir}/" + Var.replace("DES_variational_", "variational_S_0_").replace( "CLF_2", "CLF_7" ) + "/" + Var.replace("DES_variational_", "variational_S_0_").replace( "CLF_2", "CLF_7" ) + ".pt" ) if os.path.exists(model_file): if debug: print(model_file) else: model_settings = conf.get_settings_from_dump(model_file) early_prediction.make_early_prediction(model_settings, nb_lcs=20) else: print(lu.str_to_redstr(f"File not found {model_file}")) # 4. Cosmology print(lu.str_to_bluestr("SNe Ia for cosmology")) # Plotting Hubble residuals adn other science plots for just one seed if plots: print(lu.str_to_yellowstr("Plotting Hubble residuals (Figures 10 and 11)")) tmp_pred_files = settings.prediction_files settings.prediction_files = [ settings.models_dir + "/" + model.strip("DES_").replace("CLF_2", "S_0_CLF_2") + "/" + model.replace("DES_", "PRED_DES_").replace("CLF_2", "S_0_CLF_2") + ".pickle" for model in [Base, Var, BBB] ] if debug: print(settings.prediction_files) else: sp.science_plots(settings) settings.prediction_files = tmp_pred_files