From 6c4d688ffa8095f6dbaa959a51b53a91073f2aeb Mon Sep 17 00:00:00 2001 From: Vladimir Maryasin <67067775+vmaryasin@users.noreply.github.com> Date: Wed, 24 Nov 2021 12:22:03 +0100 Subject: [PATCH] add cache_dir for tokenizer verification loading (#14508) When loading a pretrained tokenizer, a verification is done to ensure that the actual tokenizer class matches the class it was called from. If the tokenizer is absent, its config file is loaded from the repo. However, the cache_dir for downloading is not provided, which leads to ignoring of the user-specified cache_dir, storing files in several places and and may result in incorrect warnings when the default cache_dir is unreachsble. This commit fixes that. --- src/transformers/tokenization_utils_base.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 0da576be74..d72ad37dcf 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1747,6 +1747,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): init_configuration, *init_inputs, use_auth_token=use_auth_token, + cache_dir=cache_dir, **kwargs, ) @@ -1758,6 +1759,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): init_configuration, *init_inputs, use_auth_token=None, + cache_dir=None, **kwargs ): # We instantiate fast tokenizers based on a slow tokenizer if we don't have access to the tokenizer.json @@ -1797,7 +1799,11 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): # Second attempt. If we have not yet found tokenizer_class, let's try to use the config. try: - config = AutoConfig.from_pretrained(pretrained_model_name_or_path, use_auth_token=use_auth_token) + config = AutoConfig.from_pretrained( + pretrained_model_name_or_path, + use_auth_token=use_auth_token, + cache_dir=cache_dir, + ) config_tokenizer_class = config.tokenizer_class except (OSError, ValueError, KeyError): # skip if an error occurred.