import os
import re
import glob
import shutil
import numpy as np
import pandas as pd
from tqdm import tqdm
import multiprocessing
from pathlib import Path
from natsort import natsorted
from functools import partial
from astropy.table import Table
from concurrent.futures import ProcessPoolExecutor
from ..utils import data_utils
from ..utils import logging_utils
from ..paper.superNNova_plots import datasets_plots
def process_fn(inputs):
    """Call a function on a single argument, both packed in one tuple.

    Lets ``executor.map`` dispatch a different callable per item.

    Args:
        inputs (tuple): pair ``(fn, fil)`` with the callable and its argument

    Returns:
        whatever ``fn(fil)`` returns
    """
    worker, target = inputs
    outcome = worker(target)
    return outcome
def powers_of_two(x):
    """Decompose an integer into the powers of two set in its binary form.

    Args:
        x (int): bitmask to decompose

    Returns:
        list: powers of two present in ``x``, in increasing order
        (empty when ``x`` is lower than 1)
    """

    def _candidate_bits():
        # Yield 1, 2, 4, ... while the candidate does not exceed x
        bit = 1
        while bit <= x:
            yield bit
            bit <<= 1

    return [bit for bit in _candidate_bits() if bit & x]
def build_traintestval_splits(settings):
    """Build the train/validation/test splits and save them to ``SNID.pickle``.

    For the ``saltfit`` and ``photometry`` samples and for each classification
    target (binary, and one class per distinct value of ``settings.sntypes``):

    - Unless ``settings.data_testing`` or ``settings.testing_ids`` is set,
      downsample each class to the cardinality of the smallest class
    - Shuffle the light curves and assign them to a split:

      - default: 80/10/10 train/val/test
      - ``settings.testing_ids``: the listed SNIDs are the test set, the
        remaining light curves are split 90/10 train/val
      - ``settings.data_training``: 99/0.5/0.5 train/val/test
      - ``settings.data_testing``: all sampled light curves are used for testing

    The assignment is stored in columns ``dataset_{sample}_{n}classes``
    (train: 0, val: 1, test: 2, others: -1) of the DataFrame pickled to
    ``{settings.processed_dir}/SNID.pickle``.

    Args:
        settings (ExperimentSettings): controls experiment hyperparameters

    Raises:
        ValueError: if ``settings.testing_ids`` is neither a csv with a ``SNID``
            column nor a ``.npy`` file
        SystemExit: if the fit file contains SNIDs absent from the photometry
    """
    logging_utils.print_green("Computing splits")

    # Read and process files faster with ProcessPoolExecutor.
    # Keep at least one worker: cpu_count() - 2 is <= 0 on small machines
    # and ProcessPoolExecutor rejects non-positive values.
    max_workers = max(1, multiprocessing.cpu_count() - 2)

    # Binary classification + one class per distinct value in settings.sntypes
    list_nb_classes = list({2, len(set(dict(settings.sntypes).values()))})
    photo_columns = ["SNID"] + [
        f"target_{nb_classes}classes" for nb_classes in list_nb_classes
    ]
    header_columns = photo_columns + [settings.sntype_var]

    # Load headers
    # either in HEAD.FITS or csv format
    list_files = natsorted(Path(settings.raw_dir).glob("**/*HEAD*"))
    list_fmt = [re.search(r"(FITS|csv)", fil.name).group() for fil in list_files]
    list_files = [str(fil) for fil in list_files]
    print("List files", list_files)

    if not settings.debug:
        # use parallelization to speed up processing
        header_readers = {
            "FITS": partial(
                data_utils.process_header_FITS,
                settings=settings,
                columns=header_columns,
            ),
            "csv": partial(
                data_utils.process_header_csv,
                settings=settings,
                columns=header_columns,
            ),
        }
        list_pairs = [
            (header_readers[fmt], fil) for fmt, fil in zip(list_fmt, list_files)
        ]
        with ProcessPoolExecutor(max_workers=max_workers) as executor:
            # Materialize inside the context so that worker errors surface here
            list_df = list(executor.map(process_fn, list_pairs))
    else:
        logging_utils.print_yellow("Beware debugging mode (loop over files)")
        list_df = []
        for fil, fmat in zip(list_files, list_fmt):
            if fmat == "FITS":
                list_df.append(
                    data_utils.process_header_FITS(
                        fil, settings, columns=header_columns
                    )
                )
            else:
                list_df.append(
                    data_utils.process_header_csv(
                        fil, settings, columns=header_columns
                    )
                )

    # Load df_photo
    df_photo = pd.concat(list_df)
    df_photo["SNID"] = df_photo["SNID"].astype(str).str.strip()

    # load FITOPT file on which we will base our splits
    df_salt = data_utils.load_fitfile(settings)
    if len(df_salt) < 1:
        # if no fits file we include all lcs
        logging_utils.print_yellow("All lcs used for salt and photometry samples")
        df_salt = pd.DataFrame()
        df_salt["SNID"] = df_photo["SNID"].str.strip()
        df_salt["is_salt"] = 1
    # correct format SNID
    df_salt["SNID"] = df_salt["SNID"].astype(str).str.strip()

    # Check all SNID in df_salt are also in df_photo
    # (explicit test rather than assert, which is stripped with python -O)
    if not np.all(df_salt.SNID.isin(df_photo.SNID)):
        import sys

        logging_utils.print_red(
            "                BEWARE! This is not the fits file for this photometry                "
        )
        print(logging_utils.str_to_redstr("   do point at the correct --fits_dir   "))
        print(
            logging_utils.str_to_redstr(
                "   (cheat: or an empty folder to override use of salt2fits)   "
            )
        )
        sys.exit(1)

    # Merge left on df_photo
    df = df_photo.merge(df_salt[["SNID", "is_salt"]], on=["SNID"], how="left")
    # Some curves are in photo and not in salt, these curves have is_salt = NaN
    # We replace the NaN with 0
    df["is_salt"] = df["is_salt"].fillna(0).astype(int)

    # Save a dataframe to record train/test/val split for
    # binary, ternary and all-classes classification
    for dataset in ["saltfit", "photometry"]:
        for nb_classes in list_nb_classes:
            target_col = f"target_{nb_classes}classes"
            logging_utils.print_green(
                f"Computing {dataset} splits for {nb_classes}-way classification"
            )
            # Randomly sample SNIDs such that all class have the same number of occurences
            if dataset == "saltfit":
                g = df[df.is_salt == 1].groupby(target_col, group_keys=False)
            else:
                g = df.groupby(target_col, group_keys=False)

            # Types (settings.sntype_var values) present in each target class
            dic_targets = (
                g[settings.sntype_var].apply(lambda x: list(np.unique(x))).to_dict()
            )
            print(f"target {settings.sntype_var}")
            settings.data_types_training = [
                f"{k} {settings.sntypes[v[0]]} {[int(dt) for dt in dic_targets[k]]}"
                if v[0] in settings.sntypes.keys()
                else f"{k} nonIa {[int(dt) for dt in dic_targets[k]]}"
                for k, v in dic_targets.items()
            ]
            print(settings.data_types_training)

            if settings.testing_ids:
                ids_suffix = Path(settings.testing_ids).suffix
                if ids_suffix == ".csv":
                    df_ids_test = pd.read_csv(settings.testing_ids)
                    try:
                        ids_test = df_ids_test["SNID"].astype(str).values
                    except Exception as err:
                        logging_utils.print_red(
                            f"Provide a {settings.testing_ids} with SNID column"
                        )
                        raise ValueError(
                            f"No SNID column in {settings.testing_ids}"
                        ) from err
                elif ids_suffix == ".npy":
                    ids_test = np.load(settings.testing_ids)
                    ids_test = [f"{k}" for k in ids_test]
                else:
                    logging_utils.print_red("Provide a csv or numpy testing_ids file")
                    raise ValueError("testing_ids must be a .csv or .npy file")
                is_test_id = df.SNID.isin(ids_test)
                g_wo_test = df[~is_test_id].groupby(target_col, group_keys=False)
                g_test = df[is_test_id].groupby(target_col, group_keys=False)

            # We have grouped df by target; in the default case we find out which
            # group has the smallest size with g.size().min(), then we sample that
            # many light curves from each group and reset the index. We then sample
            # with frac=1 to shuffle the whole dataset. Otherwise, the classes are
            # sorted and the train/test/val splits are incorrect.
            # NOTE(review): when both data_testing and testing_ids are set,
            # g_wo_test/g_test are left un-flattened here while the testing_ids
            # branch below expects DataFrames -- confirm the intended precedence.
            if settings.data_testing:
                # when just classifying data balancing is not necessary
                g = g.apply(lambda x: x).reset_index(drop=True).sample(frac=1)
            elif settings.testing_ids:
                g_wo_test = (
                    g_wo_test.apply(lambda x: x).reset_index(drop=True).sample(frac=1)
                )
                g_test = g_test.apply(lambda x: x).reset_index(drop=True).sample(frac=1)
            else:
                min_class_size = g.size().min()
                g = (
                    g.apply(lambda x: x.sample(min_class_size))
                    .reset_index(drop=True)
                    .sample(frac=1)
                )

            if settings.testing_ids:
                sampled_SNIDs_wo_test = g_wo_test["SNID"].values
                n_samples = len(sampled_SNIDs_wo_test)
                SNID_train = sampled_SNIDs_wo_test[: int(0.9 * n_samples)]
                SNID_val = sampled_SNIDs_wo_test[int(0.9 * n_samples) : int(n_samples)]
                sampled_SNIDs_test = g_test["SNID"].values
                SNID_test = sampled_SNIDs_test[:]
            else:
                sampled_SNIDs = g["SNID"].values
                n_samples = len(sampled_SNIDs)
                # Now create train/test/validation indices
                if settings.data_training:
                    SNID_train = sampled_SNIDs[: int(0.99 * n_samples)]
                    SNID_val = sampled_SNIDs[
                        int(0.99 * n_samples) : int(0.995 * n_samples)
                    ]
                    SNID_test = sampled_SNIDs[int(0.995 * n_samples) :]
                elif settings.data_testing:
                    SNID_test = sampled_SNIDs[:]
                    # the train and val sets wont be used in this case
                    SNID_train = [sampled_SNIDs[0]]
                    SNID_val = [sampled_SNIDs[0]]
                else:
                    SNID_train = sampled_SNIDs[: int(0.8 * n_samples)]
                    SNID_val = sampled_SNIDs[
                        int(0.8 * n_samples) : int(0.9 * n_samples)
                    ]
                    SNID_test = sampled_SNIDs[int(0.9 * n_samples) :]

            # Find the indices of our train test val splits
            idxs_train = np.where(df.SNID.isin(SNID_train))[0]
            idxs_val = np.where(df.SNID.isin(SNID_val))[0]
            idxs_test = np.where(df.SNID.isin(SNID_test))[0]

            # Create a new column that will state to which data split
            # a given SNID will be assigned
            # train: 0, val: 1, test:2 others: -1
            arr_dataset = -np.ones(len(df)).astype(int)
            arr_dataset[idxs_train] = 0
            arr_dataset[idxs_val] = 1
            arr_dataset[idxs_test] = 2
            df[f"dataset_{dataset}_{nb_classes}classes"] = arr_dataset

            # Display classes balancing in the dataset, for each split
            logging_utils.print_bright("Dataset composition")
            for split_name, idxs in zip(
                ["Training", "Validation", "Test"], [idxs_train, idxs_val, idxs_test]
            ):
                # We count the number of occurence in each class and each split with
                # pandas.value_counts
                d_occurences = (
                    df[target_col].iloc[idxs].value_counts().sort_values().to_dict()
                )
                total_samples = sum(d_occurences.values())
                total_samples_str = logging_utils.str_to_yellowstr(total_samples)
                str_ = f"# samples {total_samples_str} "
                for c_, n_in_class in d_occurences.items():
                    class_str = logging_utils.str_to_yellowstr(c_)
                    class_fraction = f"{100 *(n_in_class/total_samples):.2g}%"
                    class_fraction_str = logging_utils.str_to_yellowstr(class_fraction)
                    str_ += f"Class {class_str}: {class_fraction_str} samples "
                logging_utils.print_green(f"{split_name} set", str_)

    # Save to pickle
    df.to_pickle(f"{settings.processed_dir}/SNID.pickle")
    logging_utils.print_green("Done")
