import os
import h5py
import numpy as np
import pandas as pd
from tqdm import tqdm
from astropy.table import Table
from collections import namedtuple
from . import logging_utils
# Offsets (days) relative to PEAKMJD.
# NOTE(review): OFFSETS has 5 entries while OFFSETS_STR has 7, so the two
# lists are NOT index-aligned; pair OFFSETS_STR with OFFSETS_VAL instead.
# Left unchanged in case other modules rely on it.
OFFSETS = [-2, -1, 0, 1, 2]
OOD_TYPES = ["random", "reverse", "shuffle", "sin"]
# Suffixes used in HDF5 dataset names such as "PEAKMJD-2_unique_nights"
# ("" denotes the peak itself).
OFFSETS_STR = ["-7", "-2", "-1", "", "+1", "+2", "+30"]
# Numeric offset for each entry of OFFSETS_STR (same order and length).
# A comprehension avoids leaking a loop variable into the module namespace.
OFFSETS_VAL = [int(v.replace("+", "")) if v else 0 for v in OFFSETS_STR]
# Normalization parameters returned by log_standardization
LogStandardized = namedtuple("LogStandardized", ["arr_min", "arr_mean", "arr_std"])
def load_pandas_from_fit(fit_file_path, columns=None, multidim="drop"):
    """Load a FIT file and cast it to a PANDAS dataframe.

    FITS tables can contain vector / array-valued columns (``TDIM > 1``).
    ``astropy.Table.to_pandas`` raises ``ValueError`` on such columns because
    pandas DataFrames cannot natively hold n-dim arrays in a column. This
    helper handles them so the pipeline does not crash on otherwise valid
    FITS files.

    Args:
        fit_file_path (str): path to FIT file
        columns (list, optional): if given, only these columns are kept
            before the pandas conversion. Acts as a whitelist that
            avoids loading multi-D columns we don't care about into
            pandas in the first place. Names that are not present in
            the FITS table are silently ignored; duplicates are kept once.
        multidim (str): strategy for multi-dimensional columns that
            survive the ``columns`` filter. One of:

            * ``"drop"`` (default): remove them with a yellow warning
              listing which columns were skipped.
            * ``"error"``: raise ``ValueError`` so the caller is forced
              to deal with them explicitly.

    Returns:
        (pandas.DataFrame) load dataframe from FIT file

    Raises:
        ValueError: if ``multidim`` is not a known strategy (checked before
            the file is read), or if ``multidim == "error"`` and multi-dim
            columns remain.
    """
    # Validate the strategy up front so a typo is caught even when the
    # file happens to contain no multi-dim columns.
    if multidim not in ("drop", "error"):
        raise ValueError(
            f"Unknown multidim strategy {multidim!r}; "
            f"expected 'drop' or 'error'"
        )

    dat = Table.read(fit_file_path, format="fits")

    if columns is not None:
        available = set(dat.colnames)
        # dict.fromkeys dedupes while keeping the caller's order
        keep = [c for c in dict.fromkeys(columns) if c in available]
        dat = dat[keep]

    bad = [c for c in dat.colnames if dat[c].ndim > 1]
    if bad:
        if multidim == "error":
            raise ValueError(
                f"Multi-dim FITS columns {bad} cannot be cast to pandas; "
                f"use multidim='drop' or pass a `columns` whitelist"
            )
        logging_utils.print_yellow(
            f"Skipping multi-dim FITS cols in {fit_file_path}: {bad}"
        )
        dat.remove_columns(bad)

    return dat.to_pandas()
