import torch
[docs]class VariationalRNN(torch.nn.Module):
def __init__(self, input_size, settings):
super(VariationalRNN, self).__init__()
# Params
self.input_size = input_size
self.layer_type = settings.layer_type
self.output_size = settings.nb_classes
self.hidden_size = settings.hidden_dim
self.num_layers = settings.num_layers
self.bidirectional = settings.bidirectional
self.dropout = settings.dropout
self.use_cuda = settings.use_cuda
self.rnn_output_option = settings.rnn_output_option
bidirectional_factor = 2 if self.bidirectional is True else 1
last_input_size = (
self.hidden_size * bidirectional_factor
if self.rnn_output_option == "mean"
else self.hidden_size * bidirectional_factor * self.num_layers
)
# Need to create recurrent layers one by one because we use variational dropout
self.rnn_layers = []
for i in range(self.num_layers):
if i == 0:
input_size = self.input_size
else:
previous_layer = getattr(self, f"rnn_layer{i - 1}")
input_size = previous_layer.module.hidden_size * bidirectional_factor
# Create recurrent layer
layer = getattr(torch.nn, self.layer_type.upper())(
input_size,
self.hidden_size,
num_layers=1,
dropout=0.0, # no dropout: we later create a specific layer for that
bidirectional=self.bidirectional,
)
# Apply weight drop
layer = WeightDrop(layer, ["weight_hh_l0"], dropout=self.dropout)
# Set layer as attribute
setattr(self, f"rnn_layer{i}", layer)
self.rnn_layers.append(layer)
self.recurrent_dropout_layer = VariationalRecurrentDropout()
self.output_dropout_layer = VariationalDropout()
self.output_layer = torch.nn.Linear(last_input_size, self.output_size)
def apply_recurrent_variational_dropout(
self, x, dropout_value, mean_field_inference=False
):
# apply dropout to input
if isinstance(x, torch.nn.utils.rnn.PackedSequence):
# Unpack
x, lengths = torch.nn.utils.rnn.pad_packed_sequence(x)
# Apply locked dropout
x = self.recurrent_dropout_layer(
x, dropout_value, mean_field_inference=mean_field_inference
)
# Repack
x = torch.nn.utils.rnn.pack_padded_sequence(x, lengths)
else:
x = self.recurrent_dropout_layer(
x, dropout_value, mean_field_inference=mean_field_inference
)
return x
[docs] def forward(self, x, mean_field_inference=False):
# apply variational dropout to input
x = self.apply_recurrent_variational_dropout(
x, self.dropout, mean_field_inference
)
list_hidden = []
for layer_idx, rnn_layer in enumerate(self.rnn_layers):
x, hidden = rnn_layer(x, mean_field_inference=mean_field_inference)
list_hidden.append(hidden)
# Apply Variational dropout between recurrent layers
if layer_idx != len(self.rnn_layers) - 1:
x = self.apply_recurrent_variational_dropout(
x, self.dropout, mean_field_inference=mean_field_inference
)
# Output options
# Standard: all layers, only end of pass
# - take last pass in all layers (hidden)
# - reshape and apply dropout
# - use h20 to obtain output (h2o input: hidden_size*num_layers*bi)
# Mean: last layer, mean on sequence
# - take packed output from last layer (out) that contains all time steps for the last layer
# - find where padding was done and create a mask for those values, apply this mask
# - take a mean for the whole sequence (time_steps)
# - use h2o to obtain output (beware! it is only one layer deep since it is the last one only)
if self.rnn_output_option == "standard":
# Special case for lstm where hidden = (h, c)
# In vanilla standard we have: out, hidden = rnn(X)
# hidden is built as follows:
# for each layer, for each direction, take the last hidden state
# concatenate the results. obtain hidden (num_layers * num_directions, B, D)
# here, we need to do it manually as the recurrent layers are written one by one
# We now carry out the ``concatenate`` operation.
if self.layer_type == "lstm":
hn = torch.cat([h[0] for h in list_hidden], dim=0)
else:
hn = torch.cat(list_hidden, dim=0)
hn = hn.permute(1, 2, 0).contiguous()
# hn is (num_layers * num_directions, batch, hidden_size)
batch_size = hn.shape[0]
# hn now is (batch, hidden size, num_layers * num_directions)
x = hn.view(batch_size, -1)
# x is (batch, hidden size * num_layers * num_directions)
if self.rnn_output_option == "mean":
if isinstance(x, torch.nn.utils.rnn.PackedSequence):
x, lens = torch.nn.utils.rnn.pad_packed_sequence(x)
# x is (seq_len, batch, hidden size)
# take mean over seq_len
x = x.sum(0) / lens.unsqueeze(-1).float().to(x.device)
# x is (batch, hidden_size)
else:
x = x.mean(0)
# x is (batch, hidden_size)
# apply dropout
x = self.output_dropout_layer(
x, self.dropout, mean_field_inference=mean_field_inference
)
# Final projection layer
output = self.output_layer(x)
return output
[docs]class VariationalRecurrentDropout(torch.nn.Module):
"""
This is a renamed Locked Dropout from https://github.com/salesforce/awd-lstm-lm
- We added a mean_field_inference flag to carry out mean field inference
We do this so that choosing whether to use dropout or not at inference time
is explicit
"""
def __init__(self):
super().__init__()
[docs] def forward(self, x, dropout, mean_field_inference=False):
if mean_field_inference is True:
return x
else:
m = x.data.new(1, x.size(1), x.size(2)).bernoulli_(1 - dropout)
mask = m / (1 - dropout) # rescaling to account for dropout
mask = mask.expand_as(x)
return mask * x
[docs]class VariationalDropout(torch.nn.Module):
"""Re-implementation of torch.nn.modules.Dropout
- training flag is always set to True
- We added a mean_field_inference flag to carry out mean field inference
We do this rather than using the training flag so that we can more explicitly decide
to use dropout or not at inference time
"""
[docs] def forward(self, x, dropout, mean_field_inference=False):
training = True
inplace = False
if mean_field_inference is True:
return x
else:
return torch.nn.functional.dropout(x, dropout, training, inplace)
[docs]class WeightDrop(torch.nn.Module):
"""
WeightDrop from https://github.com/salesforce/awd-lstm-lm
- Removed the variational input parameter as we will always use WeightDrop in variational mode
"""
def __init__(self, module, weights, dropout):
super(WeightDrop, self).__init__()
self.module = module
self.weights = weights
self.dropout = dropout
self._setup()
def dummy_function(*args, **kwargs):
# We need to replace flatten_parameters with a nothing function
# It must be a function rather than a lambda as otherwise pickling explodes
return
def _setup(self):
# Terrible temporary solution to an issue regarding compacting weights re: CUDNN RNN
if issubclass(type(self.module), torch.nn.RNNBase):
self.module.flatten_parameters = self.dummy_function
for name_w in self.weights:
print(f"Applying weight drop of {self.dropout} to {name_w}")
w = getattr(self.module, name_w) # take this param of the module
del self.module._parameters[name_w]
self.module.register_parameter(
name_w + "_raw", torch.nn.Parameter(w.data)
) # recreate same data with different name
def _setweights(self, mean_field_inference=False):
for name_w in self.weights:
raw_w = getattr(self.module, name_w + "_raw")
w = None
mask = torch.ones(raw_w.size(0), 1)
device = raw_w.device
mask = mask.to(device)
# forcing it to be applied training and validation
if mean_field_inference is False:
mask = torch.nn.functional.dropout(mask, p=self.dropout, training=True)
mask = mask.expand_as(raw_w)
# applying same mask to all elements of the batch
# h [batch, xdim]
w = mask * raw_w
setattr(self.module, name_w, w)
[docs] def forward(self, *args, mean_field_inference=False):
# Apply dropout to weights of module
self._setweights(mean_field_inference=mean_field_inference)
return self.module.forward(*args)
[docs]def embedded_dropout(embed, words, dropout, mean_field_inference=False):
"""
embedded_dropout from https://github.com/salesforce/awd-lstm-lm
with some modifications like the mean field inference flag
"""
if mean_field_inference:
masked_embed_weight = embed.weight
else:
mask = embed.weight.data.new().resize_((embed.weight.size(0), 1)).bernoulli_(
1 - dropout
).expand_as(embed.weight) / (1 - dropout)
masked_embed_weight = mask * embed.weight
padding_idx = embed.padding_idx
if padding_idx is None:
padding_idx = -1
X = torch.nn.functional.embedding(
words,
masked_embed_weight,
padding_idx,
embed.max_norm,
embed.norm_type,
embed.scale_grad_by_freq,
embed.sparse,
)
return X