From fecaed0ed4bf338bca5b9895107b309841f8ac57 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Tue, 20 Aug 2019 10:56:12 +0200 Subject: [PATCH 1/6] add force_download option to from_pretrained methods --- pytorch_transformers/file_utils.py | 13 ++++++++----- pytorch_transformers/modeling_utils.py | 13 +++++++++++-- pytorch_transformers/tokenization_utils.py | 6 +++++- 3 files changed, 24 insertions(+), 8 deletions(-) diff --git a/pytorch_transformers/file_utils.py b/pytorch_transformers/file_utils.py index 75c075720c..074e6743ef 100644 --- a/pytorch_transformers/file_utils.py +++ b/pytorch_transformers/file_utils.py @@ -93,12 +93,15 @@ def filename_to_url(filename, cache_dir=None): return url, etag -def cached_path(url_or_filename, cache_dir=None): +def cached_path(url_or_filename, cache_dir=None, force_download=False): """ Given something that might be a URL (or might be a local path), determine which. If it's a URL, download the file and cache it, and return the path to the cached file. If it's already a local path, make sure the file exists and then return the path. + Args: + cache_dir: specify a cache directory to save the file to (overrides the default cache dir). + force_download: if True, re-download the file even if it's already cached in the cache dir. """ if cache_dir is None: cache_dir = PYTORCH_TRANSFORMERS_CACHE @@ -111,7 +114,7 @@ def cached_path(url_or_filename, cache_dir=None): if parsed.scheme in ('http', 'https', 's3'): # URL, so get it from the cache (downloading if necessary) - return get_from_cache(url_or_filename, cache_dir) + return get_from_cache(url_or_filename, cache_dir=cache_dir, force_download=force_download) elif os.path.exists(url_or_filename): # File, and it exists. return url_or_filename @@ -184,7 +187,7 @@ def http_get(url, temp_file): progress.close() -def get_from_cache(url, cache_dir=None): +def get_from_cache(url, cache_dir=None, force_download=False): """ Given a URL, look for the corresponding dataset in the local cache.
If it's not there, download it. Then return the path to the cached file. @@ -227,11 +230,11 @@ def get_from_cache(url, cache_dir=None): if matching_files: cache_path = os.path.join(cache_dir, matching_files[-1]) - if not os.path.exists(cache_path): + if not os.path.exists(cache_path) or force_download: # Download to temporary file, then copy to cache dir once finished. # Otherwise you get corrupt cache entries if the download gets interrupted. with tempfile.NamedTemporaryFile() as temp_file: - logger.info("%s not found in cache, downloading to %s", url, temp_file.name) + logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name) # GET file object if url.startswith("s3://"): diff --git a/pytorch_transformers/modeling_utils.py b/pytorch_transformers/modeling_utils.py index edc6b3903e..3e4fbca132 100644 --- a/pytorch_transformers/modeling_utils.py +++ b/pytorch_transformers/modeling_utils.py @@ -125,6 +125,9 @@ class PretrainedConfig(object): - The values in kwargs of any keys which are configuration attributes will be used to override the loaded values. - Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled by the `return_unused_kwargs` keyword parameter. + force_download: (`optional`) boolean, default False: + Force to (re-)download the model weights and configuration files and override the cached versions if they exists. + return_unused_kwargs: (`optional`) bool: - If False, then this function returns just the final configuration object. 
@@ -146,6 +149,7 @@ class PretrainedConfig(object): """ cache_dir = kwargs.pop('cache_dir', None) + force_download = kwargs.pop('force_download', False) return_unused_kwargs = kwargs.pop('return_unused_kwargs', False) if pretrained_model_name_or_path in cls.pretrained_config_archive_map: @@ -156,7 +160,7 @@ class PretrainedConfig(object): config_file = pretrained_model_name_or_path # redirect to the cache, if necessary try: - resolved_config_file = cached_path(config_file, cache_dir=cache_dir) + resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download) except EnvironmentError: if pretrained_model_name_or_path in cls.pretrained_config_archive_map: logger.error( @@ -400,6 +404,9 @@ class PreTrainedModel(nn.Module): Path to a directory in which a downloaded pre-trained model configuration should be cached if the standard cache should not be used. + force_download: (`optional`) boolean, default False: + Force to (re-)download the model weights and configuration files and override the cached versions if they exists. + output_loading_info: (`optional`) boolean: Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages. 
@@ -424,6 +431,7 @@ class PreTrainedModel(nn.Module): state_dict = kwargs.pop('state_dict', None) cache_dir = kwargs.pop('cache_dir', None) from_tf = kwargs.pop('from_tf', False) + force_download = kwargs.pop('force_download', False) output_loading_info = kwargs.pop('output_loading_info', False) # Load config @@ -431,6 +439,7 @@ class PreTrainedModel(nn.Module): config, model_kwargs = cls.config_class.from_pretrained( pretrained_model_name_or_path, *model_args, cache_dir=cache_dir, return_unused_kwargs=True, + force_download=force_download, **kwargs ) else: @@ -453,7 +462,7 @@ class PreTrainedModel(nn.Module): archive_file = pretrained_model_name_or_path # redirect to the cache, if necessary try: - resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) + resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download) except EnvironmentError: if pretrained_model_name_or_path in cls.pretrained_model_archive_map: logger.error( diff --git a/pytorch_transformers/tokenization_utils.py b/pytorch_transformers/tokenization_utils.py index 74d50b385d..763c0cee04 100644 --- a/pytorch_transformers/tokenization_utils.py +++ b/pytorch_transformers/tokenization_utils.py @@ -193,6 +193,9 @@ class PreTrainedTokenizer(object): cache_dir: (`optional`) string: Path to a directory in which a downloaded predefined tokenizer vocabulary files should be cached if the standard cache should not be used. + force_download: (`optional`) boolean, default False: + Force to (re-)download the vocabulary files and override the cached versions if they exists. + inputs: (`optional`) positional arguments: will be passed to the Tokenizer ``__init__`` method. kwargs: (`optional`) keyword arguments: will be passed to the Tokenizer ``__init__`` method. Can be used to set special tokens like ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``. 
See parameters in the doc string of :class:`~pytorch_transformers.PreTrainedTokenizer` for details. @@ -223,6 +226,7 @@ class PreTrainedTokenizer(object): @classmethod def _from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): cache_dir = kwargs.pop('cache_dir', None) + force_download = kwargs.pop('force_download', False) s3_models = list(cls.max_model_input_sizes.keys()) vocab_files = {} @@ -283,7 +287,7 @@ class PreTrainedTokenizer(object): if file_path is None: resolved_vocab_files[file_id] = None else: - resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir) + resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir, force_download=force_download) except EnvironmentError: if pretrained_model_name_or_path in s3_models: logger.error("Couldn't reach server to download vocabulary.") From e239a4a20fbb901e60ffcafc06bfefcbb67eaa65 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Tue, 20 Aug 2019 11:02:00 +0200 Subject: [PATCH 2/6] close #984 --- docs/source/pretrained_models.rst | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/source/pretrained_models.rst b/docs/source/pretrained_models.rst index 987882d12e..6a14e3dcd1 100644 --- a/docs/source/pretrained_models.rst +++ b/docs/source/pretrained_models.rst @@ -72,16 +72,16 @@ Here is the full list of the currently provided pretrained models together with | | ``xlnet-large-cased`` | | 24-layer, 1024-hidden, 16-heads, 340M parameters. 
| | | | | XLNet Large English model | +-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| XLM | ``xlm-mlm-en-2048`` | | 12-layer, 1024-hidden, 8-heads | +| XLM | ``xlm-mlm-en-2048`` | | 12-layer, 2048-hidden, 16-heads | | | | | XLM English model | | +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``xlm-mlm-ende-1024`` | | 12-layer, 1024-hidden, 8-heads | +| | ``xlm-mlm-ende-1024`` | | 6-layer, 1024-hidden, 8-heads | | | | | XLM English-German Multi-language model | | +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``xlm-mlm-enfr-1024`` | | 12-layer, 1024-hidden, 8-heads | +| | ``xlm-mlm-enfr-1024`` | | 6-layer, 1024-hidden, 8-heads | | | | | XLM English-French Multi-language model | | +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``xlm-mlm-enro-1024`` | | 12-layer, 1024-hidden, 8-heads | +| | ``xlm-mlm-enro-1024`` | | 6-layer, 1024-hidden, 8-heads | | | | | XLM English-Romanian Multi-language model | | +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ | | ``xlm-mlm-xnli15-1024`` | | 12-layer, 1024-hidden, 8-heads | @@ -93,7 +93,7 @@ Here is the full list of the currently provided pretrained models together with | | ``xlm-clm-enfr-1024`` | | 12-layer, 1024-hidden, 8-heads | | | | | XLM 
English model trained with CLM (Causal Language Modeling) | | +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -| | ``xlm-clm-ende-1024`` | | 12-layer, 1024-hidden, 8-heads | +| | ``xlm-clm-ende-1024`` | | 6-layer, 1024-hidden, 8-heads | | | | | XLM English-German Multi-language model trained with CLM (Causal Language Modeling) | +-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ | RoBERTa | ``roberta-base`` | | 12-layer, 768-hidden, 12-heads, 125M parameters | From 901dde0e4583a00dc7e486aca6cda7acb647dea9 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Tue, 20 Aug 2019 11:05:51 +0200 Subject: [PATCH 3/6] fix #1014 --- pytorch_transformers/tokenization_bert.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytorch_transformers/tokenization_bert.py b/pytorch_transformers/tokenization_bert.py index 177d26dec1..04f35aa466 100644 --- a/pytorch_transformers/tokenization_bert.py +++ b/pytorch_transformers/tokenization_bert.py @@ -187,6 +187,8 @@ class BertTokenizer(PreTrainedTokenizer): index = 0 if os.path.isdir(vocab_path): vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file']) + else: + vocab_file = vocab_path with open(vocab_file, "w", encoding="utf-8") as writer: for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): if index != token_index: From 53c8f700f4704a58f4684674ced1c57d6ca9240c Mon Sep 17 00:00:00 2001 From: thomwolf Date: Tue, 20 Aug 2019 11:29:26 +0200 Subject: [PATCH 4/6] fix #808 --- pytorch_transformers/modeling_bert.py | 5 ++++- pytorch_transformers/modeling_gpt2.py | 2 ++ pytorch_transformers/modeling_openai.py | 2 ++ pytorch_transformers/modeling_roberta.py | 4 ++++ 
pytorch_transformers/modeling_transfo_xl.py | 2 ++ pytorch_transformers/modeling_xlm.py | 4 ++++ pytorch_transformers/modeling_xlnet.py | 2 ++ 7 files changed, 20 insertions(+), 1 deletion(-) diff --git a/pytorch_transformers/modeling_bert.py b/pytorch_transformers/modeling_bert.py index 9c20eac9bf..7b34b3fd90 100644 --- a/pytorch_transformers/modeling_bert.py +++ b/pytorch_transformers/modeling_bert.py @@ -599,7 +599,10 @@ BERT_INPUTS_DOCSTRING = r""" ``tokens: [CLS] the dog is hairy . [SEP]`` ``token_type_ids: 0 0 0 0 0 0 0`` - + + Bert is a model with absolute position embeddings so it's usually advised to pad the inputs on + the right rather than the left. + Indices can be obtained using :class:`pytorch_transformers.BertTokenizer`. See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details. diff --git a/pytorch_transformers/modeling_gpt2.py b/pytorch_transformers/modeling_gpt2.py index f67d0e88d5..91d01d0584 100644 --- a/pytorch_transformers/modeling_gpt2.py +++ b/pytorch_transformers/modeling_gpt2.py @@ -390,6 +390,8 @@ GPT2_START_DOCSTRING = r""" OpenAI GPT-2 model was proposed in GPT2_INPUTS_DOCSTRING = r""" Inputs: **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: Indices of input sequence tokens in the vocabulary. + GPT-2 is a model with absolute position embeddings so it's usually advised to pad the inputs on + the right rather than the left. Indices can be obtained using :class:`pytorch_transformers.BPT2Tokenizer`. See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details. 
diff --git a/pytorch_transformers/modeling_openai.py b/pytorch_transformers/modeling_openai.py index e8648487be..71ffb78e0f 100644 --- a/pytorch_transformers/modeling_openai.py +++ b/pytorch_transformers/modeling_openai.py @@ -404,6 +404,8 @@ OPENAI_GPT_START_DOCSTRING = r""" OpenAI GPT model was proposed in OPENAI_GPT_INPUTS_DOCSTRING = r""" Inputs: **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: Indices of input sequence tokens in the vocabulary. + GPT is a model with absolute position embeddings so it's usually advised to pad the inputs on + the right rather than the left. Indices can be obtained using :class:`pytorch_transformers.BPT2Tokenizer`. See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details. diff --git a/pytorch_transformers/modeling_roberta.py b/pytorch_transformers/modeling_roberta.py index e3065cf60b..e49b2a06b1 100644 --- a/pytorch_transformers/modeling_roberta.py +++ b/pytorch_transformers/modeling_roberta.py @@ -110,6 +110,10 @@ ROBERTA_INPUTS_DOCSTRING = r""" Fully encoded sequences or sequence pairs can be obtained using the RobertaTokenizer.encode function with the ``add_special_tokens`` parameter set to ``True``. + + RoBERTa is a model with absolute position embeddings so it's usually advised to pad the inputs on + the right rather than the left. + See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details. 
**position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: diff --git a/pytorch_transformers/modeling_transfo_xl.py b/pytorch_transformers/modeling_transfo_xl.py index 553a71fffe..3cfdee38cb 100644 --- a/pytorch_transformers/modeling_transfo_xl.py +++ b/pytorch_transformers/modeling_transfo_xl.py @@ -936,6 +936,8 @@ TRANSFO_XL_INPUTS_DOCSTRING = r""" Inputs: **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: Indices of input sequence tokens in the vocabulary. + Transformer-XL is a model with relative position embeddings so you can either pad the inputs on + the right or on the left. Indices can be obtained using :class:`pytorch_transformers.TransfoXLTokenizer`. See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details. diff --git a/pytorch_transformers/modeling_xlm.py b/pytorch_transformers/modeling_xlm.py index d01d245bbb..be2767ed0c 100644 --- a/pytorch_transformers/modeling_xlm.py +++ b/pytorch_transformers/modeling_xlm.py @@ -424,6 +424,10 @@ XLM_INPUTS_DOCSTRING = r""" Inputs: **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: Indices of input sequence tokens in the vocabulary. + + XLM is a model with absolute position embeddings so it's usually advised to pad the inputs on + the right rather than the left. + Indices can be obtained using :class:`pytorch_transformers.XLMTokenizer`. See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details. 
diff --git a/pytorch_transformers/modeling_xlnet.py b/pytorch_transformers/modeling_xlnet.py index af33c5a6c2..d44821788e 100644 --- a/pytorch_transformers/modeling_xlnet.py +++ b/pytorch_transformers/modeling_xlnet.py @@ -655,6 +655,8 @@ XLNET_INPUTS_DOCSTRING = r""" Inputs: **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: Indices of input sequence tokens in the vocabulary. + XLNet is a model with relative position embeddings so you can either pad the inputs on + the right or on the left. Indices can be obtained using :class:`pytorch_transformers.XLNetTokenizer`. See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details. From 6d0aa73981f15618cf8d01255b07194e946c3286 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Tue, 20 Aug 2019 12:20:21 +0200 Subject: [PATCH 5/6] fix #1034 --- pytorch_transformers/modeling_xlm.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pytorch_transformers/modeling_xlm.py b/pytorch_transformers/modeling_xlm.py index be2767ed0c..19800da2ed 100644 --- a/pytorch_transformers/modeling_xlm.py +++ b/pytorch_transformers/modeling_xlm.py @@ -440,8 +440,10 @@ XLM_INPUTS_DOCSTRING = r""" Indices are selected in the vocabulary (unlike BERT which has a specific vocabulary for segment indices). **langs**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: A parallel sequence of tokens to be used to indicate the language of each token in the input. - Indices are selected in the pre-trained language vocabulary, - i.e. in the range ``[0, config.n_langs - 1[``. + Indices are language ids which can be obtained from the language names by using two conversion mappings + provided in the configuration of the model (only provided for multilingual models).
+ More precisely, the `language name -> language id` mapping is in `model.config.lang2id` (dict str -> int) and + the `language id -> language name` mapping is `model.config.id2lang` (dict int -> str). **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``: Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: From 43489756ad421a99d0f3eb9d83116b9b4904c922 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Tue, 20 Aug 2019 16:59:11 +0200 Subject: [PATCH 6/6] adding proxies options for the from_pretrained methods --- .gitignore | 4 ++- pytorch_transformers/file_utils.py | 29 +++++++++++----------- pytorch_transformers/modeling_utils.py | 14 +++++++++-- pytorch_transformers/tokenization_utils.py | 7 +++++- 4 files changed, 36 insertions(+), 18 deletions(-) diff --git a/.gitignore b/.gitignore index 6bbe32df6c..466a167552 100644 --- a/.gitignore +++ b/.gitignore @@ -127,4 +127,6 @@ proc_data # examples runs -examples/runs \ No newline at end of file +examples/runs + +data \ No newline at end of file diff --git a/pytorch_transformers/file_utils.py b/pytorch_transformers/file_utils.py index 074e6743ef..f6f2151b12 100644 --- a/pytorch_transformers/file_utils.py +++ b/pytorch_transformers/file_utils.py @@ -17,8 +17,9 @@ from hashlib import sha256 from io import open import boto3 -import requests +from botocore.config import Config from botocore.exceptions import ClientError +import requests from tqdm import tqdm try: @@ -93,7 +94,7 @@ def filename_to_url(filename, cache_dir=None): return url, etag -def cached_path(url_or_filename, cache_dir=None, force_download=False): +def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=None): """ Given something that might be a URL (or might be a local path), determine which. 
If it's a URL, download the file and cache it, and @@ -114,7 +115,7 @@ def cached_path(url_or_filename, cache_dir=None, force_download=False): if parsed.scheme in ('http', 'https', 's3'): # URL, so get it from the cache (downloading if necessary) - return get_from_cache(url_or_filename, cache_dir=cache_dir, force_download=force_download) + return get_from_cache(url_or_filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies) elif os.path.exists(url_or_filename): # File, and it exists. return url_or_filename @@ -159,24 +160,24 @@ def s3_request(func): @s3_request -def s3_etag(url): +def s3_etag(url, proxies=None): """Check ETag on S3 object.""" - s3_resource = boto3.resource("s3") + s3_resource = boto3.resource("s3", config=Config(proxies=proxies)) bucket_name, s3_path = split_s3_path(url) s3_object = s3_resource.Object(bucket_name, s3_path) return s3_object.e_tag @s3_request -def s3_get(url, temp_file): +def s3_get(url, temp_file, proxies=None): """Pull a file directly from S3.""" - s3_resource = boto3.resource("s3") + s3_resource = boto3.resource("s3", config=Config(proxies=proxies)) bucket_name, s3_path = split_s3_path(url) s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) -def http_get(url, temp_file): - req = requests.get(url, stream=True) +def http_get(url, temp_file, proxies=None): + req = requests.get(url, stream=True, proxies=proxies) content_length = req.headers.get('Content-Length') total = int(content_length) if content_length is not None else None progress = tqdm(unit="B", total=total) @@ -187,7 +188,7 @@ def http_get(url, temp_file): progress.close() -def get_from_cache(url, cache_dir=None, force_download=False): +def get_from_cache(url, cache_dir=None, force_download=False, proxies=None): """ Given a URL, look for the corresponding dataset in the local cache. If it's not there, download it. Then return the path to the cached file. 
@@ -204,10 +205,10 @@ def get_from_cache(url, cache_dir=None, force_download=False): # Get eTag to add to filename, if it exists. if url.startswith("s3://"): - etag = s3_etag(url) + etag = s3_etag(url, proxies=proxies) else: try: - response = requests.head(url, allow_redirects=True) + response = requests.head(url, allow_redirects=True, proxies=proxies) if response.status_code != 200: etag = None else: @@ -238,9 +239,9 @@ def get_from_cache(url, cache_dir=None, force_download=False): # GET file object if url.startswith("s3://"): - s3_get(url, temp_file) + s3_get(url, temp_file, proxies=proxies) else: - http_get(url, temp_file) + http_get(url, temp_file, proxies=proxies) # we are copying the file before closing it, so flush to avoid truncation temp_file.flush() diff --git a/pytorch_transformers/modeling_utils.py b/pytorch_transformers/modeling_utils.py index 3e4fbca132..f1501aa8d5 100644 --- a/pytorch_transformers/modeling_utils.py +++ b/pytorch_transformers/modeling_utils.py @@ -128,6 +128,10 @@ class PretrainedConfig(object): force_download: (`optional`) boolean, default False: Force to (re-)download the model weights and configuration files and override the cached versions if they exists. + proxies: (`optional`) dict, default None: + A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. + The proxies are used on each request. + return_unused_kwargs: (`optional`) bool: - If False, then this function returns just the final configuration object. 
@@ -150,6 +154,7 @@ class PretrainedConfig(object): """ cache_dir = kwargs.pop('cache_dir', None) force_download = kwargs.pop('force_download', False) + proxies = kwargs.pop('proxies', None) return_unused_kwargs = kwargs.pop('return_unused_kwargs', False) if pretrained_model_name_or_path in cls.pretrained_config_archive_map: @@ -160,7 +165,7 @@ class PretrainedConfig(object): config_file = pretrained_model_name_or_path # redirect to the cache, if necessary try: - resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download) + resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies) except EnvironmentError: if pretrained_model_name_or_path in cls.pretrained_config_archive_map: logger.error( @@ -407,6 +412,10 @@ class PreTrainedModel(nn.Module): force_download: (`optional`) boolean, default False: Force to (re-)download the model weights and configuration files and override the cached versions if they exists. + proxies: (`optional`) dict, default None: + A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. + The proxies are used on each request. + output_loading_info: (`optional`) boolean: Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages. 
@@ -432,6 +441,7 @@ class PreTrainedModel(nn.Module): cache_dir = kwargs.pop('cache_dir', None) from_tf = kwargs.pop('from_tf', False) force_download = kwargs.pop('force_download', False) + proxies = kwargs.pop('proxies', None) output_loading_info = kwargs.pop('output_loading_info', False) # Load config @@ -462,7 +472,7 @@ class PreTrainedModel(nn.Module): archive_file = pretrained_model_name_or_path # redirect to the cache, if necessary try: - resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download) + resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies) except EnvironmentError: if pretrained_model_name_or_path in cls.pretrained_model_archive_map: logger.error( diff --git a/pytorch_transformers/tokenization_utils.py b/pytorch_transformers/tokenization_utils.py index 763c0cee04..68af97a518 100644 --- a/pytorch_transformers/tokenization_utils.py +++ b/pytorch_transformers/tokenization_utils.py @@ -196,6 +196,10 @@ class PreTrainedTokenizer(object): force_download: (`optional`) boolean, default False: Force to (re-)download the vocabulary files and override the cached versions if they exists. + proxies: (`optional`) dict, default None: + A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. + The proxies are used on each request. + inputs: (`optional`) positional arguments: will be passed to the Tokenizer ``__init__`` method. kwargs: (`optional`) keyword arguments: will be passed to the Tokenizer ``__init__`` method. Can be used to set special tokens like ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``. See parameters in the doc string of :class:`~pytorch_transformers.PreTrainedTokenizer` for details. 
@@ -227,6 +231,7 @@ class PreTrainedTokenizer(object): def _from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): cache_dir = kwargs.pop('cache_dir', None) force_download = kwargs.pop('force_download', False) + proxies = kwargs.pop('proxies', None) s3_models = list(cls.max_model_input_sizes.keys()) vocab_files = {} @@ -287,7 +292,7 @@ class PreTrainedTokenizer(object): if file_path is None: resolved_vocab_files[file_id] = None else: - resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir, force_download=force_download) + resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir, force_download=force_download, proxies=proxies) except EnvironmentError: if pretrained_model_name_or_path in s3_models: logger.error("Couldn't reach server to download vocabulary.")