def sntype_decoded(target, settings, simplify=False):
    """Map an integer class index (as produced by ``tag_type``) to a class name.

    Class ordering mirrors ``tag_type``: unique values of ``settings.sntypes``
    in first-seen order, with ``settings.target_sntype`` (default ``"Ia"``)
    moved to index 0 when present.

    Args:
        target (int): specifies the classification target
        settings (ExperimentSettings): custom class to hold hyperparameters
        simplify (Boolean): if True, the binary non-target class is named
            ``"non SN <class 0>"`` instead of listing every other class

    Returns:
        (str) the name of the class. Multi-class (``nb_classes > 2``): the bare
        class name, or ``"class_<target>"`` when out of range. Binary:
        ``"SN <class 0>"`` for target 0, otherwise the other classes joined by
        ``"|"`` (or the simplified name).
    """
    target_class = getattr(settings, "target_sntype", "Ia")

    # Ordered unique class list consistent with tag_type
    unique_classes = list(dict.fromkeys(settings.sntypes.values()))
    if target_class in unique_classes:
        unique_classes.remove(target_class)
        unique_classes.insert(0, target_class)

    if settings.nb_classes > 2:
        if target < len(unique_classes):
            return unique_classes[target]
        return f"class_{target}"

    # Binary: class 0 is target_sntype when known, else the first declared class
    if target_class in unique_classes or not unique_classes:
        class_zero = target_class
    else:
        class_zero = unique_classes[0]

    if target == 0:
        return f"SN {class_zero}"
    if simplify:
        return f"non SN {class_zero}"
    # Exclude whichever class is class 0, so the label never repeats it
    others = [k for k in unique_classes if k != class_zero]
    return f"SN {'|'.join(others)}"
def tag_type(df, settings, type_column="TYPE"):
    """Create classification targets from a type column.

    Adds two target columns to ``df``:

    * ``target_2classes``: 0 for every type whose class is class 0, 1 otherwise.
      Class 0 is ``settings.target_sntype`` (default ``"Ia"``) when it is one of
      the ``settings.sntypes`` values, else the first class declared there.
    * ``target_{N}classes``: one index per unique class of ``settings.sntypes``
      (first-seen order, with class 0 as above), so the binary target is a
      coarsening of the multi-class one. When N == 2 this is the same column.

    Side effects: ``df`` is modified in place (``type_column`` is cast to
    ``str``); types found in the data but missing from ``settings.sntypes``
    are added to it with class ``"contaminant"``.

    Args:
        df (pandas.DataFrame): the input dataframe
        settings (ExperimentSettings): controls experiment hyperparameters
        type_column (str): the type column in df

    Returns:
        (pandas.DataFrame) the dataframe, with new target columns

    Raises:
        ValueError: if ``type_column`` is missing and ``settings.data_testing``
            is not set
    """
    # SNTYPE checks
    if type_column not in df.keys():
        if not settings.data_testing:
            logging_utils.print_red(
                "Please provide SNTYPE with data (else use data_testing option)"
            )
            raise ValueError(f"Column {type_column!r} missing from data")
        # Testing data carries no types: use a dummy type 1. type_column is
        # filled too because it is the column the tagging below reads.
        df[settings.sntype_var] = np.ones(len(df), dtype=int)
        df[type_column] = np.ones(len(df), dtype=int)

    # Auto-detect types in data not present in sntypes and assign as contaminant
    df[type_column] = df[type_column].astype(str)
    missing_types = sorted(set(df[type_column].unique()) - set(settings.sntypes))
    if missing_types:
        logging_utils.print_yellow(
            "Missing sntypes",
            f"{missing_types} assigned to 'contaminant' class",
        )
        for mtyp in missing_types:
            settings.sntypes[mtyp] = "contaminant"

    # Ordered unique classes (first-seen), target_sntype moved first.
    # sntype_decoded rebuilds the same ordering to decode predictions.
    target_class = getattr(settings, "target_sntype", "Ia")
    unique_classes = list(dict.fromkeys(settings.sntypes.values()))
    if target_class in unique_classes:
        unique_classes.remove(target_class)
        unique_classes.insert(0, target_class)
    else:
        logging_utils.print_yellow(
            "target_sntype",
            f"'{target_class}' not found in sntypes values {unique_classes}, "
            f"using first entry as class 0",
        )

    # 2 classes: class 0 vs rest. Every key whose class is class 0 maps to 0,
    # matching the multi-class mapping below.
    class_zero = unique_classes[0]
    keys_class_zero = [k for k, v in settings.sntypes.items() if v == class_zero]
    df["target_2classes"] = (~df[type_column].isin(keys_class_zero)).astype(
        np.int64
    )

    # All classes
    classes_to_use = {name: idx for idx, name in enumerate(unique_classes)}
    map_keys_to_classes = {
        key: classes_to_use[name] for key, name in settings.sntypes.items()
    }
    # Every type in df is a key of sntypes at this point (missing ones added)
    df[f"target_{len(unique_classes)}classes"] = df[type_column].map(
        map_keys_to_classes
    )
    return df
