diff --git a/examples/summarization/modeling_bertabs.py b/examples/summarization/modeling_bertabs.py index 5e51526037..efca33fb56 100644 --- a/examples/summarization/modeling_bertabs.py +++ b/examples/summarization/modeling_bertabs.py @@ -158,7 +158,8 @@ class Bert(nn.Module): def __init__(self): super(Bert, self).__init__() - self.model = BertModel.from_pretrained("bert-base-uncased") + config = BertConfig.from_pretrained("bert-base-uncased") + self.model = BertModel(config) def forward(self, input_ids, attention_mask=None, token_type_ids=None, **kwargs): self.eval()