From e695470794f236392f249aeb815b62490126f595 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Tue, 25 Jan 2022 09:41:21 -0500 Subject: [PATCH] Avoid using get_list_of_files (#15287) * Avoid using get_list_of_files in config * Wip, change tokenizer file getter * Remove call in tokenizer files * Remove last call to get_list_model_files * Better tests * Unit tests for new function * Document bad API --- src/transformers/configuration_utils.py | 60 ++++------ src/transformers/file_utils.py | 112 ++++++++++++++++++ .../models/auto/processing_auto.py | 30 +++-- .../models/auto/tokenization_auto.py | 62 ++-------- src/transformers/tokenization_utils_base.py | 53 ++++----- tests/test_configuration_common.py | 14 ++- tests/test_file_utils.py | 29 +++++ tests/test_tokenization_fast.py | 6 +- 8 files changed, 232 insertions(+), 134 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 1d97de7d51..670fb78560 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -21,7 +21,7 @@ import json import os import re import warnings -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, List, Tuple, Union from packaging import version @@ -36,7 +36,6 @@ from .file_utils import ( RevisionNotFoundError, cached_path, copy_func, - get_list_of_files, hf_bucket_url, is_offline_mode, is_remote_url, @@ -46,7 +45,7 @@ from .utils import logging logger = logging.get_logger(__name__) -FULL_CONFIGURATION_FILE = "config.json" + _re_configuration_file = re.compile(r"config\.(.*)\.json") @@ -533,6 +532,23 @@ class PretrainedConfig(PushToHubMixin): `Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the configuration object. """ + original_kwargs = copy.deepcopy(kwargs) + # Get config dict associated with the base config file + config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs) + + # That config file may point us toward another config file to use. + if "configuration_files" in config_dict: + configuration_file = get_configuration_file(config_dict["configuration_files"]) + config_dict, kwargs = cls._get_config_dict( + pretrained_model_name_or_path, _configuration_file=configuration_file, **original_kwargs + ) + + return config_dict, kwargs + + @classmethod + def _get_config_dict( + cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) resume_download = kwargs.pop("resume_download", False) @@ -555,12 +571,7 @@ class PretrainedConfig(PushToHubMixin): if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): config_file = pretrained_model_name_or_path else: - configuration_file = get_configuration_file( - pretrained_model_name_or_path, - revision=revision, - use_auth_token=use_auth_token, - local_files_only=local_files_only, - ) + configuration_file = kwargs.get("_configuration_file", CONFIG_NAME) if os.path.isdir(pretrained_model_name_or_path): config_file = os.path.join(pretrained_model_name_or_path, configuration_file) @@ -840,41 +851,18 @@ class PretrainedConfig(PushToHubMixin): d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1] -def get_configuration_file( - path_or_repo: Union[str, os.PathLike], - revision: Optional[str] = None, - use_auth_token: Optional[Union[bool, str]] = None, - local_files_only: bool = False, -) -> str: +def get_configuration_file(configuration_files: List[str]) -> str: """ Get the configuration file to use for this version of transformers. Args: - path_or_repo (`str` or `os.PathLike`): - Can be either the id of a repo on huggingface.co or a path to a *directory*. - revision(`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a - git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any - identifier allowed by git. - use_auth_token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated - when running `transformers-cli login` (stored in `~/.huggingface`). - local_files_only (`bool`, *optional*, defaults to `False`): - Whether or not to only rely on local files and not to attempt to download any files. + configuration_files (`List[str]`): The list of available configuration files. Returns: `str`: The configuration file to use. """ - # Inspect all files from the repo/folder. - try: - all_files = get_list_of_files( - path_or_repo, revision=revision, use_auth_token=use_auth_token, local_files_only=local_files_only - ) - except Exception: - return FULL_CONFIGURATION_FILE - configuration_files_map = {} - for file_name in all_files: + for file_name in configuration_files: search = _re_configuration_file.search(file_name) if search is not None: v = search.groups()[0] @@ -882,7 +870,7 @@ def get_configuration_file( available_versions = sorted(configuration_files_map.keys()) # Defaults to FULL_CONFIGURATION_FILE and then try to look at some newer versions. - configuration_file = FULL_CONFIGURATION_FILE + configuration_file = CONFIG_NAME transformers_version = version.parse(__version__) for v in available_versions: if version.parse(v) <= transformers_version: diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index d71aeb0f7c..43b87f2ca3 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -2112,6 +2112,112 @@ def get_from_cache( return cache_path +def get_file_from_repo( + path_or_repo: Union[str, os.PathLike], + filename: str, + cache_dir: Optional[Union[str, os.PathLike]] = None, + force_download: bool = False, + resume_download: bool = False, + proxies: Optional[Dict[str, str]] = None, + use_auth_token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, + local_files_only: bool = False, +): + """ + Tries to locate a file in a local folder and repo, downloads and cache it if necessary. + + Args: + path_or_repo (`str` or `os.PathLike`): + This can be either: + + - a string, the *model id* of a model repo on huggingface.co. + - a path to a *directory* potentially containing the file. + filename (`str`): + The name of the file to locate in `path_or_repo`. + cache_dir (`str` or `os.PathLike`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the standard + cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force to (re-)download the configuration files and override the cached versions if they + exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `transformers-cli login` (stored in `~/.huggingface`). + revision(`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + local_files_only (`bool`, *optional*, defaults to `False`): + If `True`, will only try to load the tokenizer configuration from local files. + + + + Passing `use_auth_token=True` is required when you want to use a private model. + + + + Returns: + `Optional[str]`: Returns the resolved file (to the cache folder if downloaded from a repo) or `None` if the + file does not exist. + + Examples: + + ```python + # Download a tokenizer configuration from huggingface.co and cache. + tokenizer_config = get_file_from_repo("bert-base-uncased", "tokenizer_config.json") + # This model does not have a tokenizer config so the result will be None. + tokenizer_config = get_file_from_repo("xlm-roberta-base", "tokenizer_config.json") + ```""" + if is_offline_mode() and not local_files_only: + logger.info("Offline mode: forcing local_files_only=True") + local_files_only = True + + path_or_repo = str(path_or_repo) + if os.path.isdir(path_or_repo): + resolved_file = os.path.join(path_or_repo, filename) + return resolved_file if os.path.isfile(resolved_file) else None + else: + resolved_file = hf_bucket_url(path_or_repo, filename=filename, revision=revision, mirror=None) + + try: + # Load from URL or cache if already cached + resolved_file = cached_path( + resolved_file, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + ) + + except RepositoryNotFoundError as err: + logger.error(err) + raise EnvironmentError( + f"{path_or_repo} is not a local folder and is not a valid model identifier " + "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to " + "pass a token having permission to this repo with `use_auth_token` or log in with " + "`huggingface-cli login` and pass `use_auth_token=True`." + ) + except RevisionNotFoundError as err: + logger.error(err) + raise EnvironmentError( + f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists " + "for this model name. Check the model page at " + f"'https://huggingface.co/{path_or_repo}' for available revisions." + ) + except EnvironmentError: + # The repo and revision exist, but the file does not or there was a connection error fetching it. + return None + + return resolved_file + + def has_file( path_or_repo: Union[str, os.PathLike], filename: str, @@ -2184,6 +2290,12 @@ def get_list_of_files( local_files_only (`bool`, *optional*, defaults to `False`): Whether or not to only rely on local files and not to attempt to download any files. + + + This API is not optimized, so calling it a lot may result in connection errors. + + + Returns: `List[str]`: The list of files available in `path_or_repo`. """ diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index f6cebe349c..5a788e16b8 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -14,12 +14,14 @@ # limitations under the License. """ AutoProcessor class.""" import importlib +import inspect +import json from collections import OrderedDict # Build the list of all feature extractors from ...configuration_utils import PretrainedConfig from ...feature_extraction_utils import FeatureExtractionMixin -from ...file_utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME, get_list_of_files +from ...file_utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME, get_file_from_repo from ...tokenization_utils import TOKENIZER_CONFIG_FILE from .auto_factory import _LazyAutoMapping from .configuration_auto import ( @@ -29,7 +31,6 @@ from .configuration_auto import ( model_type_to_module_name, replace_list_option_in_docstrings, ) -from .tokenization_auto import get_tokenizer_config PROCESSOR_MAPPING_NAMES = OrderedDict( @@ -145,24 +146,29 @@ class AutoProcessor: kwargs["_from_auto"] = True # First, let's see if we have a preprocessor config. - # get_list_of_files only takes three of the kwargs we have, so we filter them. - get_list_of_files_kwargs = { - key: kwargs[key] for key in ["revision", "use_auth_token", "local_files_only"] if key in kwargs + # Filter the kwargs for `get_file_from_repo``. + get_file_from_repo_kwargs = { + key: kwargs[key] for key in inspect.signature(get_file_from_repo).parameters.keys() if key in kwargs } - model_files = get_list_of_files(pretrained_model_name_or_path, **get_list_of_files_kwargs) - # strip to file name - model_files = [f.split("/")[-1] for f in model_files] - # Let's start by checking whether the processor class is saved in a feature extractor - if FEATURE_EXTRACTOR_NAME in model_files: + preprocessor_config_file = get_file_from_repo( + pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME, **get_file_from_repo_kwargs + ) + if preprocessor_config_file is not None: config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs) if "processor_class" in config_dict: processor_class = processor_class_from_name(config_dict["processor_class"]) return processor_class.from_pretrained(pretrained_model_name_or_path, **kwargs) # Next, let's check whether the processor class is saved in a tokenizer - if TOKENIZER_CONFIG_FILE in model_files: - config_dict = get_tokenizer_config(pretrained_model_name_or_path, **kwargs) + # Let's start by checking whether the processor class is saved in a feature extractor + tokenizer_config_file = get_file_from_repo( + pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE, **get_file_from_repo_kwargs + ) + if tokenizer_config_file is not None: + with open(tokenizer_config_file, encoding="utf-8") as reader: + config_dict = json.load(reader) + if "processor_class" in config_dict: processor_class = processor_class_from_name(config_dict["processor_class"]) return processor_class.from_pretrained(pretrained_model_name_or_path, **kwargs) diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 459e4f81b5..2e706427d3 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -21,15 +21,7 @@ from collections import OrderedDict from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union from ...configuration_utils import PretrainedConfig -from ...file_utils import ( - RepositoryNotFoundError, - RevisionNotFoundError, - cached_path, - hf_bucket_url, - is_offline_mode, - is_sentencepiece_available, - is_tokenizers_available, -) +from ...file_utils import get_file_from_repo, is_sentencepiece_available, is_tokenizers_available from ...tokenization_utils import PreTrainedTokenizer from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE from ...tokenization_utils_fast import PreTrainedTokenizerFast @@ -329,46 +321,18 @@ def get_tokenizer_config( tokenizer.save_pretrained("tokenizer-test") tokenizer_config = get_tokenizer_config("tokenizer-test") ```""" - if is_offline_mode() and not local_files_only: - logger.info("Offline mode: forcing local_files_only=True") - local_files_only = True - - pretrained_model_name_or_path = str(pretrained_model_name_or_path) - if os.path.isdir(pretrained_model_name_or_path): - config_file = os.path.join(pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE) - else: - config_file = hf_bucket_url( - pretrained_model_name_or_path, filename=TOKENIZER_CONFIG_FILE, revision=revision, mirror=None - ) - - try: - # Load from URL or cache if already cached - resolved_config_file = cached_path( - config_file, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - resume_download=resume_download, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - ) - - except RepositoryNotFoundError as err: - logger.error(err) - raise EnvironmentError( - f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier " - "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to " - "pass a token having permission to this repo with `use_auth_token` or log in with " - "`huggingface-cli login` and pass `use_auth_token=True`." - ) - except RevisionNotFoundError as err: - logger.error(err) - raise EnvironmentError( - f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists " - "for this model name. Check the model page at " - f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions." - ) - except EnvironmentError: + resolved_config_file = get_file_from_repo( + pretrained_model_name_or_path, + TOKENIZER_CONFIG_FILE, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + use_auth_token=use_auth_token, + revision=revision, + local_files_only=local_files_only, + ) + if resolved_config_file is None: logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.") return {} diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 290f627b04..8389e7a6cf 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -50,7 +50,7 @@ from .file_utils import ( add_end_docstrings, cached_path, copy_func, - get_list_of_files, + get_file_from_repo, hf_bucket_url, is_flax_available, is_offline_mode, @@ -1649,12 +1649,26 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): vocab_files[file_id] = pretrained_model_name_or_path else: # At this point pretrained_model_name_or_path is either a directory or a model identifier name - fast_tokenizer_file = get_fast_tokenizer_file( + + # Try to get the tokenizer config to see if there are versioned tokenizer files. + fast_tokenizer_file = FULL_TOKENIZER_FILE + resolved_config_file = get_file_from_repo( pretrained_model_name_or_path, - revision=revision, + TOKENIZER_CONFIG_FILE, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, use_auth_token=use_auth_token, + revision=revision, local_files_only=local_files_only, ) + if resolved_config_file is not None: + with open(resolved_config_file, encoding="utf-8") as reader: + tokenizer_config = json.load(reader) + if "fast_tokenizer_files" in tokenizer_config: + fast_tokenizer_file = get_fast_tokenizer_file(tokenizer_config["fast_tokenizer_files"]) + additional_files_names = { "added_tokens_file": ADDED_TOKENS_FILE, "special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE, @@ -3495,41 +3509,18 @@ For a more complete example, see the implementation of `prepare_seq2seq_batch`. return model_inputs -def get_fast_tokenizer_file( - path_or_repo: Union[str, os.PathLike], - revision: Optional[str] = None, - use_auth_token: Optional[Union[bool, str]] = None, - local_files_only: bool = False, -) -> str: +def get_fast_tokenizer_file(tokenization_files: List[str]) -> str: """ - Get the tokenizer file to use for this version of transformers. + Get the tokenization file to use for this version of transformers. Args: - path_or_repo (`str` or `os.PathLike`): - Can be either the id of a repo on huggingface.co or a path to a *directory*. - revision(`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a - git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any - identifier allowed by git. - use_auth_token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated - when running `transformers-cli login` (stored in `~/.huggingface`). - local_files_only (`bool`, *optional*, defaults to `False`): - Whether or not to only rely on local files and not to attempt to download any files. + tokenization_files (`List[str]`): The list of available configuration files. Returns: - `str`: The tokenizer file to use. + `str`: The tokenization file to use. """ - # Inspect all files from the repo/folder. - try: - all_files = get_list_of_files( - path_or_repo, revision=revision, use_auth_token=use_auth_token, local_files_only=local_files_only - ) - except Exception: - return FULL_TOKENIZER_FILE - tokenizer_files_map = {} - for file_name in all_files: + for file_name in tokenization_files: search = _re_tokenizer_file.search(file_name) if search is not None: v = search.groups()[0] diff --git a/tests/test_configuration_common.py b/tests/test_configuration_common.py index 79e0e62417..7a84d84c79 100644 --- a/tests/test_configuration_common.py +++ b/tests/test_configuration_common.py @@ -313,6 +313,7 @@ class ConfigTestUtils(unittest.TestCase): class ConfigurationVersioningTest(unittest.TestCase): def test_local_versioning(self): configuration = AutoConfig.from_pretrained("bert-base-cased") + configuration.configuration_files = ["config.4.0.0.json"] with tempfile.TemporaryDirectory() as tmp_dir: configuration.save_pretrained(tmp_dir) @@ -325,23 +326,26 @@ class ConfigurationVersioningTest(unittest.TestCase): # Will need to be adjusted if we reach v42 and this test is still here. # Should pick the old configuration file as the version of Transformers is < 4.42.0 + configuration.configuration_files = ["config.42.0.0.json"] + configuration.hidden_size = 768 + configuration.save_pretrained(tmp_dir) shutil.move(os.path.join(tmp_dir, "config.4.0.0.json"), os.path.join(tmp_dir, "config.42.0.0.json")) new_configuration = AutoConfig.from_pretrained(tmp_dir) self.assertEqual(new_configuration.hidden_size, 768) def test_repo_versioning_before(self): - # This repo has two configuration files, one for v5.0.0 and above with an added token, one for versions lower. - repo = "microsoft/layoutxlm-base" + # This repo has two configuration files, one for v4.0.0 and above with a different hidden size. + repo = "hf-internal-testing/test-two-configs" import transformers as new_transformers - new_transformers.configuration_utils.__version__ = "v5.0.0" + new_transformers.configuration_utils.__version__ = "v4.0.0" new_configuration = new_transformers.models.auto.AutoConfig.from_pretrained(repo) - self.assertEqual(new_configuration.tokenizer_class, None) + self.assertEqual(new_configuration.hidden_size, 2) # Testing an older version by monkey-patching the version in the module it's used. import transformers as old_transformers old_transformers.configuration_utils.__version__ = "v3.0.0" old_configuration = old_transformers.models.auto.AutoConfig.from_pretrained(repo) - self.assertEqual(old_configuration.tokenizer_class, "XLMRobertaTokenizer") + self.assertEqual(old_configuration.hidden_size, 768) diff --git a/tests/test_file_utils.py b/tests/test_file_utils.py index 682f1214c7..768dda263d 100644 --- a/tests/test_file_utils.py +++ b/tests/test_file_utils.py @@ -15,7 +15,10 @@ import contextlib import importlib import io +import json +import tempfile import unittest +from pathlib import Path import transformers @@ -31,6 +34,7 @@ from transformers.file_utils import ( RepositoryNotFoundError, RevisionNotFoundError, filename_to_url, + get_file_from_repo, get_from_cache, has_file, hf_bucket_url, @@ -128,6 +132,31 @@ class GetFromCacheTests(unittest.TestCase): self.assertFalse(has_file("hf-internal-testing/tiny-bert-pt-only", TF2_WEIGHTS_NAME)) self.assertFalse(has_file("hf-internal-testing/tiny-bert-pt-only", FLAX_WEIGHTS_NAME)) + def test_get_file_from_repo_distant(self): + # `get_file_from_repo` returns None if the file does not exist + self.assertIsNone(get_file_from_repo("bert-base-cased", "ahah.txt")) + + # The function raises if the repository does not exist. + with self.assertRaisesRegex(EnvironmentError, "is not a valid model identifier"): + get_file_from_repo("bert-base-case", "config.json") + + # The function raises if the revision does not exist. + with self.assertRaisesRegex(EnvironmentError, "is not a valid git identifier"): + get_file_from_repo("bert-base-cased", "config.json", revision="ahaha") + + resolved_file = get_file_from_repo("bert-base-cased", "config.json") + # The name is the cached name which is not very easy to test, so instead we load the content. + config = json.loads(open(resolved_file, "r").read()) + self.assertEqual(config["hidden_size"], 768) + + def test_get_file_from_repo_local(self): + with tempfile.TemporaryDirectory() as tmp_dir: + filename = Path(tmp_dir) / "a.txt" + filename.touch() + self.assertEqual(get_file_from_repo(tmp_dir, "a.txt"), str(filename)) + + self.assertIsNone(get_file_from_repo(tmp_dir, "b.txt")) + class ContextManagerTests(unittest.TestCase): @unittest.mock.patch("sys.stdout", new_callable=io.StringIO) diff --git a/tests/test_tokenization_fast.py b/tests/test_tokenization_fast.py index 4fb710319f..b3fd682484 100644 --- a/tests/test_tokenization_fast.py +++ b/tests/test_tokenization_fast.py @@ -108,6 +108,8 @@ class TokenizerVersioningTest(unittest.TestCase): json_tokenizer["model"]["vocab"]["huggingface"] = len(tokenizer) with tempfile.TemporaryDirectory() as tmp_dir: + # Hack to save this in the tokenizer_config.json + tokenizer.init_kwargs["fast_tokenizer_files"] = ["tokenizer.4.0.0.json"] tokenizer.save_pretrained(tmp_dir) json.dump(json_tokenizer, open(os.path.join(tmp_dir, "tokenizer.4.0.0.json"), "w")) @@ -120,6 +122,8 @@ class TokenizerVersioningTest(unittest.TestCase): # Will need to be adjusted if we reach v42 and this test is still here. # Should pick the old tokenizer file as the version of Transformers is < 4.0.0 shutil.move(os.path.join(tmp_dir, "tokenizer.4.0.0.json"), os.path.join(tmp_dir, "tokenizer.42.0.0.json")) + tokenizer.init_kwargs["fast_tokenizer_files"] = ["tokenizer.42.0.0.json"] + tokenizer.save_pretrained(tmp_dir) new_tokenizer = AutoTokenizer.from_pretrained(tmp_dir) self.assertEqual(len(new_tokenizer), len(tokenizer)) json_tokenizer = json.loads(new_tokenizer._tokenizer.to_str()) @@ -127,7 +131,7 @@ class TokenizerVersioningTest(unittest.TestCase): def test_repo_versioning(self): # This repo has two tokenizer files, one for v4.0.0 and above with an added token, one for versions lower. - repo = "sgugger/finetuned-bert-mrpc" + repo = "hf-internal-testing/test-two-tokenizers" # This should pick the new tokenizer file as the version of Transformers is > 4.0.0 tokenizer = AutoTokenizer.from_pretrained(repo)