From 21b3922e35529dfbf9213365d7d37756a59f8e0e Mon Sep 17 00:00:00 2001
From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Date: Thu, 4 Feb 2021 14:18:33 -0500
Subject: [PATCH] Authorize last version of tokenizer (#9799)

* Authorize last version of tokenizer

* Update version table

* Fix conversion of spm tokenizers and fix some hub links

* Bump tokenizers version to 0.10.1rc1

* Add script to check tokenizers conversion with XNLI

* Add some more mask_token lstrip support

* Must modify mask_token in slow tokenizers too

* Keep using the old method for Pegasus

* add missing import

Co-authored-by: Anthony MOI
---
Notes:
    scripts/check_tokenizers.py encodes every XNLI premise/hypothesis with the
    slow and the fast tokenizer of each converter and counts identical,
    benign-mismatch and wrong encodings.

    Review changes to that script: the recursive call in check_details passed
    the undefined names `sp, tok` to check_diff (NameError when reached); it
    now passes `slow, fast`. check_LTR_mark returns False explicitly, the
    unused `global batch_total` is dropped, and the helpers got docstrings.

 scripts/check_tokenizers.py                   | 174 ++++++++++++++++++
 setup.py                                      |   2 +-
 src/transformers/convert_slow_tokenizer.py    |  36 ++--
 src/transformers/dependency_versions_table.py |   2 +-
 .../models/albert/tokenization_albert.py      |   5 +-
 .../models/albert/tokenization_albert_fast.py |   4 +
 .../models/barthez/tokenization_barthez.py    |   5 +-
 .../barthez/tokenization_barthez_fast.py      |   4 +
 .../camembert/tokenization_camembert.py       |   5 +-
 .../camembert/tokenization_camembert_fast.py  |   4 +
 .../models/pegasus/tokenization_pegasus.py    |   2 +-
 .../pegasus/tokenization_pegasus_fast.py      |   6 +-
 .../models/reformer/tokenization_reformer.py  |   2 +-
 .../reformer/tokenization_reformer_fast.py    |   4 +-
 .../xlm_roberta/tokenization_xlm_roberta.py   |   5 +-
 .../tokenization_xlm_roberta_fast.py          |   4 +
 .../models/xlnet/tokenization_xlnet.py        |   5 +-
 .../models/xlnet/tokenization_xlnet_fast.py   |   4 +
 18 files changed, 250 insertions(+), 23 deletions(-)
 create mode 100644 scripts/check_tokenizers.py

diff --git a/scripts/check_tokenizers.py b/scripts/check_tokenizers.py
new file mode 100644
index 0000000000..cfd0a7f3a1
--- /dev/null
+++ b/scripts/check_tokenizers.py
@@ -0,0 +1,174 @@
+from collections import Counter
+import datasets
+import transformers
+from transformers.convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS
+
+from transformers.utils import logging
+
+logging.set_verbosity_info()
+
+TOKENIZER_CLASSES = {
+    name: (getattr(transformers, name), getattr(transformers, name + "Fast")) for name in SLOW_TO_FAST_CONVERTERS
+}
+
+dataset = datasets.load_dataset("xnli", split="test+validation")
+
+total = 0
+perfect = 0
+imperfect = 0
+wrong = 0
+
+
+def check_diff(spm_diff, tok_diff, slow, fast):
+    """Return True when the two differing id spans match one of the known benign mismatch patterns."""
+    if spm_diff == list(reversed(tok_diff)):
+        # AAA -> AA+A vs A+AA case.
+        return True
+    elif len(spm_diff) == len(tok_diff) and fast.decode(spm_diff) == fast.decode(tok_diff):
+        # Second order OK
+        # Barrich -> Barr + ich vs Bar + rich
+        return True
+    spm_reencoded = slow.encode(slow.decode(spm_diff))
+    tok_reencoded = fast.encode(fast.decode(spm_diff))
+    if spm_reencoded != spm_diff and spm_reencoded == tok_reencoded:
+        # Type 3 error.
+        # Snehagatha ->
+        #       Sne, h, aga, th, a
+        #       Sne, ha, gat, ha
+        # Encoding the wrong with sp does not even recover what spm gave us
+        # It fits tokenizer however...
+        return True
+    return False
+
+
+def check_LTR_mark(line, idx, fast):
+    """Return True if the token at ``idx`` or the one before it spans exactly U+200F (right-to-left mark)."""
+    enc = fast.encode_plus(line)[0]
+    offsets = enc.offsets
+    curr, prev = offsets[idx], offsets[idx - 1]
+    if curr is not None and line[curr[0] : curr[1]] == "\u200f":
+        return True
+    if prev is not None and line[prev[0] : prev[1]] == "\u200f":
+        return True
+    return False
+
+
+def check_details(line, spm_ids, tok_ids, slow, fast):
+    """Isolate where both id lists differ; return True if benign, else print the span and return False."""
+    # Encoding can be the same with same result AAA -> A + AA vs AA + A
+    # We can check that we use at least exactly the same number of tokens.
+    for i, (spm_id, tok_id) in enumerate(zip(spm_ids, tok_ids)):
+        if spm_id != tok_id:
+            break
+    first = i
+    for i, (spm_id, tok_id) in enumerate(zip(reversed(spm_ids), reversed(tok_ids))):
+        if spm_id != tok_id:
+            break
+    last = len(spm_ids) - i
+
+    spm_diff = spm_ids[first:last]
+    tok_diff = tok_ids[first:last]
+
+    if check_diff(spm_diff, tok_diff, slow, fast):
+        return True
+
+    if check_LTR_mark(line, first, fast):
+        return True
+
+    if last - first > 5:
+        # We might have twice a single problem, attempt to subdivide the disjointed tokens into smaller problems
+        spms = Counter(spm_ids[first:last])
+        toks = Counter(tok_ids[first:last])
+
+        removable_tokens = {spm_ for (spm_, si) in spms.items() if toks.get(spm_, 0) == si}
+        min_width = 3
+        for i in range(last - first - min_width):
+            if all(spm_ids[first + i + j] in removable_tokens for j in range(min_width)):
+                possible_matches = [
+                    k
+                    for k in range(last - first - min_width)
+                    if tok_ids[first + k : first + k + min_width] == spm_ids[first + i : first + i + min_width]
+                ]
+                for j in possible_matches:
+                    if check_diff(spm_ids[first : first + i], tok_ids[first : first + j], slow, fast) and check_details(
+                        line,
+                        spm_ids[first + i : last],
+                        tok_ids[first + j : last],
+                        slow,
+                        fast,
+                    ):
+                        return True
+
+    print(f"Spm: {[fast.decode([spm_ids[i]]) for i in range(first, last)]}")
+    try:
+        print(f"Tok: {[fast.decode([tok_ids[i]]) for i in range(first, last)]}")
+    except Exception:
+        pass
+
+    ok_start = fast.decode(spm_ids[:first])
+    ok_end = fast.decode(spm_ids[last:])
+    wrong = fast.decode(spm_ids[first:last])
+    print()
+    print(wrong)
+    return False
+
+
+def test_string(slow, fast, text):
+    """Encode ``text`` with both tokenizers, update the global counters and assert on a real mismatch."""
+    global perfect
+    global imperfect
+    global wrong
+    global total
+
+    slow_ids = slow.encode(text)
+    fast_ids = fast.encode(text)
+
+    skip_assert = False
+    total += 1
+
+    if slow_ids != fast_ids:
+        if check_details(text, slow_ids, fast_ids, slow, fast):
+            skip_assert = True
+            imperfect += 1
+        else:
+            wrong += 1
+    else:
+        perfect += 1
+
+    if total % 10000 == 0:
+        print(f"({perfect} / {imperfect} / {wrong} ----- {perfect + imperfect + wrong})")
+
+    if skip_assert:
+        return
+
+    assert (
+        slow_ids == fast_ids
+    ), f"line {text} : \n\n{slow_ids}\n{fast_ids}\n\n{slow.tokenize(text)}\n{fast.tokenize(text)}"
+
+
+def test_tokenizer(slow, fast):
+    """Run ``test_string`` on every premise and hypothesis translation of the XNLI ``dataset``."""
+    for i in range(len(dataset)):
+        # premise, all languages
+        for text in dataset[i]["premise"].values():
+            test_string(slow, fast, text)
+
+        # hypothesis, all languages
+        for text in dataset[i]["hypothesis"]["translation"]:
+            test_string(slow, fast, text)
+
+
+if __name__ == "__main__":
+    for name, (slow_class, fast_class) in TOKENIZER_CLASSES.items():
+        checkpoint_names = list(slow_class.max_model_input_sizes.keys())
+        for checkpoint in checkpoint_names:
+            imperfect = 0
+            perfect = 0
+            wrong = 0
+            total = 0
+
+            print(f"========================== Checking {name}: {checkpoint} ==========================")
+            slow = slow_class.from_pretrained(checkpoint, force_download=True)
+            fast = fast_class.from_pretrained(checkpoint, force_download=True)
+            test_tokenizer(slow, fast)
+            print(f"Accuracy {perfect * 100 / total:.2f}")
diff --git a/setup.py b/setup.py
index 0d49171f8a..567fff7a2a 100644
--- a/setup.py
+++ b/setup.py
@@ -132,7 +132,7 @@ _deps = [
     "tensorflow-cpu>=2.3",
     "tensorflow>=2.3",
     "timeout-decorator",
-    "tokenizers==0.9.4",
+    "tokenizers==0.10.1rc1",
     "torch>=1.0",
     "tqdm>=4.27",
     "unidic>=1.0.2",
diff --git a/src/transformers/convert_slow_tokenizer.py b/src/transformers/convert_slow_tokenizer.py
index ad301574f7..9a11a5b0e6 100644
--- a/src/transformers/convert_slow_tokenizer.py
+++ b/src/transformers/convert_slow_tokenizer.py
@@ -21,7 +21,7 @@
 
 from typing import Dict, List, Tuple
 
-from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors
+from tokenizers import Regex, Tokenizer, decoders, normalizers, pre_tokenizers, processors
 from tokenizers.models import BPE, Unigram, WordPiece
 
 from .file_utils import requires_protobuf, requires_sentencepiece
@@ -340,7 +340,12 @@ class SpmConverter(Converter):
 
     def normalizer(self, proto):
         precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
-        return normalizers.Precompiled(precompiled_charsmap)
+        return normalizers.Sequence(
+            [normalizers.Precompiled(precompiled_charsmap), normalizers.Replace(Regex(" {2,}"), " ")]
+        )
+
+    def pre_tokenizer(self, replacement, add_prefix_space):
+        return pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space)
 
     def post_processor(self):
         return None
@@ -353,12 +358,7 @@ class SpmConverter(Converter):
         replacement = "▁"
         add_prefix_space = True
 
-        tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
-            [
-                pre_tokenizers.WhitespaceSplit(),
-                pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space),
-            ]
-        )
+        tokenizer.pre_tokenizer = self.pre_tokenizer(replacement, add_prefix_space)
         tokenizer.decoder = decoders.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space)
         post_processor = self.post_processor()
         if post_processor:
