Easily train a new fast tokenizer from a given one (#12361)
* [WIP] Easily train a new fast tokenizer from a given one * Fix test * Roll out to other tokenizers and add tests * Fix bug with unk id and add emoji to test * Really use something different in test * Implement special tokens map * Map special tokens in the Transformers tokenizers * Fix test * Make test more robust * Fix test for BPE * More robust map and test Co-authored-by SaulLu * Test file * Stronger tests Co-authored-by: SaulLu <lucilesaul.com@gmail.com> * Map unk token for Wordpiece and address review comment * Fix lowercase test and address review comment * Fix all tests * Simplify test * Fix tests for realsies * Easily train a new fast tokenizer from a given one - tackle the special tokens format (str or AddedToken) (#12420) * Propose change in tests regarding lower case * add new test for special tokens types * put back the test part about decoding * add feature: the AddedToken is re-build with the different mapped content * Address review comment: simplify AddedToken building Co-authored-by: sgugger <sylvain.gugger@gmail.com> * Update src/transformers/tokenization_utils_fast.py Co-authored-by: sgugger <sylvain.gugger@gmail.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: SaulLu <lucilesaul.com@gmail.com> Co-authored-by: SaulLu <55560583+SaulLu@users.noreply.github.com>
This commit is contained in:
@@ -121,7 +121,7 @@ class AlbertTokenizerFast(PreTrainedTokenizerFast):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
vocab_file=None,
|
||||
tokenizer_file=None,
|
||||
do_lower_case=True,
|
||||
remove_space=True,
|
||||
|
||||
@@ -109,7 +109,7 @@ class BarthezTokenizerFast(PreTrainedTokenizerFast):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
vocab_file=None,
|
||||
tokenizer_file=None,
|
||||
bos_token="<s>",
|
||||
eos_token="</s>",
|
||||
|
||||
@@ -162,7 +162,7 @@ class BertTokenizerFast(PreTrainedTokenizerFast):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
vocab_file=None,
|
||||
tokenizer_file=None,
|
||||
do_lower_case=True,
|
||||
unk_token="[UNK]",
|
||||
|
||||
@@ -103,7 +103,7 @@ class BigBirdTokenizerFast(PreTrainedTokenizerFast):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
vocab_file=None,
|
||||
tokenizer_file=None,
|
||||
unk_token="<unk>",
|
||||
bos_token="<s>",
|
||||
|
||||
@@ -63,8 +63,8 @@ class BlenderbotSmallTokenizerFast(PreTrainedTokenizerFast):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
merges_file,
|
||||
vocab_file=None,
|
||||
merges_file=None,
|
||||
unk_token="<|endoftext|>",
|
||||
bos_token="<|endoftext|>",
|
||||
eos_token="<|endoftext|>",
|
||||
|
||||
@@ -105,7 +105,7 @@ class CamembertTokenizerFast(PreTrainedTokenizerFast):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
vocab_file=None,
|
||||
tokenizer_file=None,
|
||||
bos_token="<s>",
|
||||
eos_token="</s>",
|
||||
|
||||
@@ -105,8 +105,8 @@ class CLIPTokenizerFast(PreTrainedTokenizerFast):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
merges_file,
|
||||
vocab_file=None,
|
||||
merges_file=None,
|
||||
tokenizer_file=None,
|
||||
unk_token="<|endoftext|>",
|
||||
bos_token="<|startoftext|>",
|
||||
|
||||
@@ -95,8 +95,8 @@ class DebertaTokenizerFast(GPT2TokenizerFast):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
merges_file,
|
||||
vocab_file=None,
|
||||
merges_file=None,
|
||||
tokenizer_file=None,
|
||||
errors="replace",
|
||||
bos_token="[CLS]",
|
||||
|
||||
@@ -88,7 +88,7 @@ class FunnelTokenizerFast(BertTokenizerFast):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
vocab_file=None,
|
||||
tokenizer_file=None,
|
||||
do_lower_case=True,
|
||||
unk_token="<unk>",
|
||||
|
||||
@@ -125,8 +125,8 @@ class GPT2TokenizerFast(PreTrainedTokenizerFast):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
merges_file,
|
||||
vocab_file=None,
|
||||
merges_file=None,
|
||||
tokenizer_file=None,
|
||||
unk_token="<|endoftext|>",
|
||||
bos_token="<|endoftext|>",
|
||||
|
||||
@@ -67,8 +67,8 @@ class HerbertTokenizerFast(PreTrainedTokenizerFast):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
merges_file,
|
||||
vocab_file=None,
|
||||
merges_file=None,
|
||||
tokenizer_file=None,
|
||||
cls_token="<s>",
|
||||
unk_token="<unk>",
|
||||
|
||||
@@ -121,7 +121,10 @@ class MBartTokenizer(XLMRobertaTokenizer):
|
||||
self._additional_special_tokens = list(self.lang_code_to_id.keys())
|
||||
|
||||
if additional_special_tokens is not None:
|
||||
self._additional_special_tokens.extend(additional_special_tokens)
|
||||
# Only add those special tokens if they are not already there.
|
||||
self._additional_special_tokens.extend(
|
||||
[t for t in additional_special_tokens if t not in self._additional_special_tokens]
|
||||
)
|
||||
|
||||
self._src_lang = src_lang if src_lang is not None else "en_XX"
|
||||
self.cur_lang_code_id = self.lang_code_to_id[self._src_lang]
|
||||
|
||||
@@ -110,7 +110,7 @@ class MBart50TokenizerFast(PreTrainedTokenizerFast):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
vocab_file=None,
|
||||
src_lang=None,
|
||||
tgt_lang=None,
|
||||
tokenizer_file=None,
|
||||
|
||||
@@ -113,10 +113,16 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
|
||||
suffix_tokens: List[int] = []
|
||||
|
||||
def __init__(
|
||||
self, *args, tokenizer_file=None, src_lang=None, tgt_lang=None, additional_special_tokens=None, **kwargs
|
||||
self,
|
||||
vocab_file=None,
|
||||
tokenizer_file=None,
|
||||
src_lang=None,
|
||||
tgt_lang=None,
|
||||
additional_special_tokens=None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
*args,
|
||||
vocab_file=vocab_file,
|
||||
tokenizer_file=tokenizer_file,
|
||||
src_lang=src_lang,
|
||||
tgt_lang=tgt_lang,
|
||||
@@ -127,7 +133,10 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
|
||||
_additional_special_tokens = FAIRSEQ_LANGUAGE_CODES.copy()
|
||||
|
||||
if additional_special_tokens is not None:
|
||||
_additional_special_tokens.extend(additional_special_tokens)
|
||||
# Only add those special tokens if they are not already there.
|
||||
_additional_special_tokens.extend(
|
||||
[t for t in additional_special_tokens if t not in _additional_special_tokens]
|
||||
)
|
||||
|
||||
self.add_special_tokens({"additional_special_tokens": _additional_special_tokens})
|
||||
|
||||
|
||||
@@ -106,7 +106,7 @@ class MPNetTokenizerFast(PreTrainedTokenizerFast):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
vocab_file=None,
|
||||
tokenizer_file=None,
|
||||
do_lower_case=True,
|
||||
bos_token="<s>",
|
||||
|
||||
@@ -64,7 +64,7 @@ class OpenAIGPTTokenizerFast(PreTrainedTokenizerFast):
|
||||
model_input_names = ["input_ids", "attention_mask"]
|
||||
slow_tokenizer_class = OpenAIGPTTokenizer
|
||||
|
||||
def __init__(self, vocab_file, merges_file, tokenizer_file=None, unk_token="<unk>", **kwargs):
|
||||
def __init__(self, vocab_file=None, merges_file=None, tokenizer_file=None, unk_token="<unk>", **kwargs):
|
||||
super().__init__(vocab_file, merges_file, tokenizer_file=tokenizer_file, unk_token=unk_token, **kwargs)
|
||||
|
||||
@property
|
||||
|
||||
@@ -98,7 +98,7 @@ class PegasusTokenizerFast(PreTrainedTokenizerFast):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
vocab_file=None,
|
||||
tokenizer_file=None,
|
||||
pad_token="<pad>",
|
||||
eos_token="</s>",
|
||||
|
||||
@@ -87,7 +87,7 @@ class ReformerTokenizerFast(PreTrainedTokenizerFast):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
vocab_file=None,
|
||||
tokenizer_file=None,
|
||||
eos_token="</s>",
|
||||
unk_token="<unk>",
|
||||
|
||||
@@ -143,8 +143,8 @@ class RobertaTokenizerFast(GPT2TokenizerFast):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
merges_file,
|
||||
vocab_file=None,
|
||||
merges_file=None,
|
||||
tokenizer_file=None,
|
||||
errors="replace",
|
||||
bos_token="<s>",
|
||||
|
||||
@@ -73,7 +73,7 @@ class RoFormerTokenizerFast(PreTrainedTokenizerFast):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
vocab_file=None,
|
||||
tokenizer_file=None,
|
||||
do_lower_case=True,
|
||||
unk_token="[UNK]",
|
||||
|
||||
@@ -104,7 +104,7 @@ class T5TokenizerFast(PreTrainedTokenizerFast):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
vocab_file=None,
|
||||
tokenizer_file=None,
|
||||
eos_token="</s>",
|
||||
unk_token="<unk>",
|
||||
|
||||
@@ -117,7 +117,7 @@ class XLMRobertaTokenizerFast(PreTrainedTokenizerFast):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
vocab_file=None,
|
||||
tokenizer_file=None,
|
||||
bos_token="<s>",
|
||||
eos_token="</s>",
|
||||
|
||||
@@ -124,7 +124,7 @@ class XLNetTokenizerFast(PreTrainedTokenizerFast):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
vocab_file=None,
|
||||
tokenizer_file=None,
|
||||
do_lower_case=False,
|
||||
remove_space=True,
|
||||
|
||||
@@ -16,7 +16,6 @@
|
||||
Tokenization classes for fast tokenizers (provided by HuggingFace's tokenizers library). For slow (python) tokenizers
|
||||
see tokenization_utils.py
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from collections import defaultdict
|
||||
@@ -25,6 +24,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from tokenizers import Encoding as EncodingFast
|
||||
from tokenizers import Tokenizer as TokenizerFast
|
||||
from tokenizers.decoders import Decoder as DecoderFast
|
||||
from tokenizers.trainers import BpeTrainer, UnigramTrainer, WordLevelTrainer, WordPieceTrainer
|
||||
|
||||
from .convert_slow_tokenizer import convert_slow_tokenizer
|
||||
from .file_utils import PaddingStrategy, add_end_docstrings
|
||||
@@ -36,6 +36,7 @@ from .tokenization_utils_base import (
|
||||
PreTokenizedInput,
|
||||
PreTokenizedInputPair,
|
||||
PreTrainedTokenizerBase,
|
||||
SpecialTokensMixin,
|
||||
TextInput,
|
||||
TextInputPair,
|
||||
TruncationStrategy,
|
||||
@@ -60,6 +61,13 @@ INIT_TOKENIZER_DOCSTRING += """
|
||||
from 🤗 tokenizers <../fast_tokenizers>` for more information.
|
||||
"""
|
||||
|
||||
MODEL_TO_TRAINER_MAPPING = {
|
||||
"BPE": BpeTrainer,
|
||||
"Unigram": UnigramTrainer,
|
||||
"WordLevel": WordLevelTrainer,
|
||||
"WordPiece": WordPieceTrainer,
|
||||
}
|
||||
|
||||
|
||||
@add_end_docstrings(INIT_TOKENIZER_DOCSTRING)
|
||||
class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
|
||||
@@ -555,3 +563,162 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
|
||||
file_names = file_names + (tokenizer_file,)
|
||||
|
||||
return file_names
|
||||
|
||||
def train_new_from_iterator(
|
||||
self, text_iterator, vocab_size, new_special_tokens=None, special_tokens_map=None, **kwargs
|
||||
):
|
||||
"""
|
||||
Trains a tokenizer on a new corpus with the same defaults (in terms of special tokens or tokenization pipeline)
|
||||
as the current one.
|
||||
|
||||
Args:
|
||||
text_iterator (generator of :obj:`List[str]`):
|
||||
The training corpus. Should be a generator of batches of texts, for instance a list of lists of texts
|
||||
if you have everything in memory.
|
||||
vocab_size (obj:`int`):
|
||||
The size of the vocabulary you want for your tokenizer.
|
||||
new_special_tokens (list of :obj:`str` or :obj:`AddedToken`, `optional`):
|
||||
A list of new special tokens to add to the tokenizer you are training.
|
||||
special_tokens_map (:obj:`Dict[str, str]`, `optional`):
|
||||
If you want to rename some of the special tokens this tokenizer uses, pass along a mapping old special
|
||||
token name to new special token name in this argument.
|
||||
kwargs:
|
||||
Additional keyword arguments passed along to the trainer from the 🤗 Tokenizers library.
|
||||
|
||||
Returns:
|
||||
:class:`~transformers.PreTrainedTokenizerFast`: A new tokenizer of the same type as the original one,
|
||||
trained on :obj:`text_iterator`.
|
||||
|
||||
"""
|
||||
tokenizer_json = json.loads(self._tokenizer.to_str())
|
||||
# Remove added tokens for now (uses IDs of tokens)
|
||||
added_tokens = tokenizer_json.pop("added_tokens")
|
||||
# Remove post processor for now (uses IDs of tokens)
|
||||
post_processor = tokenizer_json.pop("post_processor")
|
||||
|
||||
unk_token = None
|
||||
# Remove vocab
|
||||
if tokenizer_json["model"]["type"] == "BPE":
|
||||
tokenizer_json["model"]["vocab"] = {}
|
||||
tokenizer_json["model"]["merges"] = []
|
||||
elif tokenizer_json["model"]["type"] == "Unigram":
|
||||
if tokenizer_json["model"]["unk_id"] is not None:
|
||||
unk_id = tokenizer_json["model"]["unk_id"]
|
||||
unk_token = tokenizer_json["model"]["vocab"][unk_id][0]
|
||||
if special_tokens_map is not None and unk_token in special_tokens_map:
|
||||
unk_token = special_tokens_map[unk_token]
|
||||
tokenizer_json["model"]["unk_id"] = 0
|
||||
tokenizer_json["model"]["vocab"] = [[unk_token, 0.0]]
|
||||
elif tokenizer_json["model"]["type"] in ["WordLevel", "WordPiece"]:
|
||||
tokenizer_json["model"]["vocab"] = {}
|
||||
else:
|
||||
raise ValueError(
|
||||
f"This method does not support this type of tokenizer (found {tokenizer_json['model']['type']}) "
|
||||
"only BPE, Unigram, WordLevel and WordPiece."
|
||||
)
|
||||
|
||||
if (
|
||||
special_tokens_map is not None
|
||||
and "unk_token" in tokenizer_json["model"]
|
||||
and tokenizer_json["model"]["unk_token"] in special_tokens_map
|
||||
):
|
||||
tokenizer_json["model"]["unk_token"] = special_tokens_map[tokenizer_json["model"]["unk_token"]]
|
||||
|
||||
tokenizer = TokenizerFast.from_str(json.dumps(tokenizer_json))
|
||||
|
||||
# Get the special tokens from the current tokenizer if none are specified.
|
||||
special_tokens = []
|
||||
for added_token in added_tokens:
|
||||
special = added_token.pop("special", None)
|
||||
_ = added_token.pop("id", None)
|
||||
if tokenizer_json["model"]["type"] != "Unigram" and not special:
|
||||
continue
|
||||
if special_tokens_map is not None and added_token["content"] in special_tokens_map:
|
||||
added_token["content"] = special_tokens_map[added_token["content"]]
|
||||
special_tokens.append(AddedToken(**added_token))
|
||||
|
||||
if new_special_tokens is not None:
|
||||
special_tokens.extend(new_special_tokens)
|
||||
|
||||
# Trainer needs to know the end of word / continuing subword thingies in BPE
|
||||
if (
|
||||
tokenizer_json["model"]["type"] == "BPE"
|
||||
and "continuing_subword_prefix" not in kwargs
|
||||
and tokenizer_json["model"]["continuing_subword_prefix"] is not None
|
||||
):
|
||||
kwargs["continuing_subword_prefix"] = tokenizer_json["model"]["continuing_subword_prefix"]
|
||||
if (
|
||||
tokenizer_json["model"]["type"] == "BPE"
|
||||
and "end_of_work_suffix" not in kwargs
|
||||
and tokenizer_json["model"]["end_of_word_suffix"] is not None
|
||||
):
|
||||
kwargs["end_of_word_suffix"] = tokenizer_json["model"]["end_of_word_suffix"]
|
||||
|
||||
trainer_class = MODEL_TO_TRAINER_MAPPING[tokenizer_json["model"]["type"]]
|
||||
trainer = trainer_class(vocab_size=vocab_size, special_tokens=special_tokens, **kwargs)
|
||||
tokenizer.train_from_iterator(text_iterator, trainer=trainer)
|
||||
|
||||
if unk_token is not None:
|
||||
# For Unigram tokenizers we need to set back the unk id of the model (bug in Tokenizers?)
|
||||
trained_tokenizer_json = json.loads(tokenizer.to_str())
|
||||
vocab = trained_tokenizer_json["model"]["vocab"]
|
||||
unk_id = 0
|
||||
while unk_id < len(vocab) and vocab[unk_id][0] != unk_token:
|
||||
unk_id += 1
|
||||
if unk_id < len(vocab):
|
||||
trained_tokenizer_json["model"]["unk_id"] = unk_id
|
||||
tokenizer = TokenizerFast.from_str(json.dumps(trained_tokenizer_json))
|
||||
|
||||
if post_processor is not None:
|
||||
trained_tokenizer_json = json.loads(tokenizer.to_str())
|
||||
# Almost done, we just have to adjust the token IDs in the post processor
|
||||
if "special_tokens" in post_processor:
|
||||
for key in post_processor["special_tokens"]:
|
||||
tokens = post_processor["special_tokens"][key]["tokens"]
|
||||
if special_tokens_map is not None:
|
||||
tokens = [special_tokens_map.get(token, token) for token in tokens]
|
||||
post_processor["special_tokens"][key]["tokens"] = tokens
|
||||
post_processor["special_tokens"][key]["ids"] = [tokenizer.token_to_id(token) for token in tokens]
|
||||
|
||||
for special_token in ["cls", "sep"]:
|
||||
if special_token in post_processor:
|
||||
token, _ = post_processor[special_token]
|
||||
if special_tokens_map is not None and token in special_tokens_map:
|
||||
token = special_tokens_map[token]
|
||||
token_id = tokenizer.token_to_id(token)
|
||||
post_processor[special_token] = [token, token_id]
|
||||
|
||||
trained_tokenizer_json["post_processor"] = post_processor
|
||||
tokenizer = TokenizerFast.from_str(json.dumps(trained_tokenizer_json))
|
||||
|
||||
kwargs = self.init_kwargs.copy()
|
||||
# Map pad/cls/mask token at the Transformers level
|
||||
special_tokens_list = SpecialTokensMixin.SPECIAL_TOKENS_ATTRIBUTES.copy()
|
||||
special_tokens_list.remove("additional_special_tokens")
|
||||
for token in special_tokens_list:
|
||||
# Get the private one to avoid unnecessary warnings.
|
||||
if getattr(self, f"_{token}") is not None:
|
||||
special_token = getattr(self, token)
|
||||
if special_tokens_map is not None and special_token in special_tokens_map:
|
||||
special_token = special_tokens_map[special_token]
|
||||
|
||||
special_token_full = getattr(self, f"_{token}")
|
||||
if isinstance(special_token_full, AddedToken):
|
||||
# Create an added token with the same paramters except the content
|
||||
kwargs[token] = AddedToken(
|
||||
special_token,
|
||||
single_word=special_token_full.single_word,
|
||||
lstrip=special_token_full.lstrip,
|
||||
rstrip=special_token_full.rstrip,
|
||||
normalized=special_token_full.normalized,
|
||||
)
|
||||
else:
|
||||
kwargs[token] = special_token
|
||||
|
||||
additional_special_tokens = self.additional_special_tokens
|
||||
if new_special_tokens is not None:
|
||||
additional_special_tokens.extend(new_special_tokens)
|
||||
if len(additional_special_tokens) > 0:
|
||||
kwargs["additional_special_tokens"] = additional_special_tokens
|
||||
|
||||
return self.__class__(tokenizer_object=tokenizer, **kwargs)
|
||||
|
||||
@@ -33,6 +33,7 @@ from transformers import (
|
||||
PreTrainedTokenizer,
|
||||
PreTrainedTokenizerBase,
|
||||
PreTrainedTokenizerFast,
|
||||
SpecialTokensMixin,
|
||||
is_tf_available,
|
||||
is_torch_available,
|
||||
)
|
||||
@@ -57,6 +58,11 @@ if TYPE_CHECKING:
|
||||
|
||||
NON_ENGLISH_TAGS = ["chinese", "dutch", "french", "finnish", "german", "multilingual"]
|
||||
|
||||
SMALL_TRAINING_CORPUS = [
|
||||
["This is the first sentence.", "This is the second one."],
|
||||
["This sentence (contains #) over symbols and numbers 12 3.", "But not this one."],
|
||||
]
|
||||
|
||||
|
||||
def filter_non_english(_, pretrained_name: str):
|
||||
"""Filter all the model for non-english language"""
|
||||
@@ -390,7 +396,11 @@ class TokenizerTesterMixin:
|
||||
tokenizer = self.get_rust_tokenizer()
|
||||
|
||||
for parameter_name, parameter in signature.parameters.items():
|
||||
if parameter.default != inspect.Parameter.empty and parameter_name != "tokenizer_file":
|
||||
if parameter.default != inspect.Parameter.empty and parameter_name not in [
|
||||
"vocab_file",
|
||||
"merges_file",
|
||||
"tokenizer_file",
|
||||
]:
|
||||
self.assertIn(parameter_name, tokenizer.init_kwargs)
|
||||
|
||||
def test_rust_and_python_full_tokenizers(self):
|
||||
@@ -3144,6 +3154,146 @@ class TokenizerTesterMixin:
|
||||
self.assertTrue(special_token_id in p_output)
|
||||
self.assertTrue(special_token_id in cr_output)
|
||||
|
||||
def test_training_new_tokenizer(self):
|
||||
# This feature only exists for fast tokenizers
|
||||
if not self.test_rust_tokenizer:
|
||||
return
|
||||
|
||||
tokenizer = self.get_rust_tokenizer()
|
||||
new_tokenizer = tokenizer.train_new_from_iterator(SMALL_TRAINING_CORPUS, 100)
|
||||
|
||||
# Test we can use the new tokenizer with something not seen during training
|
||||
inputs = new_tokenizer(["This is the first sentence", "This sentence is different 🤗."])
|
||||
self.assertEqual(len(inputs["input_ids"]), 2)
|
||||
decoded_input = new_tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
|
||||
expected_result = "This is the first sentence"
|
||||
|
||||
# OpenAIGPT always lowercases and has no arg.
|
||||
if new_tokenizer.init_kwargs.get("do_lower_case", False) or tokenizer.__class__.__name__.startswith(
|
||||
"OpenAIGPT"
|
||||
):
|
||||
expected_result = expected_result.lower()
|
||||
self.assertEqual(expected_result, decoded_input)
|
||||
|
||||
# We check that the parameters of the tokenizer remained the same
|
||||
# Check we have the same number of added_tokens for both pair and non-pair inputs.
|
||||
self.assertEqual(tokenizer.num_special_tokens_to_add(False), new_tokenizer.num_special_tokens_to_add(False))
|
||||
self.assertEqual(tokenizer.num_special_tokens_to_add(True), new_tokenizer.num_special_tokens_to_add(True))
|
||||
|
||||
# Check we have the correct max_length for both pair and non-pair inputs.
|
||||
self.assertEqual(tokenizer.max_len_single_sentence, new_tokenizer.max_len_single_sentence)
|
||||
self.assertEqual(tokenizer.max_len_sentences_pair, new_tokenizer.max_len_sentences_pair)
|
||||
|
||||
# Assert the set of special tokens match as we didn't ask to change them
|
||||
self.assertSequenceEqual(
|
||||
tokenizer.all_special_tokens_extended,
|
||||
new_tokenizer.all_special_tokens_extended,
|
||||
)
|
||||
|
||||
self.assertDictEqual(tokenizer.special_tokens_map, new_tokenizer.special_tokens_map)
|
||||
|
||||
def test_training_new_tokenizer_with_special_tokens_change(self):
|
||||
# This feature only exists for fast tokenizers
|
||||
if not self.test_rust_tokenizer:
|
||||
return
|
||||
|
||||
tokenizer = self.get_rust_tokenizer()
|
||||
# Test with a special tokens map
|
||||
class_signature = inspect.signature(tokenizer.__class__)
|
||||
if "cls_token" in class_signature.parameters:
|
||||
new_tokenizer = tokenizer.train_new_from_iterator(
|
||||
SMALL_TRAINING_CORPUS, 100, special_tokens_map={tokenizer.cls_token: "<cls>"}
|
||||
)
|
||||
cls_id = new_tokenizer.get_vocab()["<cls>"]
|
||||
self.assertEqual(new_tokenizer.cls_token, "<cls>")
|
||||
self.assertEqual(new_tokenizer.cls_token_id, cls_id)
|
||||
|
||||
# Create a new mapping from the special tokens defined in the original tokenizer
|
||||
special_tokens_list = SpecialTokensMixin.SPECIAL_TOKENS_ATTRIBUTES.copy()
|
||||
special_tokens_list.remove("additional_special_tokens")
|
||||
special_tokens_map = {}
|
||||
for token in special_tokens_list:
|
||||
# Get the private one to avoid unnecessary warnings.
|
||||
if getattr(tokenizer, f"_{token}") is not None:
|
||||
special_token = getattr(tokenizer, token)
|
||||
special_tokens_map[special_token] = f"{special_token}a"
|
||||
|
||||
# Train new tokenizer
|
||||
new_tokenizer = tokenizer.train_new_from_iterator(
|
||||
SMALL_TRAINING_CORPUS, 100, special_tokens_map=special_tokens_map
|
||||
)
|
||||
|
||||
# Check the changes
|
||||
for token in special_tokens_list:
|
||||
# Get the private one to avoid unnecessary warnings.
|
||||
if getattr(tokenizer, f"_{token}") is None:
|
||||
continue
|
||||
special_token = getattr(tokenizer, token)
|
||||
if special_token in special_tokens_map:
|
||||
new_special_token = getattr(new_tokenizer, token)
|
||||
self.assertEqual(special_tokens_map[special_token], new_special_token)
|
||||
|
||||
new_id = new_tokenizer.get_vocab()[new_special_token]
|
||||
self.assertEqual(getattr(new_tokenizer, f"{token}_id"), new_id)
|
||||
|
||||
# Check if the AddedToken / string format has been kept
|
||||
for special_token in tokenizer.all_special_tokens_extended:
|
||||
if isinstance(special_token, AddedToken) and special_token.content not in special_tokens_map:
|
||||
# The special token must appear identically in the list of the new tokenizer.
|
||||
self.assertTrue(
|
||||
special_token in new_tokenizer.all_special_tokens_extended,
|
||||
f"'{special_token}' should be in {new_tokenizer.all_special_tokens_extended}",
|
||||
)
|
||||
elif isinstance(special_token, AddedToken):
|
||||
# The special token must appear in the list of the new tokenizer as an object of type AddedToken with
|
||||
# the same parameters as the old AddedToken except the content that the user has requested to change.
|
||||
special_token_str = special_token.content
|
||||
new_special_token_str = special_tokens_map[special_token_str]
|
||||
|
||||
find = False
|
||||
for candidate in new_tokenizer.all_special_tokens_extended:
|
||||
if (
|
||||
isinstance(candidate, AddedToken)
|
||||
and candidate.content == new_special_token_str
|
||||
and candidate.lstrip == special_token.lstrip
|
||||
and candidate.rstrip == special_token.rstrip
|
||||
and candidate.normalized == special_token.normalized
|
||||
and candidate.single_word == special_token.single_word
|
||||
):
|
||||
find = True
|
||||
break
|
||||
self.assertTrue(
|
||||
find,
|
||||
(
|
||||
f"'{new_special_token_str}' doesn't appear in the list "
|
||||
f"'{new_tokenizer.all_special_tokens_extended}' as an AddedToken with the same parameters as "
|
||||
f"'{special_token}' in the list {tokenizer.all_special_tokens_extended}"
|
||||
),
|
||||
)
|
||||
elif special_token not in special_tokens_map:
|
||||
# The special token must appear identically in the list of the new tokenizer.
|
||||
self.assertTrue(
|
||||
special_token in new_tokenizer.all_special_tokens_extended,
|
||||
f"'{special_token}' should be in {new_tokenizer.all_special_tokens_extended}",
|
||||
)
|
||||
|
||||
else:
|
||||
# The special token must appear in the list of the new tokenizer as an object of type string.
|
||||
self.assertTrue(special_tokens_map[special_token] in new_tokenizer.all_special_tokens_extended)
|
||||
|
||||
# Test we can use the new tokenizer with something not seen during training
|
||||
inputs = new_tokenizer(["This is the first sentence", "This sentence is different 🤗."])
|
||||
self.assertEqual(len(inputs["input_ids"]), 2)
|
||||
decoded_input = new_tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
|
||||
expected_result = "This is the first sentence"
|
||||
|
||||
# OpenAIGPT always lowercases and has no arg.
|
||||
if new_tokenizer.init_kwargs.get("do_lower_case", False) or tokenizer.__class__.__name__.startswith(
|
||||
"OpenAIGPT"
|
||||
):
|
||||
expected_result = expected_result.lower()
|
||||
self.assertEqual(expected_result, decoded_input)
|
||||
|
||||
|
||||
@is_staging_test
|
||||
class TokenizerPushToHubTester(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user