From 72fa8d03a7c081170fb2e8bb0a4592125adbb039 Mon Sep 17 00:00:00 2001
From: Haozhe Ji <1395082425@qq.com>
Date: Thu, 7 Mar 2019 20:02:55 +0800
Subject: [PATCH] add 'padding_idx=0' for BertEmbeddings

---
 pytorch_pretrained_bert/modeling.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/pytorch_pretrained_bert/modeling.py b/pytorch_pretrained_bert/modeling.py
index ece1dddacc..4417c74227 100644
--- a/pytorch_pretrained_bert/modeling.py
+++ b/pytorch_pretrained_bert/modeling.py
@@ -238,9 +238,9 @@ class BertEmbeddings(nn.Module):
     """
     def __init__(self, config):
         super(BertEmbeddings, self).__init__()
-        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
-        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
-        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
+        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
+        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size, padding_idx=0)
+        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size, padding_idx=0)
 
         # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
         # any TensorFlow checkpoint file