From 43489756ad421a99d0f3eb9d83116b9b4904c922 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Tue, 20 Aug 2019 16:59:11 +0200 Subject: [PATCH] adding proxies options for the from_pretrained methods --- .gitignore | 4 ++- pytorch_transformers/file_utils.py | 29 +++++++++++----------- pytorch_transformers/modeling_utils.py | 14 +++++++++-- pytorch_transformers/tokenization_utils.py | 7 +++++- 4 files changed, 36 insertions(+), 18 deletions(-) diff --git a/.gitignore b/.gitignore index 6bbe32df6c..466a167552 100644 --- a/.gitignore +++ b/.gitignore @@ -127,4 +127,6 @@ proc_data # examples runs -examples/runs \ No newline at end of file +examples/runs + +data \ No newline at end of file diff --git a/pytorch_transformers/file_utils.py b/pytorch_transformers/file_utils.py index 074e6743ef..f6f2151b12 100644 --- a/pytorch_transformers/file_utils.py +++ b/pytorch_transformers/file_utils.py @@ -17,8 +17,9 @@ from hashlib import sha256 from io import open import boto3 -import requests +from botocore.config import Config from botocore.exceptions import ClientError +import requests from tqdm import tqdm try: @@ -93,7 +94,7 @@ def filename_to_url(filename, cache_dir=None): return url, etag -def cached_path(url_or_filename, cache_dir=None, force_download=False): +def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=None): """ Given something that might be a URL (or might be a local path), determine which. If it's a URL, download the file and cache it, and @@ -114,7 +115,7 @@ def cached_path(url_or_filename, cache_dir=None, force_download=False): if parsed.scheme in ('http', 'https', 's3'): # URL, so get it from the cache (downloading if necessary) - return get_from_cache(url_or_filename, cache_dir=cache_dir, force_download=force_download) + return get_from_cache(url_or_filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies) elif os.path.exists(url_or_filename): # File, and it exists. return url_or_filename @@ -159,24 +160,24 @@ def s3_request(func): @s3_request -def s3_etag(url): +def s3_etag(url, proxies=None): """Check ETag on S3 object.""" - s3_resource = boto3.resource("s3") + s3_resource = boto3.resource("s3", config=Config(proxies=proxies)) bucket_name, s3_path = split_s3_path(url) s3_object = s3_resource.Object(bucket_name, s3_path) return s3_object.e_tag @s3_request -def s3_get(url, temp_file): +def s3_get(url, temp_file, proxies=None): """Pull a file directly from S3.""" - s3_resource = boto3.resource("s3") + s3_resource = boto3.resource("s3", config=Config(proxies=proxies)) bucket_name, s3_path = split_s3_path(url) s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) -def http_get(url, temp_file): - req = requests.get(url, stream=True) +def http_get(url, temp_file, proxies=None): + req = requests.get(url, stream=True, proxies=proxies) content_length = req.headers.get('Content-Length') total = int(content_length) if content_length is not None else None progress = tqdm(unit="B", total=total) @@ -187,7 +188,7 @@ def http_get(url, temp_file): progress.close() -def get_from_cache(url, cache_dir=None, force_download=False): +def get_from_cache(url, cache_dir=None, force_download=False, proxies=None): """ Given a URL, look for the corresponding dataset in the local cache. If it's not there, download it. Then return the path to the cached file. @@ -204,10 +205,10 @@ def get_from_cache(url, cache_dir=None, force_download=False): # Get eTag to add to filename, if it exists. if url.startswith("s3://"): - etag = s3_etag(url) + etag = s3_etag(url, proxies=proxies) else: try: - response = requests.head(url, allow_redirects=True) + response = requests.head(url, allow_redirects=True, proxies=proxies) if response.status_code != 200: etag = None else: @@ -238,9 +239,9 @@ def get_from_cache(url, cache_dir=None, force_download=False): # GET file object if url.startswith("s3://"): - s3_get(url, temp_file) + s3_get(url, temp_file, proxies=proxies) else: - http_get(url, temp_file) + http_get(url, temp_file, proxies=proxies) # we are copying the file before closing it, so flush to avoid truncation temp_file.flush() diff --git a/pytorch_transformers/modeling_utils.py b/pytorch_transformers/modeling_utils.py index 3e4fbca132..f1501aa8d5 100644 --- a/pytorch_transformers/modeling_utils.py +++ b/pytorch_transformers/modeling_utils.py @@ -128,6 +128,10 @@ class PretrainedConfig(object): force_download: (`optional`) boolean, default False: Force to (re-)download the model weights and configuration files and override the cached versions if they exists. + proxies: (`optional`) dict, default None: + A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. + The proxies are used on each request. + return_unused_kwargs: (`optional`) bool: - If False, then this function returns just the final configuration object. @@ -150,6 +154,7 @@ class PretrainedConfig(object): """ cache_dir = kwargs.pop('cache_dir', None) force_download = kwargs.pop('force_download', False) + proxies = kwargs.pop('proxies', None) return_unused_kwargs = kwargs.pop('return_unused_kwargs', False) if pretrained_model_name_or_path in cls.pretrained_config_archive_map: @@ -160,7 +165,7 @@ class PretrainedConfig(object): config_file = pretrained_model_name_or_path # redirect to the cache, if necessary try: - resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download) + resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies) except EnvironmentError: if pretrained_model_name_or_path in cls.pretrained_config_archive_map: logger.error( @@ -407,6 +412,10 @@ class PreTrainedModel(nn.Module): force_download: (`optional`) boolean, default False: Force to (re-)download the model weights and configuration files and override the cached versions if they exists. + proxies: (`optional`) dict, default None: + A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. + The proxies are used on each request. + output_loading_info: (`optional`) boolean: Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages. @@ -432,6 +441,7 @@ class PreTrainedModel(nn.Module): cache_dir = kwargs.pop('cache_dir', None) from_tf = kwargs.pop('from_tf', False) force_download = kwargs.pop('force_download', False) + proxies = kwargs.pop('proxies', None) output_loading_info = kwargs.pop('output_loading_info', False) # Load config @@ -462,7 +472,7 @@ class PreTrainedModel(nn.Module): archive_file = pretrained_model_name_or_path # redirect to the cache, if necessary try: - resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download) + resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies) except EnvironmentError: if pretrained_model_name_or_path in cls.pretrained_model_archive_map: logger.error( diff --git a/pytorch_transformers/tokenization_utils.py b/pytorch_transformers/tokenization_utils.py index 763c0cee04..68af97a518 100644 --- a/pytorch_transformers/tokenization_utils.py +++ b/pytorch_transformers/tokenization_utils.py @@ -196,6 +196,10 @@ class PreTrainedTokenizer(object): force_download: (`optional`) boolean, default False: Force to (re-)download the vocabulary files and override the cached versions if they exists. + proxies: (`optional`) dict, default None: + A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. + The proxies are used on each request. + inputs: (`optional`) positional arguments: will be passed to the Tokenizer ``__init__`` method. kwargs: (`optional`) keyword arguments: will be passed to the Tokenizer ``__init__`` method. Can be used to set special tokens like ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``. See parameters in the doc string of :class:`~pytorch_transformers.PreTrainedTokenizer` for details. @@ -227,6 +231,7 @@ class PreTrainedTokenizer(object): def _from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): cache_dir = kwargs.pop('cache_dir', None) force_download = kwargs.pop('force_download', False) + proxies = kwargs.pop('proxies', None) s3_models = list(cls.max_model_input_sizes.keys()) vocab_files = {} @@ -287,7 +292,7 @@ class PreTrainedTokenizer(object): if file_path is None: resolved_vocab_files[file_id] = None else: - resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir, force_download=force_download) + resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir, force_download=force_download, proxies=proxies) except EnvironmentError: if pretrained_model_name_or_path in s3_models: logger.error("Couldn't reach server to download vocabulary.")