def process_single_FITS(file_path, settings):
    """
    Carry out preprocessing on FITS file and save results to pickle.

    Pickle is preferred to csv as it is faster to read and write.

    - Join column from header files
    - Select columns that will be useful later on
    - Compute SNID to tag each light curve
    - Optional photometry window and quality flag selection
    - Filter preprocessing
    - Removal of delimiter rows
    - Compute delta times between measures
    - Add class and split columns from ``SNID.pickle``

    Args:
        file_path (str): path to ``.FITS`` file
        settings (ExperimentSettings): controls experiment hyperparameters

    Returns:
        list: SNIDs in this file with ``HOSTGAL_SPECZ`` > 0 (empty if the
        column is not available)

    Raises:
        ValueError: if the photometry file is empty
        IndexError: if the header has fewer entries than light curves
        Exception: if the photometry window file cannot be matched to the header
    """
    # Load the PHOT file
    df = data_utils.load_pandas_from_fit(file_path)
    if len(df) < 1:
        logging_utils.print_red("Do not provide empty photometry file", file_path)
        # Fail with a clear error instead of an IndexError on the next access
        raise ValueError(f"Empty photometry file: {file_path}")

    # Last line may be a line with MJD = -777.
    # Remove it so that it does not interfere with the SNID assignment below
    if df.MJD.values[-1] == -777.0:
        df = df.drop(df.index[-1])

    # Keep only columns of interest
    keep_col = ["MJD", "FLUXCAL", "FLUXCALERR", "FLT"]
    # BAND and FLT are exchangeable
    if "FLT" not in df.keys() and "BAND" in df.keys():
        df = df.rename(columns={"BAND": "FLT"})
    df = (
        df[keep_col + [settings.phot_reject]].copy()
        if settings.phot_reject
        else df[keep_col].copy()
    )

    # Load the companion HEAD file
    header = Table.read(file_path.replace("PHOT", "HEAD"), format="fits")
    df_header = header.to_pandas()
    try:
        df_header["SNID"] = df_header["SNID"].str.decode("utf-8")
    except Exception:
        # SNID is not stored as bytes
        df_header["SNID"] = df_header["SNID"].astype(str)

    # Hack for using the final redshift not the galaxy
    if settings.redshift_label != "none":
        logging_utils.print_yellow("Changed redshift to", settings.redshift_label)
        df_header["HOSTGAL_SPECZ"] = df_header[settings.redshift_label]
        df_header["HOSTGAL_SPECZ_ERR"] = df_header[f"{settings.redshift_label}_ERR"]
        df_header["SIM_REDSHIFT_CMB"] = df_header[settings.redshift_label]

    # Keep only columns of interest
    keep_col_header = [
        "SNID",
        "PEAKMJD",
        "HOSTGAL_PHOTOZ",
        "HOSTGAL_PHOTOZ_ERR",
        "HOSTGAL_SPECZ",
        "HOSTGAL_SPECZ_ERR",
        "SIM_REDSHIFT_CMB",
        "SIM_PEAKMAG_z",
        "SIM_PEAKMAG_g",
        "SIM_PEAKMAG_r",
        "SIM_PEAKMAG_i",
        settings.sntype_var,
    ]
    if settings.photo_window_var not in keep_col_header:
        keep_col_header += [settings.photo_window_var]
    if settings.additional_train_var:
        keep_col_header += list(settings.additional_train_var)
    # check if keys are in header
    keep_col_header = [k for k in keep_col_header if k in df_header.keys()]
    df_header = df_header[keep_col_header].copy()
    df_header["SNID"] = df_header["SNID"].astype(str).str.strip()

    #############################################
    # Photometry window init
    #############################################
    if settings.photo_window_files:
        if Path(settings.photo_window_files[0]).exists():
            # load the file holding the photometry window variable
            df_peak = pd.read_csv(
                settings.photo_window_files[0],
                comment="#",
                delimiter=" ",
                skipinitialspace=True,
            )
            if "SNID" not in df_peak.keys():
                df_peak["SNID"] = df_peak["CID"].astype(str)
            else:
                df_peak["SNID"] = df_peak["SNID"].astype(str)
            try:
                df_peak = df_peak[["SNID", settings.photo_window_var]]
            except Exception as err:
                logging_utils.print_red("Provide a correct photo_window variable")
                raise Exception(
                    f"{settings.photo_window_var} not found in photo_window_file"
                ) from err
            # merge with header: every header entry must find exactly one match
            df_header_tmp = pd.merge(df_header, df_peak, on="SNID")
            if len(df_header) == len(df_header_tmp):
                df_header = df_header_tmp
            else:
                raise Exception(
                    "photo_window_file does not match the header one-to-one on SNID"
                )
            if len(df_header) < 1:
                logging_utils.print_red(
                    "Provide a matching photo_window_file (not a common SNID found) "
                )
                raise Exception("No common SNID between header and photo_window_file")
        else:
            if settings.photo_window_files[0] == "HEAD":
                # if using a variable from header file
                if settings.photo_window_var not in df_header.keys():
                    logging_utils.print_red(
                        "Provide a valid peak key in header or a photo_window_file"
                    )
            else:
                logging_utils.print_red(
                    "Provide a valid peak key in header or a photo_window_file"
                )

    #############################################
    # Compute SNID for df and join with df_header
    #############################################
    # New light curves are identified by MJD == -777.0
    arr_idx = np.where(df["MJD"].values == -777.0)[0]
    arr_idx = np.hstack((np.array([0]), arr_idx, np.array([len(df)])))
    # Segment k spans rows [arr_idx[k], arr_idx[k + 1]) and gets the k-th header SNID
    rows_per_segment = np.diff(arr_idx)
    n_segments = len(rows_per_segment)
    if len(df_header) < n_segments:
        raise IndexError(
            f"{file_path}: {n_segments} light curves but only "
            f"{len(df_header)} header entries"
        )
    # np.repeat keeps the SNID untouched (a fixed-width char array would
    # silently truncate long identifiers and break the join below)
    df["SNID"] = np.repeat(
        df_header["SNID"].to_numpy()[:n_segments], rows_per_segment
    )
    df["SNID"] = df["SNID"].astype(str).str.strip()
    df = df.set_index("SNID")
    df_header["SNID"] = df_header["SNID"].str.strip()
    df_header = df_header.set_index("SNID")
    # join df and header
    df = df.join(df_header).reset_index()

    #############################################
    # Photometry window & quality (flag) selection
    #############################################
    # window
    if settings.photo_window_files:
        is_observation = df["MJD"] != -777.00
        delta_time = df["MJD"] - df[settings.photo_window_var]
        in_window = ((delta_time > 0) & (delta_time < settings.photo_window_max)) | (
            (delta_time <= 0) & (delta_time > settings.photo_window_min)
        )
        # Delimiter rows are always kept here, they are dropped further down.
        # Index with the boolean mask itself: `series is True` is a plain False.
        df = df[in_window | ~is_observation]
    # quality
    if settings.phot_reject:
        # only valid for powers of two combinations
        n_snid_before = len(df.SNID.unique())
        n_phot_before = len(df)
        rejected_bits = set(settings.phot_reject_list)
        # keep a row when none of the bits set in its flag is in the reject list
        keep_row = df[settings.phot_reject].apply(
            lambda flag: rejected_bits.isdisjoint(powers_of_two(flag))
        )
        df = df[keep_row.astype(bool)]
        if settings.debug:
            logging_utils.print_blue("Phot reject", file_path)
            logging_utils.print_blue(
                f"SNID {n_snid_before} to {len(df.SNID.unique())}"
            )
            logging_utils.print_blue(f"Phot {n_phot_before} to {len(df)}")

    #############################################
    # Miscellaneous data processing
    #############################################
    df = df[keep_col + keep_col_header].copy()
    # filters have a trailing white space which we remove
    df.FLT = df.FLT.apply(lambda x: x.rstrip()).values.astype(str)
    # keep only filters we are going to use for classification
    df = df[df["FLT"].isin(settings.list_filters)]
    # Drop the delimiter lines
    df = df[df.MJD != -777.000]
    # Reset the index (it is no longer continuous after dropping lines)
    df.reset_index(inplace=True, drop=True)
    # Add delta time
    df = data_utils.compute_delta_time(df)

    #############################################
    # Add class and dataset information
    #############################################
    df_SNID = pd.read_pickle(f"{settings.processed_dir}/SNID.pickle")
    # Check all SNID in df are in df_SNID
    assert np.all(df.SNID.isin(df_SNID.SNID))
    # Merge left on df: len(df) will not change and will now include
    # relevant columns from df_SNID
    merge_columns = ["SNID"]
    distinct_classes = len(set(dict(settings.sntypes).values()))
    for c_ in list({2, distinct_classes}):
        merge_columns += [f"target_{c_}classes"]
        for dataset in ["photometry", "saltfit"]:
            merge_columns += [f"dataset_{dataset}_{c_}classes"]
    df = df.merge(df_SNID[merge_columns], on=["SNID"], how="left")

    # Save for future use; files in sub-folders of raw_dir are prefixed with
    # their folder so that names stay unique in the flat preprocessed_dir
    basename = os.path.basename(file_path)
    folder_name = Path(file_path.split(f"{settings.raw_dir}/")[-1]).parent
    if folder_name != Path("."):
        prefix = str(folder_name).replace("/", "_")
        basename = f"{prefix}_{basename}"
    pickle_name = basename.replace(".FITS", ".pickle").replace(".gz", "")
    df.to_pickle(f"{settings.preprocessed_dir}/{pickle_name}")

    # getting SNIDs for SNe with Host_spec
    host_spe = (
        df[df["HOSTGAL_SPECZ"] > 0]["SNID"].unique().tolist()
        if "HOSTGAL_SPECZ" in df.keys()
        else []
    )
    return host_spe