def load_fitfile(settings, verbose=True):
    """Load the FITOPT000 FITRES table as a pandas dataframe.

    A readable pickle in ``settings.preprocessed_dir`` is used first because
    it loads faster. Otherwise ``FITOPT000.FITRES`` (or its ``.gz`` variant)
    from ``settings.fits_dir`` is parsed, tagged with target columns via
    ``tag_type``, its ``CID`` column renamed to ``SNID`` when needed, and
    pickled for the next call.

    Args:
        settings (ExperimentSettings): controls experiment hyperparameters
        verbose (bool): whether to display logging message. Default: ``True``

    Returns:
        (pandas.DataFrame) dataframe with FITOPT data; empty when no file
        could be found
    """
    if verbose:
        logging_utils.print_green("Loading FITRES file")

    pickle_path = f"{settings.preprocessed_dir}/FITOPT000.FITRES.pickle"
    plain_path = f"{settings.fits_dir}/FITOPT000.FITRES"
    gz_path = f"{settings.fits_dir}/FITOPT000.FITRES.gz"

    # Fast path: previously pickled dataframe
    if os.access(pickle_path, os.R_OK):
        cached = pd.read_pickle(pickle_path)
        if verbose:
            print(f"Loaded {pickle_path}")
        return cached

    plain_readable = os.access(plain_path, os.R_OK)
    if not (plain_readable or os.access(gz_path, os.R_OK)):
        if verbose:
            logging_utils.print_yellow("Warning: No FITRES file to load")
        # nothing to read: hand back an empty frame
        return pd.DataFrame()

    fit_name = plain_path if plain_readable else gz_path
    fitres = pd.read_csv(
        fit_name, index_col=False, comment="#", delimiter=" ", skipinitialspace=True
    )
    fitres = tag_type(fitres, settings)
    # The lightcurve ID column is called CID in FITOPT000.FITRES
    if "SNID" not in fitres.keys():
        fitres = fitres.rename(columns={"CID": "SNID"})
    # Cache for fast reload
    fitres.to_pickle(pickle_path, protocol=4)
    if verbose:
        print(f"Loaded {fit_name}")
    return fitres
# TODO: double-check the following function, which has been commented out, before deleting it. It seems to be related only to randomforest.
# def add_redshift_features(settings, df):
# """Add redshift features to pandas dataframe.
# Args:
# settings (ExperimentSettings): controls experiment hyperparameters
# df (str): pandas DataFrame with FIT data
# Returns:
# (pandas.DataFrame) the dataframe, possibly with added redshift features
# """
# # check if we use host redshift as feature
# host_features = [f for f in settings.randomforest_features if "HOST" in f]
# use_redshift = len(host_features) > 0
# if use_redshift > 0:
# logging_utils.print_green("Adding redshift features...")
# columns_to_read = ["SNID"] + host_features
# # reading from batch pickles
# list_files = natsorted(glob.glob(f"{settings.preprocessed_dir}/*_PHOT.pickle"))
# # Check file with redshift features exist
# error_msg = "Preprocessed_file not found. Call python run.py --data"
# assert os.path.isfile(list_files[0]), error_msg
# extra_info_df = pd.concat(
# [pd.read_pickle(f)[columns_to_read] for f in list_files]
# )
# # In extra_info_df, there are many SNID duplicates as each row corresponds to a time step in a given curve
# # We use groupby + first to only select the first row of each lightcurve
# # Then we can merge knowing there won't be SNID duplicates in extra_info_df
# extra_info_df = (
# extra_info_df.groupby("SNID")[host_features].first().reset_index()
# )
# # Add redshift info to df
# df = df.merge(extra_info_df, how="left", on="SNID")
# return df
def compute_delta_time(df):
    """Compute the delta time between two consecutive observations.

    Rows are sorted by ``SNID`` then ``MJD``; ``delta_time`` is the MJD
    difference to the previous row, set to 0 on the first observation of
    every light curve.

    Args:
        df (pandas.DataFrame): dataframe holding lightcurve data, with ``MJD``
            as a column and ``SNID`` as a column or as an index level

    Returns:
        (pandas.DataFrame) sorted copy of the lightcurve data with a
        ``delta_time`` column (the input frame is not modified)
    """
    # in case the photometry is not time sorted
    df = df.sort_values(["SNID", "MJD"])

    if "SNID" in df.columns:
        IDs = df["SNID"].to_numpy()
    else:
        # Deal with the case where the lightcurve ID is the index (or a level)
        IDs = df.index.get_level_values("SNID").to_numpy()

    # copy=True gives a writable array; under pandas Copy-on-Write the
    # arrays exposed by .values / .to_numpy() are read-only views
    delta_time = df["MJD"].diff().fillna(0).to_numpy(copy=True)
    # Rows where a new light curve starts get zero delta_time
    delta_time[np.flatnonzero(IDs[1:] != IDs[:-1]) + 1] = 0
    df["delta_time"] = delta_time
    return df
