From 0c5a4fe9c9641b8ca47b6a1ff7b994ff18ca98c6 Mon Sep 17 00:00:00 2001 From: VictorSanh Date: Fri, 31 May 2019 00:27:18 -0400 Subject: [PATCH] modify from_pretrained for OpenAIGPT --- pytorch_pretrained_bert/modeling_openai.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/pytorch_pretrained_bert/modeling_openai.py b/pytorch_pretrained_bert/modeling_openai.py index f956462ddb..8cf4117134 100644 --- a/pytorch_pretrained_bert/modeling_openai.py +++ b/pytorch_pretrained_bert/modeling_openai.py @@ -419,9 +419,7 @@ class OpenAIGPTPreTrainedModel(nn.Module): pass @classmethod - def from_pretrained( - cls, pretrained_model_name_or_path, num_special_tokens=None, state_dict=None, cache_dir=None, from_tf=False, *inputs, **kwargs - ): + def from_pretrained(cls, pretrained_model_name_or_path, num_special_tokens=None, *inputs, **kwargs): """ Instantiate a OpenAIGPTPreTrainedModel from a pre-trained model file or a pytorch state dict. Download and cache the pre-trained model file if needed. @@ -434,14 +432,20 @@ class OpenAIGPTPreTrainedModel(nn.Module): . `openai_gpt_config.json` a configuration file for the model . `pytorch_model.bin` a PyTorch dump of a OpenAIGPTModel instance - a path or url to a pretrained model archive containing: - . `bert_config.json` a configuration file for the model + . `openai-gpt-config.json` a configuration file for the model . a series of NumPy files containing OpenAI TensorFlow trained weights from_tf: should we load the weights from a locally saved TensorFlow checkpoint cache_dir: an optional path to a folder in which the pre-trained models will be cached. state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of pre-trained models - *inputs, **kwargs: additional input for the specific Bert class - (ex: num_labels for BertForSequenceClassification) + *inputs, **kwargs: additional input for the specific OpenAI-GPT class """ + state_dict = kwargs.get('state_dict', None) + kwargs.pop('state_dict', None) + cache_dir = kwargs.get('cache_dir', None) + kwargs.pop('cache_dir', None) + from_tf = kwargs.get('from_tf', False) + kwargs.pop('from_tf', None) + if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP: archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path] config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]