Quickstart guide (GitHub)

Welcome to SuperNNova!

This is a quick start guide to help you begin testing our framework. If you want to install SuperNNova as a module, please take a look at Quickstart guide (pip).

Installation

Clone the GitHub repository

git clone https://github.com/supernnova/supernnova.git

Set up your environment. There are three options:

Please be aware that SuperNNova only runs properly on Unix systems (Linux, macOS).
  1. Create a Docker image (see Docker).

  2. Create a conda virtual environment (see Environment configuration).

  3. Install packages manually. Inspect conda_env.txt for the list of packages we use.

Usage

For quick tests, a database that contains a limited number of light-curves is provided. It is located in tests/raw. For more information on the available data, check Data walkthrough.

Using command line

Build the database:

python run.py --data --dump_dir tests/dump

An additional argument, --fits_dir tests/fits, provides a SALT2 fits file for random forest training and interpretation.
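For example, both can be combined in one step (paths as in the repository's test data):

python run.py --data --dump_dir tests/dump --fits_dir tests/fits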

Train an RNN

python run.py --train_rnn --dump_dir tests/dump

With this command you are training and validating our Baseline RNN with the test database. The trained model will be saved in a newly created model folder inside tests/dump/models.
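The naming conventions below list other model types (variational, bayesian). Assuming the CLI exposes a --model flag for this (an assumption; verify against the actual options), a variational model could be trained along these lines:

# assumes run.py exposes a --model flag; check with: python run.py --help
python run.py --train_rnn --model variational --dump_dir tests/dump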

The model folder is named as follows: vanilla_S_0_CLF_2_R_None_photometry_DF_1.0_N_global_lstm_32x2_0.05_128_True_mean (see below for the naming conventions). This folder’s contents are:

  • saved model (*.pt): PyTorch RNN model.

  • statistics (METRICS*.pickle): pickled Pandas DataFrame with accuracy and other performance statistics for this model.

  • predictions (PRED*.pickle): pickled Pandas DataFrame with the predictions of our model on the test set.

  • figures (train_and_val_*.png): figures showing the evolution of the chosen metric at each training step.

Remember that our data is split into training, validation, and test sets.
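For example, the pickled DataFrames can be inspected with pandas (a minimal sketch; the exact file names depend on the training configuration):

import glob
import pandas as pd

# locate the metrics and predictions of the trained model
# (file names encode the model configuration, hence the glob)
metrics_file = glob.glob("tests/dump/models/*/METRICS*.pickle")[0]
pred_file = glob.glob("tests/dump/models/*/PRED*.pickle")[0]

df_metrics = pd.read_pickle(metrics_file)  # accuracy and other statistics
df_pred = pd.read_pickle(pred_file)        # predictions on the test set
print(df_metrics.head())
print(df_pred.head())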

Plot light-curves and their predictions

python run.py --dump_dir tests/dump --plot_lcs

You can now inspect the test light-curves and their predictions in tests/dump/lightcurves.

You have trained, validated and tested your model.

Using YAML

Build the database:

python run_yml.py configs_yml/default.yml --mode data

Train an RNN

python run_yml.py configs_yml/default.yml --mode train_rnn

Available modes: data, train_rnn, validate_rnn, plot_lcs. Currently, random forest (RF) classification is not supported with the YAML configurations. An example of classification using an existing model is in configs_yml/classify.yml.
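Validation and light-curve plotting follow the same invocation pattern:

python run_yml.py configs_yml/default.yml --mode validate_rnn

python run_yml.py configs_yml/default.yml --mode plot_lcs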

Reproduce SuperNNova paper

To reproduce the results of the paper, please use the paper branch and run:

cd SuperNNova && python run_paper.py --debug --dump_dir tests/dump

--debug will train simplified models with a reduced number of epochs. Remove this flag for full reproducibility. With the --debug flag on, this should take between 15 and 30 minutes on the CPU.

Naming conventions

  • vanilla/variational/bayesian: The type of RNN to be trained. variational and bayesian are Bayesian recurrent networks.

  • S_0: seed used for training. Default is 0.

  • CLF_2: number of targets to be used in classification. This case uses two classes: type Ia supernovae vs. all others.

  • R_None: host-galaxy redshift provided to the network. None means no redshift is used; the other options are zpho (photometric) and zspe (spectroscopic).

  • photometry: data used. Our database separates light-curves with a successful SALT2 fit (saltfit) from the complete dataset (photometry).

  • DF_1.0: data fraction used in training. With large datasets it is useful to test training with a fraction of the available training set. In this case we use the whole dataset (1.0).

  • N_global: normalization used. Default: global.

  • lstm: type of layer used. Default lstm.

  • 32x2: hidden layer dimension × number of layers.

  • 0.05: dropout value.

  • 128: batch size.

  • True: whether the model is bidirectional.

  • mean: output option. mean is mean pooling.
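Putting these together, the example folder name above decodes as:

  vanilla     → vanilla RNN
  S_0         → training seed 0
  CLF_2       → two classes (type Ia vs. all others)
  R_None      → no host-galaxy redshift
  photometry  → complete dataset
  DF_1.0      → full training set
  N_global    → global normalization
  lstm        → LSTM layers
  32x2        → hidden dimension 32, 2 layers
  0.05        → dropout 0.05
  128         → batch size 128
  True        → bidirectional
  mean        → mean-pooled output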

The naming convention is defined in SuperNNova/conf.py.