def remove_data_post_large_delta_time(df):
    """Remove rows in the same light curve after a gap > 150 days.

    Reason: If no signal has been saved in a time frame of 150 days,
    it is unlikely there is much left afterwards.

    Within each contiguous run of identical ``SNID`` values, the first row
    whose ``delta_time`` exceeds 150 and every following row of that run are
    dropped. Selection is positional, so any index (e.g. the permuted index
    left by a previous sort) is handled correctly, including a gap in the
    last light curve.

    Args:
        df (pandas.DataFrame): dataframe holding lightcurve data with
            ``SNID`` and ``delta_time`` columns

    Returns:
        (pandas.DataFrame) dataframe where large delta time rows have been
        removed, with a fresh ``RangeIndex``
    """
    ids = df["SNID"].to_numpy()
    is_gap = (df["delta_time"].to_numpy() > 150).astype(np.int64)

    # True on the first row of each contiguous run of identical SNIDs
    run_start = np.ones(len(df), dtype=bool)
    run_start[1:] = ids[1:] != ids[:-1]
    run_id = np.cumsum(run_start) - 1

    # Count large gaps seen so far, and how many were seen before each run
    gaps_so_far = np.cumsum(is_gap)
    gaps_before_run = (gaps_so_far - is_gap)[run_start]
    # A row goes if its own run has hit a large gap at or before it
    to_remove = (gaps_so_far - gaps_before_run[run_id]) > 0

    df = df.iloc[np.flatnonzero(~to_remove)]
    # Reset index to account for dropped rows
    return df.reset_index(drop=True)
def load_HDF5_SNinfo(settings):
    """Load physical information related to the created database of lightcurves.

    Reads ``{settings.processed_dir}/database.h5`` and keeps ``SNID``,
    ``settings.sntype_var``, ``mB``, ``c``, ``x1`` and every dataset whose
    name contains ``SIM_``, ``dataset_`` or ``PEAK`` (in that order, each
    name once). Requested names that are absent from the file are skipped,
    since ``save_to_HDF5`` only writes the misc features present in its input.

    Args:
        settings (ExperimentSettings): controls experiment hyperparameters

    Returns:
        (pandas.DataFrame) dataframe holding physics information about the
        dataset, with ``SNID`` decoded to ``str`` when stored as bytes
    """
    file_name = f"{settings.processed_dir}/database.h5"
    with h5py.File(file_name, "r") as hf:
        available = list(hf.keys())
        wanted = ["SNID", settings.sntype_var, "mB", "c", "x1"]
        wanted += [c for c in available if "SIM_" in c]
        wanted += [c for c in available if "dataset_" in c]
        wanted += [c for c in available if "PEAK" in c]
        # dict.fromkeys: a name like SIM_PEAKMAG_g matches two patterns but is
        # read only once, keeping its first position
        dict_SNinfo = {key: hf[key][:] for key in dict.fromkeys(wanted) if key in hf}
    df_SNinfo = pd.DataFrame(dict_SNinfo)
    # h5py returns variable-length strings as bytes
    if len(df_SNinfo) > 0 and isinstance(df_SNinfo["SNID"].values[0], bytes):
        df_SNinfo["SNID"] = df_SNinfo["SNID"].str.decode("utf8")
    return df_SNinfo