@@ -375,7 +375,11 @@ class AlbertConverter(SpmConverter):
         ]
 
     def normalizer(self, proto):
-        list_normalizers = [normalizers.Replace("``", '"'), normalizers.Replace("''", '"')]
+        list_normalizers = [
+            normalizers.Replace("``", '"'),
+            normalizers.Replace("''", '"'),
+            normalizers.Replace(Regex(" {2,}"), " "),
+        ]
         if not self.original_tokenizer.keep_accents:
             list_normalizers.append(normalizers.NFKD())
             list_normalizers.append(normalizers.StripAccents())
@@ -529,7 +533,11 @@ class XLNetConverter(SpmConverter):
         ]
 
     def normalizer(self, proto):
-        list_normalizers = [normalizers.Replace("``", '"'), normalizers.Replace("''", '"')]
+        list_normalizers = [
+            normalizers.Replace("``", '"'),
+            normalizers.Replace("''", '"'),
+            normalizers.Replace(Regex(" {2,}"), " "),
+        ]
         if not self.original_tokenizer.keep_accents:
             list_normalizers.append(normalizers.NFKD())
             list_normalizers.append(normalizers.StripAccents())
@@ -574,6 +582,14 @@ class PegasusConverter(SpmConverter):
     def unk_id(self, proto):
         return proto.trainer_spec.unk_id + self.original_tokenizer.offset
 
