From fecaed0ed4bf338bca5b9895107b309841f8ac57 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Tue, 20 Aug 2019 10:56:12 +0200 Subject: [PATCH] add force_download option to from_pretrained methods --- pytorch_transformers/file_utils.py | 13 ++++++++----- pytorch_transformers/modeling_utils.py | 13 +++++++++++-- pytorch_transformers/tokenization_utils.py | 6 +++++- 3 files changed, 24 insertions(+), 8 deletions(-) diff --git a/pytorch_transformers/file_utils.py b/pytorch_transformers/file_utils.py index 75c075720c..074e6743ef 100644 --- a/pytorch_transformers/file_utils.py +++ b/pytorch_transformers/file_utils.py @@ -93,12 +93,15 @@ def filename_to_url(filename, cache_dir=None): return url, etag -def cached_path(url_or_filename, cache_dir=None): +def cached_path(url_or_filename, cache_dir=None, force_download=False): """ Given something that might be a URL (or might be a local path), determine which. If it's a URL, download the file and cache it, and return the path to the cached file. If it's already a local path, make sure the file exists and then return the path. + Args: + cache_dir: specify a cache directory to save the file to (overwrite the default cache dir). + force_download: if True, re-dowload the file even if it's already cached in the cache dir. """ if cache_dir is None: cache_dir = PYTORCH_TRANSFORMERS_CACHE @@ -111,7 +114,7 @@ def cached_path(url_or_filename, cache_dir=None): if parsed.scheme in ('http', 'https', 's3'): # URL, so get it from the cache (downloading if necessary) - return get_from_cache(url_or_filename, cache_dir) + return get_from_cache(url_or_filename, cache_dir=cache_dir, force_download=force_download) elif os.path.exists(url_or_filename): # File, and it exists. return url_or_filename @@ -184,7 +187,7 @@ def http_get(url, temp_file): progress.close() -def get_from_cache(url, cache_dir=None): +def get_from_cache(url, cache_dir=None, force_download=False): """ Given a URL, look for the corresponding dataset in the local cache. If it's not there, download it. Then return the path to the cached file. @@ -227,11 +230,11 @@ def get_from_cache(url, cache_dir=None): if matching_files: cache_path = os.path.join(cache_dir, matching_files[-1]) - if not os.path.exists(cache_path): + if not os.path.exists(cache_path) or force_download: # Download to temporary file, then copy to cache dir once finished. # Otherwise you get corrupt cache entries if the download gets interrupted. with tempfile.NamedTemporaryFile() as temp_file: - logger.info("%s not found in cache, downloading to %s", url, temp_file.name) + logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name) # GET file object if url.startswith("s3://"): diff --git a/pytorch_transformers/modeling_utils.py b/pytorch_transformers/modeling_utils.py index edc6b3903e..3e4fbca132 100644 --- a/pytorch_transformers/modeling_utils.py +++ b/pytorch_transformers/modeling_utils.py @@ -125,6 +125,9 @@ class PretrainedConfig(object): - The values in kwargs of any keys which are configuration attributes will be used to override the loaded values. - Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled by the `return_unused_kwargs` keyword parameter. + force_download: (`optional`) boolean, default False: + Force to (re-)download the model weights and configuration files and override the cached versions if they exists. + return_unused_kwargs: (`optional`) bool: - If False, then this function returns just the final configuration object. @@ -146,6 +149,7 @@ class PretrainedConfig(object): """ cache_dir = kwargs.pop('cache_dir', None) + force_download = kwargs.pop('force_download', False) return_unused_kwargs = kwargs.pop('return_unused_kwargs', False) if pretrained_model_name_or_path in cls.pretrained_config_archive_map: @@ -156,7 +160,7 @@ class PretrainedConfig(object): config_file = pretrained_model_name_or_path # redirect to the cache, if necessary try: - resolved_config_file = cached_path(config_file, cache_dir=cache_dir) + resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download) except EnvironmentError: if pretrained_model_name_or_path in cls.pretrained_config_archive_map: logger.error( @@ -400,6 +404,9 @@ class PreTrainedModel(nn.Module): Path to a directory in which a downloaded pre-trained model configuration should be cached if the standard cache should not be used. + force_download: (`optional`) boolean, default False: + Force to (re-)download the model weights and configuration files and override the cached versions if they exists. + output_loading_info: (`optional`) boolean: Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages. @@ -424,6 +431,7 @@ class PreTrainedModel(nn.Module): state_dict = kwargs.pop('state_dict', None) cache_dir = kwargs.pop('cache_dir', None) from_tf = kwargs.pop('from_tf', False) + force_download = kwargs.pop('force_download', False) output_loading_info = kwargs.pop('output_loading_info', False) # Load config @@ -431,6 +439,7 @@ class PreTrainedModel(nn.Module): config, model_kwargs = cls.config_class.from_pretrained( pretrained_model_name_or_path, *model_args, cache_dir=cache_dir, return_unused_kwargs=True, + force_download=force_download, **kwargs ) else: @@ -453,7 +462,7 @@ class PreTrainedModel(nn.Module): archive_file = pretrained_model_name_or_path # redirect to the cache, if necessary try: - resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) + resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download) except EnvironmentError: if pretrained_model_name_or_path in cls.pretrained_model_archive_map: logger.error( diff --git a/pytorch_transformers/tokenization_utils.py b/pytorch_transformers/tokenization_utils.py index 74d50b385d..763c0cee04 100644 --- a/pytorch_transformers/tokenization_utils.py +++ b/pytorch_transformers/tokenization_utils.py @@ -193,6 +193,9 @@ class PreTrainedTokenizer(object): cache_dir: (`optional`) string: Path to a directory in which a downloaded predefined tokenizer vocabulary files should be cached if the standard cache should not be used. + force_download: (`optional`) boolean, default False: + Force to (re-)download the vocabulary files and override the cached versions if they exists. + inputs: (`optional`) positional arguments: will be passed to the Tokenizer ``__init__`` method. kwargs: (`optional`) keyword arguments: will be passed to the Tokenizer ``__init__`` method. Can be used to set special tokens like ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``. See parameters in the doc string of :class:`~pytorch_transformers.PreTrainedTokenizer` for details. @@ -223,6 +226,7 @@ class PreTrainedTokenizer(object): @classmethod def _from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): cache_dir = kwargs.pop('cache_dir', None) + force_download = kwargs.pop('force_download', False) s3_models = list(cls.max_model_input_sizes.keys()) vocab_files = {} @@ -283,7 +287,7 @@ class PreTrainedTokenizer(object): if file_path is None: resolved_vocab_files[file_id] = None else: - resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir) + resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir, force_download=force_download) except EnvironmentError: if pretrained_model_name_or_path in s3_models: logger.error("Couldn't reach server to download vocabulary.")