prevent BERT weights from being downloaded twice
This commit is contained in:
committed by
Julien Chaumond
parent
5909f71028
commit
076602bdc4
@@ -158,7 +158,8 @@ class Bert(nn.Module):
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(Bert, self).__init__()
|
super(Bert, self).__init__()
|
||||||
self.model = BertModel.from_pretrained("bert-base-uncased")
|
config = BertConfig.from_pretrained("bert-base-uncased")
|
||||||
|
self.model = BertModel(config)
|
||||||
|
|
||||||
def forward(self, input_ids, attention_mask=None, token_type_ids=None, **kwargs):
|
def forward(self, input_ids, attention_mask=None, token_type_ids=None, **kwargs):
|
||||||
self.eval()
|
self.eval()
|
||||||
|
|||||||
Reference in New Issue
Block a user