From 81ee29ee8d64c292c3fd5fc7e13b387acd1bfc39 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Thu, 10 Oct 2019 14:13:37 +0200 Subject: [PATCH] remove the staticmethod used to load the config --- transformers/modeling_bert.py | 33 ++++++++++++++++----------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/transformers/modeling_bert.py b/transformers/modeling_bert.py index 5fcf41a1e1..6dae6d6ce5 100644 --- a/transformers/modeling_bert.py +++ b/transformers/modeling_bert.py @@ -715,7 +715,7 @@ class BertDecoderModel(BertPreTrainedModel): """ def __init__(self, config): - super(BertModel, self).__init__(config) + super(BertDecoderModel, self).__init__(config) self.embeddings = BertEmbeddings(config) self.decoder = BertDecoder(config) @@ -1357,28 +1357,27 @@ class Bert2Rnd(BertPreTrainedModel): pretrained weights we need to override the `from_pretrained` method of the base `PreTrainedModel` class. """ - pretrained_encoder = BertModel.from_pretrained(pretrained_model_or_path, *model_args, **model_kwargs) - config = cls._load_config(pretrained_model_or_path, *model_args, **model_kwargs) - model = cls(config) - model.encoder = pretrained_encoder - - return model - - def _load_config(self, pretrained_model_name_or_path, *args, **kwargs): - config = kwargs.pop('config', None) + # Load the configuration + config = model_kwargs.pop('config', None) if config is None: - cache_dir = kwargs.pop('cache_dir', None) - force_download = kwargs.pop('force_download', False) - config, _ = self.config_class.from_pretrained( - pretrained_model_name_or_path, - *args, + cache_dir = model_kwargs.pop('cache_dir', None) + force_download = model_kwargs.pop('force_download', False) + config, _ = cls.config_class.from_pretrained( + pretrained_model_or_path, + *model_args, cache_dir=cache_dir, return_unused_kwargs=True, force_download=force_download, - **kwargs + **model_kwargs ) - return config + model = cls(config) + + # The encoder is loaded with pretrained weights + pretrained_encoder = BertModel.from_pretrained(pretrained_model_or_path, *model_args, **model_kwargs) + model.encoder = pretrained_encoder + + return model def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None): encoder_outputs = self.encoder(input_ids,