def log_standardization(arr):
    """Compute normalization parameters for the fluxes and fluxes error.

    - Log transform the data: ``log(x - arr_min + 1e-5)``
    - Mean and std dev of the log-transformed data

    When the data minimum is below -2000, ``arr_min`` is clipped to -2000
    (with a warning), and values below it are clipped to it before the log.
    The returned mean/std therefore describe exactly the transform defined
    by the returned ``arr_min``; when no clipping occurs, the results are
    those of the plain transform.

    Args:
        arr (np.array): data to normalize (any shape; reduced over all elements)

    Returns:
        (LogStandardized) namedtuple holding normalization data
    """
    arr_min = np.min(arr)
    if arr_min < -2000:
        logging_utils.print_yellow(
            f"Warning: extreme data values {arr_min}",
            "clipping normalization min to -2000",
        )
        arr_min = -2000
    # No-op unless arr_min was clipped above; keeps the log argument positive
    arr_log = np.log(np.maximum(arr, arr_min) - arr_min + 1e-5)
    arr_mean = arr_log.mean()
    arr_std = arr_log.std()
    return LogStandardized(arr_min=arr_min, arr_mean=arr_mean, arr_std=arr_std)
def _per_sn_uint8_counts(df_SNID, df_counts, column):
    """Align per-SNID counts on the row order of ``df_SNID`` as uint8.

    SNIDs missing from ``df_counts`` (no row passed the time cut) get 0
    instead of NaN, and counts saturate at 255 instead of wrapping around.
    """
    counts = df_SNID.merge(df_counts, on="SNID", how="left")[column]
    return (
        counts.fillna(0).clip(upper=np.iinfo(np.uint8).max).to_numpy().astype(np.uint8)
    )


