[split_special_tokens] Add support for split_special_tokens argument to encode (#25081)
* draft changes * update and add tests * styling for no * move test * path to usable model * update test * small update * update bertbased tokenizers * don'tuse kwargs for _tokenize * don'tuse kwargs for _tokenize * fix copies * update * update test for special tokenizers * fixup * skip two tests * remove pdb breakpiont() * wowo * rewrite custom tests * nits * revert chang in target keys * fix markup lm * update documentation of the argument
This commit is contained in:
@@ -238,10 +238,12 @@ class BertTokenizer(PreTrainedTokenizer):
|
|||||||
def get_vocab(self):
|
def get_vocab(self):
|
||||||
return dict(self.vocab, **self.added_tokens_encoder)
|
return dict(self.vocab, **self.added_tokens_encoder)
|
||||||
|
|
||||||
def _tokenize(self, text):
|
def _tokenize(self, text, split_special_tokens=False):
|
||||||
split_tokens = []
|
split_tokens = []
|
||||||
if self.do_basic_tokenize:
|
if self.do_basic_tokenize:
|
||||||
for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
|
for token in self.basic_tokenizer.tokenize(
|
||||||
|
text, never_split=self.all_special_tokens if not split_special_tokens else None
|
||||||
|
):
|
||||||
# If the token is part of the never_split set
|
# If the token is part of the never_split set
|
||||||
if token in self.basic_tokenizer.never_split:
|
if token in self.basic_tokenizer.never_split:
|
||||||
split_tokens.append(token)
|
split_tokens.append(token)
|
||||||
|
|||||||
@@ -177,10 +177,12 @@ class ConvBertTokenizer(PreTrainedTokenizer):
|
|||||||
def get_vocab(self):
|
def get_vocab(self):
|
||||||
return dict(self.vocab, **self.added_tokens_encoder)
|
return dict(self.vocab, **self.added_tokens_encoder)
|
||||||
|
|
||||||
def _tokenize(self, text):
|
def _tokenize(self, text, split_special_tokens=False):
|
||||||
split_tokens = []
|
split_tokens = []
|
||||||
if self.do_basic_tokenize:
|
if self.do_basic_tokenize:
|
||||||
for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
|
for token in self.basic_tokenizer.tokenize(
|
||||||
|
text, never_split=self.all_special_tokens if not split_special_tokens else None
|
||||||
|
):
|
||||||
# If the token is part of the never_split set
|
# If the token is part of the never_split set
|
||||||
if token in self.basic_tokenizer.never_split:
|
if token in self.basic_tokenizer.never_split:
|
||||||
split_tokens.append(token)
|
split_tokens.append(token)
|
||||||
|
|||||||
@@ -178,10 +178,12 @@ class RetriBertTokenizer(PreTrainedTokenizer):
|
|||||||
return dict(self.vocab, **self.added_tokens_encoder)
|
return dict(self.vocab, **self.added_tokens_encoder)
|
||||||
|
|
||||||
# Copied from transformers.models.bert.tokenization_bert.BertTokenizer._tokenize
|
# Copied from transformers.models.bert.tokenization_bert.BertTokenizer._tokenize
|
||||||
def _tokenize(self, text):
|
def _tokenize(self, text, split_special_tokens=False):
|
||||||
split_tokens = []
|
split_tokens = []
|
||||||
if self.do_basic_tokenize:
|
if self.do_basic_tokenize:
|
||||||
for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
|
for token in self.basic_tokenizer.tokenize(
|
||||||
|
text, never_split=self.all_special_tokens if not split_special_tokens else None
|
||||||
|
):
|
||||||
# If the token is part of the never_split set
|
# If the token is part of the never_split set
|
||||||
if token in self.basic_tokenizer.never_split:
|
if token in self.basic_tokenizer.never_split:
|
||||||
split_tokens.append(token)
|
split_tokens.append(token)
|
||||||
|
|||||||
@@ -195,10 +195,12 @@ class DistilBertTokenizer(PreTrainedTokenizer):
|
|||||||
return dict(self.vocab, **self.added_tokens_encoder)
|
return dict(self.vocab, **self.added_tokens_encoder)
|
||||||
|
|
||||||
# Copied from transformers.models.bert.tokenization_bert.BertTokenizer._tokenize
|
# Copied from transformers.models.bert.tokenization_bert.BertTokenizer._tokenize
|
||||||
def _tokenize(self, text):
|
def _tokenize(self, text, split_special_tokens=False):
|
||||||
split_tokens = []
|
split_tokens = []
|
||||||
if self.do_basic_tokenize:
|
if self.do_basic_tokenize:
|
||||||
for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
|
for token in self.basic_tokenizer.tokenize(
|
||||||
|
text, never_split=self.all_special_tokens if not split_special_tokens else None
|
||||||
|
):
|
||||||
# If the token is part of the never_split set
|
# If the token is part of the never_split set
|
||||||
if token in self.basic_tokenizer.never_split:
|
if token in self.basic_tokenizer.never_split:
|
||||||
split_tokens.append(token)
|
split_tokens.append(token)
|
||||||
|
|||||||
@@ -194,10 +194,12 @@ class ElectraTokenizer(PreTrainedTokenizer):
|
|||||||
def get_vocab(self):
|
def get_vocab(self):
|
||||||
return dict(self.vocab, **self.added_tokens_encoder)
|
return dict(self.vocab, **self.added_tokens_encoder)
|
||||||
|
|
||||||
def _tokenize(self, text):
|
def _tokenize(self, text, split_special_tokens=False):
|
||||||
split_tokens = []
|
split_tokens = []
|
||||||
if self.do_basic_tokenize:
|
if self.do_basic_tokenize:
|
||||||
for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
|
for token in self.basic_tokenizer.tokenize(
|
||||||
|
text, never_split=self.all_special_tokens if not split_special_tokens else None
|
||||||
|
):
|
||||||
# If the token is part of the never_split set
|
# If the token is part of the never_split set
|
||||||
if token in self.basic_tokenizer.never_split:
|
if token in self.basic_tokenizer.never_split:
|
||||||
split_tokens.append(token)
|
split_tokens.append(token)
|
||||||
|
|||||||
@@ -205,10 +205,12 @@ class FunnelTokenizer(PreTrainedTokenizer):
|
|||||||
return dict(self.vocab, **self.added_tokens_encoder)
|
return dict(self.vocab, **self.added_tokens_encoder)
|
||||||
|
|
||||||
# Copied from transformers.models.bert.tokenization_bert.BertTokenizer._tokenize
|
# Copied from transformers.models.bert.tokenization_bert.BertTokenizer._tokenize
|
||||||
def _tokenize(self, text):
|
def _tokenize(self, text, split_special_tokens=False):
|
||||||
split_tokens = []
|
split_tokens = []
|
||||||
if self.do_basic_tokenize:
|
if self.do_basic_tokenize:
|
||||||
for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
|
for token in self.basic_tokenizer.tokenize(
|
||||||
|
text, never_split=self.all_special_tokens if not split_special_tokens else None
|
||||||
|
):
|
||||||
# If the token is part of the never_split set
|
# If the token is part of the never_split set
|
||||||
if token in self.basic_tokenizer.never_split:
|
if token in self.basic_tokenizer.never_split:
|
||||||
split_tokens.append(token)
|
split_tokens.append(token)
|
||||||
|
|||||||
@@ -176,10 +176,12 @@ class LayoutLMTokenizer(PreTrainedTokenizer):
|
|||||||
def get_vocab(self):
|
def get_vocab(self):
|
||||||
return dict(self.vocab, **self.added_tokens_encoder)
|
return dict(self.vocab, **self.added_tokens_encoder)
|
||||||
|
|
||||||
def _tokenize(self, text):
|
def _tokenize(self, text, split_special_tokens=False):
|
||||||
split_tokens = []
|
split_tokens = []
|
||||||
if self.do_basic_tokenize:
|
if self.do_basic_tokenize:
|
||||||
for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
|
for token in self.basic_tokenizer.tokenize(
|
||||||
|
text, never_split=self.all_special_tokens if not split_special_tokens else None
|
||||||
|
):
|
||||||
# If the token is part of the never_split set
|
# If the token is part of the never_split set
|
||||||
if token in self.basic_tokenizer.never_split:
|
if token in self.basic_tokenizer.never_split:
|
||||||
split_tokens.append(token)
|
split_tokens.append(token)
|
||||||
|
|||||||
@@ -168,10 +168,12 @@ class LxmertTokenizer(PreTrainedTokenizer):
|
|||||||
def get_vocab(self):
|
def get_vocab(self):
|
||||||
return dict(self.vocab, **self.added_tokens_encoder)
|
return dict(self.vocab, **self.added_tokens_encoder)
|
||||||
|
|
||||||
def _tokenize(self, text):
|
def _tokenize(self, text, split_special_tokens=False):
|
||||||
split_tokens = []
|
split_tokens = []
|
||||||
if self.do_basic_tokenize:
|
if self.do_basic_tokenize:
|
||||||
for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
|
for token in self.basic_tokenizer.tokenize(
|
||||||
|
text, never_split=self.all_special_tokens if not split_special_tokens else None
|
||||||
|
):
|
||||||
# If the token is part of the never_split set
|
# If the token is part of the never_split set
|
||||||
if token in self.basic_tokenizer.never_split:
|
if token in self.basic_tokenizer.never_split:
|
||||||
split_tokens.append(token)
|
split_tokens.append(token)
|
||||||
|
|||||||
@@ -166,10 +166,12 @@ class MobileBertTokenizer(PreTrainedTokenizer):
|
|||||||
def get_vocab(self):
|
def get_vocab(self):
|
||||||
return dict(self.vocab, **self.added_tokens_encoder)
|
return dict(self.vocab, **self.added_tokens_encoder)
|
||||||
|
|
||||||
def _tokenize(self, text):
|
def _tokenize(self, text, split_special_tokens=False):
|
||||||
split_tokens = []
|
split_tokens = []
|
||||||
if self.do_basic_tokenize:
|
if self.do_basic_tokenize:
|
||||||
for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
|
for token in self.basic_tokenizer.tokenize(
|
||||||
|
text, never_split=self.all_special_tokens if not split_special_tokens else None
|
||||||
|
):
|
||||||
# If the token is part of the never_split set
|
# If the token is part of the never_split set
|
||||||
if token in self.basic_tokenizer.never_split:
|
if token in self.basic_tokenizer.never_split:
|
||||||
split_tokens.append(token)
|
split_tokens.append(token)
|
||||||
|
|||||||
@@ -210,10 +210,12 @@ class RoCBertTokenizer(PreTrainedTokenizer):
|
|||||||
return dict(self.vocab, **self.added_tokens_encoder)
|
return dict(self.vocab, **self.added_tokens_encoder)
|
||||||
|
|
||||||
# Copied from transformers.models.bert.tokenization_bert.BertTokenizer._tokenize
|
# Copied from transformers.models.bert.tokenization_bert.BertTokenizer._tokenize
|
||||||
def _tokenize(self, text):
|
def _tokenize(self, text, split_special_tokens=False):
|
||||||
split_tokens = []
|
split_tokens = []
|
||||||
if self.do_basic_tokenize:
|
if self.do_basic_tokenize:
|
||||||
for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
|
for token in self.basic_tokenizer.tokenize(
|
||||||
|
text, never_split=self.all_special_tokens if not split_special_tokens else None
|
||||||
|
):
|
||||||
# If the token is part of the never_split set
|
# If the token is part of the never_split set
|
||||||
if token in self.basic_tokenizer.never_split:
|
if token in self.basic_tokenizer.never_split:
|
||||||
split_tokens.append(token)
|
split_tokens.append(token)
|
||||||
|
|||||||
@@ -180,10 +180,12 @@ class SqueezeBertTokenizer(PreTrainedTokenizer):
|
|||||||
def get_vocab(self):
|
def get_vocab(self):
|
||||||
return dict(self.vocab, **self.added_tokens_encoder)
|
return dict(self.vocab, **self.added_tokens_encoder)
|
||||||
|
|
||||||
def _tokenize(self, text):
|
def _tokenize(self, text, split_special_tokens=False):
|
||||||
split_tokens = []
|
split_tokens = []
|
||||||
if self.do_basic_tokenize:
|
if self.do_basic_tokenize:
|
||||||
for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
|
for token in self.basic_tokenizer.tokenize(
|
||||||
|
text, never_split=self.all_special_tokens if not split_special_tokens else None
|
||||||
|
):
|
||||||
# If the token is part of the never_split set
|
# If the token is part of the never_split set
|
||||||
if token in self.basic_tokenizer.never_split:
|
if token in self.basic_tokenizer.never_split:
|
||||||
split_tokens.append(token)
|
split_tokens.append(token)
|
||||||
|
|||||||
@@ -498,6 +498,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
|
|||||||
all_special_tokens_extended = {
|
all_special_tokens_extended = {
|
||||||
str(t): t for t in self.all_special_tokens_extended if isinstance(t, AddedToken)
|
str(t): t for t in self.all_special_tokens_extended if isinstance(t, AddedToken)
|
||||||
}
|
}
|
||||||
|
split_special_tokens = kwargs.pop("split_special_tokens", self.split_special_tokens)
|
||||||
|
|
||||||
text, kwargs = self.prepare_for_tokenization(text, **kwargs)
|
text, kwargs = self.prepare_for_tokenization(text, **kwargs)
|
||||||
|
|
||||||
@@ -513,8 +514,14 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
|
|||||||
pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
|
pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
|
||||||
text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text)
|
text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text)
|
||||||
|
|
||||||
no_split_token = set(self.unique_no_split_tokens)
|
# split_special_tokens: empty `no_split_token`
|
||||||
tokens = self.tokens_trie.split(text)
|
if split_special_tokens:
|
||||||
|
no_split_token = []
|
||||||
|
tokens = [text]
|
||||||
|
else:
|
||||||
|
no_split_token = set(self.unique_no_split_tokens)
|
||||||
|
tokens = self.tokens_trie.split(text)
|
||||||
|
|
||||||
# ["This is something", "<special_token_1>", " else"]
|
# ["This is something", "<special_token_1>", " else"]
|
||||||
for i, token in enumerate(tokens):
|
for i, token in enumerate(tokens):
|
||||||
if token in no_split_token:
|
if token in no_split_token:
|
||||||
|
|||||||
@@ -1492,6 +1492,11 @@ INIT_TOKENIZER_DOCSTRING = r"""
|
|||||||
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
|
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
|
||||||
Whether or not the model should cleanup the spaces that were added when splitting the input text during the
|
Whether or not the model should cleanup the spaces that were added when splitting the input text during the
|
||||||
tokenization process.
|
tokenization process.
|
||||||
|
split_special_tokens (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not the special tokens should be split during the tokenization process. The default behavior is
|
||||||
|
to not split special tokens. This means that if `<s>` is the `bos_token`, then `tokenizer.tokenize("<s>") =
|
||||||
|
['<s>`]. Otherwise, if `split_special_tokens=True`, then `tokenizer.tokenize("<s>")` will be give `['<',
|
||||||
|
's', '>']`. This argument is only supported for `slow` tokenizers for the moment.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@@ -1546,6 +1551,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
# By default, cleaning tokenization spaces for both fast and slow tokenizers
|
# By default, cleaning tokenization spaces for both fast and slow tokenizers
|
||||||
self.clean_up_tokenization_spaces = kwargs.pop("clean_up_tokenization_spaces", True)
|
self.clean_up_tokenization_spaces = kwargs.pop("clean_up_tokenization_spaces", True)
|
||||||
|
|
||||||
|
# By default, do not split special tokens for both fast and slow tokenizers
|
||||||
|
self.split_special_tokens = kwargs.pop("split_special_tokens", False)
|
||||||
|
|
||||||
self.deprecation_warnings = (
|
self.deprecation_warnings = (
|
||||||
{}
|
{}
|
||||||
) # Use to store when we have already noticed a deprecation warning (avoid overlogging).
|
) # Use to store when we have already noticed a deprecation warning (avoid overlogging).
|
||||||
|
|||||||
@@ -384,6 +384,10 @@ class LayoutLMv2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
def test_right_and_left_truncation(self):
|
def test_right_and_left_truncation(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("Not implemented")
|
||||||
|
def test_split_special_tokens(self):
|
||||||
|
pass
|
||||||
|
|
||||||
def test_encode_plus_with_padding(self):
|
def test_encode_plus_with_padding(self):
|
||||||
tokenizers = self.get_tokenizers(do_lower_case=False)
|
tokenizers = self.get_tokenizers(do_lower_case=False)
|
||||||
for tokenizer in tokenizers:
|
for tokenizer in tokenizers:
|
||||||
|
|||||||
@@ -264,6 +264,10 @@ class LayoutLMv3TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
def test_right_and_left_truncation(self):
|
def test_right_and_left_truncation(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("Not implemented")
|
||||||
|
def test_split_special_tokens(self):
|
||||||
|
pass
|
||||||
|
|
||||||
def test_encode_plus_with_padding(self):
|
def test_encode_plus_with_padding(self):
|
||||||
tokenizers = self.get_tokenizers(do_lower_case=False)
|
tokenizers = self.get_tokenizers(do_lower_case=False)
|
||||||
for tokenizer in tokenizers:
|
for tokenizer in tokenizers:
|
||||||
|
|||||||
@@ -144,6 +144,19 @@ class LayoutXLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
self.assertEqual(encoding_tokenizer_slow_1, encoding_tokenizer_slow_2)
|
self.assertEqual(encoding_tokenizer_slow_1, encoding_tokenizer_slow_2)
|
||||||
self.assertEqual(encoding_tokenizer_slow_1, encoding_tokenizer_slow_3)
|
self.assertEqual(encoding_tokenizer_slow_1, encoding_tokenizer_slow_3)
|
||||||
|
|
||||||
|
def test_split_special_tokens(self):
|
||||||
|
tokenizer = self.tokenizer_class.from_pretrained("microsoft/layoutxlm-base")
|
||||||
|
_, _, boxes = self.get_question_words_and_boxes()
|
||||||
|
special_token = "[SPECIAL_TOKEN]"
|
||||||
|
tokenizer.add_special_tokens({"additional_special_tokens": [special_token]})
|
||||||
|
encoded_special_token = tokenizer.tokenize(special_token, boxes=boxes, add_special_tokens=False)
|
||||||
|
self.assertEqual(len(encoded_special_token), 1)
|
||||||
|
|
||||||
|
encoded_split_special_token = tokenizer.tokenize(
|
||||||
|
special_token, add_special_tokens=False, split_special_tokens=True, boxes=boxes
|
||||||
|
)
|
||||||
|
self.assertTrue(len(encoded_split_special_token) > 1)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_sequence_builders(self):
|
def test_sequence_builders(self):
|
||||||
tokenizer = self.tokenizer_class.from_pretrained("microsoft/layoutxlm-base")
|
tokenizer = self.tokenizer_class.from_pretrained("microsoft/layoutxlm-base")
|
||||||
|
|||||||
@@ -1344,6 +1344,19 @@ class MarkupLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
self.assertTrue(special_token_id in p_output)
|
self.assertTrue(special_token_id in p_output)
|
||||||
self.assertTrue(special_token_id in cr_output)
|
self.assertTrue(special_token_id in cr_output)
|
||||||
|
|
||||||
|
def test_split_special_tokens(self):
|
||||||
|
# TODO this is only possible for slow currently
|
||||||
|
tokenizer = self.get_tokenizer()
|
||||||
|
special_token = "[SPECIAL_TOKEN]"
|
||||||
|
tokenizer.add_special_tokens({"additional_special_tokens": [special_token]})
|
||||||
|
encoded_special_token = tokenizer.tokenize(special_token, add_special_tokens=False)
|
||||||
|
self.assertEqual(len(encoded_special_token), 1)
|
||||||
|
|
||||||
|
encoded_split_special_token = tokenizer.tokenize(
|
||||||
|
special_token, add_special_tokens=False, split_special_tokens=True
|
||||||
|
)
|
||||||
|
self.assertTrue(len(encoded_split_special_token) > 1)
|
||||||
|
|
||||||
def test_training_new_tokenizer(self):
|
def test_training_new_tokenizer(self):
|
||||||
# This feature only exists for fast tokenizers
|
# This feature only exists for fast tokenizers
|
||||||
if not self.test_rust_tokenizer:
|
if not self.test_rust_tokenizer:
|
||||||
|
|||||||
@@ -3909,6 +3909,7 @@ class TokenizerTesterMixin:
|
|||||||
# Should not raise an error
|
# Should not raise an error
|
||||||
self.rust_tokenizer_class.from_pretrained(tmp_dir_2)
|
self.rust_tokenizer_class.from_pretrained(tmp_dir_2)
|
||||||
|
|
||||||
|
# TODO This is ran for all models but only tests bert...
|
||||||
def test_clean_up_tokenization_spaces(self):
|
def test_clean_up_tokenization_spaces(self):
|
||||||
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
||||||
assert tokenizer.clean_up_tokenization_spaces is True
|
assert tokenizer.clean_up_tokenization_spaces is True
|
||||||
@@ -3953,3 +3954,29 @@ class TokenizerTesterMixin:
|
|||||||
tokenizer.clean_up_tokenization_spaces = True
|
tokenizer.clean_up_tokenization_spaces = True
|
||||||
decoded = tokenizer.decode(tokens)
|
decoded = tokenizer.decode(tokens)
|
||||||
assert decoded == "[CLS] this shouldn't be! he'll go. [SEP]"
|
assert decoded == "[CLS] this shouldn't be! he'll go. [SEP]"
|
||||||
|
|
||||||
|
def test_split_special_tokens(self):
|
||||||
|
if not self.test_slow_tokenizer:
|
||||||
|
return
|
||||||
|
|
||||||
|
for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
|
||||||
|
special_token = "[SPECIAL_TOKEN]"
|
||||||
|
with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
|
||||||
|
tokenizer = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
|
||||||
|
|
||||||
|
if not tokenizer.is_fast:
|
||||||
|
# bloom, gptneox etc only have a fast
|
||||||
|
tokenizer.add_special_tokens({"additional_special_tokens": [special_token]})
|
||||||
|
encoded_special_token = tokenizer.encode(special_token, add_special_tokens=False)
|
||||||
|
self.assertEqual(len(encoded_special_token), 1)
|
||||||
|
|
||||||
|
encoded_split_special_token = tokenizer.encode(
|
||||||
|
special_token, add_special_tokens=False, split_special_tokens=True
|
||||||
|
)
|
||||||
|
if len(encoded_split_special_token) == 1:
|
||||||
|
# if we have subword tokenization or special vocab
|
||||||
|
self.assertTrue(
|
||||||
|
encoded_split_special_token[0] != tokenizer.convert_tokens_to_ids(special_token)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.assertTrue(len(encoded_split_special_token) > 1)
|
||||||
|
|||||||
Reference in New Issue
Block a user