From 93f335ef86b2a14ffc41daba612d022a1c73e045 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Tue, 11 Dec 2018 11:50:38 +0100 Subject: [PATCH] add pretrained loading from state_dict --- pytorch_pretrained_bert/modeling.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/pytorch_pretrained_bert/modeling.py b/pytorch_pretrained_bert/modeling.py index 1aeff4dd04..bfc5585ea8 100644 --- a/pytorch_pretrained_bert/modeling.py +++ b/pytorch_pretrained_bert/modeling.py @@ -448,9 +448,9 @@ class PreTrainedBertModel(nn.Module): module.bias.data.zero_() @classmethod - def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwargs): + def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None, *inputs, **kwargs): """ - Instantiate a PreTrainedBertModel from a pre-trained model file. + Instantiate a PreTrainedBertModel from a pre-trained model file or a pytorch state dict. Download and cache the pre-trained model file if needed. Params: @@ -464,6 +464,8 @@ class PreTrainedBertModel(nn.Module): - a path or url to a pretrained model archive containing: . `bert_config.json` a configuration file for the model . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance + cache_dir: an optional path to a folder in which the pre-trained models will be cached. + state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models *inputs, **kwargs: additional input for the specific Bert class (ex: num_labels for BertForSequenceClassification) """ @@ -505,8 +507,9 @@ class PreTrainedBertModel(nn.Module): logger.info("Model config {}".format(config)) # Instantiate model. model = cls(config, *inputs, **kwargs) - weights_path = os.path.join(serialization_dir, WEIGHTS_NAME) - state_dict = torch.load(weights_path) + if state_dict is None: + weights_path = os.path.join(serialization_dir, WEIGHTS_NAME) + state_dict = torch.load(weights_path) old_keys = [] new_keys = []