adding proxies options for the from_pretrained methods
This commit is contained in:
@@ -128,6 +128,10 @@ class PretrainedConfig(object):
|
||||
force_download: (`optional`) boolean, default False:
|
||||
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
|
||||
|
||||
proxies: (`optional`) dict, default None:
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
|
||||
The proxies are used on each request.
|
||||
|
||||
return_unused_kwargs: (`optional`) bool:
|
||||
|
||||
- If False, then this function returns just the final configuration object.
|
||||
@@ -150,6 +154,7 @@ class PretrainedConfig(object):
|
||||
"""
|
||||
cache_dir = kwargs.pop('cache_dir', None)
|
||||
force_download = kwargs.pop('force_download', False)
|
||||
proxies = kwargs.pop('proxies', None)
|
||||
return_unused_kwargs = kwargs.pop('return_unused_kwargs', False)
|
||||
|
||||
if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
|
||||
@@ -160,7 +165,7 @@ class PretrainedConfig(object):
|
||||
config_file = pretrained_model_name_or_path
|
||||
# redirect to the cache, if necessary
|
||||
try:
|
||||
resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download)
|
||||
resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
|
||||
except EnvironmentError:
|
||||
if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
|
||||
logger.error(
|
||||
@@ -407,6 +412,10 @@ class PreTrainedModel(nn.Module):
|
||||
force_download: (`optional`) boolean, default False:
|
||||
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
|
||||
|
||||
proxies: (`optional`) dict, default None:
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
|
||||
The proxies are used on each request.
|
||||
|
||||
output_loading_info: (`optional`) boolean:
|
||||
Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages.
|
||||
|
||||
@@ -432,6 +441,7 @@ class PreTrainedModel(nn.Module):
|
||||
cache_dir = kwargs.pop('cache_dir', None)
|
||||
from_tf = kwargs.pop('from_tf', False)
|
||||
force_download = kwargs.pop('force_download', False)
|
||||
proxies = kwargs.pop('proxies', None)
|
||||
output_loading_info = kwargs.pop('output_loading_info', False)
|
||||
|
||||
# Load config
|
||||
@@ -462,7 +472,7 @@ class PreTrainedModel(nn.Module):
|
||||
archive_file = pretrained_model_name_or_path
|
||||
# redirect to the cache, if necessary
|
||||
try:
|
||||
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download)
|
||||
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
|
||||
except EnvironmentError:
|
||||
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
|
||||
logger.error(
|
||||
|
||||
Reference in New Issue
Block a user