def process_single_csv(file_path, settings):
    """
    Carry out preprocessing on csv file and save results to pickle.

    Pickle is preferred to csv as it is faster to read and write.

    - Join column from header files
    - Optional photometry window and quality flag selection
    - Filter preprocessing
    - Compute delta times between measures
    - Add class and split columns from ``SNID.pickle``

    Args:
        file_path (str): path to ``.csv`` file
        settings (ExperimentSettings): controls experiment hyperparameters

    Returns:
        list: SNIDs in this file with ``HOSTGAL_SPECZ`` > 0 (empty if the
        column is not available)

    Raises:
        ValueError: if the photometry file is empty
    """
    # Load the PHOT file
    df_phot = pd.read_csv(file_path)
    if len(df_phot) < 1:
        logging_utils.print_red("Do not provide empty photometry file", file_path)
        raise ValueError(f"Empty photometry file: {file_path}")

    # Load the companion HEAD file
    df_header = pd.read_csv(file_path.replace("PHOT", "HEAD"))
    if settings.redshift_label != "none":
        logging_utils.print_yellow("Changed redshift to", settings.redshift_label)
        df_header["HOSTGAL_SPECZ"] = df_header[settings.redshift_label]
        df_header["HOSTGAL_SPECZ_ERR"] = df_header[f"{settings.redshift_label}_ERR"]
        df_header["SIM_REDSHIFT_CMB"] = df_header[settings.redshift_label]

    # Keep only columns of interest
    keep_col_header = [
        "SNID",
        "PEAKMJD",
        "HOSTGAL_PHOTOZ",
        "HOSTGAL_PHOTOZ_ERR",
        "HOSTGAL_SPECZ",
        "HOSTGAL_SPECZ_ERR",
        "SIM_REDSHIFT_CMB",
        "SIM_PEAKMAG_z",
        "SIM_PEAKMAG_g",
        "SIM_PEAKMAG_r",
        "SIM_PEAKMAG_i",
        settings.sntype_var,
    ]
    if settings.photo_window_var not in keep_col_header:
        keep_col_header += [settings.photo_window_var]
    if settings.additional_train_var:
        keep_col_header += list(settings.additional_train_var)
        print(f"Adding additional variables to dataset {settings.additional_train_var}")
    # check if keys are in header
    keep_col_header = [k for k in keep_col_header if k in df_header.keys()]
    df_header = df_header[keep_col_header].copy()
    df_header["SNID"] = df_header["SNID"].astype(str)
    df_header["SNID"] = df_header["SNID"].str.strip()
    df_header = df_header.set_index("SNID")

    # Keep only photometry columns of interest
    keep_col = ["SNID", "MJD", "FLUXCAL", "FLUXCALERR", "FLT"]
    phot_col = list(keep_col)
    # The quality flag lives in the photometry: keep it until the quality cut,
    # unless the header already provides a column with that name
    if (
        settings.phot_reject
        and settings.phot_reject in df_phot.keys()
        and settings.phot_reject not in phot_col
        and settings.phot_reject not in keep_col_header
    ):
        phot_col.append(settings.phot_reject)
    df = df_phot[phot_col].copy()
    # Same SNID formatting as the header so that the join below matches
    df["SNID"] = df["SNID"].astype(str).str.strip()
    df = df.set_index("SNID")

    # join df and header
    df = df.join(df_header).reset_index()

    # window
    if settings.photo_window_files:
        is_observation = df["MJD"] != -777.00
        delta_time = df["MJD"] - df[settings.photo_window_var]
        in_window = ((delta_time > 0) & (delta_time < settings.photo_window_max)) | (
            (delta_time <= 0) & (delta_time > settings.photo_window_min)
        )
        # Index with the boolean mask itself: `series is True` is a plain False
        df = df[in_window | ~is_observation]
    # quality
    if settings.phot_reject:
        # only valid for powers of two combinations
        n_snid_before = len(df.SNID.unique())
        n_phot_before = len(df)
        rejected_bits = set(settings.phot_reject_list)
        # keep a row when none of the bits set in its flag is in the reject list
        keep_row = df[settings.phot_reject].apply(
            lambda flag: rejected_bits.isdisjoint(powers_of_two(flag))
        )
        df = df[keep_row.astype(bool)]
        if settings.debug:
            logging_utils.print_blue("Phot reject", file_path)
            logging_utils.print_blue(
                f"SNID {n_snid_before} to {len(df.SNID.unique())}"
            )
            logging_utils.print_blue(f"Phot {n_phot_before} to {len(df)}")

    #############################################
    # Miscellaneous data processing
    #############################################
    # De-duplicate (SNID is in both lists) while keeping a stable column order
    df = df[list(dict.fromkeys(keep_col + keep_col_header))].copy()
    # filters have a trailing white space which we remove
    df.FLT = df.FLT.apply(lambda x: x.rstrip()).values.astype(str)
    # keep only filters we are going to use for classification
    df = df[df["FLT"].isin(settings.list_filters)]
    # Drop the delimiter lines
    df = df[df.MJD != -777.000]
    # Reset the index (it is no longer continuous after dropping lines)
    df.reset_index(inplace=True, drop=True)
    # Add delta time
    df = data_utils.compute_delta_time(df)

    #############################################
    # Add class and dataset information
    #############################################
    df_SNID = pd.read_pickle(f"{settings.processed_dir}/SNID.pickle")
    # Check all SNID in df are in df_SNID
    assert np.all(df.SNID.isin(df_SNID.SNID))
    # Merge left on df: len(df) will not change and will now include
    # relevant columns from df_SNID
    merge_columns = ["SNID"]
    distinct_classes = len(set(dict(settings.sntypes).values()))
    for c_ in list({2, distinct_classes}):
        merge_columns += [f"target_{c_}classes"]
        for dataset in ["photometry", "saltfit"]:
            merge_columns += [f"dataset_{dataset}_{c_}classes"]
    df = df.merge(df_SNID[merge_columns], on=["SNID"], how="left")

    # Save for future use
    # NOTE(review): only '.FITS' and '.gz' are rewritten, so a csv input keeps
    # its '.csv' name although the content is a pickle; the output name is left
    # unchanged here in case other code relies on it
    basename = os.path.basename(file_path)
    pickle_name = basename.replace(".FITS", ".pickle").replace(".gz", "")
    df.to_pickle(f"{settings.preprocessed_dir}/{pickle_name}")

    # getting SNIDs for SNe with Host_spec
    host_spe = (
        df[df["HOSTGAL_SPECZ"] > 0]["SNID"].unique().tolist()
        if "HOSTGAL_SPECZ" in df.keys()
        else []
    )
    return host_spe
