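Fixes #471: update the expected `mc_token_ids` shape in `OpenAIGPTMultipleChoiceHead.forward`, and copy only the word embeddings (no longer the positional slice) when rebuilding the token embeddings in `OpenAIGPTModel`.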
This commit is contained in:
@@ -371,8 +371,8 @@ class OpenAIGPTMultipleChoiceHead(nn.Module):
|
|||||||
def forward(self, hidden_states, mc_token_ids):
|
def forward(self, hidden_states, mc_token_ids):
|
||||||
# Classification logits
|
# Classification logits
|
||||||
# hidden_state (bsz, num_choices, seq_length, hidden_size)
|
# hidden_state (bsz, num_choices, seq_length, hidden_size)
|
||||||
# mc_token_ids (bsz, num_choices)
|
# mc_token_ids (bsz, num_choices, 1)
|
||||||
mc_token_ids = mc_token_ids.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, hidden_states.size(-1))
|
mc_token_ids = mc_token_ids.unsqueeze(-1).expand(-1, -1, -1, hidden_states.size(-1))
|
||||||
# (bsz, num_choices, 1, hidden_size)
|
# (bsz, num_choices, 1, hidden_size)
|
||||||
multiple_choice_h = hidden_states.gather(2, mc_token_ids).squeeze(2)
|
multiple_choice_h = hidden_states.gather(2, mc_token_ids).squeeze(2)
|
||||||
# (bsz, num_choices, hidden_size)
|
# (bsz, num_choices, hidden_size)
|
||||||
@@ -605,14 +605,12 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
|
|||||||
return
|
return
|
||||||
# Update config
|
# Update config
|
||||||
self.config.n_special = num_special_tokens
|
self.config.n_special = num_special_tokens
|
||||||
# # Build new embeddings and initialize
|
# Build new embeddings and initialize all new embeddings (in particular the special tokens)
|
||||||
old_embed = self.tokens_embed
|
old_embed = self.tokens_embed
|
||||||
self.tokens_embed = nn.Embedding(self.config.total_tokens_embeddings, self.config.n_embd)
|
self.tokens_embed = nn.Embedding(self.config.total_tokens_embeddings, self.config.n_embd)
|
||||||
# Initialize all new embeddings (in particular the special tokens)
|
|
||||||
self.init_weights(self.tokens_embed)
|
self.init_weights(self.tokens_embed)
|
||||||
# Copy word and positional embeddings from the previous weights
|
# Copy word embeddings from the previous weights
|
||||||
self.tokens_embed.weight.data[: self.config.vocab_size, :] = old_embed.weight.data[: self.config.vocab_size, :]
|
self.tokens_embed.weight.data[:self.config.vocab_size, :] = old_embed.weight.data[:self.config.vocab_size, :]
|
||||||
self.tokens_embed.weight.data[-self.config.n_positions :, :] = old_embed.weight.data[-self.config.n_positions :, :]
|
|
||||||
|
|
||||||
def forward(self, input_ids, position_ids=None, token_type_ids=None):
|
def forward(self, input_ids, position_ids=None, token_type_ids=None):
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
|
|||||||
Reference in New Issue
Block a user