+    def pre_tokenizer(self, replacement, add_prefix_space):
+        return pre_tokenizers.Sequence(
+            [
+                pre_tokenizers.WhitespaceSplit(),
+                pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space),
+            ]
+        )
+
     def post_processor(self):
         eos = self.original_tokenizer.eos_token
         special_tokens = [
diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py
index 8dca85c2f0..6bc57add82 100644
--- a/src/transformers/dependency_versions_table.py
+++ b/src/transformers/dependency_versions_table.py
@@ -45,7 +45,7 @@ deps = {
     "tensorflow-cpu": "tensorflow-cpu>=2.3",
     "tensorflow": "tensorflow>=2.3",
     "timeout-decorator": "timeout-decorator",
-    "tokenizers": "tokenizers==0.9.4",
+    "tokenizers": "tokenizers==0.10.1rc1",
     "torch": "torch>=1.0",
     "tqdm": "tqdm>=4.27",
     "unidic": "unidic>=1.0.2",
diff --git a/src/transformers/models/albert/tokenization_albert.py b/src/transformers/models/albert/tokenization_albert.py
index 890c7f8707..c51e30bb99 100644
--- a/src/transformers/models/albert/tokenization_albert.py
+++ b/src/transformers/models/albert/tokenization_albert.py
@@ -22,7 +22,7 @@ from typing import List, Optional, Tuple
 
 import sentencepiece as spm
 
-from ...tokenization_utils import PreTrainedTokenizer
+from ...tokenization_utils import AddedToken, PreTrainedTokenizer
 from ...utils import logging
 
 
@@ -127,6 +127,9 @@ class AlbertTokenizer(PreTrainedTokenizer):
         mask_token="[MASK]",
         **kwargs
     ):
+        # Mask token behave like a normal word, i.e. include the space before it
+        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
+
         super().__init__(
             do_lower_case=do_lower_case,
             remove_space=remove_space,
diff --git a/src/transformers/models/albert/tokenization_albert_fast.py b/src/transformers/models/albert/tokenization_albert_fast.py
index 5cfa584386..40b80f0142 100644
--- a/src/transformers/models/albert/tokenization_albert_fast.py
+++ b/src/transformers/models/albert/tokenization_albert_fast.py
@@ -20,6 +20,7 @@ from shutil import copyfile
 from typing import List, Optional, Tuple
 
 from ...file_utils import is_sentencepiece_available
+from ...tokenization_utils import AddedToken
 from ...tokenization_utils_fast import PreTrainedTokenizerFast
 from ...utils import logging
 
@@ -134,6 +135,9 @@ class AlbertTokenizerFast(PreTrainedTokenizerFast):
         mask_token="[MASK]",
         **kwargs
     ):
+        # Mask token behave like a normal word, i.e. include the space before it
+        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
+
         super().__init__(
             vocab_file,
             tokenizer_file=tokenizer_file,
diff --git a/src/transformers/models/barthez/tokenization_barthez.py b/src/transformers/models/barthez/tokenization_barthez.py
index d751de0e0c..f8061b323b 100644
--- a/src/transformers/models/barthez/tokenization_barthez.py
+++ b/src/transformers/models/barthez/tokenization_barthez.py
@@ -21,7 +21,7 @@ from typing import List, Optional, Tuple
 
 import sentencepiece as spm
 
-from ...tokenization_utils import PreTrainedTokenizer
+from ...tokenization_utils import AddedToken, PreTrainedTokenizer
 from ...utils import logging
 
 
@@ -112,6 +112,9 @@ class BarthezTokenizer(PreTrainedTokenizer):
         mask_token="",
         **kwargs
     ):
+        # Mask token behave like a normal word, i.e. include the space before it
+        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
+
         super().__init__(
             bos_token=bos_token,
             eos_token=eos_token,
diff --git a/src/transformers/models/barthez/tokenization_barthez_fast.py b/src/transformers/models/barthez/tokenization_barthez_fast.py
index 070d6e6c7e..d61ac07446 100644
--- a/src/transformers/models/barthez/tokenization_barthez_fast.py
+++ b/src/transformers/models/barthez/tokenization_barthez_fast.py
@@ -20,6 +20,7 @@ from shutil import copyfile
 from typing import List, Optional, Tuple
 
 from ...file_utils import is_sentencepiece_available
+from ...tokenization_utils import AddedToken
 from ...tokenization_utils_fast import PreTrainedTokenizerFast
 from ...utils import logging
 
@@ -119,6 +120,9 @@ class BarthezTokenizerFast(PreTrainedTokenizerFast):
         mask_token="",
         **kwargs
     ):
+        # Mask token behave like a normal word, i.e. include the space before it
+        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
+
         super().__init__(
             vocab_file,
             tokenizer_file=tokenizer_file,
diff --git a/src/transformers/models/camembert/tokenization_camembert.py b/src/transformers/models/camembert/tokenization_camembert.py
index 0be12e76be..6e866ba638 100644
--- a/src/transformers/models/camembert/tokenization_camembert.py
+++ b/src/transformers/models/camembert/tokenization_camembert.py
@@ -21,7 +21,7 @@ from typing import List, Optional, Tuple
 
 import sentencepiece as spm
 
-from ...tokenization_utils import PreTrainedTokenizer
+from ...tokenization_utils import AddedToken, PreTrainedTokenizer
 from ...utils import logging
 
 
@@ -116,6 +116,9 @@ class CamembertTokenizer(PreTrainedTokenizer):
         additional_special_tokens=["NOTUSED", "NOTUSED"],
         **kwargs
     ):
+        # Mask token behave like a normal word, i.e. include the space before it
+        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
+
         super().__init__(
             bos_token=bos_token,
             eos_token=eos_token,
diff --git a/src/transformers/models/camembert/tokenization_camembert_fast.py b/src/transformers/models/camembert/tokenization_camembert_fast.py
index 437fa77173..87019e7253 100644
--- a/src/transformers/models/camembert/tokenization_camembert_fast.py
+++ b/src/transformers/models/camembert/tokenization_camembert_fast.py
@@ -20,6 +20,7 @@ from shutil import copyfile
 from typing import List, Optional, Tuple
 
 from ...file_utils import is_sentencepiece_available
+from ...tokenization_utils import AddedToken
 from ...tokenization_utils_fast import PreTrainedTokenizerFast
 from ...utils import logging
 
@@ -123,6 +124,9 @@ class CamembertTokenizerFast(PreTrainedTokenizerFast):
         additional_special_tokens=["NOTUSED", "NOTUSED"],
         **kwargs
     ):
+        # Mask token behave like a normal word, i.e. include the space before it
+        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
+
         super().__init__(
             vocab_file,
             tokenizer_file=tokenizer_file,
diff --git a/src/transformers/models/pegasus/tokenization_pegasus.py b/src/transformers/models/pegasus/tokenization_pegasus.py
index 08e4f7d194..68ad5b83ad 100644
--- a/src/transformers/models/pegasus/tokenization_pegasus.py
+++ b/src/transformers/models/pegasus/tokenization_pegasus.py
@@ -27,7 +27,7 @@ SPIECE_UNDERLINE = "▁"
 VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
 
 PRETRAINED_VOCAB_FILES_MAP = {
-    "vocab_file": {"google/pegasus-xsum": "https://cdn.huggingface.co/google/pegasus-xsum/spiece.model"}
+    "vocab_file": {"google/pegasus-xsum": "https://huggingface.co/google/pegasus-xsum/resolve/main/spiece.model"}
 }
 
 PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
diff --git a/src/transformers/models/pegasus/tokenization_pegasus_fast.py b/src/transformers/models/pegasus/tokenization_pegasus_fast.py
index d097692a98..626d930398 100644
--- a/src/transformers/models/pegasus/tokenization_pegasus_fast.py
+++ b/src/transformers/models/pegasus/tokenization_pegasus_fast.py
@@ -38,8 +38,10 @@ SPIECE_UNDERLINE = "▁"
 VOCAB_FILES_NAMES = {"vocab_file": "spiece.model", "tokenizer_file": "tokenizer.json"}
 
 PRETRAINED_VOCAB_FILES_MAP = {
-    "vocab_file": {"google/pegasus-xsum": "https://cdn.huggingface.co/google/pegasus-xsum/spiece.model"},
-    "tokenizer_file": {"google/pegasus-xsum": "https://cdn.huggingface.co/google/pegasus-xsum/tokenizer.json"},
+    "vocab_file": {"google/pegasus-xsum": "https://huggingface.co/google/pegasus-xsum/resolve/main/spiece.model"},
+    "tokenizer_file": {
+        "google/pegasus-xsum": "https://huggingface.co/google/pegasus-xsum/resolve/main/tokenizer.json"
+    },
 }
 
 PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
diff --git a/src/transformers/models/reformer/tokenization_reformer.py b/src/transformers/models/reformer/tokenization_reformer.py
index 99c2c4c9b7..3c6ad94703 100644
--- a/src/transformers/models/reformer/tokenization_reformer.py
+++ b/src/transformers/models/reformer/tokenization_reformer.py
@@ -42,7 +42,7 @@ VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
 ####################################################
 PRETRAINED_VOCAB_FILES_MAP = {
     "vocab_file": {
-        "google/reformer-crime-and-punishment": "https://cdn.huggingface.co/google/reformer-crime-and-punishment/spiece.model"
+        "google/reformer-crime-and-punishment": "https://huggingface.co/google/reformer-crime-and-punishment/resolve/main/spiece.model"
     }
 }
 
diff --git a/src/transformers/models/reformer/tokenization_reformer_fast.py b/src/transformers/models/reformer/tokenization_reformer_fast.py
index ebe5ccab80..f8ab110a2f 100644
--- a/src/transformers/models/reformer/tokenization_reformer_fast.py
+++ b/src/transformers/models/reformer/tokenization_reformer_fast.py
@@ -47,10 +47,10 @@ VOCAB_FILES_NAMES = {"vocab_file": "spiece.model", "tokenizer_file": "tokenizer.
 ####################################################
 PRETRAINED_VOCAB_FILES_MAP = {
     "vocab_file": {
-        "google/reformer-crime-and-punishment": "https://cdn.huggingface.co/google/reformer-crime-and-punishment/spiece.model"
+        "google/reformer-crime-and-punishment": "https://huggingface.co/google/reformer-crime-and-punishment/resolve/main/spiece.model"
     },
     "tokenizer_file": {
-        "google/reformer-crime-and-punishment": "https://cdn.huggingface.co/google/reformer-crime-and-punishment/tokenizer.json"
+        "google/reformer-crime-and-punishment": "https://huggingface.co/google/reformer-crime-and-punishment/resolve/main/tokenizer.json"
     },
 }
 
diff --git a/src/transformers/models/xlm_roberta/tokenization_xlm_roberta.py b/src/transformers/models/xlm_roberta/tokenization_xlm_roberta.py
index 2d85e30dd0..5d642ef431 100644
--- a/src/transformers/models/xlm_roberta/tokenization_xlm_roberta.py
+++ b/src/transformers/models/xlm_roberta/tokenization_xlm_roberta.py
@@ -21,7 +21,7 @@ from typing import List, Optional, Tuple
 
 import sentencepiece as spm
 
-from ...tokenization_utils import PreTrainedTokenizer
+from ...tokenization_utils import AddedToken, PreTrainedTokenizer
 from ...utils import logging
 
 
@@ -117,6 +117,9 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
         mask_token="",
         **kwargs
     ):
+        # Mask token behave like a normal word, i.e. include the space before it
+        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
+
         super().__init__(
             bos_token=bos_token,
             eos_token=eos_token,
diff --git a/src/transformers/models/xlm_roberta/tokenization_xlm_roberta_fast.py b/src/transformers/models/xlm_roberta/tokenization_xlm_roberta_fast.py
index befd84be94..b3f97e3eaf 100644
--- a/src/transformers/models/xlm_roberta/tokenization_xlm_roberta_fast.py
+++ b/src/transformers/models/xlm_roberta/tokenization_xlm_roberta_fast.py
@@ -20,6 +20,7 @@ from shutil import copyfile
 from typing import List, Optional, Tuple
 
 from ...file_utils import is_sentencepiece_available
+from ...tokenization_utils import AddedToken
 from ...tokenization_utils_fast import PreTrainedTokenizerFast
 from ...utils import logging
 
@@ -127,6 +128,9 @@ class XLMRobertaTokenizerFast(PreTrainedTokenizerFast):
         mask_token="",
         **kwargs
     ):
+        # Mask token behave like a normal word, i.e. include the space before it
+        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
+
         super().__init__(
             vocab_file,
             tokenizer_file=tokenizer_file,
diff --git a/src/transformers/models/xlnet/tokenization_xlnet.py b/src/transformers/models/xlnet/tokenization_xlnet.py
index 82d7122b6f..054fbf7c4f 100644
--- a/src/transformers/models/xlnet/tokenization_xlnet.py
+++ b/src/transformers/models/xlnet/tokenization_xlnet.py
@@ -23,7 +23,7 @@ from typing import List, Optional, Tuple
 import sentencepiece as spm
 
 from ...file_utils import SPIECE_UNDERLINE
-from ...tokenization_utils import PreTrainedTokenizer
+from ...tokenization_utils import AddedToken, PreTrainedTokenizer
 from ...utils import logging
 
 
@@ -126,6 +126,9 @@ class XLNetTokenizer(PreTrainedTokenizer):
         additional_special_tokens=["", ""],
         **kwargs
     ):
+        # Mask token behave like a normal word, i.e. include the space before it
+        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
+
         super().__init__(
             do_lower_case=do_lower_case,
             remove_space=remove_space,
diff --git a/src/transformers/models/xlnet/tokenization_xlnet_fast.py b/src/transformers/models/xlnet/tokenization_xlnet_fast.py
index 84af74070d..e2ebd0cfbb 100644
--- a/src/transformers/models/xlnet/tokenization_xlnet_fast.py
+++ b/src/transformers/models/xlnet/tokenization_xlnet_fast.py
@@ -20,6 +20,7 @@ from shutil import copyfile
 from typing import List, Optional, Tuple
 
 from ...file_utils import is_sentencepiece_available
+from ...tokenization_utils import AddedToken
 from ...tokenization_utils_fast import PreTrainedTokenizerFast
 from ...utils import logging
 
@@ -138,6 +139,9 @@ class XLNetTokenizerFast(PreTrainedTokenizerFast):
         additional_special_tokens=["", ""],
         **kwargs
     ):
+        # Mask token behave like a normal word, i.e. include the space before it
+        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
+
         super().__init__(
             vocab_file=vocab_file,
             tokenizer_file=tokenizer_file,