From e79ceb15331ddadbe0f0ccb857218b1ba2cca368 Mon Sep 17 00:00:00 2001
From: thomwolf
Date: Tue, 30 Apr 2019 11:05:54 +0200
Subject: [PATCH] gpt-2 special tokens

---
 pytorch_pretrained_bert/modeling_gpt2.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_pretrained_bert/modeling_gpt2.py b/pytorch_pretrained_bert/modeling_gpt2.py
index 05a748d43c..5537f93f66 100644
--- a/pytorch_pretrained_bert/modeling_gpt2.py
+++ b/pytorch_pretrained_bert/modeling_gpt2.py
@@ -547,7 +547,7 @@ class GPT2Model(GPT2PreTrainedModel):
 
     def __init__(self, config):
         super(GPT2Model, self).__init__(config)
-        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
+        self.wte = nn.Embedding(config.total_tokens_embeddings, config.n_embd)
         self.wpe = nn.Embedding(config.n_positions, config.n_embd)
         block = Block(config.n_ctx, config, scale=True)
         self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])