import re
import h5py
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
from tqdm import tqdm
from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from ..utils import data_utils as du
from ..utils import logging_utils as lu
from ..utils import training_utils as tu
from ..utils.visualization_utils import FILTER_COLORS, ALL_COLORS, LINE_STYLE
# Select the "agg" backend at import time. Agg is matplotlib's non-interactive
# raster backend, so the figures of this module can be rendered and saved
# without a display.
plt.switch_backend("agg")
def get_predictions(settings, dict_rnn, X, target, OOD=None):
    """Run every model of ``dict_rnn`` on a single light curve.

    The light curve is turned into a batch of size one with
    ``tu.get_data_batch`` and fed to each model. Raw model outputs are
    converted to probabilities with a softmax over the last dimension.

    Args:
        settings: (ExperimentSettings) custom class to hold hyperparameters
        dict_rnn (dict): maps a model name to a callable model. The name
            selects the sampling strategy: names containing ``"variational"``
            are evaluated once on an input expanded to
            ``settings.num_inference_samples`` along dimension 1; names
            containing ``"bayesian"`` are evaluated
            ``settings.num_inference_samples`` times and the outputs are
            concatenated along dimension 0; any other model is evaluated once
        X: features of the light curve; ``X.copy()`` is what gets batched, so
            the caller's array is not handed to ``tu.get_data_batch``
        target: label of the light curve, forwarded to ``tu.get_data_batch``
        OOD (optional): forwarded as-is to ``tu.get_data_batch``.
            Default: ``None``

    Returns:
        tuple: ``(d_pred, X_normed)`` where ``d_pred[name]["prob"]`` is a numpy
        array of probabilities with a leading axis of size one added by the
        final stack, and ``X_normed`` is the squeezed batch tensor returned by
        ``tu.get_data_batch``, converted to a numpy array
    """
    list_data = [(X.copy(), target)]
    _, X_tensor, *_ = tu.get_data_batch(list_data, [0], settings, OOD=OOD)

    if settings.use_cuda:
        # Tensor.cuda() returns a tensor on the GPU rather than moving the
        # tensor in place, so the result has to be kept.
        X_tensor = X_tensor.cuda()

    d_pred = {key: {"prob": []} for key in dict_rnn}

    # Apply rnn to obtain prediction
    for model_type, rnn in dict_rnn.items():
        n = settings.num_inference_samples if "variational" in model_type else 1
        new_size = (X_tensor.size(0), n, X_tensor.size(2))
        X_expanded = X_tensor.expand(new_size)

        if "bayesian" in model_type:
            # One forward pass per inference sample, concatenated on dim 0
            list_out = [
                rnn(X_expanded) for _ in range(settings.num_inference_samples)
            ]
            out = torch.cat(list_out, dim=0)
        else:
            out = rnn(X_expanded)

        # Apply softmax to obtain a proba
        pred_proba = nn.functional.softmax(out, dim=-1).data.cpu().numpy()
        # Add to buffer list
        d_pred[model_type]["prob"].append(pred_proba)

    # Stack the buffered predictions of each model into a single array.
    # NOTE(review): the original comment gave the stacked shape as
    # (T, num_samples, 2) - not verifiable from this file, confirm.
    for key in dict_rnn:
        d_pred[key]["prob"] = np.stack(d_pred[key]["prob"], axis=0)

    return d_pred, X_tensor.squeeze().detach().cpu().numpy()
