updating examples
This commit is contained in:
@@ -57,16 +57,18 @@ class PretrainedConfig(object):
|
||||
pretrained_model_name_or_path: either:
|
||||
- a str with the name of a pre-trained model to load selected in the list of:
|
||||
. `xlnet-large-cased`
|
||||
- a path or url to a pretrained model archive containing:
|
||||
. `config.json` a configuration file for the model
|
||||
- a path or url to a directory containing a configuration file `config.json` for the model,
|
||||
- a path or url to a configuration file for the model.
|
||||
cache_dir: an optional path to a folder in which the pre-trained model configuration will be cached.
|
||||
"""
|
||||
cache_dir = kwargs.pop('cache_dir', None)
|
||||
|
||||
if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
|
||||
config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
|
||||
else:
|
||||
elif os.path.isdir(pretrained_model_name_or_path):
|
||||
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
|
||||
else:
|
||||
config_file = pretrained_model_name_or_path
|
||||
# redirect to the cache, if necessary
|
||||
try:
|
||||
resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
|
||||
@@ -200,6 +202,7 @@ class PreTrainedModel(nn.Module):
|
||||
- a path or url to a tensorflow pretrained model checkpoint containing:
|
||||
. `config.json` a configuration file for the model
|
||||
. `model.chkpt` a TensorFlow checkpoint
|
||||
config: an optional configuration for the model
|
||||
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
|
||||
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
|
||||
state_dict: an optional state dictionnary (collections.OrderedDict object) to use
|
||||
@@ -207,23 +210,31 @@ class PreTrainedModel(nn.Module):
|
||||
*inputs, **kwargs: additional input for the specific XLNet class
|
||||
(ex: num_labels for XLNetForSequenceClassification)
|
||||
"""
|
||||
config = kwargs.pop('config', None)
|
||||
state_dict = kwargs.pop('state_dict', None)
|
||||
cache_dir = kwargs.pop('cache_dir', None)
|
||||
from_tf = kwargs.pop('from_tf', False)
|
||||
output_loading_info = kwargs.pop('output_loading_info', False)
|
||||
|
||||
# Load config
|
||||
config = cls.config_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
if config is None:
|
||||
config = cls.config_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
|
||||
# Load model
|
||||
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
|
||||
archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
|
||||
else:
|
||||
elif os.path.isdir(pretrained_model_name_or_path):
|
||||
if from_tf:
|
||||
# Directly load from a TensorFlow checkpoint
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
|
||||
else:
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
||||
else:
|
||||
if from_tf:
|
||||
# Directly load from a TensorFlow checkpoint
|
||||
archive_file = pretrained_model_name_or_path + ".index"
|
||||
else:
|
||||
archive_file = pretrained_model_name_or_path
|
||||
# redirect to the cache, if necessary
|
||||
try:
|
||||
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
|
||||
|
||||
Reference in New Issue
Block a user