From 8e11de0e86124d61ae883f8ce06d71dfaed9b01f Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Fri, 1 Nov 2019 16:28:32 +0000 Subject: [PATCH] model forwards can take an inputs_embeds param --- transformers/modeling_gpt2.py | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/transformers/modeling_gpt2.py b/transformers/modeling_gpt2.py index 4abe21d22b..6cbed33733 100644 --- a/transformers/modeling_gpt2.py +++ b/transformers/modeling_gpt2.py @@ -370,9 +370,15 @@ class GPT2Model(GPT2PreTrainedModel): for layer, heads in heads_to_prune.items(): self.h[layer].attn.prune_heads(heads) - def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None): - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) + def forward(self, input_ids=None, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None): + if input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + if token_type_ids is not None: token_type_ids = token_type_ids.view(-1, input_shape[-1]) if position_ids is not None: @@ -384,8 +390,9 @@ class GPT2Model(GPT2PreTrainedModel): else: past_length = past[0][0].size(-2) if position_ids is None: - position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device) - position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) # Attention mask. if attention_mask is not None: @@ -419,7 +426,8 @@ class GPT2Model(GPT2PreTrainedModel): else: head_mask = [None] * self.config.n_layer - inputs_embeds = self.wte(input_ids) + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) position_embeds = self.wpe(position_ids) if token_type_ids is not None: token_type_embeds = self.wte(token_type_ids) @@ -520,14 +528,15 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): def get_output_embeddings(self): return self.lm_head - def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, + def forward(self, input_ids=None, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, labels=None): transformer_outputs = self.transformer(input_ids, past=past, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, - head_mask=head_mask) + head_mask=head_mask, + inputs_embeds=inputs_embeds) hidden_states = transformer_outputs[0] lm_logits = self.lm_head(hidden_states) @@ -623,14 +632,15 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel): def get_output_embeddings(self): return self.lm_head - def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, + def forward(self, input_ids=None, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, mc_token_ids=None, lm_labels=None, mc_labels=None): transformer_outputs = self.transformer(input_ids, past=past, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, - head_mask=head_mask) + head_mask=head_mask, + inputs_embeds=inputs_embeds) hidden_states = transformer_outputs[0]