fixed tests

This commit is contained in:
thomwolf
2019-07-15 12:32:19 +02:00
parent e28d8bde0d
commit f7cd7392fd
7 changed files with 63 additions and 38 deletions

View File

@@ -253,7 +253,7 @@ class BertEmbeddings(nn.Module):
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, input_ids, position_ids=None, token_type_ids=None):
def forward(self, input_ids, token_type_ids=None, position_ids=None):
seq_length = input_ids.size(1)
if position_ids is None:
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
@@ -667,7 +667,7 @@ class BertModel(BertPreTrainedModel):
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, head_mask=None):
def forward(self, input_ids, token_type_ids=None, attention_mask=None, position_ids=None, head_mask=None):
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
if token_type_ids is None:
@@ -703,7 +703,7 @@ class BertModel(BertPreTrainedModel):
else:
head_mask = [None] * self.config.num_hidden_layers
embedding_output = self.embeddings(input_ids, position_ids, token_type_ids)
embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
encoder_outputs = self.encoder(embedding_output,
extended_attention_mask,
head_mask=head_mask)
@@ -772,9 +772,10 @@ class BertForPreTraining(BertPreTrainedModel):
self._tie_or_clone_weights(self.cls.predictions.decoder,
self.bert.embeddings.word_embeddings)
def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
next_sentence_label=None, head_mask=None):
outputs = self.bert(input_ids, position_ids, token_type_ids, attention_mask, head_mask=head_mask)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
next_sentence_label=None, position_ids=None, head_mask=None):
outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
attention_mask=attention_mask, head_mask=head_mask)
sequence_output, pooled_output = outputs[:2]
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
@@ -841,8 +842,10 @@ class BertForMaskedLM(BertPreTrainedModel):
self._tie_or_clone_weights(self.cls.predictions.decoder,
self.bert.embeddings.word_embeddings)
def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, masked_lm_labels=None, head_mask=None):
outputs = self.bert(input_ids, position_ids, token_type_ids, attention_mask, head_mask=head_mask)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
position_ids=None, head_mask=None):
outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
attention_mask=attention_mask, head_mask=head_mask)
sequence_output = outputs[0]
prediction_scores = self.cls(sequence_output)
@@ -898,8 +901,10 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
self.apply(self.init_weights)
def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, next_sentence_label=None, head_mask=None):
outputs = self.bert(input_ids, position_ids, token_type_ids, attention_mask, head_mask=head_mask)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None,
position_ids=None, head_mask=None):
outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
attention_mask=attention_mask, head_mask=head_mask)
pooled_output = outputs[1]
seq_relationship_score = self.cls(pooled_output)
@@ -959,8 +964,10 @@ class BertForSequenceClassification(BertPreTrainedModel):
self.apply(self.init_weights)
def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
outputs = self.bert(input_ids, position_ids, token_type_ids, attention_mask, head_mask=head_mask)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
position_ids=None, head_mask=None):
outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
attention_mask=attention_mask, head_mask=head_mask)
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output)
@@ -1063,14 +1070,16 @@ class BertForMultipleChoice(BertPreTrainedModel):
self.apply(self.init_weights)
def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
position_ids=None, head_mask=None):
num_choices = input_ids.shape[1]
flat_input_ids = input_ids.view(-1, input_ids.size(-1))
flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
outputs = self.bert(flat_input_ids, flat_position_ids, flat_token_type_ids, flat_attention_mask, head_mask=head_mask)
outputs = self.bert(flat_input_ids, position_ids=flat_position_ids, token_type_ids=flat_token_type_ids,
attention_mask=flat_attention_mask, head_mask=head_mask)
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output)
@@ -1131,8 +1140,10 @@ class BertForTokenClassification(BertPreTrainedModel):
self.apply(self.init_weights)
def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
outputs = self.bert(input_ids, position_ids, token_type_ids, attention_mask, head_mask=head_mask)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
position_ids=None, head_mask=None):
outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
attention_mask=attention_mask, head_mask=head_mask)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output)
@@ -1205,9 +1216,10 @@ class BertForQuestionAnswering(BertPreTrainedModel):
self.apply(self.init_weights)
def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, start_positions=None,
end_positions=None, head_mask=None):
outputs = self.bert(input_ids, position_ids, token_type_ids, attention_mask, head_mask=head_mask)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None,
end_positions=None, position_ids=None, head_mask=None):
outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
attention_mask=attention_mask, head_mask=head_mask)
sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output)