From 270fa2f20b6dd9736a08f24e6050f24b2a96b010 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Tue, 11 Dec 2018 11:50:38 +0100 Subject: [PATCH] add pretrained loading from state_dict --- pytorch_pretrained_bert/modeling.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/pytorch_pretrained_bert/modeling.py b/pytorch_pretrained_bert/modeling.py index 3af5854072..3d04f0842c 100644 --- a/pytorch_pretrained_bert/modeling.py +++ b/pytorch_pretrained_bert/modeling.py @@ -445,9 +445,9 @@ class PreTrainedBertModel(nn.Module): module.bias.data.zero_() @classmethod - def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwargs): + def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None, *inputs, **kwargs): """ - Instantiate a PreTrainedBertModel from a pre-trained model file. + Instantiate a PreTrainedBertModel from a pre-trained model file or a pytorch state dict. Download and cache the pre-trained model file if needed. Params: @@ -461,6 +461,8 @@ class PreTrainedBertModel(nn.Module): - a path or url to a pretrained model archive containing: . `bert_config.json` a configuration file for the model . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance + cache_dir: an optional path to a folder in which the pre-trained models will be cached. + state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models *inputs, **kwargs: additional input for the specific Bert class (ex: num_labels for BertForSequenceClassification) """ @@ -502,8 +504,9 @@ class PreTrainedBertModel(nn.Module): logger.info("Model config {}".format(config)) # Instantiate model. model = cls(config, *inputs, **kwargs) - weights_path = os.path.join(serialization_dir, WEIGHTS_NAME) - state_dict = torch.load(weights_path) + if state_dict is None: + weights_path = os.path.join(serialization_dir, WEIGHTS_NAME) + state_dict = torch.load(weights_path) missing_keys = [] unexpected_keys = []