gpt-2 special tokens

This commit is contained in:
thomwolf
2019-04-30 11:05:54 +02:00
parent 1f5fc95b68
commit e79ceb1533

View File

@@ -547,7 +547,7 @@ class GPT2Model(GPT2PreTrainedModel):
def __init__(self, config):
super(GPT2Model, self).__init__(config)
-self.wte = nn.Embedding(config.vocab_size, config.n_embd)
+self.wte = nn.Embedding(config.total_tokens_embeddings, config.n_embd)
self.wpe = nn.Embedding(config.n_positions, config.n_embd)
block = Block(config.n_ctx, config, scale=True)
self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])