here's one big commit

This commit is contained in:
Rémi Louf
2019-10-18 12:29:30 +02:00
parent 932543f77e
commit 4c3ac4a7d8
8 changed files with 951 additions and 408 deletions

View File

@@ -646,7 +646,7 @@ class BertModel(BertPreTrainedModel):
if attention_mask.dim() == 2:
if self.config.is_decoder:
batch_size, seq_length = input_ids.size()
seq_ids = torch.arange(seq_length)
seq_ids = torch.arange(seq_length, device=input_ids.device)
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
else:
@@ -660,6 +660,13 @@ class BertModel(BertPreTrainedModel):
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
# If a 2D encoder attention mask is provided for the cross-attention
# we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length]
if encoder_attention_mask is not None:
encoder_attention_mask = encoder_attention_mask[:, None, None, :]
encoder_attention_mask = encoder_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
encoder_attention_mask = (1.0 - encoder_attention_mask) * -10000.0
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
@@ -819,7 +826,7 @@ class BertForMaskedLM(BertPreTrainedModel):
self.bert.embeddings.word_embeddings)
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
masked_lm_labels=None, lm_labels=None, encoder_hidden_states=None, encoder_attention_mask=None):
masked_lm_labels=None, encoder_hidden_states=None, encoder_attention_mask=None, lm_labels=None, ):
outputs = self.bert(input_ids,
attention_mask=attention_mask,
@@ -838,11 +845,8 @@ class BertForMaskedLM(BertPreTrainedModel):
# 1. If a tensor that contains the indices of masked labels is provided,
# the cross-entropy is the MLM cross-entropy that measures the likelihood
# of predictions for masked words.
# 2. If encoder hidden states are provided we are in a causal situation where we
# 2. If `lm_label` is provided we are in a causal scenario where we
# try to predict the next word for each input in the encoder.
if masked_lm_labels is not None and lm_labels is not None:
raise AttributeError("Masked LM training with an encoder-decoder is not supported.")
if masked_lm_labels is not None:
loss_fct = CrossEntropyLoss(ignore_index=-1) # -1 index = padding token
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
@@ -851,9 +855,9 @@ class BertForMaskedLM(BertPreTrainedModel):
if lm_labels is not None:
# we are doing next-token prediction; shift prediction scores and input ids by one
prediction_scores = prediction_scores[:, :-1, :]
lm_labels = lm_labels[:, 1:, :]
lm_labels = lm_labels[:, 1:]
loss_fct = CrossEntropyLoss(ignore_index=-1)
seq2seq_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), lm_labels.view(-1))
seq2seq_loss = loss_fct(prediction_scores.reshape(-1, self.config.vocab_size), lm_labels.reshape(-1))
outputs = (seq2seq_loss,) + outputs
return outputs # (mlm_or_seq2seq_loss), prediction_scores, (hidden_states), (attentions)