resolve PR comments
This commit is contained in:
@@ -632,6 +632,8 @@ class BertModel(BertPreTrainedModel):
|
||||
"""
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
if encoder_attention_mask is None:
|
||||
encoder_attention_mask = torch.ones_like(input_ids)
|
||||
if token_type_ids is None:
|
||||
token_type_ids = torch.zeros_like(input_ids)
|
||||
|
||||
@@ -660,12 +662,15 @@ class BertModel(BertPreTrainedModel):
|
||||
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
||||
|
||||
# If a 2D encoder attention mask is provided for the cross-attention
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
if encoder_attention_mask is not None:
|
||||
encoder_attention_mask = encoder_attention_mask[:, None, None, :]
|
||||
encoder_attention_mask = encoder_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
||||
encoder_attention_mask = (1.0 - encoder_attention_mask) * -10000.0
|
||||
if encoder_attention_mask.dim() == 3:
|
||||
encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
|
||||
if encoder_attention_mask.dim() == 2:
|
||||
encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
|
||||
|
||||
encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
||||
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicates we keep the head
|
||||
@@ -687,7 +692,7 @@ class BertModel(BertPreTrainedModel):
|
||||
attention_mask=extended_attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask)
|
||||
encoder_attention_mask=encoder_extended_attention_mask)
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = self.pooler(sequence_output)
|
||||
|
||||
@@ -788,8 +793,10 @@ class BertForMaskedLM(BertPreTrainedModel):
|
||||
in ``[0, ..., config.vocab_size]``
|
||||
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**loss**: (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
**masked_lm_loss**: (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
Masked language modeling loss.
|
||||
**next_token_loss**: (`optional`, returned when ``lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
Next token prediction loss.
|
||||
**prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
||||
@@ -854,13 +861,13 @@ class BertForMaskedLM(BertPreTrainedModel):
|
||||
|
||||
if lm_labels is not None:
|
||||
# we are doing next-token prediction; shift prediction scores and input ids by one
|
||||
prediction_scores = prediction_scores[:, :-1, :]
|
||||
lm_labels = lm_labels[:, 1:]
|
||||
prediction_scores = prediction_scores[:, :-1, :].contiguous()
|
||||
lm_labels = lm_labels[:, 1:].contiguous()
|
||||
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
||||
seq2seq_loss = loss_fct(prediction_scores.reshape(-1, self.config.vocab_size), lm_labels.reshape(-1))
|
||||
outputs = (seq2seq_loss,) + outputs
|
||||
next_token_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), lm_labels.view(-1))
|
||||
outputs = (next_token_loss,) + outputs
|
||||
|
||||
return outputs # (mlm_or_seq2seq_loss), prediction_scores, (hidden_states), (attentions)
|
||||
return outputs # (masked_lm_loss), (next_token_loss), prediction_scores, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""Bert Model with a `next sentence prediction (classification)` head on top. """,
|
||||
|
||||
Reference in New Issue
Block a user