add cache_dir for tokenizer verification loading (#14508)

When loading a pretrained tokenizer, a verification is done to ensure
that the actual tokenizer class matches the class it was called from.
If the tokenizer is absent, its config file is loaded from the repo.

However, the cache_dir for downloading is not provided, which leads to
ignoring of the user-specified cache_dir, storing files in several
places and and may result in incorrect warnings when the default
cache_dir is unreachsble.

This commit fixes that.
This commit is contained in:
Vladimir Maryasin
2021-11-24 12:22:03 +01:00
committed by GitHub
parent 956a483173
commit 6c4d688ffa

View File

@@ -1747,6 +1747,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
init_configuration,
*init_inputs,
use_auth_token=use_auth_token,
cache_dir=cache_dir,
**kwargs,
)
@@ -1758,6 +1759,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
init_configuration,
*init_inputs,
use_auth_token=None,
cache_dir=None,
**kwargs
):
# We instantiate fast tokenizers based on a slow tokenizer if we don't have access to the tokenizer.json
@@ -1797,7 +1799,11 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
# Second attempt. If we have not yet found tokenizer_class, let's try to use the config.
try:
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, use_auth_token=use_auth_token)
config = AutoConfig.from_pretrained(
pretrained_model_name_or_path,
use_auth_token=use_auth_token,
cache_dir=cache_dir,
)
config_tokenizer_class = config.tokenizer_class
except (OSError, ValueError, KeyError):
# skip if an error occurred.