def preprocess_data(settings):
    """Preprocess the raw photometry files (FITS or csv)

    - Use multiprocessing to speed up data processing (a plain loop in
      debug mode)
    - Preprocess every ``*PHOT*`` file found under the raw data dir
    - Also save a DataFrame of Host Spe SNIDs to
      ``{settings.processed_dir}/hostspe_SNID.pickle`` for publication plots

    Args:
        settings (ExperimentSettings): controls experiment hyperparameters
    """
    # Get the list of photometry files, either in FITS or csv format
    list_files = natsorted(Path(settings.raw_dir).glob("**/*PHOT*"))
    list_fmt = [re.search(r"(FITS|csv)", fil.name).group() for fil in list_files]
    list_files = [str(fil) for fil in list_files]

    logging_utils.print_green("List to preprocess ", list_files)

    host_spe_tmp = []

    if not settings.debug:
        # One processing function per file, chosen from the detected format
        processors = {
            "FITS": partial(process_single_FITS, settings=settings),
            "csv": partial(process_single_csv, settings=settings),
        }
        list_fn = [processors[fmt] for fmt in list_fmt]

        # Keep at least one worker: cpu_count() - 2 is <= 0 on small machines
        max_workers = max(1, multiprocessing.cpu_count() - 2)

        # use parallelization to speed up processing
        # Split list files in chunks of size 10 or less
        # to get a progress bar and alleviate memory constraints
        num_elem = len(list_files)
        num_chunks = num_elem // 10 + 1
        list_chunks = np.array_split(np.arange(num_elem), num_chunks)

        # Loop over chunks of files
        for chunk_idx in tqdm(list_chunks, desc="Preprocess", ncols=100):
            if len(chunk_idx) == 0:
                # happens when there is no file to process
                continue
            start, end = chunk_idx[0], chunk_idx[-1] + 1
            list_pairs = list(zip(list_fn[start:end], list_files[start:end]))
            # Process each file in the chunk in parallel
            with ProcessPoolExecutor(max_workers=max_workers) as executor:
                # Need to cast to list because executor returns an iterator
                host_spe_tmp += list(executor.map(process_fn, list_pairs))
    else:
        logging_utils.print_yellow("Beware debugging mode (loop over files)")
        # for debugging only: dispatch on the format detected from the file
        # name, as in the parallel path (not on the full path string)
        for fil, fmt in zip(list_files, list_fmt):
            if fmt == "FITS":
                out = process_single_FITS(fil, settings)
            else:
                out = process_single_csv(fil, settings)
            host_spe_tmp.append(out)

    # Save host spe for plotting and performance tests
    host_spe = [item for sublist in host_spe_tmp for item in sublist]
    pd.DataFrame(host_spe, columns=["SNID"]).to_pickle(
        f"{settings.processed_dir}/hostspe_SNID.pickle"
    )
    logging_utils.print_green("Finished preprocessing")
