updating head masking, readme and docstrings

This commit is contained in:
thomwolf
2019-06-17 15:51:28 +02:00
parent 965f172de6
commit 33d3db5c43
7 changed files with 220 additions and 43 deletions

View File

@@ -477,8 +477,8 @@ class BertEncoder(nn.Module):
def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True, head_mask=None):
all_encoder_layers = []
all_attentions = []
for layer_module in self.layer:
hidden_states = layer_module(hidden_states, attention_mask, head_mask)
for i, layer_module in enumerate(self.layer):
hidden_states = layer_module(hidden_states, attention_mask, head_mask[i])
if self.output_attentions:
attentions, hidden_states = hidden_states
all_attentions.append(attentions)
@@ -618,6 +618,9 @@ class BertPreTrainedModel(nn.Module):
. `bert-base-multilingual-uncased`
. `bert-base-multilingual-cased`
. `bert-base-chinese`
. `bert-base-german-cased`
. `bert-large-uncased-whole-word-masking`
. `bert-large-cased-whole-word-masking`
- a path or url to a pretrained model archive containing:
. `bert_config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
@@ -744,7 +747,10 @@ class BertModel(BertPreTrainedModel):
"""BERT model ("Bidirectional Embedding Representations from a Transformer").
Params:
config: a BertConfig class instance with the configuration to build a new model
`config`: a BertConfig class instance with the configuration to build a new model
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
This can be used to compute head importance metrics. Default: False
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
@@ -758,6 +764,9 @@ class BertModel(BertPreTrainedModel):
input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences.
`output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
`head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
Outputs: Tuple of (encoded_layers, pooled_output)
`encoded_layers`: controled by `output_all_encoded_layers` argument:
@@ -828,15 +837,19 @@ class BertModel(BertPreTrainedModel):
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
# Prepare head mask if needed
# 1 in head_mask indicate we need to mask the head
# 1.0 in head_mask indicate we mask the head
# attention_probs has shape bsz x n_heads x N x N
# head_mask has shape num_hidden_layers x batch x n_heads x N x N
if head_mask is not None:
if head_mask.dim() == 1:
head_mask = head_mask.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
head_mask = head_mask.expand_as(self.config.num_hidden_layers, -1, -1, -1, -1)
elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each instance in batch
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility
head_mask = (1.0 - head_mask)
else:
head_mask = [None] * self.config.num_hidden_layers
embedding_output = self.embeddings(input_ids, token_type_ids)
encoded_layers = self.encoder(embedding_output,
@@ -861,7 +874,10 @@ class BertForPreTraining(BertPreTrainedModel):
- the next sentence classification head.
Params:
config: a BertConfig class instance with the configuration to build a new model.
`config`: a BertConfig class instance with the configuration to build a new model
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
This can be used to compute head importance metrics. Default: False
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
@@ -880,6 +896,8 @@ class BertForPreTraining(BertPreTrainedModel):
`next_sentence_label`: optional next sentence classification loss: torch.LongTensor of shape [batch_size]
with indices selected in [0, 1].
0 => next sentence is the continuation, 1 => next sentence is a random sentence.
`head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
Outputs:
if `masked_lm_labels` and `next_sentence_label` are not `None`:
@@ -937,7 +955,10 @@ class BertForMaskedLM(BertPreTrainedModel):
This module comprises the BERT model followed by the masked language modeling head.
Params:
config: a BertConfig class instance with the configuration to build a new model.
`config`: a BertConfig class instance with the configuration to build a new model
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
This can be used to compute head importance metrics. Default: False
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
@@ -953,6 +974,12 @@ class BertForMaskedLM(BertPreTrainedModel):
`masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length]
with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss
is only computed for the labels set in [0, ..., vocab_size]
`head_mask`: an optional torch.LongTensor of shape [num_heads] with indices
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences.
`head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
Outputs:
if `masked_lm_labels` is not `None`:
@@ -1006,7 +1033,10 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
This module comprises the BERT model followed by the next sentence classification head.
Params:
config: a BertConfig class instance with the configuration to build a new model.
`config`: a BertConfig class instance with the configuration to build a new model
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
This can be used to compute head importance metrics. Default: False
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
@@ -1022,6 +1052,8 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
`next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size]
with indices selected in [0, 1].
0 => next sentence is the continuation, 1 => next sentence is a random sentence.
`head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
Outputs:
if `next_sentence_label` is not `None`:
@@ -1077,7 +1109,10 @@ class BertForSequenceClassification(BertPreTrainedModel):
the pooled output.
Params:
`config`: a BertConfig class instance with the configuration to build a new model.
`config`: a BertConfig class instance with the configuration to build a new model
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
This can be used to compute head importance metrics. Default: False
`num_labels`: the number of classes for the classifier. Default = 2.
Inputs:
@@ -1093,6 +1128,8 @@ class BertForSequenceClassification(BertPreTrainedModel):
a batch has varying length sentences.
`labels`: labels for the classification output: torch.LongTensor of shape [batch_size]
with indices selected in [0, ..., num_labels].
`head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
Outputs:
if `labels` is not `None`:
@@ -1150,7 +1187,10 @@ class BertForMultipleChoice(BertPreTrainedModel):
the pooled output.
Params:
`config`: a BertConfig class instance with the configuration to build a new model.
`config`: a BertConfig class instance with the configuration to build a new model
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
This can be used to compute head importance metrics. Default: False
`num_choices`: the number of classes for the classifier. Default = 2.
Inputs:
@@ -1166,6 +1206,8 @@ class BertForMultipleChoice(BertPreTrainedModel):
a batch has varying length sentences.
`labels`: labels for the classification output: torch.LongTensor of shape [batch_size]
with indices selected in [0, ..., num_choices].
`head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
Outputs:
if `labels` is not `None`:
@@ -1226,7 +1268,10 @@ class BertForTokenClassification(BertPreTrainedModel):
the full hidden state of the last layer.
Params:
`config`: a BertConfig class instance with the configuration to build a new model.
`config`: a BertConfig class instance with the configuration to build a new model
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
This can be used to compute head importance metrics. Default: False
`num_labels`: the number of classes for the classifier. Default = 2.
Inputs:
@@ -1242,6 +1287,8 @@ class BertForTokenClassification(BertPreTrainedModel):
a batch has varying length sentences.
`labels`: labels for the classification output: torch.LongTensor of shape [batch_size, sequence_length]
with indices selected in [0, ..., num_labels].
`head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
Outputs:
if `labels` is not `None`:
@@ -1306,7 +1353,10 @@ class BertForQuestionAnswering(BertPreTrainedModel):
the sequence output that computes start_logits and end_logits
Params:
`config`: a BertConfig class instance with the configuration to build a new model.
`config`: a BertConfig class instance with the configuration to build a new model
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
This can be used to compute head importance metrics. Default: False
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
@@ -1325,6 +1375,8 @@ class BertForQuestionAnswering(BertPreTrainedModel):
`end_positions`: position of the last token for the labeled span: torch.LongTensor of shape [batch_size].
Positions are clamped to the length of the sequence and position outside of the sequence are not taken
into account for computing the loss.
`head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
Outputs:
if `start_positions` and `end_positions` are not `None`: