add force_download option to from_pretrained methods
This commit is contained in:
@@ -125,6 +125,9 @@ class PretrainedConfig(object):
|
||||
- The values in kwargs of any keys which are configuration attributes will be used to override the loaded values.
|
||||
- Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled by the `return_unused_kwargs` keyword parameter.
|
||||
|
||||
force_download: (`optional`) boolean, default False:
|
||||
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
|
||||
|
||||
return_unused_kwargs: (`optional`) bool:
|
||||
|
||||
- If False, then this function returns just the final configuration object.
|
||||
@@ -146,6 +149,7 @@ class PretrainedConfig(object):
|
||||
|
||||
"""
|
||||
cache_dir = kwargs.pop('cache_dir', None)
|
||||
force_download = kwargs.pop('force_download', False)
|
||||
return_unused_kwargs = kwargs.pop('return_unused_kwargs', False)
|
||||
|
||||
if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
|
||||
@@ -156,7 +160,7 @@ class PretrainedConfig(object):
|
||||
config_file = pretrained_model_name_or_path
|
||||
# redirect to the cache, if necessary
|
||||
try:
|
||||
resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
|
||||
resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download)
|
||||
except EnvironmentError:
|
||||
if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
|
||||
logger.error(
|
||||
@@ -400,6 +404,9 @@ class PreTrainedModel(nn.Module):
|
||||
Path to a directory in which a downloaded pre-trained model
|
||||
configuration should be cached if the standard cache should not be used.
|
||||
|
||||
force_download: (`optional`) boolean, default False:
|
||||
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
|
||||
|
||||
output_loading_info: (`optional`) boolean:
|
||||
Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages.
|
||||
|
||||
@@ -424,6 +431,7 @@ class PreTrainedModel(nn.Module):
|
||||
state_dict = kwargs.pop('state_dict', None)
|
||||
cache_dir = kwargs.pop('cache_dir', None)
|
||||
from_tf = kwargs.pop('from_tf', False)
|
||||
force_download = kwargs.pop('force_download', False)
|
||||
output_loading_info = kwargs.pop('output_loading_info', False)
|
||||
|
||||
# Load config
|
||||
@@ -431,6 +439,7 @@ class PreTrainedModel(nn.Module):
|
||||
config, model_kwargs = cls.config_class.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args,
|
||||
cache_dir=cache_dir, return_unused_kwargs=True,
|
||||
force_download=force_download,
|
||||
**kwargs
|
||||
)
|
||||
else:
|
||||
@@ -453,7 +462,7 @@ class PreTrainedModel(nn.Module):
|
||||
archive_file = pretrained_model_name_or_path
|
||||
# redirect to the cache, if necessary
|
||||
try:
|
||||
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
|
||||
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download)
|
||||
except EnvironmentError:
|
||||
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
|
||||
logger.error(
|
||||
|
||||
Reference in New Issue
Block a user