diff --git a/pytorch_transformers/modeling_bert.py b/pytorch_transformers/modeling_bert.py index a044832282..27c8023c0a 100644 --- a/pytorch_transformers/modeling_bert.py +++ b/pytorch_transformers/modeling_bert.py @@ -253,7 +253,7 @@ class BertEmbeddings(nn.Module): self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) - def forward(self, input_ids, position_ids=None, token_type_ids=None): + def forward(self, input_ids, token_type_ids=None, position_ids=None): seq_length = input_ids.size(1) if position_ids is None: position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) @@ -667,7 +667,7 @@ class BertModel(BertPreTrainedModel): for layer, heads in heads_to_prune.items(): self.encoder.layer[layer].attention.prune_heads(heads) - def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, head_mask=None): + def forward(self, input_ids, token_type_ids=None, attention_mask=None, position_ids=None, head_mask=None): if attention_mask is None: attention_mask = torch.ones_like(input_ids) if token_type_ids is None: @@ -703,7 +703,7 @@ class BertModel(BertPreTrainedModel): else: head_mask = [None] * self.config.num_hidden_layers - embedding_output = self.embeddings(input_ids, position_ids, token_type_ids) + embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids) encoder_outputs = self.encoder(embedding_output, extended_attention_mask, head_mask=head_mask) @@ -772,9 +772,10 @@ class BertForPreTraining(BertPreTrainedModel): self._tie_or_clone_weights(self.cls.predictions.decoder, self.bert.embeddings.word_embeddings) - def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, masked_lm_labels=None, - next_sentence_label=None, head_mask=None): - outputs = self.bert(input_ids, position_ids, token_type_ids, attention_mask, head_mask=head_mask) + def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, + next_sentence_label=None, position_ids=None, head_mask=None): + outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids, + attention_mask=attention_mask, head_mask=head_mask) sequence_output, pooled_output = outputs[:2] prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) @@ -841,8 +842,10 @@ class BertForMaskedLM(BertPreTrainedModel): self._tie_or_clone_weights(self.cls.predictions.decoder, self.bert.embeddings.word_embeddings) - def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, masked_lm_labels=None, head_mask=None): - outputs = self.bert(input_ids, position_ids, token_type_ids, attention_mask, head_mask=head_mask) + def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, + position_ids=None, head_mask=None): + outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids, + attention_mask=attention_mask, head_mask=head_mask) sequence_output = outputs[0] prediction_scores = self.cls(sequence_output) @@ -898,8 +901,10 @@ class BertForNextSentencePrediction(BertPreTrainedModel): self.apply(self.init_weights) - def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, next_sentence_label=None, head_mask=None): - outputs = self.bert(input_ids, position_ids, token_type_ids, attention_mask, head_mask=head_mask) + def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None, + position_ids=None, head_mask=None): + outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids, + attention_mask=attention_mask, head_mask=head_mask) pooled_output = outputs[1] seq_relationship_score = self.cls(pooled_output) @@ -959,8 +964,10 @@ class BertForSequenceClassification(BertPreTrainedModel): self.apply(self.init_weights) - def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, labels=None, head_mask=None): - outputs = self.bert(input_ids, position_ids, token_type_ids, attention_mask, head_mask=head_mask) + def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, + position_ids=None, head_mask=None): + outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids, + attention_mask=attention_mask, head_mask=head_mask) pooled_output = outputs[1] pooled_output = self.dropout(pooled_output) @@ -1063,14 +1070,16 @@ class BertForMultipleChoice(BertPreTrainedModel): self.apply(self.init_weights) - def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, labels=None, head_mask=None): + def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, + position_ids=None, head_mask=None): num_choices = input_ids.shape[1] flat_input_ids = input_ids.view(-1, input_ids.size(-1)) flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None - outputs = self.bert(flat_input_ids, flat_position_ids, flat_token_type_ids, flat_attention_mask, head_mask=head_mask) + outputs = self.bert(flat_input_ids, position_ids=flat_position_ids, token_type_ids=flat_token_type_ids, + attention_mask=flat_attention_mask, head_mask=head_mask) pooled_output = outputs[1] pooled_output = self.dropout(pooled_output) @@ -1131,8 +1140,10 @@ class BertForTokenClassification(BertPreTrainedModel): self.apply(self.init_weights) - def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, labels=None, head_mask=None): - outputs = self.bert(input_ids, position_ids, token_type_ids, attention_mask, head_mask=head_mask) + def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, + position_ids=None, head_mask=None): + outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids, + attention_mask=attention_mask, head_mask=head_mask) sequence_output = outputs[0] sequence_output = self.dropout(sequence_output) @@ -1205,9 +1216,10 @@ class BertForQuestionAnswering(BertPreTrainedModel): self.apply(self.init_weights) - def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, start_positions=None, - end_positions=None, head_mask=None): - outputs = self.bert(input_ids, position_ids, token_type_ids, attention_mask, head_mask=head_mask) + def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, + end_positions=None, position_ids=None, head_mask=None): + outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids, + attention_mask=attention_mask, head_mask=head_mask) sequence_output = outputs[0] logits = self.qa_outputs(sequence_output) diff --git a/pytorch_transformers/modeling_gpt2.py b/pytorch_transformers/modeling_gpt2.py index 415396496c..8edd7555db 100644 --- a/pytorch_transformers/modeling_gpt2.py +++ b/pytorch_transformers/modeling_gpt2.py @@ -591,7 +591,8 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): self.transformer.wte) def forward(self, input_ids, position_ids=None, token_type_ids=None, labels=None, past=None, head_mask=None): - transformer_outputs = self.transformer(input_ids, position_ids, token_type_ids, past, head_mask) + transformer_outputs = self.transformer(input_ids, position_ids=position_ids, token_type_ids=token_type_ids, + past=past, head_mask=head_mask) hidden_states = transformer_outputs[0] lm_logits = self.lm_head(hidden_states) @@ -709,7 +710,8 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel): def forward(self, input_ids, mc_token_ids=None, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None, past=None, head_mask=None): - transformer_outputs = self.transformer(input_ids, position_ids, token_type_ids, past, head_mask) + transformer_outputs = self.transformer(input_ids, position_ids=position_ids, token_type_ids=token_type_ids, + past=past, head_mask=head_mask) hidden_states = transformer_outputs[0] lm_logits = self.lm_head(hidden_states) diff --git a/pytorch_transformers/modeling_openai.py b/pytorch_transformers/modeling_openai.py index d51e4309b8..ebd4166b99 100644 --- a/pytorch_transformers/modeling_openai.py +++ b/pytorch_transformers/modeling_openai.py @@ -582,7 +582,8 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel): self.transformer.tokens_embed) def forward(self, input_ids, position_ids=None, token_type_ids=None, labels=None, head_mask=None): - transformer_outputs = self.transformer(input_ids, position_ids, token_type_ids, head_mask) + transformer_outputs = self.transformer(input_ids, position_ids=position_ids, token_type_ids=token_type_ids, + head_mask=head_mask) hidden_states = transformer_outputs[0] lm_logits = self.lm_head(hidden_states) @@ -693,7 +694,8 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): def forward(self, input_ids, mc_token_ids=None, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None, head_mask=None): - transformer_outputs = self.transformer(input_ids, position_ids, token_type_ids, head_mask) + transformer_outputs = self.transformer(input_ids, position_ids=position_ids, token_type_ids=token_type_ids, + head_mask=head_mask) hidden_states = transformer_outputs[0] lm_logits = self.lm_head(hidden_states) diff --git a/pytorch_transformers/modeling_transfo_xl.py b/pytorch_transformers/modeling_transfo_xl.py index d9c8cba8db..c9ae7cd1a9 100644 --- a/pytorch_transformers/modeling_transfo_xl.py +++ b/pytorch_transformers/modeling_transfo_xl.py @@ -1344,7 +1344,7 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel): bsz = input_ids.size(0) tgt_len = input_ids.size(1) - transformer_outputs = self.transformer(input_ids, mems, head_mask) + transformer_outputs = self.transformer(input_ids, mems=mems, head_mask=head_mask) last_hidden = transformer_outputs[0] pred_hid = last_hidden[:, -tgt_len:] diff --git a/pytorch_transformers/modeling_utils.py b/pytorch_transformers/modeling_utils.py index 7a9777a0eb..3f21c98b04 100644 --- a/pytorch_transformers/modeling_utils.py +++ b/pytorch_transformers/modeling_utils.py @@ -594,7 +594,7 @@ class SQuADHead(nn.Module): """ outputs = () - start_logits = self.start_logits(hidden_states, p_mask) + start_logits = self.start_logits(hidden_states, p_mask=p_mask) if start_positions is not None and end_positions is not None: # If we are on multi-GPU, let's remove the dimension added by batch splitting diff --git a/pytorch_transformers/modeling_xlm.py b/pytorch_transformers/modeling_xlm.py index 33b5bcf7fe..7d08c462ad 100644 --- a/pytorch_transformers/modeling_xlm.py +++ b/pytorch_transformers/modeling_xlm.py @@ -768,8 +768,9 @@ class XLMWithLMHeadModel(XLMPreTrainedModel): def forward(self, input_ids, lengths=None, position_ids=None, langs=None, token_type_ids=None, attention_mask=None, cache=None, labels=None, head_mask=None): - transformer_outputs = self.transformer(input_ids, lengths=lengths, position_ids=position_ids, token_type_ids=token_type_ids, - langs=langs, attention_mask=attention_mask, cache=cache, head_mask=head_mask) + transformer_outputs = self.transformer(input_ids, lengths=lengths, position_ids=position_ids, + token_type_ids=token_type_ids, langs=langs, + attention_mask=attention_mask, cache=cache, head_mask=head_mask) output = transformer_outputs[0] outputs = self.pred_layer(output, labels) @@ -825,8 +826,9 @@ class XLMForSequenceClassification(XLMPreTrainedModel): def forward(self, input_ids, lengths=None, position_ids=None, langs=None, token_type_ids=None, attention_mask=None, cache=None, labels=None, head_mask=None): - transformer_outputs = self.transformer(input_ids, lengths=lengths, position_ids=position_ids, token_type_ids=token_type_ids, - langs=langs, attention_mask=attention_mask, cache=cache, head_mask=head_mask) + transformer_outputs = self.transformer(input_ids, lengths=lengths, position_ids=position_ids, + token_type_ids=token_type_ids, langs=langs, + attention_mask=attention_mask, cache=cache, head_mask=head_mask) output = transformer_outputs[0] logits = self.sequence_summary(output) @@ -905,8 +907,9 @@ class XLMForQuestionAnswering(XLMPreTrainedModel): def forward(self, input_ids, lengths=None, position_ids=None, langs=None, token_type_ids=None, attention_mask=None, cache=None, start_positions=None, end_positions=None, cls_index=None, is_impossible=None, p_mask=None, head_mask=None): - transformer_outputs = self.transformer(input_ids, lengths=lengths, position_ids=position_ids, token_type_ids=token_type_ids, - langs=langs, attention_mask=attention_mask, cache=cache, head_mask=head_mask) + transformer_outputs = self.transformer(input_ids, lengths=lengths, position_ids=position_ids, + token_type_ids=token_type_ids, langs=langs, + attention_mask=attention_mask, cache=cache, head_mask=head_mask) output = transformer_outputs[0] diff --git a/pytorch_transformers/modeling_xlnet.py b/pytorch_transformers/modeling_xlnet.py index a46426d82a..5e576c51c1 100644 --- a/pytorch_transformers/modeling_xlnet.py +++ b/pytorch_transformers/modeling_xlnet.py @@ -1049,8 +1049,10 @@ class XLNetLMHeadModel(XLNetPreTrainedModel): def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None, labels=None, head_mask=None): - transformer_outputs = self.transformer(input_ids, token_type_ids, input_mask, attention_mask, - mems, perm_mask, target_mapping, head_mask) + transformer_outputs = self.transformer(input_ids, token_type_ids=token_type_ids, + input_mask=input_mask, attention_mask=attention_mask, + mems=mems, perm_mask=perm_mask, target_mapping=target_mapping, + head_mask=head_mask) logits = self.lm_loss(transformer_outputs[0]) @@ -1119,8 +1121,10 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel): def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None, labels=None, head_mask=None): - transformer_outputs = self.transformer(input_ids, token_type_ids, input_mask, attention_mask, - mems, perm_mask, target_mapping, head_mask) + transformer_outputs = self.transformer(input_ids, token_type_ids=token_type_ids, + input_mask=input_mask, attention_mask=attention_mask, + mems=mems, perm_mask=perm_mask, target_mapping=target_mapping, + head_mask=head_mask) output = transformer_outputs[0] output = self.sequence_summary(output) @@ -1209,10 +1213,12 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel): mems=None, perm_mask=None, target_mapping=None, start_positions=None, end_positions=None, cls_index=None, is_impossible=None, p_mask=None, head_mask=None): - transformer_outputs = self.transformer(input_ids, token_type_ids, input_mask, attention_mask, - mems, perm_mask, target_mapping, head_mask) + transformer_outputs = self.transformer(input_ids, token_type_ids=token_type_ids, + input_mask=input_mask, attention_mask=attention_mask, + mems=mems, perm_mask=perm_mask, target_mapping=target_mapping, + head_mask=head_mask) hidden_states = transformer_outputs[0] - start_logits = self.start_logits(hidden_states, p_mask) + start_logits = self.start_logits(hidden_states, p_mask=p_mask) outputs = transformer_outputs[1:] # Keep mems, hidden states, attentions if there are in it