def pivot_dataframe_single(filename, settings):
    """Pivot one pre-processed pickle and cache the result next to it.

    The output is written to the input name with ``.pickle`` replaced
    by ``_pivot.pickle``.

    Args:
        filename (str): path to a ``.pickle`` file containing pre-processed data
        settings (ExperimentSettings): controls experiment hyperparameters
    """
    pivoted = pivot_dataframe_single_from_df(pd.read_pickle(filename), settings)
    # Save to pickle
    stem = filename.partition(".pickle")[0]
    pivoted.to_pickle(stem + "_pivot.pickle")
def pivot_dataframe_single_from_df(df, settings):
    """
    Carry out pivot: we will group time-wise close observations on the same row
    and each row in the dataframe will show a value for each of the flux and flux
    error column

    - All observations within 0.33 MJD (~8 hours) of each other are assigned
      the same MJD
    - When a filter appears several times in a group, the observation with the
      lowest FLUXCALERR is kept
    - ``mB``, ``c`` and ``x1`` are added from the fit file (zeros if there is none)

    Args:
        df (pandas.DataFrame): pre-processed photometry; a ``grouped_MJD``
            column is added to it in place
        settings (ExperimentSettings): controls experiment hyperparameters

    Returns:
        pandas.DataFrame: pivoted data, indexed by SNID
    """
    list_filters = settings.list_filters
    assert len(list_filters) > 0

    arr_MJD = df.MJD.values
    arr_delta_time = df.delta_time.values
    # Loop over times to create grouped MJD:
    # if filters are acquired within less than 0.33 MJD (~8 hours) of each other
    # they get assigned the same time
    min_dt = 0.33
    time_last_change = 0
    arr_grouped_MJD = np.zeros_like(arr_MJD)
    for i in range(len(df)):
        time = arr_MJD[i]
        dt = arr_delta_time[i]
        # 2 possibilities to update the time
        # dt == 0 (it's a new light curve)
        # time - time_last_change > min_dt
        if dt == 0 or (time - time_last_change) > min_dt:
            arr_grouped_MJD[i] = time
            time_last_change = time
        else:
            arr_grouped_MJD[i] = arr_grouped_MJD[i - 1]
    # Add grouped delta time to dataframe
    df["grouped_MJD"] = np.array(arr_grouped_MJD)

    # Some filters (i, r, g, z) may appear multiple times with the same grouped MJD within same light curve
    # When this happens, we select the one with lowest FLUXCALERR
    df = df.sort_values("FLUXCALERR").groupby(["SNID", "grouped_MJD", "FLT"]).first()
    # We then reset the index
    df = df.reset_index()

    # Compute PEAKMJDNORM = PEAKMJD in days since the start of the light curve
    df["PEAKMJDNORM"] = df["PEAKMJD"] - df["MJD"]
    # The correct PEAKMJDNORM is the first one hence the use of first after groupby
    df_PEAKMJDNORM = df[["SNID", "PEAKMJDNORM"]].groupby("SNID").first().reset_index()
    # Remove PEAKMJDNORM
    df = df.drop(labels="PEAKMJDNORM", axis=1)
    # Add PEAKMJDNORM back to df with a merge on SNID
    df = df.merge(df_PEAKMJDNORM, how="left", on="SNID")
    # drop columns that won't be used onwards
    df = df.drop(labels=["MJD", "delta_time"], axis=1)

    class_columns = []
    distinct_classes = len(set(dict(settings.sntypes).values()))
    for c_ in list({2, distinct_classes}):
        class_columns += [f"target_{c_}classes"]
        for dataset in ["photometry", "saltfit"]:
            class_columns += [f"dataset_{dataset}_{c_}classes"]

    group_features_list = (
        [
            "SNID",
            "grouped_MJD",
            "PEAKMJD",
            "PEAKMJDNORM",
            "SIM_REDSHIFT_CMB",
            settings.sntype_var,
            "SIM_PEAKMAG_z",
            "SIM_PEAKMAG_g",
            "SIM_PEAKMAG_r",
            "SIM_PEAKMAG_i",
        ]
        + [k for k in df.keys() if "HOST" in k]
        + class_columns
    )
    if settings.photo_window_var not in group_features_list:
        group_features_list += [settings.photo_window_var]
    if settings.additional_train_var:
        group_features_list += list(settings.additional_train_var)
    # keep only the features actually available in df
    group_features_list = [k for k in group_features_list if k in df.keys()]

    # Pivot so that for a given MJD, we have info on all available fluxes / error
    df = pd.pivot_table(df, index=group_features_list, columns=["FLT"])
    # Flatten columns
    df.columns = ["_".join(col).strip() for col in df.columns.values]
    # Reset index to get grouped_MJD and target as columns
    cols_to_reset_list = [c for c in df.index.names if c != "SNID"]
    df.reset_index(cols_to_reset_list, inplace=True)
    # Rename grouped_MJD to MJD
    df.rename(columns={"grouped_MJD": "MJD"}, inplace=True)

    # fix missing filters: add empty flux columns for filters never observed
    missing_filters = [k for k in list_filters if f"FLUXCAL_{k}" not in df.columns]
    for missing in missing_filters:
        df[f"FLUXCAL_{missing}"] = np.nan
        df[f"FLUXCALERR_{missing}"] = np.nan

    # New column to indicate which channel (r,g,z,i) is present
    # The column will read ``rg`` if r,g are present; ``rgz`` if r,g and z are present, etc.
    for flt in list_filters:
        df[flt] = np.where(df["FLUXCAL_%s" % flt].isnull(), "", flt)
    df["FLT"] = df[list_filters[0]]
    for flt in list_filters[1:]:
        df["FLT"] += df[flt]
    # Drop the temporary per-filter columns
    df = df.drop(labels=list_filters, axis=1)
    # Finally replace NaN with 0
    df = df.fillna(0)

    # Add delta_time back. We removed all delta time columns above as they get
    # filled with NaN during pivot. It is clearer to recompute delta time once the pivot is complete
    df = data_utils.compute_delta_time(df)

    # Cast columns to float32, int8 to save space
    for c in df.columns:
        if df[c].dtype == np.float64:
            df[c] = df[c].astype(np.float32)
        elif "classes" in c and df[c].dtype == np.int64:
            df[c] = df[c].astype(np.int8)

    # Add some extra columns from the FITOPT file
    df_salt = data_utils.load_fitfile(settings, verbose=False)
    # Any non-empty fit file is used (same emptiness criterion as in
    # build_traintestval_splits); a single-row file is a valid fit file
    if len(df_salt) > 0:
        df_salt = df_salt.set_index("SNID")
    else:
        # if no fits file we populate with dummies
        unique_snid = df.index.unique()
        df_salt = pd.DataFrame()
        df_salt["SNID"] = np.array(unique_snid)
        df_salt["mB"] = np.zeros(len(unique_snid))
        df_salt["c"] = np.zeros(len(unique_snid))
        df_salt["x1"] = np.zeros(len(unique_snid))
        df_salt = df_salt.set_index("SNID")
    df = df.join(df_salt[["mB", "c", "x1"]], how="left")

    df.drop(columns="MJD", inplace=True)
    return df
