From a424892fab8ddfe631d7498bc44072aa3a42eb3c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Wed, 16 Oct 2019 18:24:32 +0200 Subject: [PATCH] correct syntax error: dim() and not dims() --- transformers/modeling_bert.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/transformers/modeling_bert.py b/transformers/modeling_bert.py index cd9151cf62..e717031dcb 100644 --- a/transformers/modeling_bert.py +++ b/transformers/modeling_bert.py @@ -634,13 +634,13 @@ class BertModel(BertPreTrainedModel): # we may want to provide a mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just make it broadcastable to all heads. - if attention_mask.dims() == 3: + if attention_mask.dim() == 3: extended_attention_mask = attention_mask[:, None, :, :] # provided a padding mask of dimensions [batch_size, seq_length] # - if encoder, make it broadcastable to [batch_size, num_heads, seq_length, seq_length] # - if decoder, make it causal - if attention_mask.dims() == 2: + if attention_mask.dim() == 2: if self.config.is_decoder: batch_size, seq_length = input_ids.size() seq_ids = torch.arange(seq_length) @@ -816,13 +816,15 @@ class BertForMaskedLM(BertPreTrainedModel): self.bert.embeddings.word_embeddings) def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, - masked_lm_labels=None): + masked_lm_labels=None, encoder_hidden_states=None, encoder_attention_mask=None): outputs = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, - head_mask=head_mask) + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask) sequence_output = outputs[0] prediction_scores = self.cls(sequence_output) @@ -833,6 +835,15 @@ class BertForMaskedLM(BertPreTrainedModel): masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) outputs = (masked_lm_loss,) + outputs + if encoder_hidden_states is not None: + loss_fct = CrossEntropyLoss(ignore_index=-1) + + # shift predictions scores and input ids by one before computing loss + prediction_scores = prediction_scores[:, :-1, :] + input_ids = input_ids[:, 1:, :] + seq2seq_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), input_ids.view(-1)) + outputs = (seq2seq_loss,) + outputs + return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)