From 076602bdc4b186e715538f437f2bce4b1ee5020e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Fri, 6 Dec 2019 10:11:44 +0100 Subject: [PATCH] prevent BERT weights from being downloaded twice --- examples/summarization/modeling_bertabs.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/summarization/modeling_bertabs.py b/examples/summarization/modeling_bertabs.py index 5e51526037..efca33fb56 100644 --- a/examples/summarization/modeling_bertabs.py +++ b/examples/summarization/modeling_bertabs.py @@ -158,7 +158,8 @@ class Bert(nn.Module): def __init__(self): super(Bert, self).__init__() - self.model = BertModel.from_pretrained("bert-base-uncased") + config = BertConfig.from_pretrained("bert-base-uncased") + self.model = BertModel(config) def forward(self, input_ids, attention_mask=None, token_type_ids=None, **kwargs): self.eval()