def pivot_dataframe_batch(list_files, settings):
    """
    - Use multiprocessing to speed up data processing (a plain loop in
      debug mode)
    - Pivot every file in list_files and cache the result with pickle

    Args:
        list_files (list): list of ``.pickle`` files containing pre-processed data
        settings (ExperimentSettings): controls experiment hyperparameters
    """
    if not settings.debug:
        # Split list files in chunks of size 10 or less
        # to get a progress bar and alleviate memory constraints
        num_elem = len(list_files)
        num_chunks = num_elem // 10 + 1
        list_chunks = np.array_split(np.arange(num_elem), num_chunks)

        # Keep at least one worker: cpu_count() - 2 is <= 0 on small machines
        max_workers = max(1, multiprocessing.cpu_count() - 2)
        parallel_fn = partial(pivot_dataframe_single, settings=settings)

        # use parallelization to speed up processing
        # Loop over chunks of files
        for chunk_idx in tqdm(list_chunks, desc="Pivoting dataframes", ncols=100):
            if len(chunk_idx) == 0:
                # happens when there is no file to pivot
                continue
            start, end = chunk_idx[0], chunk_idx[-1] + 1
            # Process each file in the chunk in parallel
            with ProcessPoolExecutor(max_workers=max_workers) as executor:
                # Consume the iterator: an exception raised in a worker is only
                # re-raised when its result is retrieved, otherwise a failed
                # pivot would go unnoticed
                list(executor.map(parallel_fn, list_files[start:end]))
    else:
        logging_utils.print_yellow("Beware debugging mode (loop over pivot)")
        # for debugging only, process files one at a time
        for fil in list_files:
            pivot_dataframe_single(fil, settings)

    logging_utils.print_green("Finished pivot")
