Merge pull request #490 from huggingface/better_finetuning_GPT_GPT-2
Clean up GPT and GPT-2 losses computation
This commit is contained in:
@@ -608,6 +608,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
|
|||||||
# Build new embeddings and initialize all new embeddings (in particular the special tokens)
|
# Build new embeddings and initialize all new embeddings (in particular the special tokens)
|
||||||
old_embed = self.tokens_embed
|
old_embed = self.tokens_embed
|
||||||
self.tokens_embed = nn.Embedding(self.config.total_tokens_embeddings, self.config.n_embd)
|
self.tokens_embed = nn.Embedding(self.config.total_tokens_embeddings, self.config.n_embd)
|
||||||
|
self.tokens_embed.to(old_embed.weight.device)
|
||||||
self.init_weights(self.tokens_embed)
|
self.init_weights(self.tokens_embed)
|
||||||
# Copy word embeddings from the previous weights
|
# Copy word embeddings from the previous weights
|
||||||
self.tokens_embed.weight.data[:self.config.vocab_size, :] = old_embed.weight.data[:self.config.vocab_size, :]
|
self.tokens_embed.weight.data[:self.config.vocab_size, :] = old_embed.weight.data[:self.config.vocab_size, :]
|
||||||
@@ -715,9 +716,8 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
|
|||||||
lm_logits = self.lm_head(hidden_states)
|
lm_logits = self.lm_head(hidden_states)
|
||||||
if lm_labels is not None:
|
if lm_labels is not None:
|
||||||
# Shift so that tokens < n predict n
|
# Shift so that tokens < n predict n
|
||||||
shift_logits = lm_logits[:, :-1].contiguous()
|
shift_logits = lm_logits[..., :-1, :].contiguous()
|
||||||
shift_labels = lm_labels[:, 1:].contiguous()
|
shift_labels = lm_labels[..., 1:].contiguous()
|
||||||
|
|
||||||
# Flatten the tokens
|
# Flatten the tokens
|
||||||
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
||||||
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
|
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
|
||||||
@@ -807,11 +807,10 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
|||||||
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
|
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
|
||||||
losses = []
|
losses = []
|
||||||
if lm_labels is not None:
|
if lm_labels is not None:
|
||||||
shift_logits = lm_logits[:, :-1].contiguous()
|
shift_logits = lm_logits[..., :-1, :].contiguous()
|
||||||
shift_labels = lm_labels[:, 1:].contiguous()
|
shift_labels = lm_labels[..., 1:].contiguous()
|
||||||
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
||||||
losses.append(loss_fct(shift_logits.view(-1,
|
losses.append(loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)))
|
||||||
shift_logits.size(-1)), shift_labels.view(-1)))
|
|
||||||
if mc_labels is not None:
|
if mc_labels is not None:
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)))
|
losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)))
|
||||||
|
|||||||
Reference in New Issue
Block a user