AutoTokenizer: infer the class from the tokenizer config if possible (#12208)
* AutoTokenizer: infer the class from the tokenizer config if possible * Add tests * Update src/transformers/models/auto/tokenization_auto.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -14,12 +14,21 @@
|
||||
# limitations under the License.
|
||||
""" Auto Tokenizer class. """
|
||||
|
||||
|
||||
import json
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from ... import GPTNeoConfig
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...file_utils import is_sentencepiece_available, is_tokenizers_available
|
||||
from ...file_utils import (
|
||||
cached_path,
|
||||
hf_bucket_url,
|
||||
is_offline_mode,
|
||||
is_sentencepiece_available,
|
||||
is_tokenizers_available,
|
||||
)
|
||||
from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE
|
||||
from ...utils import logging
|
||||
from ..bart.tokenization_bart import BartTokenizer
|
||||
from ..bert.tokenization_bert import BertTokenizer
|
||||
@@ -323,6 +332,105 @@ def tokenizer_class_from_name(class_name: str):
|
||||
return c
|
||||
|
||||
|
||||
def get_tokenizer_config(
|
||||
pretrained_model_name_or_path: Union[str, os.PathLike],
|
||||
cache_dir: Optional[Union[str, os.PathLike]] = None,
|
||||
force_download: bool = False,
|
||||
resume_download: bool = False,
|
||||
proxies: Optional[Dict[str, str]] = None,
|
||||
use_auth_token: Optional[Union[bool, str]] = None,
|
||||
revision: Optional[str] = None,
|
||||
local_files_only: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Loads the tokenizer configuration from a pretrained model tokenizer configuration.
|
||||
|
||||
Args:
|
||||
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
|
||||
This can be either:
|
||||
|
||||
- a string, the `model id` of a pretrained model configuration hosted inside a model repo on
|
||||
huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or
|
||||
namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``.
|
||||
- a path to a `directory` containing a configuration file saved using the
|
||||
:func:`~transformers.PreTrainedTokenizer.save_pretrained` method, e.g., ``./my_model_directory/``.
|
||||
|
||||
cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
|
||||
Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
|
||||
cache should not be used.
|
||||
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to force to (re-)download the configuration files and override the cached versions if they
|
||||
exist.
|
||||
resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
|
||||
proxies (:obj:`Dict[str, str]`, `optional`):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
|
||||
use_auth_token (:obj:`str` or `bool`, `optional`):
|
||||
The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
|
||||
generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`).
|
||||
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
||||
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
|
||||
identifier allowed by git.
|
||||
local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
If :obj:`True`, will only try to load the tokenizer configuration from local files.
|
||||
|
||||
.. note::
|
||||
|
||||
Passing :obj:`use_auth_token=True` is required when you want to use a private model.
|
||||
|
||||
|
||||
Returns:
|
||||
:obj:`Dict`: The configuration of the tokenizer.
|
||||
|
||||
Examples::
|
||||
|
||||
# Download configuration from huggingface.co and cache.
|
||||
tokenizer_config = get_tokenizer_config("bert-base-uncased")
|
||||
# This model does not have a tokenizer config so the result will be an empty dict.
|
||||
tokenizer_config = get_tokenizer_config("xlm-roberta-base")
|
||||
|
||||
# Save a pretrained tokenizer locally and you can reload its config
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
|
||||
tokenizer.save_pretrained("tokenizer-test")
|
||||
tokenizer_config = get_tokenizer_config("tokenizer-test")
|
||||
"""
|
||||
if is_offline_mode() and not local_files_only:
|
||||
logger.info("Offline mode: forcing local_files_only=True")
|
||||
local_files_only = True
|
||||
|
||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
config_file = os.path.join(pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE)
|
||||
else:
|
||||
config_file = hf_bucket_url(
|
||||
pretrained_model_name_or_path, filename=TOKENIZER_CONFIG_FILE, revision=revision, mirror=None
|
||||
)
|
||||
|
||||
try:
|
||||
# Load from URL or cache if already cached
|
||||
resolved_config_file = cached_path(
|
||||
config_file,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
)
|
||||
|
||||
except EnvironmentError:
|
||||
logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.")
|
||||
return {}
|
||||
|
||||
with open(resolved_config_file, encoding="utf-8") as reader:
|
||||
return json.load(reader)
|
||||
|
||||
|
||||
class AutoTokenizer:
|
||||
r"""
|
||||
This is a generic tokenizer class that will be instantiated as one of the tokenizer classes of the library when
|
||||
@@ -408,18 +516,27 @@ class AutoTokenizer:
|
||||
"""
|
||||
config = kwargs.pop("config", None)
|
||||
kwargs["_from_auto"] = True
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
use_fast = kwargs.pop("use_fast", True)
|
||||
|
||||
if config.tokenizer_class is not None:
|
||||
# First, let's try to use the tokenizer_config file to get the tokenizer class.
|
||||
tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
|
||||
config_tokenizer_class = tokenizer_config.get("tokenizer_class")
|
||||
|
||||
# If that did not work, let's try to use the config.
|
||||
if config_tokenizer_class is None:
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
config_tokenizer_class = config.tokenizer_class
|
||||
|
||||
# If we have the tokenizer class from the tokenizer config or the model config we're good!
|
||||
if config_tokenizer_class is not None:
|
||||
tokenizer_class = None
|
||||
if use_fast and not config.tokenizer_class.endswith("Fast"):
|
||||
tokenizer_class_candidate = f"{config.tokenizer_class}Fast"
|
||||
if use_fast and not config_tokenizer_class.endswith("Fast"):
|
||||
tokenizer_class_candidate = f"{config_tokenizer_class}Fast"
|
||||
tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
|
||||
if tokenizer_class is None:
|
||||
tokenizer_class_candidate = config.tokenizer_class
|
||||
tokenizer_class_candidate = config_tokenizer_class
|
||||
tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
|
||||
|
||||
if tokenizer_class is None:
|
||||
@@ -428,6 +545,7 @@ class AutoTokenizer:
|
||||
)
|
||||
return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
|
||||
# Otherwise we have to be creative.
|
||||
# if model is an encoder decoder, the encoder tokenizer class is used by default
|
||||
if isinstance(config, EncoderDecoderConfig):
|
||||
if type(config.decoder) is not type(config.encoder): # noqa: E721
|
||||
|
||||
@@ -1745,6 +1745,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
if tokenizer_config_file is not None:
|
||||
with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle:
|
||||
init_kwargs = json.load(tokenizer_config_handle)
|
||||
init_kwargs.pop("tokenizer_class", None)
|
||||
saved_init_inputs = init_kwargs.pop("init_inputs", ())
|
||||
if not init_inputs:
|
||||
init_inputs = saved_init_inputs
|
||||
@@ -1920,6 +1921,14 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
|
||||
# add_type_field=True to allow dicts in the kwargs / differentiate from AddedToken serialization
|
||||
tokenizer_config = convert_added_tokens(tokenizer_config, add_type_field=True)
|
||||
|
||||
# Add tokenizer class to the tokenizer config to be able to reload it with from_pretrained
|
||||
tokenizer_class = self.__class__.__name__
|
||||
# Remove the Fast at the end unless we have a special `PreTrainedTokenizerFast`
|
||||
if tokenizer_class.endswith("Fast") and tokenizer_class != "PreTrainedTokenizerFast":
|
||||
tokenizer_class = tokenizer_class[:-4]
|
||||
tokenizer_config["tokenizer_class"] = tokenizer_class
|
||||
|
||||
with open(tokenizer_config_file, "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(tokenizer_config, ensure_ascii=False))
|
||||
logger.info(f"tokenizer config file saved in {tokenizer_config_file}")
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import (
|
||||
@@ -29,7 +29,7 @@ from transformers import (
|
||||
RobertaTokenizerFast,
|
||||
)
|
||||
from transformers.models.auto.configuration_auto import AutoConfig
|
||||
from transformers.models.auto.tokenization_auto import TOKENIZER_MAPPING
|
||||
from transformers.models.auto.tokenization_auto import TOKENIZER_MAPPING, get_tokenizer_config
|
||||
from transformers.models.roberta.configuration_roberta import RobertaConfig
|
||||
from transformers.testing_utils import (
|
||||
DUMMY_DIFF_TOKENIZER_IDENTIFIER,
|
||||
@@ -129,3 +129,34 @@ class AutoTokenizerTest(unittest.TestCase):
|
||||
self.assertEqual(tokenizer.vocab_size, 30000)
|
||||
self.assertEqual(tokenizer.unk_token, "[UNK]")
|
||||
self.assertEqual(tokenizer.padding_side, "right")
|
||||
|
||||
def test_auto_tokenizer_from_local_folder(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained(SMALL_MODEL_IDENTIFIER)
|
||||
self.assertIsInstance(tokenizer, (BertTokenizer, BertTokenizerFast))
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tokenizer.save_pretrained(tmp_dir)
|
||||
tokenizer2 = AutoTokenizer.from_pretrained(tmp_dir)
|
||||
|
||||
self.assertIsInstance(tokenizer2, tokenizer.__class__)
|
||||
self.assertEqual(tokenizer2.vocab_size, 12)
|
||||
|
||||
def test_get_tokenizer_config(self):
|
||||
# Check we can load the tokenizer config of an online model.
|
||||
config = get_tokenizer_config("bert-base-cased")
|
||||
# If we ever update bert-base-cased tokenizer config, this dict here will need to be updated.
|
||||
self.assertEqual(config, {"do_lower_case": False})
|
||||
|
||||
# This model does not have a tokenizer_config so we get back an empty dict.
|
||||
config = get_tokenizer_config(SMALL_MODEL_IDENTIFIER)
|
||||
self.assertDictEqual(config, {})
|
||||
|
||||
# A tokenizer saved with `save_pretrained` always creates a tokenizer config.
|
||||
tokenizer = AutoTokenizer.from_pretrained(SMALL_MODEL_IDENTIFIER)
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tokenizer.save_pretrained(tmp_dir)
|
||||
config = get_tokenizer_config(tmp_dir)
|
||||
|
||||
# Check the class of the tokenizer was properly saved (note that it always saves the slow class).
|
||||
self.assertEqual(config["tokenizer_class"], "BertTokenizer")
|
||||
# Check other keys just to make sure the config was properly saved /reloaded.
|
||||
self.assertEqual(config["name_or_path"], SMALL_MODEL_IDENTIFIER)
|
||||
|
||||
Reference in New Issue
Block a user