update special token addition
This commit is contained in:
@@ -608,6 +608,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
|
|||||||
# Build new embeddings and initialize all new embeddings (in particular the special tokens)
|
# Build new embeddings and initialize all new embeddings (in particular the special tokens)
|
||||||
old_embed = self.tokens_embed
|
old_embed = self.tokens_embed
|
||||||
self.tokens_embed = nn.Embedding(self.config.total_tokens_embeddings, self.config.n_embd)
|
self.tokens_embed = nn.Embedding(self.config.total_tokens_embeddings, self.config.n_embd)
|
||||||
|
self.tokens_embed.to(old_embed.weight.device)
|
||||||
self.init_weights(self.tokens_embed)
|
self.init_weights(self.tokens_embed)
|
||||||
# Copy word embeddings from the previous weights
|
# Copy word embeddings from the previous weights
|
||||||
self.tokens_embed.weight.data[:self.config.vocab_size, :] = old_embed.weight.data[:self.config.vocab_size, :]
|
self.tokens_embed.weight.data[:self.config.vocab_size, :] = old_embed.weight.data[:self.config.vocab_size, :]
|
||||||
|
|||||||
Reference in New Issue
Block a user