@logging_utils.timer("Data processing")
def make_dataset(settings):
    """Entry point of the data processing pipeline

    - Optionally empty the preprocessed and processed folders
    - Create the train test val splits
    - Preprocess all the raw photometry, then pivot
    - Save all of the processed data to a single HDF5 database
    - Try to produce the dataset plots
    - Remove the preprocessed folder (kept in debug mode)

    Args:
        settings (ExperimentSettings): controls experiment hyperparameters
    """
    # Clean up data folders
    if settings.overwrite is True:
        for folder in (settings.preprocessed_dir, settings.processed_dir):
            leftovers = glob.glob(f"{folder}/*")
            for leftover in leftovers:
                os.remove(leftover)

    # split dataset in train test and validation
    build_traintestval_splits(settings)

    # Preprocess dataset
    preprocess_data(settings)

    # Pivot dataframe
    preprocessed_files = natsorted(glob.glob(f"{settings.preprocessed_dir}/*PHOT*"))
    pivot_dataframe_batch(preprocessed_files, settings)

    # Aggregate the pivoted dataframe
    pivot_pattern = os.path.join(settings.preprocessed_dir, "*pivot.pickle*")
    pivot_files = natsorted(glob.glob(pivot_pattern))
    logging_utils.print_green("Concatenating pivot")
    pivot_frames = [pd.read_pickle(pivot_file) for pivot_file in pivot_files]
    df = pd.concat(pivot_frames, axis=0)

    # Save to HDF5
    data_utils.save_to_HDF5(settings, df)

    # Save plots to visualize the distribution of some of the data features
    # (best effort: any failure only triggers a warning)
    try:
        datasets_plots(data_utils.load_HDF5_SNinfo(settings), settings)
    except Exception:
        logging_utils.print_yellow(
            "Warning: can't do data plots if no saltfit for this dataset"
        )

    # Clean preprocessed directory
    if not settings.debug:
        shutil.rmtree(settings.preprocessed_dir)
    else:
        logging_utils.print_red("Debugging mode, keeping preprocessed data")

    logging_utils.print_green("Finished making dataset")