From 79eb391586193b86c7f11bb5cf66effe0e446395 Mon Sep 17 00:00:00 2001 From: Thomas Wolf Date: Mon, 26 Oct 2020 10:27:48 +0100 Subject: [PATCH] [tokenizers] Fixing #8001 - Adding tests on tokenizers serialization (#8006) * fixing #8001 * make T5 tokenizer serialization more robust - style --- src/transformers/tokenization_albert.py | 3 +++ src/transformers/tokenization_bert.py | 5 +++++ src/transformers/tokenization_bertweet.py | 3 ++- src/transformers/tokenization_deberta.py | 6 ++---- src/transformers/tokenization_fsmt.py | 3 +++ src/transformers/tokenization_gpt2.py | 9 ++++++++- src/transformers/tokenization_marian.py | 6 ++++-- src/transformers/tokenization_prophetnet.py | 7 ++++++- src/transformers/tokenization_t5.py | 15 +++++++++++---- src/transformers/tokenization_t5_fast.py | 19 ++++++++++++------- src/transformers/tokenization_transfo_xl.py | 14 +++++++++++++- src/transformers/tokenization_utils_base.py | 2 +- src/transformers/tokenization_utils_fast.py | 3 +-- src/transformers/tokenization_xlm.py | 3 +++ .../tokenization_xlm_prophetnet.py | 3 ++- src/transformers/tokenization_xlnet.py | 3 +++ tests/test_tokenization_common.py | 19 +++++++++++++++++++ 17 files changed, 98 insertions(+), 25 deletions(-) diff --git a/src/transformers/tokenization_albert.py b/src/transformers/tokenization_albert.py index a0e00baf25..26c5a9ffac 100644 --- a/src/transformers/tokenization_albert.py +++ b/src/transformers/tokenization_albert.py @@ -129,6 +129,9 @@ class AlbertTokenizer(PreTrainedTokenizer): **kwargs ): super().__init__( + do_lower_case=do_lower_case, + remove_space=remove_space, + keep_accents=keep_accents, bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, diff --git a/src/transformers/tokenization_bert.py b/src/transformers/tokenization_bert.py index dbcc117dd7..5bc81cb9d8 100644 --- a/src/transformers/tokenization_bert.py +++ b/src/transformers/tokenization_bert.py @@ -178,11 +178,16 @@ class BertTokenizer(PreTrainedTokenizer): **kwargs ): super().__init__( + do_lower_case=do_lower_case, + do_basic_tokenize=do_basic_tokenize, + never_split=never_split, unk_token=unk_token, sep_token=sep_token, pad_token=pad_token, cls_token=cls_token, mask_token=mask_token, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, **kwargs, ) diff --git a/src/transformers/tokenization_bertweet.py b/src/transformers/tokenization_bertweet.py index d846cb6c26..1d28289f1b 100644 --- a/src/transformers/tokenization_bertweet.py +++ b/src/transformers/tokenization_bertweet.py @@ -129,11 +129,12 @@ class BertweetTokenizer(PreTrainedTokenizer): **kwargs ): super().__init__( + normalization=normalization, bos_token=bos_token, eos_token=eos_token, - unk_token=unk_token, sep_token=sep_token, cls_token=cls_token, + unk_token=unk_token, pad_token=pad_token, mask_token=mask_token, **kwargs, diff --git a/src/transformers/tokenization_deberta.py b/src/transformers/tokenization_deberta.py index 130b7da4ba..d4d29bc0cc 100644 --- a/src/transformers/tokenization_deberta.py +++ b/src/transformers/tokenization_deberta.py @@ -308,16 +308,13 @@ class GPT2Tokenizer(object): - We remapped the token ids in our dictionary with regarding to the new special tokens, `[PAD]` => 0, `[CLS]` => 1, `[SEP]` => 2, `[UNK]` => 3, `[MASK]` => 50264 - do_lower_case (:obj:`bool`, optional): - Whether to convert inputs to lower case. **Not used in GPT2 tokenizer**. - special_tokens (:obj:`list`, optional): List of special tokens to be added to the end of the vocabulary. """ - def __init__(self, vocab_file=None, do_lower_case=True, special_tokens=None): + def __init__(self, vocab_file=None, special_tokens=None): self.pad_token = "[PAD]" self.sep_token = "[SEP]" self.unk_token = "[UNK]" @@ -523,6 +520,7 @@ class DebertaTokenizer(PreTrainedTokenizer): **kwargs ): super().__init__( + do_lower_case=do_lower_case, unk_token=unk_token, sep_token=sep_token, pad_token=pad_token, diff --git a/src/transformers/tokenization_fsmt.py b/src/transformers/tokenization_fsmt.py index 7efbb7df6a..9e3bfc289c 100644 --- a/src/transformers/tokenization_fsmt.py +++ b/src/transformers/tokenization_fsmt.py @@ -194,6 +194,9 @@ class FSMTTokenizer(PreTrainedTokenizer): ): super().__init__( langs=langs, + src_vocab_file=src_vocab_file, + tgt_vocab_file=tgt_vocab_file, + merges_file=merges_file, unk_token=unk_token, bos_token=bos_token, sep_token=sep_token, diff --git a/src/transformers/tokenization_gpt2.py b/src/transformers/tokenization_gpt2.py index 96557330a5..b65eba4bb2 100644 --- a/src/transformers/tokenization_gpt2.py +++ b/src/transformers/tokenization_gpt2.py @@ -164,7 +164,14 @@ class GPT2Tokenizer(PreTrainedTokenizer): bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token - super().__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs) + super().__init__( + errors=errors, + unk_token=unk_token, + bos_token=bos_token, + eos_token=eos_token, + add_prefix_space=add_prefix_space, + **kwargs, + ) with open(vocab_file, encoding="utf-8") as vocab_handle: self.encoder = json.load(vocab_handle) diff --git a/src/transformers/tokenization_marian.py b/src/transformers/tokenization_marian.py index 0db025898b..1819b50f0a 100644 --- a/src/transformers/tokenization_marian.py +++ b/src/transformers/tokenization_marian.py @@ -97,10 +97,12 @@ class MarianTokenizer(PreTrainedTokenizer): ): super().__init__( # bos_token=bos_token, unused. Start decoding with config.decoder_start_token_id - model_max_length=model_max_length, - eos_token=eos_token, + source_lang=source_lang, + target_lang=target_lang, unk_token=unk_token, + eos_token=eos_token, pad_token=pad_token, + model_max_length=model_max_length, **kwargs, ) assert Path(source_spm).exists(), f"cannot find spm source {source_spm}" diff --git a/src/transformers/tokenization_prophetnet.py b/src/transformers/tokenization_prophetnet.py index ca6288a2ed..30c293c19c 100644 --- a/src/transformers/tokenization_prophetnet.py +++ b/src/transformers/tokenization_prophetnet.py @@ -119,11 +119,16 @@ class ProphetNetTokenizer(PreTrainedTokenizer): **kwargs ): super().__init__( + do_lower_case=do_lower_case, + do_basic_tokenize=do_basic_tokenize, + never_split=never_split, unk_token=unk_token, sep_token=sep_token, + x_sep_token=x_sep_token, pad_token=pad_token, mask_token=mask_token, - x_sep_token=x_sep_token, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, **kwargs, ) self.unique_no_split_tokens.append(x_sep_token) diff --git a/src/transformers/tokenization_t5.py b/src/transformers/tokenization_t5.py index 1502d32898..72630dbe54 100644 --- a/src/transformers/tokenization_t5.py +++ b/src/transformers/tokenization_t5.py @@ -112,15 +112,22 @@ class T5Tokenizer(PreTrainedTokenizer): **kwargs ): # Add extra_ids to the special token list - if extra_ids > 0: - if additional_special_tokens is None: - additional_special_tokens = [] - additional_special_tokens.extend(["".format(i) for i in range(extra_ids)]) + if extra_ids > 0 and additional_special_tokens is None: + additional_special_tokens = ["".format(i) for i in range(extra_ids)] + elif extra_ids > 0 and additional_special_tokens is not None: + # Check that we have the right number of extra_id special tokens + extra_tokens = len(set(filter(lambda x: bool("extra_id" in x), additional_special_tokens))) + if extra_tokens != extra_ids: + raise ValueError( + f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are provided to T5Tokenizer. " + "In this case the additional_special_tokens must include the extra_ids tokens" + ) super().__init__( eos_token=eos_token, unk_token=unk_token, pad_token=pad_token, + extra_ids=extra_ids, additional_special_tokens=additional_special_tokens, **kwargs, ) diff --git a/src/transformers/tokenization_t5_fast.py b/src/transformers/tokenization_t5_fast.py index b972b439ff..3ef613f75a 100644 --- a/src/transformers/tokenization_t5_fast.py +++ b/src/transformers/tokenization_t5_fast.py @@ -126,6 +126,18 @@ class T5TokenizerFast(PreTrainedTokenizerFast): additional_special_tokens=None, **kwargs ): + # Add extra_ids to the special token list + if extra_ids > 0 and additional_special_tokens is None: + additional_special_tokens = ["".format(i) for i in range(extra_ids)] + elif extra_ids > 0 and additional_special_tokens is not None: + # Check that we have the right number of extra special tokens + extra_tokens = len(set(filter(lambda x: bool("extra_id_" in x), additional_special_tokens))) + if extra_tokens != extra_ids: + raise ValueError( + f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are provided to T5Tokenizer. " + "In this case the additional_special_tokens must include the extra_ids tokens" + ) + super().__init__( vocab_file, tokenizer_file=tokenizer_file, @@ -137,13 +149,6 @@ class T5TokenizerFast(PreTrainedTokenizerFast): **kwargs, ) - if extra_ids > 0: - all_extra_tokens = ["".format(i) for i in range(extra_ids)] - if all(tok not in self.additional_special_tokens for tok in all_extra_tokens): - self.additional_special_tokens = self.additional_special_tokens + [ - "".format(i) for i in range(extra_ids) - ] - self.vocab_file = vocab_file self._extra_ids = extra_ids diff --git a/src/transformers/tokenization_transfo_xl.py b/src/transformers/tokenization_transfo_xl.py index e1b4c99ac3..5dbd9b5473 100644 --- a/src/transformers/tokenization_transfo_xl.py +++ b/src/transformers/tokenization_transfo_xl.py @@ -164,7 +164,19 @@ class TransfoXLTokenizer(PreTrainedTokenizer): **kwargs ): super().__init__( - unk_token=unk_token, eos_token=eos_token, additional_special_tokens=additional_special_tokens, **kwargs + special=special, + min_freq=min_freq, + max_size=max_size, + lower_case=lower_case, + delimiter=delimiter, + vocab_file=vocab_file, + pretrained_vocab_file=pretrained_vocab_file, + never_split=never_split, + unk_token=unk_token, + eos_token=eos_token, + additional_special_tokens=additional_special_tokens, + language=language, + **kwargs, ) if never_split is None: diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 39a49af535..129abccf7a 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1673,7 +1673,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin): if ( "tokenizer_file" not in resolved_vocab_files or resolved_vocab_files["tokenizer_file"] is None ) and cls.slow_tokenizer_class is not None: - slow_tokenizer = cls.slow_tokenizer_class._from_pretrained( + slow_tokenizer = (cls.slow_tokenizer_class)._from_pretrained( copy.deepcopy(resolved_vocab_files), pretrained_model_name_or_path, copy.deepcopy(init_configuration), diff --git a/src/transformers/tokenization_utils_fast.py b/src/transformers/tokenization_utils_fast.py index 419b8108eb..037976a175 100644 --- a/src/transformers/tokenization_utils_fast.py +++ b/src/transformers/tokenization_utils_fast.py @@ -16,7 +16,6 @@ For slow (python) tokenizers see tokenization_utils.py """ -import copy import json import os import warnings @@ -105,7 +104,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase): self._tokenizer = fast_tokenizer if slow_tokenizer is not None: - kwargs = copy.deepcopy(slow_tokenizer.init_kwargs) + kwargs.update(slow_tokenizer.init_kwargs) # We call this after having initialized the backend tokenizer because we update it. super().__init__(**kwargs) diff --git a/src/transformers/tokenization_xlm.py b/src/transformers/tokenization_xlm.py index 76a36f38e3..f4ab9d57f8 100644 --- a/src/transformers/tokenization_xlm.py +++ b/src/transformers/tokenization_xlm.py @@ -621,6 +621,9 @@ class XLMTokenizer(PreTrainedTokenizer): cls_token=cls_token, mask_token=mask_token, additional_special_tokens=additional_special_tokens, + lang2id=lang2id, + id2lang=id2lang, + do_lowercase_and_remove_accent=do_lowercase_and_remove_accent, **kwargs, ) diff --git a/src/transformers/tokenization_xlm_prophetnet.py b/src/transformers/tokenization_xlm_prophetnet.py index a7e5d2ffcb..b235b97a46 100644 --- a/src/transformers/tokenization_xlm_prophetnet.py +++ b/src/transformers/tokenization_xlm_prophetnet.py @@ -123,9 +123,10 @@ class XLMProphetNetTokenizer(PreTrainedTokenizer): super().__init__( bos_token=bos_token, eos_token=eos_token, - unk_token=unk_token, sep_token=sep_token, + unk_token=unk_token, pad_token=pad_token, + cls_token=cls_token, mask_token=mask_token, **kwargs, ) diff --git a/src/transformers/tokenization_xlnet.py b/src/transformers/tokenization_xlnet.py index d41f7a5bc9..ecb5b6c3c1 100644 --- a/src/transformers/tokenization_xlnet.py +++ b/src/transformers/tokenization_xlnet.py @@ -128,6 +128,9 @@ class XLNetTokenizer(PreTrainedTokenizer): **kwargs ): super().__init__( + do_lower_case=do_lower_case, + remove_space=remove_space, + keep_accents=keep_accents, bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index 7d39f3276a..0090c0f47d 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -177,6 +177,25 @@ class TokenizerTesterMixin: self.assertIn("tokenizer_file", signature.parameters) self.assertIsNone(signature.parameters["tokenizer_file"].default) + def test_tokenizer_slow_store_full_signature(self): + signature = inspect.signature(self.tokenizer_class.__init__) + tokenizer = self.get_tokenizer() + + for parameter_name, parameter in signature.parameters.items(): + if parameter.default != inspect.Parameter.empty: + self.assertIn(parameter_name, tokenizer.init_kwargs) + + def test_tokenizer_fast_store_full_signature(self): + if not self.test_rust_tokenizer: + return + + signature = inspect.signature(self.rust_tokenizer_class.__init__) + tokenizer = self.get_rust_tokenizer() + + for parameter_name, parameter in signature.parameters.items(): + if parameter.default != inspect.Parameter.empty: + self.assertIn(parameter_name, tokenizer.init_kwargs) + def test_rust_and_python_full_tokenizers(self): if not self.test_rust_tokenizer: return