From e99b2014ccaa4a19846ccb5191e63b4bfdb1baa6 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Thu, 11 Apr 2019 11:43:13 +0200 Subject: [PATCH] fixes #471 --- pytorch_pretrained_bert/modeling_openai.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/pytorch_pretrained_bert/modeling_openai.py b/pytorch_pretrained_bert/modeling_openai.py index 296abbfc31..7bf643675e 100644 --- a/pytorch_pretrained_bert/modeling_openai.py +++ b/pytorch_pretrained_bert/modeling_openai.py @@ -371,8 +371,8 @@ class OpenAIGPTMultipleChoiceHead(nn.Module): def forward(self, hidden_states, mc_token_ids): # Classification logits # hidden_state (bsz, num_choices, seq_length, hidden_size) - # mc_token_ids (bsz, num_choices) - mc_token_ids = mc_token_ids.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, hidden_states.size(-1)) + # mc_token_ids (bsz, num_choices, 1) + mc_token_ids = mc_token_ids.unsqueeze(-1).expand(-1, -1, -1, hidden_states.size(-1)) # (bsz, num_choices, 1, hidden_size) multiple_choice_h = hidden_states.gather(2, mc_token_ids).squeeze(2) # (bsz, num_choices, hidden_size) @@ -605,14 +605,12 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): return # Update config self.config.n_special = num_special_tokens - # # Build new embeddings and initialize + # Build new embeddings and initialize all new embeddings (in particular the special tokens) old_embed = self.tokens_embed self.tokens_embed = nn.Embedding(self.config.total_tokens_embeddings, self.config.n_embd) - # Initialize all new embeddings (in particular the special tokens) self.init_weights(self.tokens_embed) - # Copy word and positional embeddings from the previous weights - self.tokens_embed.weight.data[: self.config.vocab_size, :] = old_embed.weight.data[: self.config.vocab_size, :] - self.tokens_embed.weight.data[-self.config.n_positions :, :] = old_embed.weight.data[-self.config.n_positions :, :] + # Copy word embeddings from the previous weights + self.tokens_embed.weight.data[:self.config.vocab_size, :] = old_embed.weight.data[:self.config.vocab_size, :] def forward(self, input_ids, position_ids=None, token_type_ids=None): if position_ids is None: