Quickstart guide (GitHub)
Welcome to SuperNNova!
This is a quick start guide so you can start testing our framework. If you want to install SuperNNova as a module, please take a look at Quickstart guide (pip).
Installation
Clone the GitHub repository
git clone https://github.com/supernnova/supernnova.git
Set up your environment. Please be aware that SuperNNova only runs properly on Unix systems (Linux, macOS). There are 3 options:
- Create a Docker image (see Docker).
- Create a conda virtual environment (see Environment configuration).
- Install the packages manually: inspect conda_env.txt for the list of packages we use.
Usage
For quick tests, a database containing a limited number of light-curves is provided in tests/raw. For more information on the available data, check Data walkthrough.
Using command line
Build the database:
python run.py --data --dump_dir tests/dump
An additional argument, --fits_dir tests/fits, can provide a SALT2 fits file for random forest training and interpretation.
Train an RNN
python run.py --train_rnn --dump_dir tests/dump
With this command you train and validate our Baseline RNN on the test database. The trained model is saved in a newly created folder inside tests/dump/models.
The model folder is named as follows: vanilla_S_0_CLF_2_R_None_photometry_DF_1.0_N_global_lstm_32x2_0.05_128_True_mean (see below for the naming conventions). This folder contains:
- saved model (*.pt): the trained PyTorch RNN model.
- statistics (METRICS*.pickle): a pickled Pandas DataFrame with accuracy and other performance statistics for this model.
- predictions (PRED*.pickle): a pickled Pandas DataFrame with the predictions of our model on the test set.
- figures (train_and_val_*.png): figures showing the evolution of the chosen metric at each training step.
Remember that our data is split into training, validation and test sets.
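As a quick sanity check, the pickled DataFrames can be opened with pandas. Below is a minimal sketch; the helper name is ours and not part of SuperNNova, and only the file patterns listed above are assumed:

```python
import glob
import os

import pandas as pd

def load_model_outputs(model_dir):
    """Load the pickled pandas DataFrames (METRICS*, PRED*) found in a
    model folder, keyed by file name. Sketch only: the helper name is
    ours; the file patterns are the ones listed above."""
    outputs = {}
    for pattern in ("METRICS*.pickle", "PRED*.pickle"):
        for path in glob.glob(os.path.join(model_dir, pattern)):
            outputs[os.path.basename(path)] = pd.read_pickle(path)
    return outputs

# Point this at the folder created by your own training run:
outputs = load_model_outputs(
    "tests/dump/models/"
    "vanilla_S_0_CLF_2_R_None_photometry_DF_1.0_N_global_lstm_32x2_0.05_128_True_mean"
)
for name, df in outputs.items():
    print(name, df.shape)
```

From there, standard DataFrame methods (head, describe) let you inspect the metrics and per-light-curve predictions.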
Plot light-curves and their predictions
python run.py --dump_dir tests/dump --plot_lcs
You can now inspect the test light-curves and their predictions in tests/dump/lightcurves.
You have trained, validated and tested your model.
Using Yaml
Build the database:
python run_yml.py configs_yml/default.yml --mode data
Train an RNN
python run_yml.py configs_yml/default.yml --mode train_rnn
Available modes: data, train_rnn, validate_rnn, plot_lcs. Currently, RF classification is not supported with the yaml configurations. An example of classification using an existing model is in configs_yml/classify.yml.
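The pattern of a yaml config plus a --mode switch can be sketched as below. This is illustrative only: run_yml.py has its own implementation, and the config keys shown are hypothetical ones that simply mirror the CLI flags:

```python
import yaml  # PyYAML

# The four modes listed above; anything else is rejected up front.
ALLOWED_MODES = ("data", "train_rnn", "validate_rnn", "plot_lcs")

def load_config(text, mode):
    """Parse a yaml config and check the requested mode is supported.
    Sketch only: not SuperNNova's actual entry point."""
    if mode not in ALLOWED_MODES:
        raise ValueError(f"mode must be one of {ALLOWED_MODES}, got {mode!r}")
    return yaml.safe_load(text)

example = """
dump_dir: tests/dump   # hypothetical keys mirroring the CLI flags
raw_dir: tests/raw
"""
config = load_config(example, "train_rnn")
print(config)
```

For the real key names, inspect configs_yml/default.yml in the repository.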
Reproduce SuperNNova paper
To reproduce the results of the paper please use the branch paper
and run:
cd SuperNNova && python run_paper.py --debug --dump_dir tests/dump
The --debug flag trains simplified models with a reduced number of epochs; remove it for full reproducibility. With --debug on, this should take between 15 and 30 minutes on a CPU.
Naming conventions
- vanilla/variational/bayesian: the type of RNN to be trained. variational and bayesian are Bayesian recurrent networks.
- S_0: seed used for training. Default is 0.
- CLF_2: number of targets to be used in classification. This case uses two classes: type Ia supernovae vs. all others.
- R_None: host-galaxy redshift provided. Options: zpho (photometric) or zspe (spectroscopic).
- photometry: data used. In our database we split light-curves that have a successful SALT2 fit (saltfit) from the complete dataset (photometry).
- DF_1.0: data fraction used in training. With large datasets it is useful to test training with a fraction of the available training set. In this case we use the whole dataset (1.0).
- N_global: normalization used. Default: global.
- lstm: type of layer used. Default: lstm.
- 32x2: hidden layer dimension x number of layers.
- 0.05: dropout value.
- 128: batch size.
- True: whether this model is bidirectional.
- mean: output option. mean is mean pooling.
The naming convention is defined in SuperNNova/conf.py.
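A model-folder name can also be decoded programmatically. Below is a minimal sketch assuming only the field order described in this section; the function and dictionary keys are ours, not SuperNNova's:

```python
def parse_model_name(name):
    """Split a SuperNNova model-folder name into its components.
    Sketch only: follows the field order documented above."""
    parts = name.split("_")
    return {
        "rnn_type": parts[0],              # vanilla / variational / bayesian
        "seed": int(parts[2]),             # S_<seed>
        "n_classes": int(parts[4]),        # CLF_<number of targets>
        "redshift": parts[6],              # R_<None|zpho|zspe>
        "data": parts[7],                  # photometry / saltfit
        "data_fraction": float(parts[9]),  # DF_<fraction>
        "norm": parts[11],                 # N_<normalization>
        "layer_type": parts[12],           # e.g. lstm
        "hidden": parts[13],               # <dim>x<layers>
        "dropout": float(parts[14]),
        "batch_size": int(parts[15]),
        "bidirectional": parts[16] == "True",
        "output": parts[17],               # e.g. mean pooling
    }

name = "vanilla_S_0_CLF_2_R_None_photometry_DF_1.0_N_global_lstm_32x2_0.05_128_True_mean"
print(parse_model_name(name))
```

This is handy when scanning tests/dump/models for folders produced by many runs.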