prevent BERT weights from being downloaded twice

This commit is contained in:
Rémi Louf
2019-12-06 10:11:44 +01:00
committed by Julien Chaumond
parent 5909f71028
commit 076602bdc4

View File

@@ -158,7 +158,8 @@ class Bert(nn.Module):
def __init__(self):
super(Bert, self).__init__()
self.model = BertModel.from_pretrained("bert-base-uncased")
config = BertConfig.from_pretrained("bert-base-uncased")
self.model = BertModel(config)
def forward(self, input_ids, attention_mask=None, token_type_ids=None, **kwargs):
self.eval()