diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index a8fefa2310..f07e366c79 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -14,12 +14,21 @@ # limitations under the License. """ Auto Tokenizer class. """ - +import json +import os from collections import OrderedDict +from typing import Dict, Optional, Union from ... import GPTNeoConfig from ...configuration_utils import PretrainedConfig -from ...file_utils import is_sentencepiece_available, is_tokenizers_available +from ...file_utils import ( + cached_path, + hf_bucket_url, + is_offline_mode, + is_sentencepiece_available, + is_tokenizers_available, +) +from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE from ...utils import logging from ..bart.tokenization_bart import BartTokenizer from ..bert.tokenization_bert import BertTokenizer @@ -323,6 +332,105 @@ def tokenizer_class_from_name(class_name: str): return c +def get_tokenizer_config( + pretrained_model_name_or_path: Union[str, os.PathLike], + cache_dir: Optional[Union[str, os.PathLike]] = None, + force_download: bool = False, + resume_download: bool = False, + proxies: Optional[Dict[str, str]] = None, + use_auth_token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, + local_files_only: bool = False, + **kwargs, +): + """ + Loads the tokenizer configuration from a pretrained model tokenizer configuration. + + Args: + pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`): + This can be either: + + - a string, the `model id` of a pretrained model configuration hosted inside a model repo on + huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or + namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``. + - a path to a `directory` containing a configuration file saved using the + :func:`~transformers.PreTrainedTokenizer.save_pretrained` method, e.g., ``./my_model_directory/``. + + cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`): + Path to a directory in which a downloaded pretrained model configuration should be cached if the standard + cache should not be used. + force_download (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to force to (re-)download the configuration files and override the cached versions if they + exist. + resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists. + proxies (:obj:`Dict[str, str]`, `optional`): + A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. + use_auth_token (:obj:`str` or `bool`, `optional`): + The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token + generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`). + revision(:obj:`str`, `optional`, defaults to :obj:`"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any + identifier allowed by git. + local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`): + If :obj:`True`, will only try to load the tokenizer configuration from local files. + + .. note:: + + Passing :obj:`use_auth_token=True` is required when you want to use a private model. + + + Returns: + :obj:`Dict`: The configuration of the tokenizer. + + Examples:: + + # Download configuration from huggingface.co and cache. + tokenizer_config = get_tokenizer_config("bert-base-uncased") + # This model does not have a tokenizer config so the result will be an empty dict. + tokenizer_config = get_tokenizer_config("xlm-roberta-base") + + # Save a pretrained tokenizer locally and you can reload its config + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained("bert-base-cased") + tokenizer.save_pretrained("tokenizer-test") + tokenizer_config = get_tokenizer_config("tokenizer-test") + """ + if is_offline_mode() and not local_files_only: + logger.info("Offline mode: forcing local_files_only=True") + local_files_only = True + + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + if os.path.isdir(pretrained_model_name_or_path): + config_file = os.path.join(pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE) + else: + config_file = hf_bucket_url( + pretrained_model_name_or_path, filename=TOKENIZER_CONFIG_FILE, revision=revision, mirror=None + ) + + try: + # Load from URL or cache if already cached + resolved_config_file = cached_path( + config_file, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + ) + + except EnvironmentError: + logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.") + return {} + + with open(resolved_config_file, encoding="utf-8") as reader: + return json.load(reader) + + class AutoTokenizer: r""" This is a generic tokenizer class that will be instantiated as one of the tokenizer classes of the library when @@ -408,18 +516,27 @@ class AutoTokenizer: """ config = kwargs.pop("config", None) kwargs["_from_auto"] = True - if not isinstance(config, PretrainedConfig): - config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) use_fast = kwargs.pop("use_fast", True) - if config.tokenizer_class is not None: + # First, let's try to use the tokenizer_config file to get the tokenizer class. + tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs) + config_tokenizer_class = tokenizer_config.get("tokenizer_class") + + # If that did not work, let's try to use the config. + if config_tokenizer_class is None: + if not isinstance(config, PretrainedConfig): + config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + config_tokenizer_class = config.tokenizer_class + + # If we have the tokenizer class from the tokenizer config or the model config we're good! + if config_tokenizer_class is not None: tokenizer_class = None - if use_fast and not config.tokenizer_class.endswith("Fast"): - tokenizer_class_candidate = f"{config.tokenizer_class}Fast" + if use_fast and not config_tokenizer_class.endswith("Fast"): + tokenizer_class_candidate = f"{config_tokenizer_class}Fast" tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate) if tokenizer_class is None: - tokenizer_class_candidate = config.tokenizer_class + tokenizer_class_candidate = config_tokenizer_class tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate) if tokenizer_class is None: @@ -428,6 +545,7 @@ class AutoTokenizer: ) return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) + # Otherwise we have to be creative. # if model is an encoder decoder, the encoder tokenizer class is used by default if isinstance(config, EncoderDecoderConfig): if type(config.decoder) is not type(config.encoder): # noqa: E721 diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 4f3129e16d..31d13a3fdc 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1745,6 +1745,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): if tokenizer_config_file is not None: with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle: init_kwargs = json.load(tokenizer_config_handle) + init_kwargs.pop("tokenizer_class", None) saved_init_inputs = init_kwargs.pop("init_inputs", ()) if not init_inputs: init_inputs = saved_init_inputs @@ -1920,6 +1921,14 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): # add_type_field=True to allow dicts in the kwargs / differentiate from AddedToken serialization tokenizer_config = convert_added_tokens(tokenizer_config, add_type_field=True) + + # Add tokenizer class to the tokenizer config to be able to reload it with from_pretrained + tokenizer_class = self.__class__.__name__ + # Remove the Fast at the end unless we have a special `PreTrainedTokenizerFast` + if tokenizer_class.endswith("Fast") and tokenizer_class != "PreTrainedTokenizerFast": + tokenizer_class = tokenizer_class[:-4] + tokenizer_config["tokenizer_class"] = tokenizer_class + with open(tokenizer_config_file, "w", encoding="utf-8") as f: f.write(json.dumps(tokenizer_config, ensure_ascii=False)) logger.info(f"tokenizer config file saved in {tokenizer_config_file}") diff --git a/tests/test_tokenization_auto.py b/tests/test_tokenization_auto.py index 72db79d1c5..f35d0eb5e2 100644 --- a/tests/test_tokenization_auto.py +++ b/tests/test_tokenization_auto.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. - +import tempfile import unittest from transformers import ( @@ -29,7 +29,7 @@ from transformers import ( RobertaTokenizerFast, ) from transformers.models.auto.configuration_auto import AutoConfig -from transformers.models.auto.tokenization_auto import TOKENIZER_MAPPING +from transformers.models.auto.tokenization_auto import TOKENIZER_MAPPING, get_tokenizer_config from transformers.models.roberta.configuration_roberta import RobertaConfig from transformers.testing_utils import ( DUMMY_DIFF_TOKENIZER_IDENTIFIER, @@ -129,3 +129,34 @@ class AutoTokenizerTest(unittest.TestCase): self.assertEqual(tokenizer.vocab_size, 30000) self.assertEqual(tokenizer.unk_token, "[UNK]") self.assertEqual(tokenizer.padding_side, "right") + + def test_auto_tokenizer_from_local_folder(self): + tokenizer = AutoTokenizer.from_pretrained(SMALL_MODEL_IDENTIFIER) + self.assertIsInstance(tokenizer, (BertTokenizer, BertTokenizerFast)) + with tempfile.TemporaryDirectory() as tmp_dir: + tokenizer.save_pretrained(tmp_dir) + tokenizer2 = AutoTokenizer.from_pretrained(tmp_dir) + + self.assertIsInstance(tokenizer2, tokenizer.__class__) + self.assertEqual(tokenizer2.vocab_size, 12) + + def test_get_tokenizer_config(self): + # Check we can load the tokenizer config of an online model. + config = get_tokenizer_config("bert-base-cased") + # If we ever update bert-base-cased tokenizer config, this dict here will need to be updated. + self.assertEqual(config, {"do_lower_case": False}) + + # This model does not have a tokenizer_config so we get back an empty dict. + config = get_tokenizer_config(SMALL_MODEL_IDENTIFIER) + self.assertDictEqual(config, {}) + + # A tokenizer saved with `save_pretrained` always creates a tokenizer config. + tokenizer = AutoTokenizer.from_pretrained(SMALL_MODEL_IDENTIFIER) + with tempfile.TemporaryDirectory() as tmp_dir: + tokenizer.save_pretrained(tmp_dir) + config = get_tokenizer_config(tmp_dir) + + # Check the class of the tokenizer was properly saved (note that it always saves the slow class). + self.assertEqual(config["tokenizer_class"], "BertTokenizer") + # Check other keys just to make sure the config was properly saved /reloaded. + self.assertEqual(config["name_or_path"], SMALL_MODEL_IDENTIFIER)