modify from_pretrained for OpenAIGPT
This commit is contained in:
@@ -419,9 +419,7 @@ class OpenAIGPTPreTrainedModel(nn.Module):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(
|
def from_pretrained(cls, pretrained_model_name_or_path, num_special_tokens=None, *inputs, **kwargs):
|
||||||
cls, pretrained_model_name_or_path, num_special_tokens=None, state_dict=None, cache_dir=None, from_tf=False, *inputs, **kwargs
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Instantiate a OpenAIGPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
|
Instantiate a OpenAIGPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
|
||||||
Download and cache the pre-trained model file if needed.
|
Download and cache the pre-trained model file if needed.
|
||||||
@@ -434,14 +432,20 @@ class OpenAIGPTPreTrainedModel(nn.Module):
|
|||||||
. `openai_gpt_config.json` a configuration file for the model
|
. `openai_gpt_config.json` a configuration file for the model
|
||||||
. `pytorch_model.bin` a PyTorch dump of a OpenAIGPTModel instance
|
. `pytorch_model.bin` a PyTorch dump of a OpenAIGPTModel instance
|
||||||
- a path or url to a pretrained model archive containing:
|
- a path or url to a pretrained model archive containing:
|
||||||
. `bert_config.json` a configuration file for the model
|
. `openai-gpt-config.json` a configuration file for the model
|
||||||
. a series of NumPy files containing OpenAI TensorFlow trained weights
|
. a series of NumPy files containing OpenAI TensorFlow trained weights
|
||||||
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
|
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
|
||||||
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
|
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
|
||||||
state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of pre-trained models
|
state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of pre-trained models
|
||||||
*inputs, **kwargs: additional input for the specific Bert class
|
*inputs, **kwargs: additional input for the specific OpenAI-GPT class
|
||||||
(ex: num_labels for BertForSequenceClassification)
|
|
||||||
"""
|
"""
|
||||||
|
state_dict = kwargs.get('state_dict', None)
|
||||||
|
kwargs.pop('state_dict', None)
|
||||||
|
cache_dir = kwargs.get('cache_dir', None)
|
||||||
|
kwargs.pop('cache_dir', None)
|
||||||
|
from_tf = kwargs.get('from_tf', False)
|
||||||
|
kwargs.pop('from_tf', None)
|
||||||
|
|
||||||
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
|
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
|
||||||
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
|
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
|
||||||
config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
|
config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
|
||||||
|
|||||||
Reference in New Issue
Block a user