Fix incorrect method name: use dim(), not dims() (tensors have no dims() method)
This commit is contained in:
@@ -634,13 +634,13 @@ class BertModel(BertPreTrainedModel):
|
|||||||
|
|
||||||
# we may want to provide a mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
# we may want to provide a mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||||
# ourselves in which case we just make it broadcastable to all heads.
|
# ourselves in which case we just make it broadcastable to all heads.
|
||||||
if attention_mask.dims() == 3:
|
if attention_mask.dim() == 3:
|
||||||
extended_attention_mask = attention_mask[:, None, :, :]
|
extended_attention_mask = attention_mask[:, None, :, :]
|
||||||
|
|
||||||
# provided a padding mask of dimensions [batch_size, seq_length]
|
# provided a padding mask of dimensions [batch_size, seq_length]
|
||||||
# - if encoder, make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
# - if encoder, make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||||
# - if decoder, make it causal
|
# - if decoder, make it causal
|
||||||
if attention_mask.dims() == 2:
|
if attention_mask.dim() == 2:
|
||||||
if self.config.is_decoder:
|
if self.config.is_decoder:
|
||||||
batch_size, seq_length = input_ids.size()
|
batch_size, seq_length = input_ids.size()
|
||||||
seq_ids = torch.arange(seq_length)
|
seq_ids = torch.arange(seq_length)
|
||||||
@@ -816,13 +816,15 @@ class BertForMaskedLM(BertPreTrainedModel):
|
|||||||
self.bert.embeddings.word_embeddings)
|
self.bert.embeddings.word_embeddings)
|
||||||
|
|
||||||
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
|
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
|
||||||
masked_lm_labels=None):
|
masked_lm_labels=None, encoder_hidden_states=None, encoder_attention_mask=None):
|
||||||
|
|
||||||
outputs = self.bert(input_ids,
|
outputs = self.bert(input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
token_type_ids=token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
head_mask=head_mask)
|
head_mask=head_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask)
|
||||||
|
|
||||||
sequence_output = outputs[0]
|
sequence_output = outputs[0]
|
||||||
prediction_scores = self.cls(sequence_output)
|
prediction_scores = self.cls(sequence_output)
|
||||||
@@ -833,6 +835,15 @@ class BertForMaskedLM(BertPreTrainedModel):
|
|||||||
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
|
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
|
||||||
outputs = (masked_lm_loss,) + outputs
|
outputs = (masked_lm_loss,) + outputs
|
||||||
|
|
||||||
|
if encoder_hidden_states is not None:
|
||||||
|
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
||||||
|
|
||||||
|
# shift prediction scores and input ids by one before computing loss
|
||||||
|
prediction_scores = prediction_scores[:, :-1, :]
|
||||||
|
input_ids = input_ids[:, 1:, :]
|
||||||
|
seq2seq_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), input_ids.view(-1))
|
||||||
|
outputs = (seq2seq_loss,) + outputs
|
||||||
|
|
||||||
return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
|
return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user