add force_download option to from_pretrained methods

This commit is contained in:
thomwolf
2019-08-20 10:56:12 +02:00
parent c589862b78
commit fecaed0ed4
3 changed files with 24 additions and 8 deletions

View File

@@ -193,6 +193,9 @@ class PreTrainedTokenizer(object):
cache_dir: (`optional`) string:
Path to a directory in which a downloaded predefined tokenizer vocabulary files should be cached if the standard cache should not be used.
force_download: (`optional`) boolean, default False:
Force to (re-)download the vocabulary files and override the cached versions if they exists.
inputs: (`optional`) positional arguments: will be passed to the Tokenizer ``__init__`` method.
kwargs: (`optional`) keyword arguments: will be passed to the Tokenizer ``__init__`` method. Can be used to set special tokens like ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``. See parameters in the doc string of :class:`~pytorch_transformers.PreTrainedTokenizer` for details.
@@ -223,6 +226,7 @@ class PreTrainedTokenizer(object):
@classmethod
def _from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
cache_dir = kwargs.pop('cache_dir', None)
force_download = kwargs.pop('force_download', False)
s3_models = list(cls.max_model_input_sizes.keys())
vocab_files = {}
@@ -283,7 +287,7 @@ class PreTrainedTokenizer(object):
if file_path is None:
resolved_vocab_files[file_id] = None
else:
resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir)
resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir, force_download=force_download)
except EnvironmentError:
if pretrained_model_name_or_path in s3_models:
logger.error("Couldn't reach server to download vocabulary.")