prevent BERT weights from being downloaded twice
This commit is contained in:
committed by
Julien Chaumond
parent
5909f71028
commit
076602bdc4
@@ -158,7 +158,8 @@ class Bert(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(Bert, self).__init__()
|
||||
self.model = BertModel.from_pretrained("bert-base-uncased")
|
||||
config = BertConfig.from_pretrained("bert-base-uncased")
|
||||
self.model = BertModel(config)
|
||||
|
||||
def forward(self, input_ids, attention_mask=None, token_type_ids=None, **kwargs):
|
||||
self.eval()
|
||||
|
||||
Reference in New Issue
Block a user