def save_to_HDF5(settings, df):
    """Save processed dataframe to HDF5.

    Writes to ``settings.hdf5_file_name``, one entry per light curve in a
    shuffled order:

    * ``SNID``, misc features, and the target / dataset split columns;
    * for each offset in ``OFFSETS_VAL`` (named with the matching
      ``OFFSETS_STR`` suffix), the number of observations and of each
      filter occurrence with ``time < PEAKMJDNORM + offset``;
    * per-feature and global log-standardization parameters;
    * the feature names and the flattened per-lightcurve feature arrays.

    Args:
        settings (ExperimentSettings): controls experiment hyperparameters
        df (pandas.DataFrame): dataframe holding processed data, indexed by
            ``SNID`` with the rows of each light curve contiguous

    Raises:
        AttributeError: if a non-HOST training feature is missing from ``df``
    """
    # One hot encode filter information and Normalize features
    list_training_features = [f"FLUXCAL_{f}" for f in settings.list_filters]
    list_training_features += [f"FLUXCALERR_{f}" for f in settings.list_filters]
    list_training_features += [
        "delta_time",
        "HOSTGAL_PHOTOZ",
        "HOSTGAL_PHOTOZ_ERR",
        "HOSTGAL_SPECZ",
        "HOSTGAL_SPECZ_ERR",
    ]
    if settings.additional_train_var:
        list_training_features += list(settings.additional_train_var)

    list_misc_features = [
        "PEAKMJD",
        settings.sntype_var,
        "mB",
        "c",
        "x1",
        "SIM_REDSHIFT_CMB",
        "SIM_PEAKMAG_z",
        "SIM_PEAKMAG_g",
        "SIM_PEAKMAG_r",
        "SIM_PEAKMAG_i",
    ]
    # append, not +=: += with a str would add its individual characters.
    # PEAKMJDNORM is excluded: it is needed below and saved separately.
    if settings.photo_window_var not in list_misc_features + ["PEAKMJDNORM"]:
        list_misc_features.append(settings.photo_window_var)
    list_misc_features = [k for k in list_misc_features if k in df.keys()]

    assert df.index.name == "SNID", "Must set SNID as index"

    # Get the list of lightcurve IDs
    ID = df.index.values
    # Find out when ID changes => find start and end idx of each lightcurve
    idx_change = np.where(ID[1:] != ID[:-1])[0] + 1
    idx_change = np.hstack(([0], idx_change, [len(df)]))
    list_start_end = list(zip(idx_change[:-1], idx_change[1:]))
    # N.B. We could use df.loc[SNID], more elegant but much slower

    # Filter list start end so we get only light curves with at least 3 points
    # except when creating testing data (we want to classify all lcs even w. 1-2 epochs)
    if not settings.data_testing:
        list_start_end = [(s, e) for s, e in list_start_end if e - s >= 3]

    # Shuffle
    np.random.shuffle(list_start_end)

    # Save hdf5 file
    with h5py.File(settings.hdf5_file_name, "w") as hf:
        n_samples = len(list_start_end)
        unique_classes = list(dict.fromkeys(settings.sntypes.values()))
        list_classes = sorted({2, len(unique_classes)})
        list_names = ["target", "dataset_photometry", "dataset_saltfit"]

        # These arrays can be filled in one shot
        start_idxs = [i[0] for i in list_start_end]
        shuffled_ID = ID[start_idxs]
        hf.create_dataset("SNID", data=shuffled_ID, dtype=h5py.special_dtype(vlen=str))
        df_SNID = pd.DataFrame(shuffled_ID, columns=["SNID"])

        logging_utils.print_green("Saving misc features")
        for feat in list_misc_features:
            if feat == settings.sntype_var:
                dtype = np.dtype("int32")
            else:
                dtype = np.dtype("float32")
            hf.create_dataset(feat, data=df[feat].values[start_idxs], dtype=dtype)
            df = df.drop(columns=feat)

        logging_utils.print_green("Saving class")
        for c_ in list_classes:
            for name in list_names:
                field_name = f"{name}_{c_}classes"
                hf.create_dataset(
                    field_name,
                    data=df[field_name].values[start_idxs],
                    dtype=np.dtype("int8"),
                )
                df = df.drop(columns=field_name)

        df["time"] = df[["delta_time"]].groupby(df.index).cumsum()
        df = df.reset_index()

        logging_utils.print_green("Saving unique nights")
        # Compute how many unique nights of data taking existed around PEAKMJD.
        # OFFSETS_VAL (not OFFSETS) is index-aligned with OFFSETS_STR.
        for offset, suffix in zip(OFFSETS_VAL, OFFSETS_STR):
            new_column = f"PEAKMJD{suffix}_unique_nights"
            df_nights = (
                df[df["time"] < df["PEAKMJDNORM"] + offset][["PEAKMJDNORM", "SNID"]]
                .groupby("SNID")
                .count()
                .rename(columns={"PEAKMJDNORM": new_column})
                .reset_index()
            )
            hf.create_dataset(
                new_column,
                data=_per_sn_uint8_counts(df_SNID, df_nights, new_column),
                dtype=np.dtype("uint8"),
            )

        logging_utils.print_green("Saving filter occurences")
        # Compute how many occurences of a specific filter around PEAKMJD
        for flt in settings.list_filters:
            # Check presence / absence of the filter at all time steps
            df[f"has_{flt}"] = df.FLT.str.contains(flt, regex=False).astype(np.uint8)
            for offset, suffix in zip(OFFSETS_VAL, OFFSETS_STR):
                new_column = f"PEAKMJD{suffix}_num_{flt}"
                df_flt = (
                    df[df["time"] < df["PEAKMJDNORM"] + offset][[f"has_{flt}", "SNID"]]
                    .groupby("SNID")
                    .sum()
                    .rename(columns={f"has_{flt}": new_column})
                    .reset_index()
                )
                hf.create_dataset(
                    new_column,
                    data=_per_sn_uint8_counts(df_SNID, df_flt, new_column),
                    dtype=np.dtype("uint8"),
                )
            df = df.drop(columns=f"has_{flt}")

        # Finally save PEAKMJDNORM
        hf.create_dataset(
            "PEAKMJDNORM",
            data=df["PEAKMJDNORM"].values[start_idxs],
            dtype=np.dtype("float32"),
        )

        cols_to_drop = [
            k
            for k in ["time", "SNID", "PEAKMJDNORM", settings.photo_window_var]
            if k in df.keys()
        ]
        df = df.drop(columns=list(set(cols_to_drop)))

        ########################
        # Normalize per feature
        ########################
        logging_utils.print_green("Compute normalizations")
        gnorm = hf.create_group("normalizations")
        # using normalization per feature
        for feat in settings.training_features_to_normalize:
            # Log transform plus mean subtraction and standard dev subtraction
            log_standardized = log_standardization(df[feat].values)
            # Store normalization parameters
            gnorm.create_dataset(f"{feat}/min", data=log_standardized.arr_min)
            gnorm.create_dataset(f"{feat}/mean", data=log_standardized.arr_mean)
            gnorm.create_dataset(f"{feat}/std", data=log_standardized.arr_std)

        #####################################
        # Normalize flux and fluxerr globally
        #####################################
        logging_utils.print_green("Compute global normalizations")
        gnorm = hf.create_group("normalizations_global")

        ################
        # FLUX features
        #################
        flux_features = [f"FLUXCAL_{f}" for f in settings.list_filters]
        flux_log_standardized = log_standardization(df[flux_features].values)
        # Store normalization parameters
        gnorm.create_dataset("FLUXCAL/min", data=flux_log_standardized.arr_min)
        gnorm.create_dataset("FLUXCAL/mean", data=flux_log_standardized.arr_mean)
        gnorm.create_dataset("FLUXCAL/std", data=flux_log_standardized.arr_std)

        ###################
        # FLUXERR features
        ###################
        fluxerr_features = [f"FLUXCALERR_{f}" for f in settings.list_filters]
        fluxerr_log_standardized = log_standardization(df[fluxerr_features].values)
        # Store normalization parameters
        gnorm.create_dataset("FLUXCALERR/min", data=fluxerr_log_standardized.arr_min)
        gnorm.create_dataset("FLUXCALERR/mean", data=fluxerr_log_standardized.arr_mean)
        gnorm.create_dataset("FLUXCALERR/std", data=fluxerr_log_standardized.arr_std)

        ####################################
        # Save the rest of the data to hdf5
        ####################################
        logging_utils.print_green("Save non-data features to HDF5")

        # This type allows one to store flat arrays of variable
        # length inside an HDF5 group
        data_type = h5py.special_dtype(vlen=np.dtype("float32"))
        hf.create_dataset("data", (n_samples,), dtype=data_type)

        # If header does not have HOST info fill with empty arrays
        list_to_fill = [
            k for k in list_training_features if k not in df.columns.values.tolist()
        ]
        if len([k for k in list_to_fill if "HOST" not in k]) > 0:
            logging_utils.print_red("missing information in input")
            raise AttributeError
        for key in list_to_fill:
            df[key] = np.zeros(len(df))

        logging_utils.print_green("Fit onehot on FLT")
        assert sorted(df.columns.values.tolist()) == sorted(
            list_training_features + ["FLT"]
        )

        # Fit a one hot encoder for FLT
        # to have the same onehot for all datasets
        tmp = pd.concat([pd.Series(settings.list_filters_combination), df["FLT"]])
        # Explicit numeric dtype: pandas >= 2 defaults to bool, which would make
        # the combined feature matrix below an object array
        tmp_onehot = pd.get_dummies(tmp, dtype=np.uint8)
        # Positional slice drops the reference rows; what remains carries
        # df's index, so the concat below aligns row by row
        FLT_onehot = tmp_onehot.iloc[len(settings.list_filters_combination) :]
        df = pd.concat([df[list_training_features], FLT_onehot], axis=1)

        # store feature names
        list_training_features = df.columns.values.tolist()
        hf.create_dataset(
            "features",
            (len(list_training_features),),
            dtype=h5py.special_dtype(vlen=str),
        )
        hf["features"][:] = list_training_features
        logging_utils.print_green("Saved features:", ",".join(list_training_features))

        # Save training features to hdf5
        logging_utils.print_green("Save data features to HDF5")
        arr_feat = df[list_training_features].values
        hf["data"].attrs["n_features"] = len(list_training_features)
        for idx, idx_pair in enumerate(
            tqdm(list_start_end, desc="Filling hdf5", ncols=100)
        ):
            arr = arr_feat[idx_pair[0] : idx_pair[1]]
            hf["data"][idx] = np.ravel(arr)

        # save data types for training; fall back to the string form when
        # the value cannot be stored as fixed-length bytes
        try:
            hf["data_types_training"] = np.asarray(settings.data_types_training).astype(
                np.dtype("S100")
            )
        except Exception:
            hf["data_types_training"] = f"{settings.data_types_training}"