import torch
import json
import numpy as np
import torch.nn as nn
from tqdm import tqdm
from time import time
from pathlib import Path
from ..utils import training_utils as tu
from ..utils import logging_utils as lu
def get_lr(settings):
"""Select optimal starting learning rate when training with a 1-cycle policy
Args:
settings (ExperimentSettings): controls experiment hyperparameters
"""
# Data
list_data_train, list_data_val = tu.load_HDF5(settings, test=False)
    num_elem = len(list_data_train)
    # Guarantee at least two batches, even for very small training sets
    num_batches = num_elem // min(num_elem // 2, settings.batch_size)
    list_batches = np.array_split(np.arange(num_elem), num_batches)
np.random.shuffle(list_batches)
lr_init_value = 1e-8
lr = float(lr_init_value)
lr_final_value = 10.0
beta = 0.98
avg_loss = 0.0
best_loss = 0.0
batch_num = 0
list_losses = []
list_lr = []
    # Multiplicative step so the learning rate sweeps geometrically from
    # lr_init_value to lr_final_value over num_batches batches
    mult = (lr_final_value / lr_init_value) ** (1 / num_batches)
    settings.learning_rate = lr_init_value
# Model specification
rnn = tu.get_model(settings, len(settings.training_features))
criterion = nn.CrossEntropyLoss()
optimizer = tu.get_optimizer(settings, rnn)
# Prepare for GPU if required
if settings.use_cuda:
rnn.cuda()
criterion.cuda()
for batch_idxs in tqdm(list_batches, ncols=100):
batch_num += 1
# Sample a batch in packed sequence form
packed, _, target_tensor, idxs_rev_sort = tu.get_data_batch(
list_data_train, batch_idxs, settings
)
# Train step : forward backward pass
loss = tu.train_step(
settings,
rnn,
packed,
target_tensor,
criterion,
optimizer,
target_tensor.size(0),
len(list_batches),
)
        loss = loss.item()
        # Exponentially weighted moving average of the loss, with bias
        # correction for the first batches (as in Adam)
        avg_loss = beta * avg_loss + (1 - beta) * loss
        smoothed_loss = avg_loss / (1 - beta ** batch_num)
# Stop if the loss is exploding
if batch_num > 1 and smoothed_loss > 4 * best_loss:
break
# Record the best loss
if smoothed_loss < best_loss or batch_num == 1:
best_loss = smoothed_loss
# Store the values
list_losses.append(smoothed_loss)
list_lr.append(lr)
# Update the lr for the next step
lr *= mult
# Set learning rate
for param_group in optimizer.param_groups:
param_group["lr"] = lr
idx_min = np.argmin(list_losses)
print("Min loss", list_losses[idx_min], "LR", list_lr[idx_min])
return list_lr[idx_min]
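# A minimal sketch, not part of the original module: it reproduces the
# geometric learning-rate grid that get_lr sweeps over, which can be handy
# for sanity-checking the sweep bounds without loading any data. The helper
# name and default arguments are illustrative assumptions.
def _demo_lr_sweep(num_batches=100, lr_init=1e-8, lr_final=10.0):
    """Return the learning rates visited by an LR-finder sweep of num_batches steps"""
    mult = (lr_final / lr_init) ** (1 / num_batches)
    return [lr_init * mult ** i for i in range(num_batches)]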
def train_cyclic(settings):
"""Train RNN models with a 1-cycle policy
Args:
settings (ExperimentSettings): controls experiment hyperparameters
"""
# save training data config
save_normalizations(settings)
    # The LR finder returns the LR at minimum loss; use a tenth of it as the
    # cycle's peak and a further tenth of that as the floor
    max_learning_rate = get_lr(settings) / 10
    min_learning_rate = max_learning_rate / 10
    settings.learning_rate = min_learning_rate
    print("Setting learning rate to", min_learning_rate)
    def one_cycle_sched(epoch, minv, maxv, phases):
        # Phase 1: linear warm-up from minv to maxv
        if epoch <= phases[0]:
            out = minv + (maxv - minv) / phases[0] * epoch
        # Phase 2: linear cool-down back to minv
        elif phases[0] < epoch <= phases[1]:
            increment = (minv - maxv) / (phases[1] - phases[0])
            out = maxv + increment * (epoch - phases[0])
        # Phase 3: annihilation, decaying further towards minv / 100
        else:
            increment = (minv / 100 - minv) / (phases[2] - phases[1])
            out = minv + increment * (epoch - phases[1])
        return out
# Data
list_data_train, list_data_val = tu.load_HDF5(settings, test=False)
# Model specification
rnn = tu.get_model(settings, len(settings.training_features))
criterion = nn.CrossEntropyLoss()
optimizer = tu.get_optimizer(settings, rnn)
# Prepare for GPU if required
if settings.use_cuda:
rnn.cuda()
criterion.cuda()
# Keep track of losses for plotting
loss_str = ""
d_monitor_train = {"loss": [], "AUC": [], "Acc": [], "epoch": []}
d_monitor_val = {"loss": [], "AUC": [], "Acc": [], "epoch": []}
if "bayesian" in settings.pytorch_model_name:
d_monitor_train["KL"] = []
d_monitor_val["KL"] = []
lu.print_green("Starting training")
best_loss = float("inf")
training_start_time = time()
for epoch in tqdm(range(settings.cyclic_phases[-1]), desc="Training", ncols=100):
desc = f"Epoch: {epoch} -- {loss_str}"
num_elem = len(list_data_train)
num_batches = num_elem // min(num_elem // 2, settings.batch_size)
list_batches = np.array_split(np.arange(num_elem), num_batches)
np.random.shuffle(list_batches)
for batch_idxs in tqdm(
list_batches,
desc=desc,
ncols=100,
bar_format="{desc} |{bar}| {n_fmt}/{total_fmt} {rate_fmt}{postfix}",
):
# Sample a batch in packed sequence form
packed, _, target_tensor, idxs_rev_sort = tu.get_data_batch(
list_data_train, batch_idxs, settings
)
# Train step : forward backward pass
tu.train_step(
settings,
rnn,
packed,
target_tensor,
criterion,
optimizer,
target_tensor.size(0),
len(list_batches),
)
for param_group in optimizer.param_groups:
param_group["lr"] = one_cycle_sched(
epoch, min_learning_rate, max_learning_rate, settings.cyclic_phases
)
if (epoch + 1) % settings.monitor_interval == 0:
# Get metrics (subsample training set to same size as validation set for speed)
d_losses_train = tu.get_evaluation_metrics(
settings, list_data_train, rnn, sample_size=len(list_data_val)
)
d_losses_val = tu.get_evaluation_metrics(
settings, list_data_val, rnn, sample_size=None
)
# Add current loss avg to list of losses
for key in d_losses_train.keys():
d_monitor_train[key].append(d_losses_train[key])
d_monitor_val[key].append(d_losses_val[key])
d_monitor_train["epoch"].append(epoch + 1)
d_monitor_val["epoch"].append(epoch + 1)
# Prepare loss_str to update progress bar
loss_str = tu.get_loss_string(d_losses_train, d_losses_val)
tu.plot_loss(d_monitor_train, d_monitor_val, epoch, settings)
if d_monitor_val["loss"][-1] < best_loss:
best_loss = d_monitor_val["loss"][-1]
torch.save(
rnn.state_dict(),
f"{settings.rnn_dir}/{settings.pytorch_model_name}.pt",
)
training_time = time() - training_start_time
lu.print_green("Finished training")
tu.save_training_results(settings, d_monitor_val, training_time)
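# Hedged sketch, not used by the module: a standalone copy of the nested
# one_cycle_sched above, exposed so the three phases of the schedule are easy
# to plot or unit-test. The name and default arguments are assumptions.
def _demo_one_cycle_schedule(minv=1e-4, maxv=1e-3, phases=(10, 20, 30)):
    """Return the learning rate used at each epoch of one full cycle"""
    lrs = []
    for epoch in range(phases[-1]):
        if epoch <= phases[0]:
            # Linear warm-up
            lr = minv + (maxv - minv) / phases[0] * epoch
        elif epoch <= phases[1]:
            # Linear cool-down
            lr = maxv + (minv - maxv) / (phases[1] - phases[0]) * (epoch - phases[0])
        else:
            # Annihilation towards minv / 100
            lr = minv + (minv / 100 - minv) / (phases[2] - phases[1]) * (epoch - phases[1])
        lrs.append(lr)
    return lrs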
def save_normalizations(settings):
"""Save normalization used for training
Saves a json file with the normalization used for each feature
Arguments:
settings (ExperimentSettings): controls experiment hyperparameters
"""
dic_norm = {}
for i, f in enumerate(settings.training_features_to_normalize):
dic_norm[f] = {}
for j, w in enumerate(["min", "mean", "std"]):
dic_norm[f][w] = float(settings.arr_norm[i, j])
fname = f"{Path(settings.rnn_dir)}/data_norm.json"
with open(fname, "w") as f:
json.dump(dic_norm, f, indent=4, sort_keys=True)
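# Hedged sketch, assuming the file layout written by save_normalizations: how
# the saved statistics might be read back, e.g. at inference time. The helper
# name is illustrative; the transform actually applied to each feature lives
# in training_utils, not here.
def _load_normalizations(rnn_dir):
    """Return {feature: {"min": ..., "mean": ..., "std": ...}} from data_norm.json"""
    with open(Path(rnn_dir) / "data_norm.json") as f:
        return json.load(f)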
def train(settings):
"""Train RNN models with a decay on plateau policy
Args:
settings (ExperimentSettings): controls experiment hyperparameters
"""
# save training data config
save_normalizations(settings)
# Data
list_data_train, list_data_val = tu.load_HDF5(settings, test=False)
# Model specification
rnn = tu.get_model(settings, len(settings.training_features))
criterion = nn.CrossEntropyLoss()
optimizer = tu.get_optimizer(settings, rnn)
# Prepare for GPU if required
if settings.use_cuda:
rnn.cuda()
criterion.cuda()
# Keep track of losses for plotting
loss_str = ""
d_monitor_train = {"loss": [], "AUC": [], "Acc": [], "epoch": []}
d_monitor_val = {"loss": [], "AUC": [], "Acc": [], "epoch": []}
if "bayesian" in settings.pytorch_model_name:
d_monitor_train["KL"] = []
d_monitor_val["KL"] = []
lu.print_green("Starting training")
plateau_accuracy = tu.StopOnPlateau(reduce_lr_on_plateau=True)
best_loss = float("inf")
training_start_time = time()
for epoch in tqdm(range(settings.nb_epoch), desc="Training", ncols=100):
desc = f"Epoch: {epoch} -- {loss_str}"
num_elem = len(list_data_train)
num_batches = num_elem // min(num_elem // 2, settings.batch_size)
list_batches = np.array_split(np.arange(num_elem), num_batches)
np.random.shuffle(list_batches)
for batch_idxs in tqdm(
list_batches,
desc=desc,
ncols=100,
bar_format="{desc} |{bar}| {n_fmt}/{total_fmt} {rate_fmt}{postfix}",
):
# Sample a batch in packed sequence form
packed, _, target_tensor, idxs_rev_sort = tu.get_data_batch(
list_data_train, batch_idxs, settings
)
            # Exception for multiclass: a target label outside
            # [0, nb_classes) means some sntypes were not defined when the
            # database was created
            if settings.nb_classes <= int(target_tensor.max()):
                print("")
                lu.print_red(
                    "Not all sntypes were defined during database creation (multiclass fails)"
                )
                raise ValueError
# Train step : forward backward pass
tu.train_step(
settings,
rnn,
packed,
target_tensor,
criterion,
optimizer,
target_tensor.size(0),
len(list_batches),
)
if (epoch + 1) % settings.monitor_interval == 0:
# Get metrics (subsample training set to same size as validation set for speed)
d_losses_train = tu.get_evaluation_metrics(
settings, list_data_train, rnn, sample_size=len(list_data_val)
)
d_losses_val = tu.get_evaluation_metrics(
settings, list_data_val, rnn, sample_size=None
)
end_condition = plateau_accuracy.step(d_losses_val["Acc"], optimizer)
            if end_condition:
break
# Add current loss avg to list of losses
            for key in d_losses_train:
d_monitor_train[key].append(d_losses_train[key])
d_monitor_val[key].append(d_losses_val[key])
d_monitor_train["epoch"].append(epoch + 1)
d_monitor_val["epoch"].append(epoch + 1)
# Prepare loss_str to update progress bar
loss_str = tu.get_loss_string(d_losses_train, d_losses_val)
tu.plot_loss(d_monitor_train, d_monitor_val, epoch, settings)
if d_monitor_val["loss"][-1] < best_loss:
best_loss = d_monitor_val["loss"][-1]
torch.save(
rnn.state_dict(),
f"{settings.rnn_dir}/{settings.pytorch_model_name}.pt",
)
lu.print_green("Finished training")
training_time = time() - training_start_time
tu.save_training_results(settings, d_monitor_val, training_time)
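# Hedged usage sketch (comments only): both training entry points take a
# single settings object and write the best model to settings.rnn_dir. How
# the settings object is built is not shown in this module and is assumed
# below.
#
#     settings = ...  # build an ExperimentSettings instance
#     train(settings)          # decay-on-plateau policy
#     train_cyclic(settings)   # 1-cycle policy, runs the LR finder first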