def plot_distributions(settings, list_d_plot):
    """Plot light curves next to histograms of their predicted probabilities.

    The figure is an 8 x 2 grid: row ``i`` shows the light curve of
    ``list_d_plot[i]`` on the left and, on the right, one step histogram per
    model and per class of the predicted probabilities. The figure is saved as
    a PNG under ``settings.lightcurves_dir``.

    Args:
        settings: (ExperimentSettings) custom class to hold hyperparameters
        list_d_plot (list): one dict per light curve. Each dict holds the keys
            ``"SNID"``, ``"OOD"``, ``"target"``, ``"redshift"``, ``"peak_MJD"``
            and ``"d_pred"``, plus one entry per filter of
            ``settings.list_filters`` with the keys ``"MJD"``, ``"FLUXCAL"``
            and ``"FLUXCALERR"``. The grid has 8 rows, so at most 8 entries
            can be drawn
    """
    plt.figure(figsize=(20, 30))
    gs = gridspec.GridSpec(8, 2, hspace=0.3, wspace=0.1)

    # Axis labels and the legend are only drawn on the last plotted row
    last_row = len(list_d_plot) - 1

    for i, d_plot in enumerate(list_d_plot):
        SNID = d_plot["SNID"]
        OOD = d_plot["OOD"]
        target = d_plot["target"]
        redshift = d_plot["redshift"]
        peak_MJD = d_plot["peak_MJD"]
        d_pred = d_plot["d_pred"]

        # Plot the lightcurve (left column of row i)
        ax = plt.subplot(gs[2 * i])
        for flt in settings.list_filters:
            flt_time = d_plot[flt]["MJD"]
            # Only plot a time series if it's non empty
            if len(flt_time) > 0:
                flux = d_plot[flt]["FLUXCAL"]
                fluxerr = d_plot[flt]["FLUXCALERR"]
                ax.errorbar(
                    flt_time,
                    flux,
                    yerr=fluxerr,
                    fmt="o",
                    label=f"Filter {flt}",
                    color=FILTER_COLORS[flt],
                )
        ax.set_ylabel("FLUXCAL", fontsize=24)
        if i == last_row:
            ax.set_xlabel("days", fontsize=24)
        # Read the limits before drawing the peak marker so the vertical line
        # spans the range of the data points
        ylim = ax.get_ylim()
        plt.yticks(fontsize=14)
        plt.xticks(fontsize=14)

        SNtype = du.sntype_decoded(target, settings)
        if OOD is not None:
            ax.set_title(f"OOD {OOD} ID: {SNID}", fontsize=24)
        else:
            ax.set_title(
                SNtype + f" (ID: {SNID}, redshift: {redshift:.3g})", fontsize=24
            )
        # Add PEAKMJD
        ax.plot([peak_MJD, peak_MJD], ylim, "k--", label="Peak MJD")

        # Plot the classifications (right column of row i)
        ax = plt.subplot(gs[2 * i + 1])
        ax.set_xlim(-0.1, 1.1)

        for idx, key in enumerate(d_pred):
            arr_prob = np.squeeze(d_pred[key]["prob"])
            # Shared bin edges so the histograms of all classes are comparable
            all_probs = np.ravel(arr_prob)
            _, bin_edges = np.histogram(all_probs, bins=25)

            for class_prob in range(settings.nb_classes):
                color = ALL_COLORS[class_prob + idx * settings.nb_classes]
                label = du.sntype_decoded(class_prob, settings)
                linestyle = LINE_STYLE[class_prob]
                if len(d_pred) > 1:
                    label += f" {key}"

                ax.hist(
                    arr_prob[:, class_prob],
                    color=color,
                    histtype="step",
                    linestyle=linestyle,
                    linewidth=2,
                    label=label,
                    bins=bin_edges,
                )
        plt.yticks(fontsize=14)
        plt.xticks(fontsize=14)
        ax.set_yscale("log")

        if i == last_row:
            # Single legend for the whole figure, anchored from the last panel
            ax.legend(
                bbox_to_anchor=(-2.8, 5.8),
                loc=2,
                borderaxespad=0.0,
                fontsize=16,
                ncol=(settings.nb_classes // 2) + 1,
                framealpha=0,
            )
            ax.set_xlabel("classification probability", fontsize=24)

    plt.subplots_adjust(
        left=0.08, right=0.99, bottom=0.03, top=0.98, wspace=0.0, hspace=0.02
    )

    # With exactly one model file, name the output after the directory holding
    # it; otherwise fall back to the model name derived from the settings.
    model_files = settings.model_files
    if model_files is not None and len(model_files) == 1:
        parent_dir = Path(model_files[0]).parent.name
        fig_path = f"{settings.lightcurves_dir}/{parent_dir}/prediction_distribution"
        fig_name = f"{parent_dir}.png"
    else:
        fig_path = f"{settings.lightcurves_dir}/{settings.pytorch_model_name}/prediction_distribution"
        fig_name = f"{settings.pytorch_model_name}.png"
    Path(fig_path).mkdir(parents=True, exist_ok=True)
    plt.savefig(Path(fig_path) / fig_name)
    plt.clf()
    plt.close()
def plot_prediction_distribution(settings):
    """Plot light curves with the distribution of their predicted probabilities.

    Loads the model corresponding to ``settings`` or, if
    ``settings.model_files`` is specified, the single model file it lists.
    A few test light curves are selected, each is additionally evaluated with
    an out-of-distribution (OOD) transformation, and the result is handed to
    ``plot_distributions``.

    Side effects: ``settings.random_length`` and ``settings.random_redshift``
    are set to ``False``, ``settings.model`` is set from the model file name
    and ``settings.model_files`` is filled in when it is ``None``.

    Args:
        settings: (ExperimentSettings) custom class to hold hyperparameters

    Raises:
        ValueError: if ``settings.nb_classes`` is neither 2, 3 nor 7
        AssertionError: if more than one model file is given, or if the
            number of classes / redshift encoded in the model file name do
            not match ``settings``
    """
    # No plot for ternary tasks
    if settings.nb_classes == 3:
        return
    # Light curve selection below is only defined for binary and 7-way tasks
    if settings.nb_classes not in (2, 7):
        raise ValueError(
            "Prediction distribution plots are only available for 2 or 7 classes, "
            f"got nb_classes={settings.nb_classes}"
        )

    settings.random_length = False
    settings.random_redshift = False

    # Load the test data
    list_data_test = tu.load_HDF5(settings, test=True)

    # Load features list
    file_name = f"{settings.processed_dir}/database.h5"
    with h5py.File(file_name, "r") as hf:
        features = hf["features"][settings.idx_features]

    # Load RNN model
    dict_rnn = {}
    if settings.model_files is None:
        settings.model_files = [f"{settings.rnn_dir}/{settings.pytorch_model_name}.pt"]
    else:
        assert (
            len(settings.model_files) == 1
        ), "Only one model file allowed at a time for these plots"
        # Check that the settings match the model file: the number of classes
        # and the redshift option are parsed from the file name
        base_files = [Path(f).name for f in settings.model_files]
        classes = [int(re.search(r"(?<=CLF\_)\d+(?=\_)", f).group()) for f in base_files]
        redshifts = [re.search(r"(?<=R\_)[A-Za-z]+(?=\_)", f).group() for f in base_files]
        nb_classes, redshift = classes[0], redshifts[0]
        assert settings.nb_classes == nb_classes, lu.str_to_redstr(
            "Incompatible nb_classes between CLI and model files"
        )
        assert str(settings.redshift) == redshift, lu.str_to_redstr(
            "Incompatible redshift between CLI and model files"
        )

    for model_file in settings.model_files:
        # The model type is inferred from the file path; when several names
        # occur in the path, the last one checked wins
        for model_name in ("variational", "vanilla", "bayesian"):
            if model_name in model_file:
                settings.model = model_name

        rnn = tu.get_model(settings, len(settings.training_features))
        # NOTE: torch.load unpickles the file - only load trusted checkpoints.
        # map_location keeps the storages on CPU until rnn.to() below.
        rnn_state = torch.load(model_file, map_location=lambda storage, loc: storage)
        rnn.load_state_dict(rnn_state)
        rnn.to(settings.device)
        rnn.eval()
        name = (
            f"{settings.model} photometry"
            if "photometry" in model_file
            else f"{settings.model} salt"
        )
        dict_rnn[name] = rnn

    # Load SN info
    SNinfo_df = du.load_HDF5_SNinfo(settings)

    targets = np.array([o[1] for o in list_data_test])
    if settings.nb_classes == 2:
        # 2 lightcurves of each type
        idxs_keep = (
            np.where(targets == 0)[0][:2].tolist()
            + np.where(targets == 1)[0][:2].tolist()
        )
    else:
        # 7-way task: first light curve of each of Ia, Ib, Ic, IIP
        sntype_names = list(settings.sntypes.values())
        idxs_keep = [
            np.where(targets == sntype_names.index(i))[0][0]
            for i in ["Ia", "Ib", "Ic", "IIP"]
        ]

    # Carry out 8 plots: 4 real light curves + 4 OOD
    list_OOD_types = ["random", "sin", "reverse", "shuffle"]
    list_data_test = [(list_data_test[i], None) for i in idxs_keep]
    # Each selected light curve is repeated once, paired with one OOD type
    list_data_test += [(o[0], ood) for (o, ood) in zip(list_data_test, list_OOD_types)]

    list_d_plot = []
    # Loop over data to plot prediction
    for ((X, target, SNID, _, X_ori), OOD) in tqdm(list_data_test, ncols=100):
        sn_row = SNinfo_df[SNinfo_df["SNID"] == SNID]
        redshift = sn_row["SIM_REDSHIFT_CMB"].values[0]
        peak_MJD = sn_row["PEAKMJDNORM"].values[0]

        # Prepare plotting data in a dict
        d_plot = {
            flt: {"FLUXCAL": [], "FLUXCALERR": [], "MJD": []}
            for flt in settings.list_filters
        }

        with torch.no_grad():
            d_pred, X_normed = get_predictions(settings, dict_rnn, X, target, OOD=OOD)

        # X here has been normalized. We unnormalize X
        X_unnormed = tu.unnormalize_arr(X_normed, settings)
        # Check we do recover X_ori when OOD is None
        if OOD is None:
            assert np.all(np.isclose(np.ravel(X_ori), np.ravel(X_unnormed), atol=1e-2))

        # TODO: IMPROVE
        df_temp = pd.DataFrame(data=X_unnormed, columns=features)
        # Time axis is rebuilt by accumulating the per-step time deltas
        arr_time = np.cumsum(df_temp.delta_time.values)
        for flt in settings.list_filters:
            # Steps whose flux is (close to) zero in this filter are dropped
            non_zero = np.where(
                ~np.isclose(df_temp[f"FLUXCAL_{flt}"].values, 0, atol=1e-2)
            )[0]
            d_plot[flt]["FLUXCAL"] = df_temp[f"FLUXCAL_{flt}"].values[non_zero]
            d_plot[flt]["FLUXCALERR"] = df_temp[f"FLUXCALERR_{flt}"].values[non_zero]
            d_plot[flt]["MJD"] = arr_time[non_zero]

        d_plot["redshift"] = redshift
        d_plot["peak_MJD"] = peak_MJD
        d_plot["SNID"] = SNID
        d_plot["OOD"] = OOD
        d_plot["target"] = target
        d_plot["d_pred"] = d_pred
        list_d_plot.append(d_plot)

    plot_distributions(settings, list_d_plot)
    lu.print_green("Finished plotting lightcurves and predictions ")