diff --git a/docs/source/model_doc/transformerxl.rst b/docs/source/model_doc/transformerxl.rst
index f5f6b22c3e..ecdf2dd3b9 100644
--- a/docs/source/model_doc/transformerxl.rst
+++ b/docs/source/model_doc/transformerxl.rst
@@ -46,13 +46,6 @@ TransfoXLTokenizer
:members: save_vocabulary
-TransfoXLTokenizerFast
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-.. autoclass:: transformers.TransfoXLTokenizerFast
- :members:
-
-
TransfoXL specific outputs
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
diff --git a/setup.py b/setup.py
index 1619038ee5..38cb92861a 100644
--- a/setup.py
+++ b/setup.py
@@ -111,7 +111,7 @@ setup(
packages=find_packages("src"),
install_requires=[
"numpy",
- "tokenizers == 0.8.1.rc2",
+ "tokenizers == 0.9.0.rc2",
# dataclasses for Python versions that don't have it
"dataclasses;python_version<'3.7'",
# utilities from PyPA to e.g. compare versions
@@ -124,8 +124,9 @@ setup(
"tqdm >= 4.27",
# for OpenAI GPT
"regex != 2019.12.17",
- # for XLNet
+ # for SentencePiece models
"sentencepiece != 0.1.92",
+ "protobuf",
# for XLM
"sacremoses",
],
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index dfe8073c0a..143ff4187d 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -152,7 +152,7 @@ from .pipelines import (
from .retrieval_rag import RagRetriever
# Tokenizers
-from .tokenization_albert import AlbertTokenizer
+from .tokenization_albert import AlbertTokenizer, AlbertTokenizerFast
from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
from .tokenization_bart import BartTokenizer, BartTokenizerFast
from .tokenization_bert import BasicTokenizer, BertTokenizer, BertTokenizerFast, WordpieceTokenizer
@@ -160,7 +160,7 @@ from .tokenization_bert_generation import BertGenerationTokenizer
from .tokenization_bert_japanese import BertJapaneseTokenizer, CharacterTokenizer, MecabTokenizer
from .tokenization_bertweet import BertweetTokenizer
from .tokenization_blenderbot import BlenderbotSmallTokenizer, BlenderbotTokenizer
-from .tokenization_camembert import CamembertTokenizer
+from .tokenization_camembert import CamembertTokenizer, CamembertTokenizerFast
from .tokenization_ctrl import CTRLTokenizer
from .tokenization_deberta import DebertaTokenizer
from .tokenization_distilbert import DistilBertTokenizer, DistilBertTokenizerFast
@@ -180,18 +180,18 @@ from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
from .tokenization_layoutlm import LayoutLMTokenizer, LayoutLMTokenizerFast
from .tokenization_longformer import LongformerTokenizer, LongformerTokenizerFast
from .tokenization_lxmert import LxmertTokenizer, LxmertTokenizerFast
-from .tokenization_mbart import MBartTokenizer
+from .tokenization_mbart import MBartTokenizer, MBartTokenizerFast
from .tokenization_mobilebert import MobileBertTokenizer, MobileBertTokenizerFast
from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
-from .tokenization_pegasus import PegasusTokenizer
+from .tokenization_pegasus import PegasusTokenizer, PegasusTokenizerFast
from .tokenization_phobert import PhobertTokenizer
from .tokenization_rag import RagTokenizer
-from .tokenization_reformer import ReformerTokenizer
+from .tokenization_reformer import ReformerTokenizer, ReformerTokenizerFast
from .tokenization_retribert import RetriBertTokenizer, RetriBertTokenizerFast
from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
from .tokenization_squeezebert import SqueezeBertTokenizer, SqueezeBertTokenizerFast
-from .tokenization_t5 import T5Tokenizer
-from .tokenization_transfo_xl import TransfoXLCorpus, TransfoXLTokenizer, TransfoXLTokenizerFast
+from .tokenization_t5 import T5Tokenizer, T5TokenizerFast
+from .tokenization_transfo_xl import TransfoXLCorpus, TransfoXLTokenizer
from .tokenization_utils import PreTrainedTokenizer
from .tokenization_utils_base import (
BatchEncoding,
@@ -203,8 +203,8 @@ from .tokenization_utils_base import (
)
from .tokenization_utils_fast import PreTrainedTokenizerFast
from .tokenization_xlm import XLMTokenizer
-from .tokenization_xlm_roberta import XLMRobertaTokenizer
-from .tokenization_xlnet import SPIECE_UNDERLINE, XLNetTokenizer
+from .tokenization_xlm_roberta import XLMRobertaTokenizer, XLMRobertaTokenizerFast
+from .tokenization_xlnet import SPIECE_UNDERLINE, XLNetTokenizer, XLNetTokenizerFast
# Trainer
from .trainer_callback import (
diff --git a/src/transformers/convert_slow_tokenizer.py b/src/transformers/convert_slow_tokenizer.py
new file mode 100644
index 0000000000..687df4b546
--- /dev/null
+++ b/src/transformers/convert_slow_tokenizer.py
@@ -0,0 +1,566 @@
+# coding=utf-8
+# Copyright 2018 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Utilities to convert slow tokenizers in their fast tokenizers counterparts.
+
+ All the conversions are grouped here to gather SentencePiece dependencies outside of
+ the fast tokenizers files and allow to make our dependency on SentencePiece optional.
+"""
+
+from typing import Dict, List, Tuple
+
+from sentencepiece import SentencePieceProcessor
+from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors
+from tokenizers.models import BPE, Unigram, WordPiece
+
+# from transformers.tokenization_openai import OpenAIGPTTokenizer
+from transformers.utils import sentencepiece_model_pb2 as model
+
+
+class SentencePieceExtractor:
+ """
+ Extractor implementation for SentencePiece trained models.
+ https://github.com/google/sentencepiece
+ """
+
+ def __init__(self, model: str):
+ # Get SentencePiece
+ self.sp = SentencePieceProcessor()
+ self.sp.Load(model)
+
+ def extract(self) -> Tuple[Dict[str, int], List[Tuple]]:
+ sp = self.sp
+ vocab = {sp.id_to_piece(index): index for index in range(sp.GetPieceSize())}
+
+ # Merges
+ merges = []
+ for piece_l in vocab.keys():
+ for piece_r in vocab.keys():
+ merge = f"{piece_l}{piece_r}"
+ piece_id = vocab.get(merge, None)
+ if piece_id:
+ merges += [(piece_l, piece_r, piece_id)]
+ merges = sorted(merges, key=lambda val: val[2])
+ merges = [(val[0], val[1]) for val in merges]
+
+ return vocab, merges
+
+
+def check_number_comma(piece: str) -> bool:
+ return len(piece) < 2 or piece[-1] != "," or not piece[-2].isdigit()
+
+
+def get_proto(filename: str):
+ m = model.ModelProto()
+ m.ParseFromString(open(filename, "rb").read())
+ return m
+
+
+class Converter:
+ def __init__(self, original_tokenizer):
+ self.original_tokenizer = original_tokenizer
+
+ def converted(self) -> Tokenizer:
+ raise NotImplementedError()
+
+
+class BertConverter(Converter):
+ def converted(self) -> Tokenizer:
+ vocab = self.original_tokenizer.vocab
+ tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
+
+ # # Let the tokenizer know about special tokens if they are part of the vocab
+ # if tokenizer.token_to_id(str(self.original_tokenizer.unk_token)) is not None:
+ # tokenizer.add_special_tokens([str(self.original_tokenizer.unk_token)])
+ # if tokenizer.token_to_id(str(self.original_tokenizer.sep_token)) is not None:
+ # tokenizer.add_special_tokens([str(self.original_tokenizer.sep_token)])
+ # if tokenizer.token_to_id(str(self.original_tokenizer.cls_token)) is not None:
+ # tokenizer.add_special_tokens([str(self.original_tokenizer.cls_token)])
+ # if tokenizer.token_to_id(str(self.original_tokenizer.pad_token)) is not None:
+ # tokenizer.add_special_tokens([str(self.original_tokenizer.pad_token)])
+ # if tokenizer.token_to_id(str(self.original_tokenizer.mask_token)) is not None:
+ # tokenizer.add_special_tokens([str(self.original_tokenizer.mask_token)])
+
+ tokenize_chinese_chars = False
+ strip_accents = False
+ do_lower_case = False
+ if hasattr(self.original_tokenizer, "basic_tokenizer"):
+ tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
+ strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
+ do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
+
+ tokenizer.normalizer = normalizers.BertNormalizer(
+ clean_text=True,
+ handle_chinese_chars=tokenize_chinese_chars,
+ strip_accents=strip_accents,
+ lowercase=do_lower_case,
+ )
+ tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
+
+ cls = str(self.original_tokenizer.cls_token)
+ sep = str(self.original_tokenizer.sep_token)
+ cls_token_id = self.original_tokenizer.cls_token_id
+ sep_token_id = self.original_tokenizer.sep_token_id
+
+ tokenizer.post_processor = processors.TemplateProcessing(
+ single=f"{cls}:0 $A:0 {sep}:0",
+ pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1",
+ special_tokens=[
+ (cls, cls_token_id),
+ (sep, sep_token_id),
+ ],
+ )
+ tokenizer.decoder = decoders.WordPiece(prefix="##")
+
+ return tokenizer
+
+
+class FunnelConverter(Converter):
+ def converted(self) -> Tokenizer:
+ vocab = self.original_tokenizer.vocab
+ tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
+
+ # # Let the tokenizer know about special tokens if they are part of the vocab
+ # if tokenizer.token_to_id(str(self.original_tokenizer.unk_token)) is not None:
+ # tokenizer.add_special_tokens([str(self.original_tokenizer.unk_token)])
+ # if tokenizer.token_to_id(str(self.original_tokenizer.sep_token)) is not None:
+ # tokenizer.add_special_tokens([str(self.original_tokenizer.sep_token)])
+ # if tokenizer.token_to_id(str(self.original_tokenizer.cls_token)) is not None:
+ # tokenizer.add_special_tokens([str(self.original_tokenizer.cls_token)])
+ # if tokenizer.token_to_id(str(self.original_tokenizer.pad_token)) is not None:
+ # tokenizer.add_special_tokens([str(self.original_tokenizer.pad_token)])
+ # if tokenizer.token_to_id(str(self.original_tokenizer.mask_token)) is not None:
+ # tokenizer.add_special_tokens([str(self.original_tokenizer.mask_token)])
+
+ tokenize_chinese_chars = False
+ strip_accents = False
+ do_lower_case = False
+ if hasattr(self.original_tokenizer, "basic_tokenizer"):
+ tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
+ strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
+ do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
+
+ tokenizer.normalizer = normalizers.BertNormalizer(
+ clean_text=True,
+ handle_chinese_chars=tokenize_chinese_chars,
+ strip_accents=strip_accents,
+ lowercase=do_lower_case,
+ )
+ tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
+
+ cls = str(self.original_tokenizer.cls_token)
+ sep = str(self.original_tokenizer.sep_token)
+ cls_token_id = self.original_tokenizer.cls_token_id
+ sep_token_id = self.original_tokenizer.sep_token_id
+
+ tokenizer.post_processor = processors.TemplateProcessing(
+ single=f"{cls}:2 $A:0 {sep}:0", # token_type_id is 2 for Funnel transformer
+ pair=f"{cls}:2 $A:0 {sep}:0 $B:1 {sep}:1",
+ special_tokens=[
+ (cls, cls_token_id),
+ (sep, sep_token_id),
+ ],
+ )
+ tokenizer.decoder = decoders.WordPiece(prefix="##")
+
+ return tokenizer
+
+
+class OpenAIGPTConverter(Converter):
+ def converted(self) -> Tokenizer:
+ vocab = self.original_tokenizer.encoder
+ merges = list(self.original_tokenizer.bpe_ranks.keys())
+ unk_token = self.original_tokenizer.unk_token
+
+ tokenizer = Tokenizer(
+ BPE(
+ vocab=vocab,
+ merges=merges,
+ dropout=None,
+ unk_token=str(unk_token),
+ end_of_word_suffix="",
+ fuse_unk=False,
+ )
+ )
+
+ if tokenizer.token_to_id(str(unk_token)) is not None:
+ tokenizer.add_special_tokens([str(unk_token)])
+
+ tokenizer.normalizer = normalizers.BertNormalizer(lowercase=True)
+ tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
+ tokenizer.decoder = decoders.BPEDecoder(suffix="")
+
+ return tokenizer
+
+
+class GPT2Converter(Converter):
+ def converted(self) -> Tokenizer:
+ vocab = self.original_tokenizer.encoder
+ merges = list(self.original_tokenizer.bpe_ranks.keys())
+
+ tokenizer = Tokenizer(
+ BPE(
+ vocab=vocab,
+ merges=merges,
+ dropout=None,
+ continuing_subword_prefix="",
+ end_of_word_suffix="",
+ fuse_unk=False,
+ )
+ )
+
+ tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=self.original_tokenizer.add_prefix_space)
+ tokenizer.decoder = decoders.ByteLevel()
+ tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
+
+ return tokenizer
+
+
+class RobertaConverter(Converter):
+ def converted(self) -> Tokenizer:
+ ot = self.original_tokenizer
+ vocab = ot.encoder
+ merges = list(ot.bpe_ranks.keys())
+
+ tokenizer = Tokenizer(
+ BPE(
+ vocab=vocab,
+ merges=merges,
+ dropout=None,
+ continuing_subword_prefix="",
+ end_of_word_suffix="",
+ fuse_unk=False,
+ )
+ )
+
+ tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space)
+ tokenizer.decoder = decoders.ByteLevel()
+ tokenizer.post_processor = processors.RobertaProcessing(
+ sep=(ot.sep_token, ot.sep_token_id),
+ cls=(ot.cls_token, ot.cls_token_id),
+ add_prefix_space=ot.add_prefix_space,
+ trim_offsets=True, # True by default on Roberta (historical)
+ )
+
+ return tokenizer
+
+
+class SpmConverter(Converter):
+ def __init__(self, *args):
+ super().__init__(*args)
+ self.proto = get_proto(self.original_tokenizer.vocab_file)
+
+ def vocab(self, proto):
+ return [(piece.piece, piece.score) for piece in proto.pieces]
+
+ def unk_id(self, proto):
+ return proto.trainer_spec.unk_id
+
+ def tokenizer(self, proto):
+ model_type = proto.trainer_spec.model_type
+ vocab = self.vocab(proto)
+ unk_id = self.unk_id(proto)
+
+ if model_type == 1:
+ tokenizer = Tokenizer(Unigram(vocab, unk_id))
+ elif model_type == 2:
+ vocab, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract()
+ tokenizer = Tokenizer(
+ BPE(
+ vocab,
+ merges,
+ unk_token=proto.trainer_spec.unk_piece,
+ fuse_unk=True,
+ )
+ )
+ else:
+ raise Exception(
+ "You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
+ )
+
+ return tokenizer
+
+ def normalizer(self, proto):
+ precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
+ return normalizers.Precompiled(precompiled_charsmap)
+
+ def post_processor(self):
+ return None
+
+ def converted(self) -> Tokenizer:
+ tokenizer = self.tokenizer(self.proto)
+
+ # Tokenizer assemble
+ tokenizer.normalizer = self.normalizer(self.proto)
+
+ replacement = "▁"
+ add_prefix_space = True
+ tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
+ [
+ pre_tokenizers.WhitespaceSplit(),
+ pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space),
+ ]
+ )
+ tokenizer.decoder = decoders.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space)
+ post_processor = self.post_processor()
+ if post_processor:
+ tokenizer.post_processor = post_processor
+
+ return tokenizer
+
+
+class AlbertConverter(SpmConverter):
+ def vocab(self, proto):
+ return [
+ (piece.piece, piece.score) if check_number_comma(piece.piece) else (piece.piece, piece.score - 100)
+ for piece in proto.pieces
+ ]
+
+ def normalizer(self, proto):
+ list_normalizers = [normalizers.Replace("``", '"'), normalizers.Replace("''", '"')]
+ if not self.original_tokenizer.keep_accents:
+ list_normalizers.append(normalizers.NFKD())
+ list_normalizers.append(normalizers.StripAccents())
+ if self.original_tokenizer.do_lower_case:
+ list_normalizers.append(normalizers.Lowercase())
+
+ precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
+ list_normalizers.append(normalizers.Precompiled(precompiled_charsmap))
+ return normalizers.Sequence(list_normalizers)
+
+ def post_processor(self):
+ return processors.TemplateProcessing(
+ single="[CLS]:0 $A:0 [SEP]:0",
+ pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
+ special_tokens=[
+ ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
+ ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
+ ],
+ )
+
+
+class CamembertConverter(SpmConverter):
+ def vocab(self, proto):
+ vocab = [
+ ("NOTUSED", 0.0),
+ ("", 0.0),
+ ("NOTUSED", 0.0),
+ ("", 0.0),
+ ]
+ # We down-grade the original SentencePiece by -100 to avoid using it and use our added token instead
+ vocab += [(piece.piece, piece.score if i != 0 else piece.score - 100) for i, piece in enumerate(proto.pieces)]
+ vocab += [("", 0.0)]
+ return vocab
+
+ def unk_id(self, proto):
+ # See vocab unk position
+ return 3
+
+ def post_processor(self):
+ return processors.TemplateProcessing(
+ single=" $A ",
+ pair=" $A $B ",
+ special_tokens=[
+ ("", self.original_tokenizer.convert_tokens_to_ids("")),
+ ("", self.original_tokenizer.convert_tokens_to_ids("")),
+ ],
+ )
+
+
+class MBartConverter(SpmConverter):
+ def vocab(self, proto):
+ vocab = [
+ ("", 0.0),
+ ("", 0.0),
+ ("", 0.0),
+ ("", 0.0),
+ ]
+ vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
+ vocab += [
+ ("ar_AR", 0.0),
+ ("cs_CZ", 0.0),
+ ("de_DE", 0.0),
+ ("en_XX", 0.0),
+ ("es_XX", 0.0),
+ ("et_EE", 0.0),
+ ("fi_FI", 0.0),
+ ("fr_XX", 0.0),
+ ("gu_IN", 0.0),
+ ("hi_IN", 0.0),
+ ("it_IT", 0.0),
+ ("ja_XX", 0.0),
+ ("kk_KZ", 0.0),
+ ("ko_KR", 0.0),
+ ("lt_LT", 0.0),
+ ("lv_LV", 0.0),
+ ("my_MM", 0.0),
+ ("ne_NP", 0.0),
+ ("nl_XX", 0.0),
+ ("ro_RO", 0.0),
+ ("ru_RU", 0.0),
+ ("si_LK", 0.0),
+ ("tr_TR", 0.0),
+ ("vi_VN", 0.0),
+ ("zh_CN", 0.0),
+ ]
+ vocab += [("", 0.0)]
+ return vocab
+
+ def unk_id(self, proto):
+ return 3
+
+ def post_processor(self):
+ return processors.TemplateProcessing(
+ single="$A en_XX",
+ pair="$A $B en_XX",
+ special_tokens=[
+ ("en_XX", self.original_tokenizer.convert_tokens_to_ids("en_XX")),
+ ("", self.original_tokenizer.convert_tokens_to_ids("")),
+ ],
+ )
+
+
+class XLMRobertaConverter(SpmConverter):
+ def vocab(self, proto):
+ vocab = [
+ ("", 0.0),
+ ("", 0.0),
+ ("", 0.0),
+ ("", 0.0),
+ ]
+ vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
+ vocab += [("", 0.0)]
+ return vocab
+
+ def unk_id(self, proto):
+ unk_id = 3
+ return unk_id
+
+ def post_processor(self):
+ return processors.TemplateProcessing(
+ single=" $A ",
+ pair=" $A $B ",
+ special_tokens=[
+ ("", self.original_tokenizer.convert_tokens_to_ids("")),
+ ("", self.original_tokenizer.convert_tokens_to_ids("")),
+ ],
+ )
+
+
+class XLNetConverter(SpmConverter):
+ def vocab(self, proto):
+ return [
+ (piece.piece, piece.score) if check_number_comma(piece.piece) else (piece.piece, piece.score - 100)
+ for piece in proto.pieces
+ ]
+
+ def normalizer(self, proto):
+ list_normalizers = [normalizers.Replace("``", '"'), normalizers.Replace("''", '"')]
+ if not self.original_tokenizer.keep_accents:
+ list_normalizers.append(normalizers.NFKD())
+ list_normalizers.append(normalizers.StripAccents())
+ if self.original_tokenizer.do_lower_case:
+ list_normalizers.append(normalizers.Lowercase())
+
+ precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
+ list_normalizers.append(normalizers.Precompiled(precompiled_charsmap))
+ return normalizers.Sequence(list_normalizers)
+
+ def post_processor(self):
+ return processors.TemplateProcessing(
+ single="$A:0 :0 :2",
+ pair="$A:0 :0 $B:1 :1 :2",
+ special_tokens=[
+ ("", self.original_tokenizer.convert_tokens_to_ids("")),
+ ("", self.original_tokenizer.convert_tokens_to_ids("")),
+ ],
+ )
+
+
+class ReformerConverter(SpmConverter):
+ pass
+
+
+class BertGenerationConverter(SpmConverter):
+ pass
+
+
+class PegasusConverter(SpmConverter):
+ def vocab(self, proto):
+ vocab = [
+ (self.original_tokenizer.pad_token, 0),
+ (self.original_tokenizer.eos_token, 0),
+ ]
+ vocab += [(f"unk_{i}", -100) for i in range(2, 2 + self.original_tokenizer.offset)]
+ vocab += [(piece.piece, piece.score) for piece in proto.pieces[2:]]
+ return vocab
+
+ def unk_id(self, proto):
+ return proto.trainer_spec.unk_id + self.original_tokenizer.offset
+
+ def post_processor(self):
+ eos = self.original_tokenizer.eos_token
+ return processors.TemplateProcessing(
+ single=["$A", eos],
+ pair=["$A", "$B", eos],
+ special_tokens=[
+ (eos, self.original_tokenizer.eos_token_id),
+ ],
+ )
+
+
+class T5Converter(SpmConverter):
+ def vocab(self, proto):
+ num_extra_ids = self.original_tokenizer._extra_ids
+ vocab = [(piece.piece, piece.score) for piece in proto.pieces]
+ vocab += [("".format(i), 0.0) for i in range(num_extra_ids - 1, -1, -1)]
+ return vocab
+
+ def post_processor(self):
+ return processors.TemplateProcessing(
+ single=["$A", ""],
+ pair=["$A", "", "$B", ""],
+ special_tokens=[
+ ("", self.original_tokenizer.convert_tokens_to_ids("")),
+ ],
+ )
+
+
+CONVERTERS = {
+ "AlbertTokenizer": AlbertConverter,
+ "BertTokenizer": BertConverter,
+ "BertGenerationTokenizer": BertGenerationConverter,
+ "BartTokenizer": RobertaConverter,
+ "CamembertTokenizer": CamembertConverter,
+ "DistilBertTokenizer": BertConverter,
+ "DPRReaderTokenizer": BertConverter,
+ "DPRQuestionEncoderTokenizer": BertConverter,
+ "DPRContextEncoderTokenizer": BertConverter,
+ "FunnelTokenizer": FunnelConverter,
+ "GPT2Tokenizer": GPT2Converter,
+ "LxmertTokenizer": BertConverter,
+ "MBartTokenizer": MBartConverter,
+ "OpenAIGPTTokenizer": OpenAIGPTConverter,
+ "PegasusTokenizer": PegasusConverter,
+ "ReformerTokenizer": ReformerConverter,
+ "RobertaTokenizer": RobertaConverter,
+ "T5Tokenizer": T5Converter,
+ "XLMRobertaTokenizer": XLMRobertaConverter,
+ "XLNetTokenizer": XLNetConverter,
+}
+
+
+def convert_slow_tokenizer(transformer_tokenizer) -> Tokenizer:
+ converter_class = CONVERTERS[transformer_tokenizer.__class__.__name__]
+ return converter_class(transformer_tokenizer).converted()
diff --git a/src/transformers/tokenization_albert.py b/src/transformers/tokenization_albert.py
index b5d9296dc5..424630b21f 100644
--- a/src/transformers/tokenization_albert.py
+++ b/src/transformers/tokenization_albert.py
@@ -21,6 +21,7 @@ from shutil import copyfile
from typing import List, Optional
from .tokenization_utils import PreTrainedTokenizer
+from .tokenization_utils_fast import PreTrainedTokenizerFast
from .utils import logging
@@ -340,3 +341,206 @@ class AlbertTokenizer(PreTrainedTokenizer):
copyfile(self.vocab_file, out_vocab_file)
return (out_vocab_file,)
+
+
+class AlbertTokenizerFast(PreTrainedTokenizerFast):
+ """
+ Construct a "fast" ALBERT tokenizer (backed by HuggingFace's `tokenizers` library). Based on
+ `SentencePiece `__.
+
+ This tokenizer inherits from :class:`~transformers.PreTrainedTokenizerFast` which contains most of the main
+ methods. Users should refer to this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (:obj:`str`):
+ `SentencePiece `__ file (generally has a `.spm` extension) that
+ contains the vocabulary necessary to instantiate a tokenizer.
+ do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
+ Whether or not to lowercase the input when tokenizing.
+ remove_space (:obj:`bool`, `optional`, defaults to :obj:`True`):
+ Whether or not to strip the text when tokenizing (removing excess spaces before and after the string).
+ keep_accents (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Whether or not to keep accents when tokenizing.
+ bos_token (:obj:`str`, `optional`, defaults to :obj:`"[CLS]"`):
+ The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
+
+ .. note::
+
+ When building a sequence using special tokens, this is not the token that is used for the beginning
+ of sequence. The token used is the :obj:`cls_token`.
+ eos_token (:obj:`str`, `optional`, defaults to :obj:`"[SEP]"`):
+ The end of sequence token.
+
+ .. note::
+
+ When building a sequence using special tokens, this is not the token that is used for the end
+ of sequence. The token used is the :obj:`sep_token`.
+ unk_token (:obj:`str`, `optional`, defaults to :obj:`""`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ sep_token (:obj:`str`, `optional`, defaults to :obj:`"[SEP]"`):
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
+ for sequence classification or for a text and a question for question answering.
+ It is also used as the last token of a sequence built with special tokens.
+ pad_token (:obj:`str`, `optional`, defaults to :obj:`""`):
+ The token used for padding, for example when batching sequences of different lengths.
+ cls_token (:obj:`str`, `optional`, defaults to :obj:`"[CLS]"`):
+ The classifier token which is used when doing sequence classification (classification of the whole
+ sequence instead of per-token classification). It is the first token of the sequence when built with
+ special tokens.
+ mask_token (:obj:`str`, `optional`, defaults to :obj:`"[MASK]"`):
+ The token used for masking values. This is the token used when training this model with masked language
+ modeling. This is the token which the model will try to predict.
+
+ Attributes:
+ sp_model (:obj:`SentencePieceProcessor`):
+ The `SentencePiece` processor that is used for every conversion (string, tokens and IDs).
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+ slow_tokenizer_class = AlbertTokenizer
+
+ def __init__(
+ self,
+ vocab_file,
+ do_lower_case=True,
+ remove_space=True,
+ keep_accents=False,
+ bos_token="[CLS]",
+ eos_token="[SEP]",
+ unk_token="",
+ sep_token="[SEP]",
+ pad_token="",
+ cls_token="[CLS]",
+ mask_token="[MASK]",
+ **kwargs
+ ):
+ super().__init__(
+ vocab_file,
+ do_lower_case=do_lower_case,
+ remove_space=remove_space,
+ keep_accents=keep_accents,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ unk_token=unk_token,
+ sep_token=sep_token,
+ pad_token=pad_token,
+ cls_token=cls_token,
+ mask_token=mask_token,
+ **kwargs,
+ )
+
+ self.do_lower_case = do_lower_case
+ self.remove_space = remove_space
+ self.keep_accents = keep_accents
+ self.vocab_file = vocab_file
+
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks
+ by concatenating and adding special tokens.
+ An ALBERT sequence has the following format:
+
+ - single sequence: ``[CLS] X [SEP]``
+ - pair of sequences: ``[CLS] A [SEP] B [SEP]``
+
+ Args:
+ token_ids_0 (:obj:`List[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (:obj:`List[int]`, `optional`):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
+ """
+ sep = [self.sep_token_id]
+ cls = [self.cls_token_id]
+ if token_ids_1 is None:
+ return cls + token_ids_0 + sep
+ return cls + token_ids_0 + sep + token_ids_1 + sep
+
+ def get_special_tokens_mask(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+ ) -> List[int]:
+ """
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer ``prepare_for_model`` method.
+
+ Args:
+ token_ids_0 (:obj:`List[int]`):
+ List of IDs.
+ token_ids_1 (:obj:`List[int]`, `optional`):
+ Optional second list of IDs for sequence pairs.
+ already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Whether or not the token list is already formatted with special tokens for the model.
+
+ Returns:
+ :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+ """
+ if already_has_special_tokens:
+ if token_ids_1 is not None:
+ raise ValueError(
+ "You should not supply a second sequence if the provided sequence of "
+ "ids is already formatted with special tokens for the model."
+ )
+ return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
+
+ if token_ids_1 is not None:
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
+ return [1] + ([0] * len(token_ids_0)) + [1]
+
+ def create_token_type_ids_from_sequences(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task.
+ An ALBERT sequence pair mask has the following format:
+
+ ::
+
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
+ | first sequence | second sequence |
+
+ If :obj:`token_ids_1` is :obj:`None`, this method only returns the first portion of the mask (0s).
+
+ Args:
+ token_ids_0 (:obj:`List[int]`):
+ List of IDs.
+ token_ids_1 (:obj:`List[int]`, `optional`):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
+ sequence(s).
+ """
+ sep = [self.sep_token_id]
+ cls = [self.cls_token_id]
+
+ if token_ids_1 is None:
+ return len(cls + token_ids_0 + sep) * [0]
+ return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
+
+ def save_vocabulary(self, save_directory):
+ """
+ Save the sentencepiece vocabulary (copy original file) and special tokens file to a directory.
+
+ Args:
+ save_directory (:obj:`str`):
+ The directory in which to save the vocabulary.
+
+ Returns:
+ :obj:`Tuple(str)`: Paths to the files saved.
+ """
+ if not os.path.isdir(save_directory):
+ logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
+ return
+ out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])
+
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+ copyfile(self.vocab_file, out_vocab_file)
+
+ return (out_vocab_file,)
diff --git a/src/transformers/tokenization_auto.py b/src/transformers/tokenization_auto.py
index fcf42d26f4..f833791dce 100644
--- a/src/transformers/tokenization_auto.py
+++ b/src/transformers/tokenization_auto.py
@@ -56,14 +56,14 @@ from .configuration_auto import (
replace_list_option_in_docstrings,
)
from .configuration_utils import PretrainedConfig
-from .tokenization_albert import AlbertTokenizer
+from .tokenization_albert import AlbertTokenizer, AlbertTokenizerFast
from .tokenization_bart import BartTokenizer, BartTokenizerFast
from .tokenization_bert import BertTokenizer, BertTokenizerFast
from .tokenization_bert_generation import BertGenerationTokenizer
from .tokenization_bert_japanese import BertJapaneseTokenizer
from .tokenization_bertweet import BertweetTokenizer
from .tokenization_blenderbot import BlenderbotSmallTokenizer
-from .tokenization_camembert import CamembertTokenizer
+from .tokenization_camembert import CamembertTokenizer, CamembertTokenizerFast
from .tokenization_ctrl import CTRLTokenizer
from .tokenization_deberta import DebertaTokenizer
from .tokenization_distilbert import DistilBertTokenizer, DistilBertTokenizerFast
@@ -77,21 +77,21 @@ from .tokenization_layoutlm import LayoutLMTokenizer, LayoutLMTokenizerFast
from .tokenization_longformer import LongformerTokenizer, LongformerTokenizerFast
from .tokenization_lxmert import LxmertTokenizer, LxmertTokenizerFast
from .tokenization_marian import MarianTokenizer
-from .tokenization_mbart import MBartTokenizer
+from .tokenization_mbart import MBartTokenizer, MBartTokenizerFast
from .tokenization_mobilebert import MobileBertTokenizer, MobileBertTokenizerFast
from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
-from .tokenization_pegasus import PegasusTokenizer
+from .tokenization_pegasus import PegasusTokenizer, PegasusTokenizerFast
from .tokenization_phobert import PhobertTokenizer
from .tokenization_rag import RagTokenizer
-from .tokenization_reformer import ReformerTokenizer
+from .tokenization_reformer import ReformerTokenizer, ReformerTokenizerFast
from .tokenization_retribert import RetriBertTokenizer, RetriBertTokenizerFast
from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
from .tokenization_squeezebert import SqueezeBertTokenizer, SqueezeBertTokenizerFast
-from .tokenization_t5 import T5Tokenizer
-from .tokenization_transfo_xl import TransfoXLTokenizer, TransfoXLTokenizerFast
+from .tokenization_t5 import T5Tokenizer, T5TokenizerFast
+from .tokenization_transfo_xl import TransfoXLTokenizer
from .tokenization_xlm import XLMTokenizer
-from .tokenization_xlm_roberta import XLMRobertaTokenizer
-from .tokenization_xlnet import XLNetTokenizer
+from .tokenization_xlm_roberta import XLMRobertaTokenizer, XLMRobertaTokenizerFast
+from .tokenization_xlnet import XLNetTokenizer, XLNetTokenizerFast
from .utils import logging
@@ -101,14 +101,14 @@ logger = logging.get_logger(__name__)
TOKENIZER_MAPPING = OrderedDict(
[
(RetriBertConfig, (RetriBertTokenizer, RetriBertTokenizerFast)),
- (T5Config, (T5Tokenizer, None)),
+ (T5Config, (T5Tokenizer, T5TokenizerFast)),
(MobileBertConfig, (MobileBertTokenizer, MobileBertTokenizerFast)),
(DistilBertConfig, (DistilBertTokenizer, DistilBertTokenizerFast)),
- (AlbertConfig, (AlbertTokenizer, None)),
- (CamembertConfig, (CamembertTokenizer, None)),
- (PegasusConfig, (PegasusTokenizer, None)),
- (MBartConfig, (MBartTokenizer, None)),
- (XLMRobertaConfig, (XLMRobertaTokenizer, None)),
+ (AlbertConfig, (AlbertTokenizer, AlbertTokenizerFast)),
+ (CamembertConfig, (CamembertTokenizer, CamembertTokenizerFast)),
+ (PegasusConfig, (PegasusTokenizer, PegasusTokenizerFast)),
+ (MBartConfig, (MBartTokenizer, MBartTokenizerFast)),
+ (XLMRobertaConfig, (XLMRobertaTokenizer, XLMRobertaTokenizerFast)),
(MarianConfig, (MarianTokenizer, None)),
(BlenderbotConfig, (BlenderbotSmallTokenizer, None)),
(LongformerConfig, (LongformerTokenizer, None)),
@@ -117,7 +117,7 @@ TOKENIZER_MAPPING = OrderedDict(
(RobertaConfig, (BertweetTokenizer, None)),
(RobertaConfig, (PhobertTokenizer, None)),
(RobertaConfig, (RobertaTokenizer, RobertaTokenizerFast)),
- (ReformerConfig, (ReformerTokenizer, None)),
+ (ReformerConfig, (ReformerTokenizer, ReformerTokenizerFast)),
(ElectraConfig, (ElectraTokenizer, ElectraTokenizerFast)),
(FunnelConfig, (FunnelTokenizer, FunnelTokenizerFast)),
(LxmertConfig, (LxmertTokenizer, LxmertTokenizerFast)),
@@ -127,15 +127,14 @@ TOKENIZER_MAPPING = OrderedDict(
(BertConfig, (BertTokenizer, BertTokenizerFast)),
(OpenAIGPTConfig, (OpenAIGPTTokenizer, OpenAIGPTTokenizerFast)),
(GPT2Config, (GPT2Tokenizer, GPT2TokenizerFast)),
- (TransfoXLConfig, (TransfoXLTokenizer, TransfoXLTokenizerFast)),
- (XLNetConfig, (XLNetTokenizer, None)),
+ (TransfoXLConfig, (TransfoXLTokenizer, None)),
+ (XLNetConfig, (XLNetTokenizer, XLNetTokenizerFast)),
(FlaubertConfig, (FlaubertTokenizer, None)),
(XLMConfig, (XLMTokenizer, None)),
(CTRLConfig, (CTRLTokenizer, None)),
(FSMTConfig, (FSMTTokenizer, None)),
(BertGenerationConfig, (BertGenerationTokenizer, None)),
(DebertaConfig, (DebertaTokenizer, None)),
- (LayoutLMConfig, (LayoutLMTokenizer, None)),
(RagConfig, (RagTokenizer, None)),
]
)
diff --git a/src/transformers/tokenization_bart.py b/src/transformers/tokenization_bart.py
index 6b676b25fa..40fe7c0e9e 100644
--- a/src/transformers/tokenization_bart.py
+++ b/src/transformers/tokenization_bart.py
@@ -163,6 +163,7 @@ class BartTokenizerFast(RobertaTokenizerFast):
"vocab_file": {m: vocab_url for m in _all_bart_models},
"merges_file": {m: merges_url for m in _all_bart_models},
}
+ slow_tokenizer_class = BartTokenizer
def prepare_seq2seq_batch(
self,
diff --git a/src/transformers/tokenization_bert.py b/src/transformers/tokenization_bert.py
index 2b2aceca6f..3b620865dc 100644
--- a/src/transformers/tokenization_bert.py
+++ b/src/transformers/tokenization_bert.py
@@ -20,8 +20,6 @@ import os
import unicodedata
from typing import List, Optional
-from tokenizers import BertWordPieceTokenizer
-
from .tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
from .tokenization_utils_fast import PreTrainedTokenizerFast
from .utils import logging
@@ -206,6 +204,10 @@ class BertTokenizer(PreTrainedTokenizer):
)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)
+ @property
+ def do_lower_case(self):
+ return self.basic_tokenizer.do_lower_case
+
@property
def vocab_size(self):
return len(self.vocab)
@@ -329,7 +331,7 @@ class BertTokenizer(PreTrainedTokenizer):
def save_vocabulary(self, vocab_path):
"""
- Save the vocabulary (copy original file) and special tokens file to a directory.
+ Save the vocabulary and special tokens file to a directory.
Args:
vocab_path (:obj:`str`):
@@ -610,6 +612,7 @@ class BertTokenizerFast(PreTrainedTokenizerFast):
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+ slow_tokenizer_class = BertTokenizer
def __init__(
self,
@@ -620,31 +623,20 @@ class BertTokenizerFast(PreTrainedTokenizerFast):
pad_token="[PAD]",
cls_token="[CLS]",
mask_token="[MASK]",
- clean_text=True,
tokenize_chinese_chars=True,
strip_accents=None,
- wordpieces_prefix="##",
**kwargs
):
super().__init__(
- BertWordPieceTokenizer(
- vocab_file=vocab_file,
- unk_token=unk_token,
- sep_token=sep_token,
- cls_token=cls_token,
- pad_token=pad_token,
- mask_token=mask_token,
- clean_text=clean_text,
- handle_chinese_chars=tokenize_chinese_chars,
- strip_accents=strip_accents,
- lowercase=do_lower_case,
- wordpieces_prefix=wordpieces_prefix,
- ),
+ vocab_file,
+ do_lower_case=do_lower_case,
unk_token=unk_token,
sep_token=sep_token,
pad_token=pad_token,
cls_token=cls_token,
mask_token=mask_token,
+ tokenize_chinese_chars=tokenize_chinese_chars,
+ strip_accents=strip_accents,
**kwargs,
)
diff --git a/src/transformers/tokenization_bert_japanese.py b/src/transformers/tokenization_bert_japanese.py
index 48b5c87c31..0248e33d2e 100644
--- a/src/transformers/tokenization_bert_japanese.py
+++ b/src/transformers/tokenization_bert_japanese.py
@@ -16,6 +16,7 @@
import collections
+import copy
import os
import unicodedata
from typing import Optional
@@ -116,6 +117,13 @@ class BertJapaneseTokenizer(BertTokenizer):
pad_token=pad_token,
cls_token=cls_token,
mask_token=mask_token,
+ do_lower_case=do_lower_case,
+ do_word_tokenize=do_word_tokenize,
+ do_subword_tokenize=do_subword_tokenize,
+ word_tokenizer_type=word_tokenizer_type,
+ subword_tokenizer_type=subword_tokenizer_type,
+ never_split=never_split,
+ mecab_kwargs=mecab_kwargs,
**kwargs,
)
# ^^ We call the grandparent's init, not the parent's.
@@ -129,6 +137,10 @@ class BertJapaneseTokenizer(BertTokenizer):
self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
self.do_word_tokenize = do_word_tokenize
+ self.word_tokenizer_type = word_tokenizer_type
+ self.lower_case = do_lower_case
+ self.never_split = never_split
+ self.mecab_kwargs = copy.deepcopy(mecab_kwargs)
if do_word_tokenize:
if word_tokenizer_type == "basic":
self.word_tokenizer = BasicTokenizer(
@@ -142,6 +154,7 @@ class BertJapaneseTokenizer(BertTokenizer):
raise ValueError("Invalid word_tokenizer_type '{}' is specified.".format(word_tokenizer_type))
self.do_subword_tokenize = do_subword_tokenize
+ self.subword_tokenizer_type = subword_tokenizer_type
if do_subword_tokenize:
if subword_tokenizer_type == "wordpiece":
self.subword_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)
@@ -150,6 +163,23 @@ class BertJapaneseTokenizer(BertTokenizer):
else:
raise ValueError("Invalid subword_tokenizer_type '{}' is specified.".format(subword_tokenizer_type))
+ @property
+ def do_lower_case(self):
+ return self.lower_case
+
+ def __getstate__(self):
+ state = dict(self.__dict__)
+ if self.word_tokenizer_type == "mecab":
+ del state["word_tokenizer"]
+ return state
+
+ def __setstate__(self, state):
+ self.__dict__ = state
+ if self.word_tokenizer_type == "mecab":
+ self.word_tokenizer = MecabTokenizer(
+ do_lower_case=self.do_lower_case, never_split=self.never_split, **(self.mecab_kwargs or {})
+ )
+
def _tokenize(self, text):
if self.do_word_tokenize:
tokens = self.word_tokenizer.tokenize(text, never_split=self.all_special_tokens)
diff --git a/src/transformers/tokenization_bertweet.py b/src/transformers/tokenization_bertweet.py
index 3c30c0d40a..b5cd4faaf2 100644
--- a/src/transformers/tokenization_bertweet.py
+++ b/src/transformers/tokenization_bertweet.py
@@ -129,7 +129,6 @@ class BertweetTokenizer(PreTrainedTokenizer):
**kwargs
):
super().__init__(
- max_len=128,
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
diff --git a/src/transformers/tokenization_camembert.py b/src/transformers/tokenization_camembert.py
index c23758f74b..2726ce1e16 100644
--- a/src/transformers/tokenization_camembert.py
+++ b/src/transformers/tokenization_camembert.py
@@ -22,6 +22,7 @@ from typing import List, Optional
import sentencepiece as spm
from .tokenization_utils import PreTrainedTokenizer
+from .tokenization_utils_fast import PreTrainedTokenizerFast
from .utils import logging
@@ -36,7 +37,7 @@ PRETRAINED_VOCAB_FILES_MAP = {
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
- "camembert-base": None,
+ "camembert-base": 512,
}
SHARED_MODEL_IDENTIFIERS = [
@@ -118,7 +119,6 @@ class CamembertTokenizer(PreTrainedTokenizer):
**kwargs
):
super().__init__(
- max_len=512,
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
@@ -223,6 +223,11 @@ class CamembertTokenizer(PreTrainedTokenizer):
def vocab_size(self):
return len(self.fairseq_tokens_to_ids) + len(self.sp_model)
+ def get_vocab(self):
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+ vocab.update(self.added_tokens_encoder)
+ return vocab
+
def _tokenize(self, text):
return self.sp_model.EncodeAsPieces(text)
@@ -284,3 +289,189 @@ class CamembertTokenizer(PreTrainedTokenizer):
copyfile(self.vocab_file, out_vocab_file)
return (out_vocab_file,)
+
+
+class CamembertTokenizerFast(PreTrainedTokenizerFast):
+ """
+ Construct a "fast" CamemBERT tokenizer (backed by HuggingFace's `tokenizers` library). Adapted from
+ :class:`~transformers.RobertaTokenizer` and :class:`~transformers.XLNetTokenizer`. Based on `SentencePiece
+ `__.
+
+ This tokenizer inherits from :class:`~transformers.PreTrainedTokenizerFast` which contains most of the main
+ methods. Users should refer to this superclass for more information regarding those methods.
+
+ vocab_file (:obj:`str`):
+ `SentencePiece `__ file (generally has a `.spm` extension) that
+ contains the vocabulary necessary to instantiate a tokenizer.
+ bos_token (:obj:`str`, `optional`, defaults to :obj:`""`):
+ The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
+
+ .. note::
+
+ When building a sequence using special tokens, this is not the token that is used for the beginning
+ of sequence. The token used is the :obj:`cls_token`.
+ eos_token (:obj:`str`, `optional`, defaults to :obj:`""`):
+ The end of sequence token.
+
+ .. note::
+
+ When building a sequence using special tokens, this is not the token that is used for the end
+ of sequence. The token used is the :obj:`sep_token`.
+ sep_token (:obj:`str`, `optional`, defaults to :obj:`""`):
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
+ for sequence classification or for a text and a question for question answering.
+ It is also used as the last token of a sequence built with special tokens.
+ cls_token (:obj:`str`, `optional`, defaults to :obj:`""`):
+ The classifier token which is used when doing sequence classification (classification of the whole
+ sequence instead of per-token classification). It is the first token of the sequence when built with
+ special tokens.
+ unk_token (:obj:`str`, `optional`, defaults to :obj:`""`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ pad_token (:obj:`str`, `optional`, defaults to :obj:`""`):
+ The token used for padding, for example when batching sequences of different lengths.
+ mask_token (:obj:`str`, `optional`, defaults to :obj:`""`):
+ The token used for masking values. This is the token used when training this model with masked language
+ modeling. This is the token which the model will try to predict.
+ additional_special_tokens (:obj:`List[str]`, `optional`, defaults to :obj:`["NOTUSED", "NOTUSED"]`):
+ Additional special tokens used by the tokenizer.
+
+ Attributes:
+ sp_model (:obj:`SentencePieceProcessor`):
+ The `SentencePiece` processor that is used for every conversion (string, tokens and IDs).
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+ model_input_names = ["attention_mask"]
+ slow_tokenizer_class = CamembertTokenizer
+
+ def __init__(
+ self,
+ vocab_file,
+ bos_token="",
+ eos_token="",
+ sep_token="",
+ cls_token="",
+ unk_token="",
+ pad_token="",
+ mask_token="",
+ additional_special_tokens=["NOTUSED", "NOTUSED"],
+ **kwargs
+ ):
+ super().__init__(
+ vocab_file,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ sep_token=sep_token,
+ cls_token=cls_token,
+ unk_token=unk_token,
+ pad_token=pad_token,
+ mask_token=mask_token,
+ additional_special_tokens=additional_special_tokens,
+ **kwargs,
+ )
+
+ self.vocab_file = vocab_file
+
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks
+ by concatenating and adding special tokens.
+ An CamemBERT sequence has the following format:
+
+ - single sequence: `` X ``
+ - pair of sequences: `` A B ``
+
+ Args:
+ token_ids_0 (:obj:`List[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (:obj:`List[int]`, `optional`):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
+ """
+
+ if token_ids_1 is None:
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+ cls = [self.cls_token_id]
+ sep = [self.sep_token_id]
+ return cls + token_ids_0 + sep + sep + token_ids_1 + sep
+
+ def get_special_tokens_mask(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+ ) -> List[int]:
+ """
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer ``prepare_for_model`` method.
+
+ Args:
+ token_ids_0 (:obj:`List[int]`):
+ List of IDs.
+ token_ids_1 (:obj:`List[int]`, `optional`):
+ Optional second list of IDs for sequence pairs.
+ already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Whether or not the token list is already formatted with special tokens for the model.
+
+ Returns:
+ :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+ """
+ if already_has_special_tokens:
+ if token_ids_1 is not None:
+ raise ValueError(
+ "You should not supply a second sequence if the provided sequence of "
+ "ids is already formated with special tokens for the model."
+ )
+ return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
+
+ if token_ids_1 is None:
+ return [1] + ([0] * len(token_ids_0)) + [1]
+ return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
+
+ def create_token_type_ids_from_sequences(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task.
+ CamemBERT, like RoBERTa, does not make use of token type ids, therefore a list of zeros is returned.
+
+ Args:
+ token_ids_0 (:obj:`List[int]`):
+ List of IDs.
+ token_ids_1 (:obj:`List[int]`, `optional`):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ :obj:`List[int]`: List of zeros.
+ """
+ sep = [self.sep_token_id]
+ cls = [self.cls_token_id]
+
+ if token_ids_1 is None:
+ return len(cls + token_ids_0 + sep) * [0]
+ return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
+
+ def save_vocabulary(self, save_directory):
+ """
+ Save the sentencepiece vocabulary (copy original file) and special tokens file to a directory.
+
+ Args:
+ save_directory (:obj:`str`):
+ The directory in which to save the vocabulary.
+
+ Returns:
+ :obj:`Tuple(str)`: Paths to the files saved.
+ """
+ if not os.path.isdir(save_directory):
+ logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
+ return
+ out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])
+
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+ copyfile(self.vocab_file, out_vocab_file)
+
+ return (out_vocab_file,)
diff --git a/src/transformers/tokenization_distilbert.py b/src/transformers/tokenization_distilbert.py
index 9de887328a..1ab8cf3009 100644
--- a/src/transformers/tokenization_distilbert.py
+++ b/src/transformers/tokenization_distilbert.py
@@ -87,3 +87,4 @@ class DistilBertTokenizerFast(BertTokenizerFast):
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
model_input_names = ["attention_mask"]
+ slow_tokenizer_class = DistilBertTokenizer
diff --git a/src/transformers/tokenization_dpr.py b/src/transformers/tokenization_dpr.py
index e645ccd1c0..bf40bc53c8 100644
--- a/src/transformers/tokenization_dpr.py
+++ b/src/transformers/tokenization_dpr.py
@@ -98,6 +98,7 @@ class DPRContextEncoderTokenizerFast(BertTokenizerFast):
pretrained_vocab_files_map = CONTEXT_ENCODER_PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = CONTEXT_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
pretrained_init_configuration = CONTEXT_ENCODER_PRETRAINED_INIT_CONFIGURATION
+ slow_tokenizer_class = DPRContextEncoderTokenizer
class DPRQuestionEncoderTokenizer(BertTokenizer):
@@ -132,6 +133,7 @@ class DPRQuestionEncoderTokenizerFast(BertTokenizerFast):
pretrained_vocab_files_map = QUESTION_ENCODER_PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = QUESTION_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
pretrained_init_configuration = QUESTION_ENCODER_PRETRAINED_INIT_CONFIGURATION
+ slow_tokenizer_class = DPRQuestionEncoderTokenizer
DPRSpanPrediction = collections.namedtuple(
@@ -417,3 +419,4 @@ class DPRReaderTokenizerFast(CustomDPRReaderTokenizerMixin, BertTokenizerFast):
max_model_input_sizes = READER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
pretrained_init_configuration = READER_PRETRAINED_INIT_CONFIGURATION
model_input_names = ["attention_mask"]
+ slow_tokenizer_class = DPRReaderTokenizer
diff --git a/src/transformers/tokenization_electra.py b/src/transformers/tokenization_electra.py
index 1184a0914a..30608ae04c 100644
--- a/src/transformers/tokenization_electra.py
+++ b/src/transformers/tokenization_electra.py
@@ -80,3 +80,4 @@ class ElectraTokenizerFast(BertTokenizerFast):
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
+ slow_tokenizer_class = ElectraTokenizer
diff --git a/src/transformers/tokenization_fsmt.py b/src/transformers/tokenization_fsmt.py
index e5e095ee8c..05ce582dff 100644
--- a/src/transformers/tokenization_fsmt.py
+++ b/src/transformers/tokenization_fsmt.py
@@ -181,6 +181,7 @@ class FSMTTokenizer(PreTrainedTokenizer):
**kwargs
):
super().__init__(
+ langs=langs,
unk_token=unk_token,
bos_token=bos_token,
sep_token=sep_token,
diff --git a/src/transformers/tokenization_funnel.py b/src/transformers/tokenization_funnel.py
index 96084933e3..48c768f59b 100644
--- a/src/transformers/tokenization_funnel.py
+++ b/src/transformers/tokenization_funnel.py
@@ -152,6 +152,7 @@ class FunnelTokenizerFast(BertTokenizerFast):
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
+ slow_tokenizer_class = FunnelTokenizer
cls_token_type_id: int = 2
def __init__(
@@ -217,16 +218,3 @@ class FunnelTokenizerFast(BertTokenizerFast):
if token_ids_1 is None:
return len(cls) * [self.cls_token_type_id] + len(token_ids_0 + sep) * [0]
return len(cls) * [self.cls_token_type_id] + len(token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
-
- def _convert_encoding(self, encoding, **kwargs):
- # The fast tokenizer doesn't use the function above so we fix the cls token type id when decoding the fast
- # tokenzier output.
- encoding_dict = super()._convert_encoding(encoding, **kwargs)
- if "token_type_ids" in encoding_dict:
- # Note: we can't assume the token is in first position because left padding is a thing, hence the
- # double list comprehension.
- encoding_dict["token_type_ids"] = [
- [self.cls_token_type_id if i == self.cls_token_id else t for i, t in zip(input_ids, type_ids)]
- for input_ids, type_ids in zip(encoding_dict["input_ids"], encoding_dict["token_type_ids"])
- ]
- return encoding_dict
diff --git a/src/transformers/tokenization_gpt2.py b/src/transformers/tokenization_gpt2.py
index ee586050a9..deae5ea66f 100644
--- a/src/transformers/tokenization_gpt2.py
+++ b/src/transformers/tokenization_gpt2.py
@@ -21,7 +21,6 @@ import warnings
from functools import lru_cache
import regex as re
-from tokenizers import ByteLevelBPETokenizer
from .tokenization_utils import AddedToken, PreTrainedTokenizer
from .tokenization_utils_base import BatchEncoding
@@ -360,6 +359,7 @@ class GPT2TokenizerFast(PreTrainedTokenizerFast):
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
model_input_names = ["attention_mask"]
+ slow_tokenizer_class = GPT2Tokenizer
def __init__(
self,
@@ -369,19 +369,15 @@ class GPT2TokenizerFast(PreTrainedTokenizerFast):
bos_token="<|endoftext|>",
eos_token="<|endoftext|>",
add_prefix_space=False,
- trim_offsets=True,
**kwargs
):
super().__init__(
- ByteLevelBPETokenizer(
- vocab_file=vocab_file,
- merges_file=merges_file,
- add_prefix_space=add_prefix_space,
- trim_offsets=trim_offsets,
- ),
+ vocab_file,
+ merges_file,
+ unk_token=unk_token,
bos_token=bos_token,
eos_token=eos_token,
- unk_token=unk_token,
+ add_prefix_space=add_prefix_space,
**kwargs,
)
self.add_prefix_space = add_prefix_space
@@ -409,8 +405,9 @@ class GPT2TokenizerFast(PreTrainedTokenizerFast):
FutureWarning,
)
is_split_into_words = kwargs.pop("is_pretokenized")
+ else:
+ is_split_into_words = kwargs.get("is_split_into_words", False)
- is_split_into_words = kwargs.get("is_split_into_words", False)
assert self.add_prefix_space or not is_split_into_words, (
f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
"to use it with pretokenized inputs."
diff --git a/src/transformers/tokenization_longformer.py b/src/transformers/tokenization_longformer.py
index 4d2ee634f6..5c2718e0da 100644
--- a/src/transformers/tokenization_longformer.py
+++ b/src/transformers/tokenization_longformer.py
@@ -69,3 +69,4 @@ class LongformerTokenizerFast(RobertaTokenizerFast):
"vocab_file": {m: vocab_url for m in _all_longformer_models},
"merges_file": {m: merges_url for m in _all_longformer_models},
}
+ slow_tokenizer_class = LongformerTokenizer
diff --git a/src/transformers/tokenization_lxmert.py b/src/transformers/tokenization_lxmert.py
index f85c9d124e..163684d9e9 100644
--- a/src/transformers/tokenization_lxmert.py
+++ b/src/transformers/tokenization_lxmert.py
@@ -79,3 +79,4 @@ class LxmertTokenizerFast(BertTokenizerFast):
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
+ slow_tokenizer_class = LxmertTokenizer
diff --git a/src/transformers/tokenization_mbart.py b/src/transformers/tokenization_mbart.py
index 0d65e64c1d..2f13279fe2 100644
--- a/src/transformers/tokenization_mbart.py
+++ b/src/transformers/tokenization_mbart.py
@@ -15,10 +15,12 @@
from typing import List, Optional
+from tokenizers import processors
+
from .file_utils import add_start_docstrings
from .tokenization_utils import BatchEncoding
from .tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING
-from .tokenization_xlm_roberta import XLMRobertaTokenizer
+from .tokenization_xlm_roberta import XLMRobertaTokenizer, XLMRobertaTokenizerFast
from .utils import logging
@@ -109,6 +111,10 @@ class MBartTokenizer(XLMRobertaTokenizer):
self._additional_special_tokens = list(self.lang_code_to_id.keys())
self.set_src_lang_special_tokens(kwargs.get("src_lang", "en_XX"))
+ @property
+ def vocab_size(self):
+ return len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset + 1 # Plus 1 for the mask token
+
def get_special_tokens_mask(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
) -> List[int]:
@@ -227,3 +233,185 @@ class MBartTokenizer(XLMRobertaTokenizer):
self.cur_lang_code = self.lang_code_to_id[lang]
self.prefix_tokens = []
self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
+
+
+class MBartTokenizerFast(XLMRobertaTokenizerFast):
+ """
+ Construct a "fast" MBART tokenizer (backed by HuggingFace's `tokenizers` library).
+
+ :class:`~transformers.MBartTokenizerFast` is a subclass of :class:`~transformers.XLMRobertaTokenizerFast` and adds
+ a new :meth:`~transformers.MBartTokenizerFast.prepare_seq2seq_batch`.
+
+ Refer to superclass :class:`~transformers.XLMRobertaTokenizerFast` for usage examples and documentation concerning
+ the initialization parameters and other methods.
+
+ .. warning::
+ ``prepare_seq2seq_batch`` should be used to encode inputs. Other tokenizer methods like ``encode`` do not work
+ properly.
+
+ The tokenization method is `` `` for source language documents, and
+ `` ``` for target language documents.
+
+ Examples::
+
+ >>> from transformers import MBartTokenizerFast
+ >>> tokenizer = MBartTokenizerFast.from_pretrained('facebook/mbart-large-en-ro')
+ >>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
+ >>> expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
+ >>> batch: dict = tokenizer.prepare_seq2seq_batch(
+ ... example_english_phrase, src_lang="en_XX", tgt_lang="ro_RO", tgt_texts=expected_translation_romanian
+ ... )
+ """
+
+ vocab_files_names = {"vocab_file": "sentencepiece.bpe.model"}
+ max_model_input_sizes = {m: 1024 for m in _all_mbart_models}
+ pretrained_vocab_files_map = {"vocab_file": {m: SPM_URL for m in _all_mbart_models}}
+ slow_tokenizer_class = MBartTokenizer
+
+ prefix_tokens: List[int] = []
+ suffix_tokens: List[int] = []
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ self.cur_lang_code = self.convert_tokens_to_ids("en_XX")
+ self.set_src_lang_special_tokens(kwargs.get("src_lang", "en_XX"))
+
+ self.add_special_tokens({"additional_special_tokens": FAIRSEQ_LANGUAGE_CODES})
+
+ def get_special_tokens_mask(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+ ) -> List[int]:
+ """
+ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer ``prepare_for_model`` method.
+
+ Args:
+ token_ids_0 (:obj:`List[int]`):
+ List of ids.
+ token_ids_1 (:obj:`List[int]`, `optional`):
+ Optional second list of IDs for sequence pairs.
+ already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Whether or not the token list is already formatted with special tokens for the model.
+
+ Returns:
+ :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+ """
+
+ if already_has_special_tokens:
+ if token_ids_1 is not None:
+ raise ValueError(
+ "You should not supply a second sequence if the provided sequence of "
+ "ids is already formated with special tokens for the model."
+ )
+ return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
+ prefix_ones = [1] * len(self.prefix_tokens)
+ suffix_ones = [1] * len(self.suffix_tokens)
+ if token_ids_1 is None:
+ return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones
+ return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones
+
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks
+ by concatenating and adding special tokens. The special tokens depend on calling set_lang.
+
+ An MBART sequence has the following format, where ``X`` represents the sequence:
+
+ - ``input_ids`` (for encoder) ``X [eos, src_lang_code]``
+ - ``decoder_input_ids``: (for decoder) ``[tgt_lang_code] X [eos]``
+
+ BOS is never used.
+ Pairs of sequences are not the expected use case, but they will be handled without a separator.
+
+ Args:
+ token_ids_0 (:obj:`List[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (:obj:`List[int]`, `optional`):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
+ """
+ if token_ids_1 is None:
+ return self.prefix_tokens + token_ids_0 + self.suffix_tokens
+ # We don't expect to process pairs, but leave the pair logic for API consistency
+ return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
+
+ @add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
+ def prepare_seq2seq_batch(
+ self,
+ src_texts: List[str],
+ src_lang: str = "en_XX",
+ tgt_texts: Optional[List[str]] = None,
+ tgt_lang: str = "ro_RO",
+ max_length: Optional[int] = None,
+ max_target_length: Optional[int] = None,
+ truncation: bool = True,
+ padding: str = "longest",
+ return_tensors: str = "pt",
+ **kwargs,
+ ) -> BatchEncoding:
+ if max_length is None:
+ max_length = self.max_len
+ self.set_src_lang_special_tokens(src_lang)
+ model_inputs: BatchEncoding = self(
+ src_texts,
+ add_special_tokens=True,
+ return_tensors=return_tensors,
+ max_length=max_length,
+ padding=padding,
+ truncation=truncation,
+ **kwargs,
+ )
+ if tgt_texts is None:
+ return model_inputs
+ # Process tgt_texts
+ if max_target_length is None:
+ max_target_length = max_length
+ self.set_tgt_lang_special_tokens(tgt_lang)
+
+ labels = self(
+ tgt_texts,
+ add_special_tokens=True,
+ return_tensors=return_tensors,
+ padding=padding,
+ max_length=max_target_length,
+ truncation=True,
+ **kwargs,
+ )["input_ids"]
+ model_inputs["labels"] = labels
+ self.set_src_lang_special_tokens(src_lang) # sets to src_lang
+ return model_inputs
+
+ def set_src_lang_special_tokens(self, src_lang) -> None:
+ """Reset the special tokens to the source lang setting. No prefix and suffix=[eos, cur_lang_code]."""
+ self.cur_lang_code = self.convert_tokens_to_ids(src_lang)
+ self.prefix_tokens = []
+ self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
+
+ prefix_tokens_str = self.convert_ids_to_tokens(self.prefix_tokens)
+ suffix_tokens_str = self.convert_ids_to_tokens(self.suffix_tokens)
+
+ self._tokenizer.post_processor = processors.TemplateProcessing(
+ single=prefix_tokens_str + ["$A"] + suffix_tokens_str,
+ pair=prefix_tokens_str + ["$A", "$B"] + suffix_tokens_str,
+ special_tokens=list(zip(prefix_tokens_str + suffix_tokens_str, self.prefix_tokens + self.suffix_tokens)),
+ )
+
+ def set_tgt_lang_special_tokens(self, lang: str) -> None:
+ """Reset the special tokens to the target language setting. Prefix [tgt_lang_code], suffix =[eos]."""
+ self.cur_lang_code = self.convert_tokens_to_ids(lang)
+ self.prefix_tokens = []
+ self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
+
+ prefix_tokens_str = self.convert_ids_to_tokens(self.prefix_tokens)
+ suffix_tokens_str = self.convert_ids_to_tokens(self.suffix_tokens)
+
+ self._tokenizer.post_processor = processors.TemplateProcessing(
+ single=prefix_tokens_str + ["$A"] + suffix_tokens_str,
+ pair=prefix_tokens_str + ["$A", "$B"] + suffix_tokens_str,
+ special_tokens=list(zip(prefix_tokens_str + suffix_tokens_str, self.prefix_tokens + self.suffix_tokens)),
+ )
diff --git a/src/transformers/tokenization_mobilebert.py b/src/transformers/tokenization_mobilebert.py
index 72c0c1ec7f..44874b8c23 100644
--- a/src/transformers/tokenization_mobilebert.py
+++ b/src/transformers/tokenization_mobilebert.py
@@ -65,3 +65,4 @@ class MobileBertTokenizerFast(BertTokenizerFast):
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
+ slow_tokenizer_class = MobileBertTokenizer
diff --git a/src/transformers/tokenization_openai.py b/src/transformers/tokenization_openai.py
index 7106030d62..d03ecfb3d0 100644
--- a/src/transformers/tokenization_openai.py
+++ b/src/transformers/tokenization_openai.py
@@ -19,8 +19,6 @@ import json
import os
import re
-from tokenizers import CharBPETokenizer
-
from .tokenization_bert import BasicTokenizer
from .tokenization_utils import PreTrainedTokenizer
from .tokenization_utils_fast import PreTrainedTokenizerFast
@@ -123,6 +121,10 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {}
+ @property
+ def do_lower_case(self):
+ return True
+
@property
def vocab_size(self):
return len(self.encoder)
@@ -243,9 +245,8 @@ class OpenAIGPTTokenizerFast(PreTrainedTokenizerFast):
Construct a "fast" GPT Tokenizer (backed by HuggingFace's `tokenizers` library). Based on Byte-Pair-Encoding with
the following peculiarities:
- - lowercases all inputs,
- - uses :obj:`SpaCy` tokenizer and :obj:`ftfy` for pre-BPE tokenization if they are installed, fallback to BERT's
- :obj:`BasicTokenizer` if not.
+ - lower case all inputs
+ - uses BERT's BasicTokenizer for pre-BPE tokenization
This tokenizer inherits from :class:`~transformers.PreTrainedTokenizerFast` which contains most of the main
methods. Users should refer to this superclass for more information regarding those methods.
@@ -264,10 +265,11 @@ class OpenAIGPTTokenizerFast(PreTrainedTokenizerFast):
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
model_input_names = ["attention_mask"]
+ slow_tokenizer_class = OpenAIGPTTokenizer
def __init__(self, vocab_file, merges_file, unk_token="", **kwargs):
- kwargs.setdefault("unk_token", unk_token)
- super().__init__(
- CharBPETokenizer(vocab_file=vocab_file, merges_file=merges_file, unk_token=unk_token, lowercase=True),
- **kwargs,
- )
+ super().__init__(vocab_file, merges_file, unk_token=unk_token, **kwargs)
+
+ @property
+ def do_lower_case(self):
+ return True
diff --git a/src/transformers/tokenization_pegasus.py b/src/transformers/tokenization_pegasus.py
index e3aeb17461..346bcdb58c 100644
--- a/src/transformers/tokenization_pegasus.py
+++ b/src/transformers/tokenization_pegasus.py
@@ -15,10 +15,23 @@
from typing import Dict, List, Optional
from .file_utils import add_start_docstrings
-from .tokenization_reformer import ReformerTokenizer
+from .tokenization_reformer import ReformerTokenizer, ReformerTokenizerFast
from .tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
+SPIECE_UNDERLINE = "▁"
+
+VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
+
+PRETRAINED_VOCAB_FILES_MAP = {
+ "vocab_file": {"google/pegasus-xsum": "https://cdn.huggingface.co/google/pegasus-xsum/spiece.model"}
+}
+
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+ "google/pegasus-xsum": 512,
+}
+
+
class PegasusTokenizer(ReformerTokenizer):
r"""
Construct a Pegasus tokenizer.
@@ -31,6 +44,8 @@ class PegasusTokenizer(ReformerTokenizer):
"""
offset = 103 # entries 2-104 are only used for pretraining
vocab_files_names = {"vocab_file": "spiece.model"}
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -150,3 +165,85 @@ class PegasusTokenizer(ReformerTokenizer):
# for k, v in decoder_inputs.items():
# model_inputs[f"decoder_{k}"] = v
return model_inputs
+
+
+class PegasusTokenizerFast(ReformerTokenizerFast):
+ offset = 103 # entries 2-104 are only used for pretraining
+ vocab_files_names = {"vocab_file": "spiece.model"}
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+ slow_tokenizer_class = PegasusTokenizer
+
+ # def num_special_tokens_to_add(self, pair=False):
+ # """Just EOS"""
+ # return 1
+
+ def _special_token_mask(self, seq):
+ all_special_ids = set(self.all_special_ids) # call it once instead of inside list comp
+ all_special_ids.remove(self.unk_token_id) # is only sometimes special
+ assert all_special_ids == set([0, 1])
+ return [1 if x in all_special_ids else 0 for x in seq]
+
+ def get_special_tokens_mask(
+ self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False
+ ) -> List[int]:
+ """Get list where entries are [1] if a token is [eos] or [pad] else 0."""
+ if already_has_special_tokens:
+ return self._special_token_mask(token_ids_0)
+ elif token_ids_1 is None:
+ return self._special_token_mask(token_ids_0) + [1]
+ else:
+ return self._special_token_mask(token_ids_0 + token_ids_1) + [1]
+
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
+ """
+ Build model inputs from a sequence by adding eos to the end. no bos token is added to the front.
+ - single sequence: ``X ``
+ - pair of sequences: ``A B `` (not intended use)
+
+ Args:
+ token_ids_0 (:obj:`List[int]`):
+ List of IDs to which the special tokens will be added
+ token_ids_1 (:obj:`List[int]`, `optional`):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
+ """
+ if token_ids_1 is None:
+ return token_ids_0 + [self.eos_token_id]
+ # We don't expect to process pairs, but leave the pair logic for API consistency
+ return token_ids_0 + token_ids_1 + [self.eos_token_id]
+
+ @add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
+ def prepare_seq2seq_batch(
+ self,
+ src_texts: List[str],
+ tgt_texts: Optional[List[str]] = None,
+ max_length: Optional[int] = None,
+ max_target_length: Optional[int] = None,
+ return_tensors: str = "pt",
+ truncation=True,
+ padding="longest",
+ **unused,
+ ) -> BatchEncoding:
+ if "" in src_texts:
+ raise ValueError(f"found empty string in src_texts: {src_texts}")
+ tokenizer_kwargs = dict(
+ add_special_tokens=True,
+ return_tensors=return_tensors,
+ max_length=max_length,
+ truncation=truncation,
+ padding=padding,
+ )
+ model_inputs: BatchEncoding = self(src_texts, **tokenizer_kwargs)
+ if tgt_texts is None:
+ return model_inputs
+ if max_target_length is not None:
+ tokenizer_kwargs["max_length"] = max_target_length
+ # TODO(@sshleifer): maybe tgt_texts = [self.pad_token + t for t in tgt_texts] # add decoder_start_token_id
+ labels: BatchEncoding = self(tgt_texts, **tokenizer_kwargs)["input_ids"]
+ model_inputs["labels"] = labels
+ # for k, v in decoder_inputs.items():
+ # model_inputs[f"decoder_{k}"] = v
+ return model_inputs
diff --git a/src/transformers/tokenization_phobert.py b/src/transformers/tokenization_phobert.py
index cb7326ca32..b09fbd1ba3 100644
--- a/src/transformers/tokenization_phobert.py
+++ b/src/transformers/tokenization_phobert.py
@@ -126,7 +126,6 @@ class PhobertTokenizer(PreTrainedTokenizer):
**kwargs
):
super().__init__(
- max_len=256,
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
diff --git a/src/transformers/tokenization_reformer.py b/src/transformers/tokenization_reformer.py
index e416d30921..017e4a3465 100644
--- a/src/transformers/tokenization_reformer.py
+++ b/src/transformers/tokenization_reformer.py
@@ -19,6 +19,7 @@ import os
from shutil import copyfile
from .tokenization_utils import PreTrainedTokenizer
+from .tokenization_utils_fast import PreTrainedTokenizerFast
from .utils import logging
@@ -184,3 +185,72 @@ class ReformerTokenizer(PreTrainedTokenizer):
copyfile(self.vocab_file, out_vocab_file)
return (out_vocab_file,)
+
+
+class ReformerTokenizerFast(PreTrainedTokenizerFast):
+ """
+ Construct a "fast" Reformer tokenizer (backed by HuggingFace's `tokenizers` library). Based on `SentencePiece
+ `__ .
+
+ This tokenizer inherits from :class:`~transformers.PreTrainedTokenizerFast` which contains most of the main
+ methods. Users should refer to this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (:obj:`str`):
+ `SentencePiece `__ file (generally has a `.spm` extension) that
+ contains the vocabulary necessary to instantiate a tokenizer.
+ eos_token (:obj:`str`, `optional`, defaults to :obj:`""`):
+ The end of sequence token.
+
+ .. note::
+
+ When building a sequence using special tokens, this is not the token that is used for the end
+ of sequence. The token used is the :obj:`sep_token`.
+ unk_token (:obj:`str`, `optional`, defaults to :obj:`""`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ pad_token (:obj:`str`, `optional`, defaults to :obj:`""`):
+ The token used for padding, for example when batching sequences of different lengths.
+ additional_special_tokens (:obj:`List[str]`, `optional`):
+ Additional special tokens used by the tokenizer.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+ model_input_names = ["attention_mask"]
+ slow_tokenizer_class = ReformerTokenizer
+
+ def __init__(
+ self,
+ vocab_file,
+ eos_token="",
+ unk_token="",
+ pad_token="",
+ additional_special_tokens=[],
+ **kwargs
+ ):
+ super().__init__(
+ vocab_file,
+ eos_token=eos_token,
+ unk_token=unk_token,
+ pad_token=pad_token,
+ additional_special_tokens=additional_special_tokens,
+ **kwargs,
+ )
+
+ self.vocab_file = vocab_file
+
+ def save_vocabulary(self, save_directory):
+ """Save the sentencepiece vocabulary (copy original file) and special tokens file
+ to a directory.
+ """
+ if not os.path.isdir(save_directory):
+ logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
+ return
+ out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])
+
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+ copyfile(self.vocab_file, out_vocab_file)
+
+ return (out_vocab_file,)
diff --git a/src/transformers/tokenization_retribert.py b/src/transformers/tokenization_retribert.py
index 15bdad3a25..58c3722d76 100644
--- a/src/transformers/tokenization_retribert.py
+++ b/src/transformers/tokenization_retribert.py
@@ -71,4 +71,5 @@ class RetriBertTokenizerFast(BertTokenizerFast):
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
+ slow_tokenizer_class = RetriBertTokenizer
model_input_names = ["attention_mask"]
diff --git a/src/transformers/tokenization_roberta.py b/src/transformers/tokenization_roberta.py
index 3aa312fd9b..4b00996414 100644
--- a/src/transformers/tokenization_roberta.py
+++ b/src/transformers/tokenization_roberta.py
@@ -17,8 +17,6 @@
import warnings
from typing import List, Optional
-from tokenizers.processors import RobertaProcessing
-
from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
from .tokenization_utils import AddedToken
from .utils import logging
@@ -344,6 +342,7 @@ class RobertaTokenizerFast(GPT2TokenizerFast):
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
model_input_names = ["attention_mask"]
+ slow_tokenizer_class = RobertaTokenizer
def __init__(
self,
@@ -358,38 +357,23 @@ class RobertaTokenizerFast(GPT2TokenizerFast):
pad_token="",
mask_token="",
add_prefix_space=False,
- trim_offsets=True,
**kwargs
):
- # Mask token behave like a normal word, i.e. include the space before it
- mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
-
- kwargs.setdefault("pad_token", pad_token)
- kwargs.setdefault("sep_token", sep_token)
- kwargs.setdefault("cls_token", cls_token)
- kwargs.setdefault("mask_token", mask_token)
-
super().__init__(
- vocab_file=vocab_file,
- merges_file=merges_file,
- unk_token=unk_token,
+ vocab_file,
+ merges_file,
+ errors=errors,
bos_token=bos_token,
eos_token=eos_token,
+ sep_token=sep_token,
+ cls_token=cls_token,
+ unk_token=unk_token,
+ pad_token=pad_token,
+ mask_token=mask_token,
add_prefix_space=add_prefix_space,
- trim_offsets=trim_offsets,
**kwargs,
)
- # This will add the necessary special tokens to the vocabulary if needed
- self.sanitize_special_tokens()
-
- self.backend_tokenizer._tokenizer.post_processor = RobertaProcessing(
- sep=(sep_token, self.sep_token_id),
- cls=(cls_token, self.cls_token_id),
- add_prefix_space=add_prefix_space,
- trim_offsets=trim_offsets,
- )
-
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
if token_ids_1 is None:
diff --git a/src/transformers/tokenization_t5.py b/src/transformers/tokenization_t5.py
index 08438aa0c4..5d885ab114 100644
--- a/src/transformers/tokenization_t5.py
+++ b/src/transformers/tokenization_t5.py
@@ -24,6 +24,7 @@ from typing import List, Optional
from .file_utils import add_start_docstrings
from .tokenization_utils import BatchEncoding, PreTrainedTokenizer
from .tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING
+from .tokenization_utils_fast import PreTrainedTokenizerFast
from .utils import logging
@@ -322,3 +323,161 @@ class T5Tokenizer(PreTrainedTokenizer):
)
model_inputs["labels"] = labels_and_decoder_mask["input_ids"]
return model_inputs
+
+
+class T5TokenizerFast(PreTrainedTokenizerFast):
+ """
+ Construct a "fast" T5 tokenizer (backed by HuggingFace's `tokenizers` library). Based on `SentencePiece
+ `__ .
+
+ This tokenizer inherits from :class:`~transformers.PreTrainedTokenizerFast` which contains most of the main
+ methods. Users should refer to this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (:obj:`str`):
+ `SentencePiece `__ file (generally has a `.spm` extension) that
+ contains the vocabulary necessary to instantiate a tokenizer.
+ eos_token (:obj:`str`, `optional`, defaults to :obj:`""`):
+ The end of sequence token.
+
+ .. note::
+
+ When building a sequence using special tokens, this is not the token that is used for the end
+ of sequence. The token used is the :obj:`sep_token`.
+ unk_token (:obj:`str`, `optional`, defaults to :obj:`""`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ pad_token (:obj:`str`, `optional`, defaults to :obj:`""`):
+ The token used for padding, for example when batching sequences of different lengths.
+ extra_ids (:obj:`int`, `optional`, defaults to 100):
+ Add a number of extra ids added to the end of the vocabulary for use as sentinels.
+ These tokens are accessible as "" where "{%d}" is a number between 0 and extra_ids-1.
+ Extra tokens are indexed from the end of the vocabulary up to beginnning ("" is the last token
+ in the vocabulary like in T5 preprocessing see `here
+ `__).
+ additional_special_tokens (:obj:`List[str]`, `optional`):
+ Additional special tokens used by the tokenizer.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+ model_input_names = ["attention_mask"]
+ slow_tokenizer_class = T5Tokenizer
+
+ prefix_tokens: List[int] = []
+
+ def __init__(
+ self,
+ vocab_file,
+ eos_token="",
+ unk_token="",
+ pad_token="",
+ extra_ids=100,
+ additional_special_tokens=None,
+ **kwargs
+ ):
+ super().__init__(
+ vocab_file,
+ eos_token=eos_token,
+ unk_token=unk_token,
+ pad_token=pad_token,
+ extra_ids=extra_ids,
+ additional_special_tokens=additional_special_tokens,
+ **kwargs,
+ )
+
+ self.vocab_file = vocab_file
+ self._extra_ids = extra_ids
+
+ def save_vocabulary(self, save_directory):
+ """
+ Save the sentencepiece vocabulary (copy original file) and special tokens file to a directory.
+
+ Args:
+ save_directory (:obj:`str`):
+ The directory in which to save the vocabulary.
+
+ Returns:
+ :obj:`Tuple(str)`: Paths to the files saved.
+ """
+ if not os.path.isdir(save_directory):
+ logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
+ return
+ out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])
+
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+ copyfile(self.vocab_file, out_vocab_file)
+
+ return (out_vocab_file,)
+
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks
+ by concatenating and adding special tokens.
+ A sequence has the following format:
+
+ - single sequence: ``X ``
+ - pair of sequences: ``A B ``
+
+ Args:
+ token_ids_0 (:obj:`List[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (:obj:`List[int]`, `optional`):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
+ """
+ token_ids_0 = token_ids_0 + [self.eos_token_id]
+ if token_ids_1 is None:
+ return self.prefix_tokens + token_ids_0
+ else:
+ token_ids_1 = token_ids_1 + [self.eos_token_id]
+ return self.prefix_tokens + token_ids_0 + token_ids_1
+
+ @add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
+ def prepare_seq2seq_batch(
+ self,
+ src_texts: List[str],
+ tgt_texts: Optional[List[str]] = None,
+ max_length: Optional[int] = None,
+ max_target_length: Optional[int] = None,
+ padding: str = "longest",
+ return_tensors: str = None,
+ truncation: bool = True,
+ **kwargs,
+ ) -> BatchEncoding:
+ if max_length is None:
+ max_length = self.max_len
+ self.prefix_tokens = []
+ model_inputs = self(
+ src_texts,
+ add_special_tokens=True,
+ return_tensors=return_tensors,
+ max_length=max_length,
+ padding=padding,
+ truncation=truncation,
+ **kwargs,
+ )
+ if tgt_texts is None:
+ return model_inputs
+ # Process tgt_texts
+ if max_target_length is None:
+ max_target_length = max_length
+ # set prefix_tokens for target text
+ self.prefix_tokens = [self.pad_token_id]
+ labels_and_decoder_mask = self(
+ tgt_texts,
+ add_special_tokens=True,
+ return_tensors=return_tensors,
+ padding=padding,
+ max_length=max_target_length,
+ truncation=truncation,
+ **kwargs,
+ )
+ model_inputs["labels"] = labels_and_decoder_mask["input_ids"]
+ self.prefix_tokens = []
+ return model_inputs
diff --git a/src/transformers/tokenization_transfo_xl.py b/src/transformers/tokenization_transfo_xl.py
index b6f34d43da..6e018150d6 100644
--- a/src/transformers/tokenization_transfo_xl.py
+++ b/src/transformers/tokenization_transfo_xl.py
@@ -22,23 +22,15 @@ import glob
import os
import pickle
import re
-import warnings
from collections import Counter, OrderedDict
-from typing import List, Optional
+from typing import List
import numpy as np
import sacremoses as sm
-from tokenizers import Tokenizer
-from tokenizers.implementations import BaseTokenizer
-from tokenizers.models import WordLevel
-from tokenizers.normalizers import Lowercase, Sequence, Strip, unicode_normalizer_from_str
-from tokenizers.pre_tokenizers import CharDelimiterSplit, WhitespaceSplit
-from tokenizers.processors import BertProcessing
from .file_utils import cached_path, is_torch_available, torch_only_method
from .tokenization_utils import PreTrainedTokenizer
-from .tokenization_utils_fast import PreTrainedTokenizerFast
from .utils import logging
@@ -53,7 +45,6 @@ VOCAB_FILES_NAMES = {
"pretrained_vocab_file_torch": "vocab.bin",
"vocab_file": "vocab.txt",
}
-VOCAB_FILES_NAMES_FAST = {"pretrained_vocab_file": "vocab.json", "vocab_file": "vocab.json"}
PRETRAINED_VOCAB_FILES_MAP = {
"pretrained_vocab_file": {
@@ -61,12 +52,6 @@ PRETRAINED_VOCAB_FILES_MAP = {
}
}
-PRETRAINED_VOCAB_FILES_MAP_FAST = {
- "pretrained_vocab_file": {
- "transfo-xl-wt103": "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.json",
- }
-}
-
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"transfo-xl-wt103": None,
}
@@ -240,6 +225,10 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
if vocab_file is not None:
self.build_vocab()
+ @property
+ def do_lower_case(self):
+ return self.lower_case
+
def _compile_space_around_punctuation_pattern(self):
look_ahead_for_special_token = "(?=[{}])".format(self.punctuation_symbols)
look_ahead_to_match_all_except_space = r"(?=[^\s])"
@@ -299,11 +288,6 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
:obj:`Tuple(str)`: Paths to the files saved.
"""
- logger.warning(
- "Please note you will not be able to load the save vocabulary in"
- " Rust-based TransfoXLTokenizerFast as they don't share the same structure."
- )
-
if os.path.isdir(vocab_path):
vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["pretrained_vocab_file"])
else:
@@ -492,165 +476,6 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
return symbols
-class _TransfoXLDelimiterLookupTokenizer(BaseTokenizer):
- def __init__(
- self,
- vocab_file,
- delimiter,
- lowercase,
- unk_token,
- eos_token,
- add_eos=False,
- add_double_eos=False,
- normalization: Optional[str] = None,
- ):
-
- try:
- tokenizer = WordLevel(vocab_file, unk_token=unk_token)
- tokenizer = Tokenizer(tokenizer)
- except Exception:
- raise ValueError(
- "Unable to parse file {}. Unknown format. "
- "If you tried to load a model saved through TransfoXLTokenizer,"
- "please note they are not compatible.".format(vocab_file)
- )
-
- # Create the correct normalization path
- normalizer = []
-
- # Include unicode normalization
- if normalization:
- normalizer += [unicode_normalizer_from_str(normalization)]
-
- # Include case normalization
- if lowercase:
- normalizer += [Lowercase()]
-
- # Strip normalizer at the end
- normalizer += [Strip(left=True, right=True)]
-
- if len(normalizer) > 0:
- tokenizer.normalizer = Sequence(normalizer) if len(normalizer) > 1 else normalizer[0]
-
- # Setup the splitter
- tokenizer.pre_tokenizer = CharDelimiterSplit(delimiter) if delimiter else WhitespaceSplit()
-
- if add_double_eos:
- tokenizer.post_processor = BertProcessing(
- (eos_token, tokenizer.token_to_id(eos_token)), (eos_token, tokenizer.token_to_id(eos_token))
- )
-
- parameters = {
- "model": "TransfoXLModel",
- "add_eos": add_eos,
- "add_double_eos": add_double_eos,
- "unk_token": unk_token,
- "eos_token": eos_token,
- "delimiter": delimiter,
- "lowercase": lowercase,
- }
-
- super().__init__(tokenizer, parameters)
-
-
-class TransfoXLTokenizerFast(PreTrainedTokenizerFast):
- """
- Construct a "fast" Transformer-XL tokenizer (backed by HuggingFace's `tokenizers` library) adapted from Vocab class
- in `the original code `__. The Transformer-XL tokenizer is a
- word-level tokenizer (no sub-word tokenization).
-
- This tokenizer inherits from :class:`~transformers.PreTrainedTokenizerFast` which contains most of the main
- methods. Users should refer to this superclass for more information regarding those methods.
-
- Args:
- special (:obj:`List[str]`, `optional`):
- A list of special tokens (to be treated by the original implementation of this tokenizer).
- min_freq (:obj:`int`, `optional`, defaults to 0):
- The minimum number of times a token has to be present in order to be kept in the vocabulary (otherwise it
- will be mapped to :obj:`unk_token`).
- max_size (:obj:`int`, `optional`):
- The maximum size of the vocabulary. If left unset, it will default to the size of the vocabulary found
- after excluding the tokens according to the :obj:`min_freq` rule.
- lower_case (:obj:`bool`, `optional`, defaults to :obj:`False`):
- Whether or not to lowercase the input when tokenizing.
- delimiter (:obj:`str`, `optional`):
- The delimiter used btween tokens.
- vocab_file (:obj:`str`, `optional`):
- File containing the vocabulary (from the original implementation).
- pretrained_vocab_file (:obj:`str`, `optional`):
- File containing the vocabulary as saved with the :obj:`save_pretrained()` method.
- never_split (xxx, `optional`):
- Fill me with intesting stuff.
- unk_token (:obj:`str`, `optional`, defaults to :obj:`""`):
- The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
- token instead.
- eos_token (:obj:`str`, `optional`, defaults to :obj:`""`):
- The end of sequence token.
- additional_special_tokens (:obj:`List[str]`, `optional`, defaults to :obj:`[""]`):
- A list of additional special tokens (for the HuggingFace functionality).
- add_eos (:obj:`bool`, `optional`, defaults to :obj:`False`):
- Whether or not to add the end-of-sentence token.
- add_double_eos (:obj:`bool`, `optional`, defaults to :obj:`False`):
- Whether or not to add the end-of-sentence token.
- normalization (xxx, `optional`):
- Fill me with intesting stuff.
- """
-
- vocab_files_names = VOCAB_FILES_NAMES_FAST
- pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP_FAST
- max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
- model_input_names = []
-
- def __init__(
- self,
- special=None,
- min_freq=0,
- max_size=None,
- lower_case=False,
- delimiter=None,
- vocab_file=None,
- pretrained_vocab_file=None,
- never_split=None,
- unk_token="",
- eos_token="",
- additional_special_tokens=[""],
- add_eos=False,
- add_double_eos=False,
- normalization=None,
- **kwargs
- ):
-
- super().__init__(
- _TransfoXLDelimiterLookupTokenizer(
- vocab_file=vocab_file or pretrained_vocab_file,
- delimiter=delimiter,
- lowercase=lower_case,
- unk_token=unk_token,
- eos_token=eos_token,
- add_eos=add_eos,
- add_double_eos=add_double_eos,
- normalization=normalization,
- ),
- unk_token=unk_token,
- eos_token=eos_token,
- additional_special_tokens=additional_special_tokens,
- **kwargs,
- )
-
- warnings.warn(
- "The class `TransfoXLTokenizerFast` is deprecated and will be removed in a future version. Please use `TransfoXLTokenizer` with it's enhanced tokenization instead.",
- FutureWarning,
- )
-
- def save_pretrained(self, save_directory):
- logger.warning(
- "Please note you will not be able to load the vocabulary in"
- " Python-based TransfoXLTokenizer as they don't share the same structure."
- )
-
- return super().save_pretrained(save_directory)
-
-
class LMOrderedIterator(object):
def __init__(self, data, bsz, bptt, device="cpu", ext_len=None):
"""
diff --git a/src/transformers/tokenization_utils.py b/src/transformers/tokenization_utils.py
index 17291ff641..188b8ac76d 100644
--- a/src/transformers/tokenization_utils.py
+++ b/src/transformers/tokenization_utils.py
@@ -15,7 +15,6 @@
""" Tokenization classes for python tokenizers.
For fast tokenizers (provided by HuggingFace's tokenizers library) see tokenization_utils_fast.py
"""
-
import itertools
import re
import unicodedata
@@ -45,6 +44,11 @@ from .utils import logging
logger = logging.get_logger(__name__)
+# Slow tokenizers are saved in a vocabulary plus three separated files
+SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json"
+ADDED_TOKENS_FILE = "added_tokens.json"
+TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
+
def _is_whitespace(char):
"""Checks whether `char` is a whitespace character."""
@@ -190,7 +194,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
tokens_to_add = []
for token in new_tokens:
assert isinstance(token, str)
- if not special_tokens and self.init_kwargs.get("do_lower_case", False):
+ if not special_tokens and hasattr(self, "do_lower_case") and self.do_lower_case:
token = token.lower()
if (
token != self.unk_token
@@ -239,6 +243,9 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
"""
Converts a string in a sequence of tokens, using the tokenizer.
+ Note that, unlike Fast tokenizers (instances of PreTrainedTokenizerFast), this method
+ won't replace the unknown tokens with the `unk_token` yet (this is done in the `encode()` method)
+
Split in words for word-based vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces).
Takes care of added tokens.
@@ -268,7 +275,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
logger.warning(f"Keyword arguments {kwargs} not recognized.")
# TODO: should this be in the base class?
- if self.init_kwargs.get("do_lower_case", False):
+ if hasattr(self, "do_lower_case") and self.do_lower_case:
# convert non-special tokens to lowercase
escaped_special_toks = [re.escape(s_tok) for s_tok in self.all_special_tokens]
pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
@@ -740,7 +747,11 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
return " ".join(tokens)
def decode(
- self, token_ids: List[int], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True
+ self,
+ token_ids: List[int],
+ skip_special_tokens: bool = False,
+ clean_up_tokenization_spaces: bool = True,
+ spaces_between_special_tokens: bool = True,
) -> str:
"""
Converts a sequence of ids in a string, using the tokenizer and vocabulary
@@ -755,6 +766,10 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
Whether or not to remove special tokens in the decoding.
clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to clean up the tokenization spaces.
+ spaces_between_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`True`):
+ Whether or not to add spaces around special tokens.
+ The behavior of Fast tokenizers is to have this to :obj:`False`.
+ This is setup to :obj:`True` in slow tokenizers for backward compatibility.
Returns:
:obj:`str`: The decoded sentence.
@@ -778,7 +793,11 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
current_sub_text.append(token)
if current_sub_text:
sub_texts.append(self.convert_tokens_to_string(current_sub_text))
- text = " ".join(sub_texts)
+
+ if spaces_between_special_tokens:
+ text = " ".join(sub_texts)
+ else:
+ text = "".join(sub_texts)
if clean_up_tokenization_spaces:
clean_text = self.clean_up_tokenization(text)
diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py
index 4330ae0a36..fcbdd32412 100644
--- a/src/transformers/tokenization_utils_base.py
+++ b/src/transformers/tokenization_utils_base.py
@@ -646,6 +646,8 @@ class SpecialTokensMixin:
# which are not yet in the vocabulary. Necesssary for serialization/de-serialization
# TODO clean this up at some point (probably by sitching to fast tokenizers)
for key, value in kwargs.items():
+ if value is None:
+ continue
if key in self.SPECIAL_TOKENS_ATTRIBUTES:
if key == "additional_special_tokens":
assert isinstance(value, (list, tuple)), f"Value {value} is not a list or tuple"
@@ -778,6 +780,9 @@ class SpecialTokensMixin:
return self._add_tokens(new_tokens, special_tokens=special_tokens)
+ def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
+ raise NotImplementedError
+
@property
def bos_token(self) -> str:
"""
@@ -1293,11 +1298,12 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
max_model_input_sizes: Dict[str, Optional[int]] = {}
model_input_names: List[str] = ["token_type_ids", "attention_mask"]
padding_side: str = "right"
+ slow_tokenizer_class = None
def __init__(self, **kwargs):
# inputs and kwargs for saving and re-loading (see ``from_pretrained`` and ``save_pretrained``)
self.init_inputs = ()
- self.init_kwargs = kwargs
+ self.init_kwargs = copy.deepcopy(kwargs)
# For backward compatibility we fallback to set model_max_length from max_len if provided
model_max_length = kwargs.pop("model_max_length", kwargs.pop("max_len", None))
@@ -1311,6 +1317,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
], f"Padding side should be selected between 'right' and 'left', current value: {self.padding_side}"
self.model_input_names = kwargs.pop("model_input_names", self.model_input_names)
+ self.deprecation_warnings = (
+ {}
+ ) # Use to store when we have already noticed a deprecation warning (avoid overlogging).
+
super().__init__(**kwargs)
@property
@@ -1343,9 +1353,11 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
def max_len_single_sentence(self, value) -> int:
# For backward compatibility, allow to try to setup 'max_len_single_sentence'.
if value == self.model_max_length - self.num_special_tokens_to_add(pair=False) and self.verbose:
- logger.warning(
- "Setting 'max_len_single_sentence' is now deprecated. " "This value is automatically set up."
- )
+ if not self.deprecation_warnings.get("max_len_single_sentence", False):
+ logger.warning(
+ "Setting 'max_len_single_sentence' is now deprecated. " "This value is automatically set up."
+ )
+ self.deprecation_warnings["max_len_single_sentence"] = True
else:
raise ValueError(
"Setting 'max_len_single_sentence' is now deprecated. " "This value is automatically set up."
@@ -1355,16 +1367,18 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
def max_len_sentences_pair(self, value) -> int:
# For backward compatibility, allow to try to setup 'max_len_sentences_pair'.
if value == self.model_max_length - self.num_special_tokens_to_add(pair=True) and self.verbose:
- logger.warning(
- "Setting 'max_len_sentences_pair' is now deprecated. " "This value is automatically set up."
- )
+ if not self.deprecation_warnings.get("max_len_sentences_pair", False):
+ logger.warning(
+ "Setting 'max_len_sentences_pair' is now deprecated. " "This value is automatically set up."
+ )
+ self.deprecation_warnings["max_len_sentences_pair"] = True
else:
raise ValueError(
"Setting 'max_len_sentences_pair' is now deprecated. " "This value is automatically set up."
)
@classmethod
- def from_pretrained(cls, *inputs, **kwargs):
+ def from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs):
r"""
Instantiate a :class:`~transformers.tokenization_utils_base.PreTrainedTokenizerBase` (or a derived class) from
a predefined tokenizer.
@@ -1425,10 +1439,6 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
assert tokenizer.unk_token == ''
"""
- return cls._from_pretrained(*inputs, **kwargs)
-
- @classmethod
- def _from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs):
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
@@ -1475,7 +1485,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
"added_tokens_file": ADDED_TOKENS_FILE,
"special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE,
"tokenizer_config_file": TOKENIZER_CONFIG_FILE,
- "full_tokenizer_file": FULL_TOKENIZER_FILE,
+ "tokenizer_file": FULL_TOKENIZER_FILE,
}
# Look for the tokenizer files
for file_id, file_name in {**cls.vocab_files_names, **additional_files_names}.items():
@@ -1541,6 +1551,28 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
else:
logger.info("loading file {} from cache at {}".format(file_path, resolved_vocab_files[file_id]))
+ return cls._from_pretrained(
+ resolved_vocab_files, pretrained_model_name_or_path, init_configuration, *init_inputs, **kwargs
+ )
+
+ @classmethod
+ def _from_pretrained(
+ cls, resolved_vocab_files, pretrained_model_name_or_path, init_configuration, *init_inputs, **kwargs
+ ):
+ # We instantiate fast tokenizers based on a slow tokenizer for now
+ # In the future we can also use a direct way based on saving/instantiating
+ # tokenizer's Tokenizer directly from it's serialization JSON
+ if cls.slow_tokenizer_class is not None:
+ slow_tokenizer = cls.slow_tokenizer_class._from_pretrained(
+ copy.deepcopy(resolved_vocab_files),
+ pretrained_model_name_or_path,
+ copy.deepcopy(init_configuration),
+ *init_inputs,
+ **(copy.deepcopy(kwargs)),
+ )
+ else:
+ slow_tokenizer = None
+
# Prepare tokenizer initialization kwargs
# Did we saved some inputs and kwargs to reload ?
tokenizer_config_file = resolved_vocab_files.pop("tokenizer_config_file", None)
@@ -1556,6 +1588,19 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
# Update with newly provided kwargs
init_kwargs.update(kwargs)
+ # Convert AddedTokens serialized as dict to class instances
+ def convert_added_tokens(obj: Union[AddedToken, Any]):
+ if isinstance(obj, dict) and "__type" in obj and obj["__type"] == "AddedToken":
+ obj.pop("__type")
+ return AddedToken(**obj)
+ elif isinstance(obj, (list, tuple)):
+ return list(convert_added_tokens(o) for o in obj)
+ elif isinstance(obj, dict):
+ return {k: convert_added_tokens(v) for k, v in obj.items()}
+ return obj
+
+ init_kwargs = convert_added_tokens(init_kwargs)
+
# Set max length if needed
if pretrained_model_name_or_path in cls.max_model_input_sizes:
# if we're using a pretrained model, ensure the tokenizer
@@ -1570,6 +1615,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
if args_name not in init_kwargs:
init_kwargs[args_name] = file_path
+ if slow_tokenizer is not None:
+ init_kwargs["__slow_tokenizer"] = slow_tokenizer
+
# Instantiate tokenizer.
try:
tokenizer = cls(*init_inputs, **init_kwargs)
@@ -1580,8 +1628,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
)
# Save inputs and kwargs for saving and re-loading with ``save_pretrained``
- tokenizer.init_inputs = init_inputs
- tokenizer.init_kwargs = init_kwargs
+ # Removed: Now done at the base class level
+ # tokenizer.init_inputs = init_inputs
+ # tokenizer.init_kwargs = init_kwargs
# If there is a complementary special token map, load it
special_tokens_map_file = resolved_vocab_files.pop("special_tokens_map_file", None)
@@ -1589,11 +1638,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
with open(special_tokens_map_file, encoding="utf-8") as special_tokens_map_handle:
special_tokens_map = json.load(special_tokens_map_handle)
+ special_tokens_map = convert_added_tokens(special_tokens_map)
for key, value in special_tokens_map.items():
- if isinstance(value, dict):
- value = AddedToken(**value)
- elif isinstance(value, list):
- value = [AddedToken(**token) if isinstance(token, dict) else token for token in value]
setattr(tokenizer, key, value)
# Add supplementary tokens.
@@ -1623,14 +1669,17 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
def save_pretrained(self, save_directory: str) -> Tuple[str]:
"""
- Save the tokenizer vocabulary files together with:
+ Save the full tokenizer state.
- - added tokens,
- - special tokens to class attributes mapping,
- - tokenizer instantiation positional and keywords inputs (e.g. do_lower_case for Bert).
This method make sure the full tokenizer can then be re-loaded using the
- :meth:`~transformers.tokenization_utils_base.PreTrainedTokenizerBase.from_pretrained` class method.
+ :meth:`~transformers.tokenization_utils_base.PreTrainedTokenizer.from_pretrained` class method.
+
+ .. Note::
+ A "fast" tokenizer (instance of :class:`transformers.PreTrainedTokenizerFast`) saved with
+ this method will not be possible to load back
+ in a "slow" tokenizer, i.e. in a :class:`transformers.PreTrainedTokenizer` instance. It can only be loaded
+ in a "fast" tokenizer, i.e. in a :class:`transformers.PreTrainedTokenizerFast` instance.
.. Warning::
This won't save modifications you may have applied to the tokenizer after the instantiation (for instance,
@@ -1648,7 +1697,6 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
os.makedirs(save_directory, exist_ok=True)
special_tokens_map_file = os.path.join(save_directory, SPECIAL_TOKENS_MAP_FILE)
- added_tokens_file = os.path.join(save_directory, ADDED_TOKENS_FILE)
tokenizer_config_file = os.path.join(save_directory, TOKENIZER_CONFIG_FILE)
tokenizer_config = copy.deepcopy(self.init_kwargs)
@@ -1657,22 +1705,33 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
for file_id in self.vocab_files_names.keys():
tokenizer_config.pop(file_id, None)
+ # Sanitize AddedTokens
+ def convert_added_tokens(obj: Union[AddedToken, Any]):
+ if isinstance(obj, AddedToken):
+ out = obj.__getstate__()
+ out["__type"] = "AddedToken"
+ return out
+ elif isinstance(obj, (list, tuple)):
+ return list(convert_added_tokens(o) for o in obj)
+ elif isinstance(obj, dict):
+ return {k: convert_added_tokens(v) for k, v in obj.items()}
+ return obj
+
+ tokenizer_config = convert_added_tokens(tokenizer_config)
with open(tokenizer_config_file, "w", encoding="utf-8") as f:
f.write(json.dumps(tokenizer_config, ensure_ascii=False))
+ # Sanitize AddedTokens in special_tokens_map
+ write_dict = convert_added_tokens(self.special_tokens_map_extended)
with open(special_tokens_map_file, "w", encoding="utf-8") as f:
- write_dict = {}
- for key, value in self.special_tokens_map_extended.items():
- if isinstance(value, AddedToken):
- write_dict[key] = value.__getstate__()
- elif isinstance(value, list):
- write_dict[key] = [
- token.__getstate__() if isinstance(token, AddedToken) else token for token in value
- ]
- else:
- write_dict[key] = value
f.write(json.dumps(write_dict, ensure_ascii=False))
+ file_names = (tokenizer_config_file, special_tokens_map_file)
+
+ return self._save_pretrained(save_directory, file_names)
+
+ def _save_pretrained(self, save_directory: str, file_names: Tuple[str]) -> Tuple[str]:
+ added_tokens_file = os.path.join(save_directory, ADDED_TOKENS_FILE)
added_vocab = self.get_added_vocab()
if added_vocab:
with open(added_tokens_file, "w", encoding="utf-8") as f:
@@ -1681,7 +1740,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
vocab_files = self.save_vocabulary(save_directory)
- return vocab_files + (special_tokens_map_file, added_tokens_file)
+ return file_names + (vocab_files, added_tokens_file)
@add_end_docstrings(
ENCODE_KWARGS_DOCSTRING,
@@ -1752,13 +1811,15 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
# If you only set max_length, it activates truncation for max_length
if max_length is not None and padding is False and truncation is False:
if verbose:
- logger.warning(
- "Truncation was not explicitely activated but `max_length` is provided a specific value, "
- "please use `truncation=True` to explicitely truncate examples to max length. "
- "Defaulting to 'longest_first' truncation strategy. "
- "If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy "
- "more precisely by providing a specific strategy to `truncation`."
- )
+ if not self.deprecation_warnings.get("Truncation-not-explicitely-activated", False):
+ logger.warning(
+ "Truncation was not explicitely activated but `max_length` is provided a specific value, "
+ "please use `truncation=True` to explicitely truncate examples to max length. "
+ "Defaulting to 'longest_first' truncation strategy. "
+ "If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy "
+ "more precisely by providing a specific strategy to `truncation`."
+ )
+ self.deprecation_warnings["Truncation-not-explicitely-activated"] = True
truncation = "longest_first"
# Get padding strategy
@@ -1818,10 +1879,12 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
if padding_strategy == PaddingStrategy.MAX_LENGTH:
if self.model_max_length > LARGE_INTEGER:
if verbose:
- logger.warning(
- "Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. "
- "Default to no padding."
- )
+ if not self.deprecation_warnings.get("Asking-to-pad-to-max_length", False):
+ logger.warning(
+ "Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. "
+ "Default to no padding."
+ )
+ self.deprecation_warnings["Asking-to-pad-to-max_length"] = True
padding_strategy = PaddingStrategy.DO_NOT_PAD
else:
max_length = self.model_max_length
@@ -1829,10 +1892,12 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE:
if self.model_max_length > LARGE_INTEGER:
if verbose:
- logger.warning(
- "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. "
- "Default to no truncation."
- )
+ if not self.deprecation_warnings.get("Asking-to-truncate-to-max_length", False):
+ logger.warning(
+ "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. "
+ "Default to no truncation."
+ )
+ self.deprecation_warnings["Asking-to-truncate-to-max_length"] = True
truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
else:
max_length = self.model_max_length
@@ -2437,6 +2502,13 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
len_ids = len(ids)
len_pair_ids = len(pair_ids) if pair else 0
+ if return_token_type_ids is not None and not add_special_tokens:
+ raise ValueError(
+ "Asking to return token_type_ids while setting add_special_tokens to False "
+ "results in an undefined behavior. Please set add_special_tokens to True or "
+ "set return_token_type_ids to None."
+ )
+
# Load from model defaults
if return_token_type_ids is None:
return_token_type_ids = "token_type_ids" in self.model_input_names
@@ -2469,7 +2541,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids)
else:
sequence = ids + pair_ids if pair else ids
- token_type_ids = [0] * len(ids) + ([1] * len(pair_ids) if pair else [])
+ token_type_ids = [0] * len(ids) + ([0] * len(pair_ids) if pair else [])
# Build output dictionnary
encoded_inputs["input_ids"] = sequence
@@ -2483,11 +2555,13 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
# Check lengths
if max_length is None and len(encoded_inputs["input_ids"]) > self.model_max_length and verbose:
- logger.warning(
- "Token indices sequence length is longer than the specified maximum sequence length "
- "for this model ({} > {}). Running this sequence through the model will result in "
- "indexing errors".format(len(encoded_inputs["input_ids"]), self.model_max_length)
- )
+ if not self.deprecation_warnings.get("sequence-length-is-longer-than-the-specified-maximum", False):
+ logger.warning(
+ "Token indices sequence length is longer than the specified maximum sequence length "
+ "for this model ({} > {}). Running this sequence through the model will result in "
+ "indexing errors".format(len(encoded_inputs["input_ids"]), self.model_max_length)
+ )
+ self.deprecation_warnings["sequence-length-is-longer-than-the-specified-maximum"] = True
# Padding
if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask:
@@ -2703,7 +2777,11 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
]
def decode(
- self, token_ids: List[int], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True
+ self,
+ token_ids: List[int],
+ skip_special_tokens: bool = False,
+ clean_up_tokenization_spaces: bool = True,
+ **kwargs
) -> str:
"""
Converts a sequence of ids in a string, using the tokenizer and vocabulary
diff --git a/src/transformers/tokenization_utils_fast.py b/src/transformers/tokenization_utils_fast.py
index c357e00c74..8c00c7ddb3 100644
--- a/src/transformers/tokenization_utils_fast.py
+++ b/src/transformers/tokenization_utils_fast.py
@@ -16,16 +16,19 @@
For slow (python) tokenizers see tokenization_utils.py
"""
+import copy
import os
import warnings
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple, Union
from tokenizers import Encoding as EncodingFast
+from tokenizers import Tokenizer as TokenizerFast
from tokenizers.decoders import Decoder as DecoderFast
-from tokenizers.implementations import BaseTokenizer as BaseTokenizerFast
+from .convert_slow_tokenizer import convert_slow_tokenizer
from .file_utils import add_end_docstrings
+from .tokenization_utils import PreTrainedTokenizer
from .tokenization_utils_base import (
INIT_TOKENIZER_DOCSTRING,
AddedToken,
@@ -44,6 +47,15 @@ from .utils import logging
logger = logging.get_logger(__name__)
+# Fast tokenizers (provided by HuggingFace tokenizer's library) can be saved in a single file
+TOKENIZER_FILE = "tokenizer.json"
+SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json"
+TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
+
+# Slow tokenizers have an additional addedd tokens files
+ADDED_TOKENS_FILE = "added_tokens.json"
+
+
@add_end_docstrings(
INIT_TOKENIZER_DOCSTRING,
"""
@@ -64,12 +76,19 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
dictionary structures (BPE, sentencepiece...).
"""
- def __init__(self, tokenizer: BaseTokenizerFast, **kwargs):
- if not isinstance(tokenizer, BaseTokenizerFast):
- raise ValueError(
- "Tokenizer should be an instance of a BaseTokenizer " "provided by HuggingFace tokenizers library."
- )
- self._tokenizer: BaseTokenizerFast = tokenizer
+ slow_tokenizer_class: PreTrainedTokenizer = None
+
+ def __init__(self, *args, **kwargs):
+ # We instantiate fast tokenizers based on a slow tokenizer for now
+ # In the future we'll also use a direct way based on saving/instantiating
+ # tokenizer's Tokenizer directly from it's serialization JSON
+ if "__slow_tokenizer" in kwargs and kwargs["__slow_tokenizer"]:
+ slow_tokenizer = kwargs.pop("__slow_tokenizer")
+ else:
+ slow_tokenizer = self.slow_tokenizer_class(*args, **kwargs)
+ self._tokenizer = convert_slow_tokenizer(slow_tokenizer)
+
+ kwargs = copy.deepcopy(slow_tokenizer.init_kwargs)
# We call this after having initialized the backend tokenizer because we update it.
super().__init__(**kwargs)
@@ -116,7 +135,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
return self._tokenizer.get_vocab_size(with_added_tokens=True)
@property
- def backend_tokenizer(self) -> BaseTokenizerFast:
+ def backend_tokenizer(self) -> TokenizerFast:
"""
:obj:`tokenizers.implementations.BaseTokenizer`: The Rust tokenizer used as a backend.
"""
@@ -259,6 +278,9 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
"""
Converts a string in a sequence of tokens, using the backend Rust tokenizer.
+ Note that, unlike slow tokenizers (instances of :class:`~transformers.PreTrainedTokenizer`), this method
+ will replace the unknown tokens with the :obj:`unk_token`.
+
Args:
text (:obj:`str`):
The sequence to be encoded.
@@ -343,7 +365,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
) -> BatchEncoding:
if not isinstance(batch_text_or_text_pairs, list):
- raise ValueError(
+ raise TypeError(
"batch_text_or_text_pairs has to be a list (got {})".format(type(batch_text_or_text_pairs))
)
@@ -487,7 +509,11 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
return batched_output
def decode(
- self, token_ids: List[int], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True
+ self,
+ token_ids: Union[int, List[int]],
+ skip_special_tokens: bool = False,
+ clean_up_tokenization_spaces: bool = True,
+ **kwargs
) -> str:
"""
Converts a sequence of ids in a string, using the tokenizer and vocabulary
@@ -496,7 +522,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
Similar to doing ``self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))``.
Args:
- token_ids (:obj:`List[int]`):
+ token_ids (:obj:`Union[int, List[int]]`):
List of tokenized input ids. Can be obtained using the ``__call__`` method.
skip_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to remove special tokens in the decoding.
@@ -506,6 +532,8 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
Returns:
:obj:`str`: The decoded sentence.
"""
+ if isinstance(token_ids, int):
+ token_ids = [token_ids]
text = self._tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
if clean_up_tokenization_spaces:
@@ -520,8 +548,8 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
and special token mappings.
.. warning::
- Please use :meth:`~transformers.PreTrainedTokenizer.save_pretrained` to save the full tokenizer state if
- you want to reload it using the :meth:`~transformers.PreTrainedTokenizer.from_pretrained` class method.
+ Please use :meth:`~transformers.PreTrainedTokenizerFast.save_pretrained` to save the full tokenizer state if
+ you want to reload it using the :meth:`~transformers.PreTrainedTokenizerFast.from_pretrained` class method.
Args:
save_directory (:obj:`str`): The path to adirectory where the tokenizer will be saved.
@@ -530,7 +558,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
A tuple of :obj:`str`: The files saved.
"""
if os.path.isdir(save_directory):
- files = self._tokenizer.save_model(save_directory)
+ files = self._tokenizer.model.save(save_directory)
else:
folder, file = os.path.split(os.path.abspath(save_directory))
files = self._tokenizer.save_model(folder, name=file)
diff --git a/src/transformers/tokenization_xlm.py b/src/transformers/tokenization_xlm.py
index 3fc4b80d9d..9a1286e442 100644
--- a/src/transformers/tokenization_xlm.py
+++ b/src/transformers/tokenization_xlm.py
@@ -648,6 +648,10 @@ class XLMTokenizer(PreTrainedTokenizer):
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {}
+ @property
+ def do_lower_case(self):
+ return self.do_lowercase_and_remove_accent
+
def moses_punct_norm(self, text, lang):
if lang not in self.cache_moses_punct_normalizer:
punct_normalizer = sm.MosesPunctNormalizer(lang=lang)
diff --git a/src/transformers/tokenization_xlm_roberta.py b/src/transformers/tokenization_xlm_roberta.py
index c19a6b0f8f..24139b8811 100644
--- a/src/transformers/tokenization_xlm_roberta.py
+++ b/src/transformers/tokenization_xlm_roberta.py
@@ -20,6 +20,7 @@ from shutil import copyfile
from typing import List, Optional
from .tokenization_utils import PreTrainedTokenizer
+from .tokenization_utils_fast import PreTrainedTokenizerFast
from .tokenization_xlnet import SPIECE_UNDERLINE
from .utils import logging
@@ -307,3 +308,190 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
copyfile(self.vocab_file, out_vocab_file)
return (out_vocab_file,)
+
+
+class XLMRobertaTokenizerFast(PreTrainedTokenizerFast):
+ """
+ Construct a "fast" XLM-RoBERTa tokenizer (backed by HuggingFace's `tokenizers` library). Adapted from
+ :class:`~transfomers.RobertaTokenizer` and class:`~transfomers.XLNetTokenizer`. Based on `SentencePiece
+ `__.
+
+ This tokenizer inherits from :class:`~transformers.PreTrainedTokenizerFast` which contains most of the main
+ methods. Users should refer to this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (:obj:`str`):
+ Path to the vocabulary file.
+ bos_token (:obj:`str`, `optional`, defaults to :obj:`""`):
+ The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
+
+ .. note::
+
+ When building a sequence using special tokens, this is not the token that is used for the beginning
+ of sequence. The token used is the :obj:`cls_token`.
+ eos_token (:obj:`str`, `optional`, defaults to :obj:`""`):
+ The end of sequence token.
+
+ .. note::
+
+ When building a sequence using special tokens, this is not the token that is used for the end
+ of sequence. The token used is the :obj:`sep_token`.
+ sep_token (:obj:`str`, `optional`, defaults to :obj:`""`):
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
+ for sequence classification or for a text and a question for question answering.
+ It is also used as the last token of a sequence built with special tokens.
+ cls_token (:obj:`str`, `optional`, defaults to :obj:`""`):
+ The classifier token which is used when doing sequence classification (classification of the whole
+ sequence instead of per-token classification). It is the first token of the sequence when built with
+ special tokens.
+ unk_token (:obj:`str`, `optional`, defaults to :obj:`""`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ pad_token (:obj:`str`, `optional`, defaults to :obj:`""`):
+ The token used for padding, for example when batching sequences of different lengths.
+ mask_token (:obj:`str`, `optional`, defaults to :obj:`""`):
+ The token used for masking values. This is the token used when training this model with masked language
+ modeling. This is the token which the model will try to predict.
+ additional_special_tokens (:obj:`List[str]`, `optional`, defaults to :obj:`["NOTUSED", "NOTUSED"]`):
+ Additional special tokens used by the tokenizer.
+
+ Attributes:
+ sp_model (:obj:`SentencePieceProcessor`):
+ The `SentencePiece` processor that is used for every conversion (string, tokens and IDs).
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+ model_input_names = ["attention_mask"]
+ slow_tokenizer_class = XLMRobertaTokenizer
+
+ def __init__(
+ self,
+ vocab_file,
+ bos_token="",
+ eos_token="",
+ sep_token="",
+ cls_token="",
+ unk_token="",
+ pad_token="",
+ mask_token="",
+ **kwargs
+ ):
+ super().__init__(
+ vocab_file,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ sep_token=sep_token,
+ cls_token=cls_token,
+ unk_token=unk_token,
+ pad_token=pad_token,
+ mask_token=mask_token,
+ **kwargs,
+ )
+
+ self.vocab_file = vocab_file
+
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks
+ by concatenating and adding special tokens.
+ An XLM-RoBERTa sequence has the following format:
+
+ - single sequence: `` X ``
+ - pair of sequences: `` A B ``
+
+ Args:
+ token_ids_0 (:obj:`List[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (:obj:`List[int]`, `optional`):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
+ """
+
+ if token_ids_1 is None:
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+ cls = [self.cls_token_id]
+ sep = [self.sep_token_id]
+ return cls + token_ids_0 + sep + sep + token_ids_1 + sep
+
+ def get_special_tokens_mask(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+ ) -> List[int]:
+ """
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer ``prepare_for_model`` method.
+
+ Args:
+ token_ids_0 (:obj:`List[int]`):
+ List of IDs.
+ token_ids_1 (:obj:`List[int]`, `optional`):
+ Optional second list of IDs for sequence pairs.
+ already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Whether or not the token list is already formatted with special tokens for the model.
+
+ Returns:
+ :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+ """
+
+ if already_has_special_tokens:
+ if token_ids_1 is not None:
+ raise ValueError(
+ "You should not supply a second sequence if the provided sequence of "
+ "ids is already formated with special tokens for the model."
+ )
+ return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
+
+ if token_ids_1 is None:
+ return [1] + ([0] * len(token_ids_0)) + [1]
+ return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
+
+ def create_token_type_ids_from_sequences(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task.
+ XLM-RoBERTa does not make use of token type ids, therefore a list of zeros is returned.
+
+ Args:
+ token_ids_0 (:obj:`List[int]`):
+ List of IDs.
+ token_ids_1 (:obj:`List[int]`, `optional`):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ :obj:`List[int]`: List of zeros.
+
+ """
+
+ sep = [self.sep_token_id]
+ cls = [self.cls_token_id]
+
+ if token_ids_1 is None:
+ return len(cls + token_ids_0 + sep) * [0]
+ return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
+
+ def save_vocabulary(self, save_directory):
+ """
+ Save the sentencepiece vocabulary (copy original file) and special tokens file to a directory.
+
+ Args:
+ save_directory (:obj:`str`):
+ The directory in which to save the vocabulary.
+
+ Returns:
+ :obj:`Tuple(str)`: Paths to the files saved.
+ """
+ if not os.path.isdir(save_directory):
+ logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
+ return
+ out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])
+
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+ copyfile(self.vocab_file, out_vocab_file)
+
+ return (out_vocab_file,)
diff --git a/src/transformers/tokenization_xlnet.py b/src/transformers/tokenization_xlnet.py
index 59c61e5e7d..f2484b2af0 100644
--- a/src/transformers/tokenization_xlnet.py
+++ b/src/transformers/tokenization_xlnet.py
@@ -21,6 +21,7 @@ from shutil import copyfile
from typing import List, Optional
from .tokenization_utils import PreTrainedTokenizer
+from .tokenization_utils_fast import PreTrainedTokenizerFast
from .utils import logging
@@ -344,3 +345,213 @@ class XLNetTokenizer(PreTrainedTokenizer):
copyfile(self.vocab_file, out_vocab_file)
return (out_vocab_file,)
+
+
+class XLNetTokenizerFast(PreTrainedTokenizerFast):
+ """
+ Construct a "fast" XLNet tokenizer (backed by HuggingFace's `tokenizers` library). Based on
+ `SentencePiece `__.
+
+ This tokenizer inherits from :class:`~transformers.PreTrainedTokenizerFast` which contains most of the main
+ methods. Users should refer to this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (:obj:`str`):
+ `SentencePiece `__ file (generally has a .spm extension) that
+ contains the vocabulary necessary to instantiate a tokenizer.
+ do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
+ Whether to lowercase the input when tokenizing.
+ remove_space (:obj:`bool`, `optional`, defaults to :obj:`True`):
+ Whether to strip the text when tokenizing (removing excess spaces before and after the string).
+ keep_accents (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Whether to keep accents when tokenizing.
+ bos_token (:obj:`str`, `optional`, defaults to :obj:`""`):
+ The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
+
+ .. note::
+
+ When building a sequence using special tokens, this is not the token that is used for the beginning
+ of sequence. The token used is the :obj:`cls_token`.
+ eos_token (:obj:`str`, `optional`, defaults to :obj:`""`):
+ The end of sequence token.
+
+ .. note::
+
+ When building a sequence using special tokens, this is not the token that is used for the end
+ of sequence. The token used is the :obj:`sep_token`.
+ unk_token (:obj:`str`, `optional`, defaults to :obj:`""`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ sep_token (:obj:`str`, `optional`, defaults to :obj:`""`):
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
+ for sequence classification or for a text and a question for question answering.
+ It is also used as the last token of a sequence built with special tokens.
+ pad_token (:obj:`str`, `optional`, defaults to :obj:`""`):
+ The token used for padding, for example when batching sequences of different lengths.
+ cls_token (:obj:`str`, `optional`, defaults to :obj:`""`):
+ The classifier token which is used when doing sequence classification (classification of the whole
+ sequence instead of per-token classification). It is the first token of the sequence when built with
+ special tokens.
+ mask_token (:obj:`str`, `optional`, defaults to :obj:`""`):
+ The token used for masking values. This is the token used when training this model with masked language
+ modeling. This is the token which the model will try to predict.
+ additional_special_tokens (:obj:`List[str]`, `optional`, defaults to :obj:`["", ""]`):
+ Additional special tokens used by the tokenizer.
+
+ Attributes:
+ sp_model (:obj:`SentencePieceProcessor`):
+ The `SentencePiece` processor that is used for every conversion (string, tokens and IDs).
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+ padding_side = "left"
+ slow_tokenizer_class = XLNetTokenizer
+
+ def __init__(
+ self,
+ vocab_file,
+ do_lower_case=False,
+ remove_space=True,
+ keep_accents=False,
+ bos_token="",
+ eos_token="",
+ unk_token="",
+ sep_token="",
+ pad_token="",
+ cls_token="",
+ mask_token="",
+ additional_special_tokens=["", ""],
+ **kwargs
+ ):
+ super().__init__(
+ vocab_file=vocab_file,
+ do_lower_case=do_lower_case,
+ remove_space=remove_space,
+ keep_accents=keep_accents,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ unk_token=unk_token,
+ sep_token=sep_token,
+ pad_token=pad_token,
+ cls_token=cls_token,
+ mask_token=mask_token,
+ additional_special_tokens=additional_special_tokens,
+ **kwargs,
+ )
+
+ self._pad_token_type_id = 3
+ self.do_lower_case = do_lower_case
+ self.remove_space = remove_space
+ self.keep_accents = keep_accents
+ self.vocab_file = vocab_file
+
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks
+ by concatenating and adding special tokens.
+ An XLNet sequence has the following format:
+
+ - single sequence: ``X ``
+ - pair of sequences: ``A B ``
+
+ Args:
+ token_ids_0 (:obj:`List[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (:obj:`List[int]`, `optional`):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
+ """
+ sep = [self.sep_token_id]
+ cls = [self.cls_token_id]
+ if token_ids_1 is None:
+ return token_ids_0 + sep + cls
+ return token_ids_0 + sep + token_ids_1 + sep + cls
+
+ def get_special_tokens_mask(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+ ) -> List[int]:
+ """
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer ``prepare_for_model`` method.
+
+ Args:
+ token_ids_0 (:obj:`List[int]`):
+ List of IDs.
+ token_ids_1 (:obj:`List[int]`, `optional`):
+ Optional second list of IDs for sequence pairs.
+ already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Whether or not the token list is already formatted with special tokens for the model.
+
+ Returns:
+ :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+ """
+
+ if already_has_special_tokens:
+ if token_ids_1 is not None:
+ raise ValueError(
+ "You should not supply a second sequence if the provided sequence of "
+ "ids is already formated with special tokens for the model."
+ )
+ return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
+
+ if token_ids_1 is not None:
+ return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1, 1]
+ return ([0] * len(token_ids_0)) + [1, 1]
+
+ def create_token_type_ids_from_sequences(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task.
+ An XLNet sequence pair mask has the following format:
+
+ ::
+
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
+ | first sequence | second sequence |
+
+ If :obj:`token_ids_1` is :obj:`None`, this method only returns the first portion of the mask (0s).
+
+ Args:
+ token_ids_0 (:obj:`List[int]`):
+ List of IDs.
+ token_ids_1 (:obj:`List[int]`, `optional`):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
+ sequence(s).
+ """
+ sep = [self.sep_token_id]
+ cls_segment_id = [2]
+
+ if token_ids_1 is None:
+ return len(token_ids_0 + sep) * [0] + cls_segment_id
+ return len(token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] + cls_segment_id
+
+ def save_vocabulary(self, save_directory):
+ """
+ Save the sentencepiece vocabulary (copy original file) and special tokens file to a directory.
+
+ Args:
+ save_directory (:obj:`str`):
+ The directory in which to save the vocabulary.
+
+ Returns:
+ :obj:`Tuple(str)`: Paths to the files saved.
+ """
+ if not os.path.isdir(save_directory):
+ logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
+ return
+ out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])
+
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+ copyfile(self.vocab_file, out_vocab_file)
+
+ return (out_vocab_file,)
diff --git a/src/transformers/utils/sentencepiece_model_pb2.py b/src/transformers/utils/sentencepiece_model_pb2.py
new file mode 100644
index 0000000000..20cd7f8bca
--- /dev/null
+++ b/src/transformers/utils/sentencepiece_model_pb2.py
@@ -0,0 +1,1169 @@
+# flake8: noqa
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: sentencepiece_model.proto
+
+import sys
+
+
+_b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode("latin1"))
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pb2
+from google.protobuf import message as _message
+from google.protobuf import reflection as _reflection
+from google.protobuf import symbol_database as _symbol_database
+
+
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+DESCRIPTOR = _descriptor.FileDescriptor(
+ name="sentencepiece_model.proto",
+ package="sentencepiece",
+ syntax="proto2",
+ serialized_pb=_b(
+ '\n\x19sentencepiece_model.proto\x12\rsentencepiece"\xf4\x08\n\x0bTrainerSpec\x12\r\n\x05input\x18\x01 \x03(\t\x12\x14\n\x0cinput_format\x18\x07 \x01(\t\x12\x14\n\x0cmodel_prefix\x18\x02 \x01(\t\x12\x41\n\nmodel_type\x18\x03 \x01(\x0e\x32$.sentencepiece.TrainerSpec.ModelType:\x07UNIGRAM\x12\x18\n\nvocab_size\x18\x04 \x01(\x05:\x04\x38\x30\x30\x30\x12\x17\n\x0f\x61\x63\x63\x65pt_language\x18\x05 \x03(\t\x12 \n\x15self_test_sample_size\x18\x06 \x01(\x05:\x01\x30\x12"\n\x12\x63haracter_coverage\x18\n \x01(\x02:\x06\x30.9995\x12\x1e\n\x13input_sentence_size\x18\x0b \x01(\x05:\x01\x30\x12$\n\x16shuffle_input_sentence\x18\x13 \x01(\x08:\x04true\x12 \n\x14mining_sentence_size\x18\x0c \x01(\x05\x42\x02\x18\x01\x12"\n\x16training_sentence_size\x18\r \x01(\x05\x42\x02\x18\x01\x12(\n\x17seed_sentencepiece_size\x18\x0e \x01(\x05:\x07\x31\x30\x30\x30\x30\x30\x30\x12\x1e\n\x10shrinking_factor\x18\x0f \x01(\x02:\x04\x30.75\x12!\n\x13max_sentence_length\x18\x12 \x01(\x05:\x04\x34\x31\x39\x32\x12\x17\n\x0bnum_threads\x18\x10 \x01(\x05:\x02\x31\x36\x12\x1d\n\x12num_sub_iterations\x18\x11 \x01(\x05:\x01\x32\x12$\n\x18max_sentencepiece_length\x18\x14 \x01(\x05:\x02\x31\x36\x12%\n\x17split_by_unicode_script\x18\x15 \x01(\x08:\x04true\x12\x1d\n\x0fsplit_by_number\x18\x17 \x01(\x08:\x04true\x12!\n\x13split_by_whitespace\x18\x16 \x01(\x08:\x04true\x12)\n\x1atreat_whitespace_as_suffix\x18\x18 \x01(\x08:\x05\x66\x61lse\x12\x17\n\x0f\x63ontrol_symbols\x18\x1e \x03(\t\x12\x1c\n\x14user_defined_symbols\x18\x1f \x03(\t\x12\x1e\n\x10hard_vocab_limit\x18! \x01(\x08:\x04true\x12\x1c\n\ruse_all_vocab\x18" \x01(\x08:\x05\x66\x61lse\x12\x11\n\x06unk_id\x18( \x01(\x05:\x01\x30\x12\x11\n\x06\x62os_id\x18) \x01(\x05:\x01\x31\x12\x11\n\x06\x65os_id\x18* \x01(\x05:\x01\x32\x12\x12\n\x06pad_id\x18+ \x01(\x05:\x02-1\x12\x18\n\tunk_piece\x18- \x01(\t:\x05\x12\x16\n\tbos_piece\x18. \x01(\t:\x03\x12\x17\n\teos_piece\x18/ \x01(\t:\x04\x12\x18\n\tpad_piece\x18\x30 \x01(\t:\x05\x12\x1a\n\x0bunk_surface\x18, \x01(\t:\x05 \xe2\x81\x87 "5\n\tModelType\x12\x0b\n\x07UNIGRAM\x10\x01\x12\x07\n\x03\x42PE\x10\x02\x12\x08\n\x04WORD\x10\x03\x12\x08\n\x04\x43HAR\x10\x04*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"\xd1\x01\n\x0eNormalizerSpec\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x1c\n\x14precompiled_charsmap\x18\x02 \x01(\x0c\x12\x1e\n\x10\x61\x64\x64_dummy_prefix\x18\x03 \x01(\x08:\x04true\x12&\n\x18remove_extra_whitespaces\x18\x04 \x01(\x08:\x04true\x12 \n\x12\x65scape_whitespaces\x18\x05 \x01(\x08:\x04true\x12\x1e\n\x16normalization_rule_tsv\x18\x06 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"y\n\x0cSelfTestData\x12\x33\n\x07samples\x18\x01 \x03(\x0b\x32".sentencepiece.SelfTestData.Sample\x1a)\n\x06Sample\x12\r\n\x05input\x18\x01 \x01(\t\x12\x10\n\x08\x65xpected\x18\x02 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"\xba\x03\n\nModelProto\x12\x37\n\x06pieces\x18\x01 \x03(\x0b\x32\'.sentencepiece.ModelProto.SentencePiece\x12\x30\n\x0ctrainer_spec\x18\x02 \x01(\x0b\x32\x1a.sentencepiece.TrainerSpec\x12\x36\n\x0fnormalizer_spec\x18\x03 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x12\x33\n\x0eself_test_data\x18\x04 \x01(\x0b\x32\x1b.sentencepiece.SelfTestData\x1a\xc8\x01\n\rSentencePiece\x12\r\n\x05piece\x18\x01 \x01(\t\x12\r\n\x05score\x18\x02 \x01(\x02\x12\x42\n\x04type\x18\x03 \x01(\x0e\x32,.sentencepiece.ModelProto.SentencePiece.Type:\x06NORMAL"J\n\x04Type\x12\n\n\x06NORMAL\x10\x01\x12\x0b\n\x07UNKNOWN\x10\x02\x12\x0b\n\x07\x43ONTROL\x10\x03\x12\x10\n\x0cUSER_DEFINED\x10\x04\x12\n\n\x06UNUSED\x10\x05*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\x42\x02H\x03'
+ ),
+)
+_sym_db.RegisterFileDescriptor(DESCRIPTOR)
+
+
+_TRAINERSPEC_MODELTYPE = _descriptor.EnumDescriptor(
+ name="ModelType",
+ full_name="sentencepiece.TrainerSpec.ModelType",
+ filename=None,
+ file=DESCRIPTOR,
+ values=[
+ _descriptor.EnumValueDescriptor(name="UNIGRAM", index=0, number=1, options=None, type=None),
+ _descriptor.EnumValueDescriptor(name="BPE", index=1, number=2, options=None, type=None),
+ _descriptor.EnumValueDescriptor(name="WORD", index=2, number=3, options=None, type=None),
+ _descriptor.EnumValueDescriptor(name="CHAR", index=3, number=4, options=None, type=None),
+ ],
+ containing_type=None,
+ options=None,
+ serialized_start=1121,
+ serialized_end=1174,
+)
+_sym_db.RegisterEnumDescriptor(_TRAINERSPEC_MODELTYPE)
+
+_MODELPROTO_SENTENCEPIECE_TYPE = _descriptor.EnumDescriptor(
+ name="Type",
+ full_name="sentencepiece.ModelProto.SentencePiece.Type",
+ filename=None,
+ file=DESCRIPTOR,
+ values=[
+ _descriptor.EnumValueDescriptor(name="NORMAL", index=0, number=1, options=None, type=None),
+ _descriptor.EnumValueDescriptor(name="UNKNOWN", index=1, number=2, options=None, type=None),
+ _descriptor.EnumValueDescriptor(name="CONTROL", index=2, number=3, options=None, type=None),
+ _descriptor.EnumValueDescriptor(name="USER_DEFINED", index=3, number=4, options=None, type=None),
+ _descriptor.EnumValueDescriptor(name="UNUSED", index=4, number=5, options=None, type=None),
+ ],
+ containing_type=None,
+ options=None,
+ serialized_start=1869,
+ serialized_end=1943,
+)
+_sym_db.RegisterEnumDescriptor(_MODELPROTO_SENTENCEPIECE_TYPE)
+
+
+_TRAINERSPEC = _descriptor.Descriptor(
+ name="TrainerSpec",
+ full_name="sentencepiece.TrainerSpec",
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name="input",
+ full_name="sentencepiece.TrainerSpec.input",
+ index=0,
+ number=1,
+ type=9,
+ cpp_type=9,
+ label=3,
+ has_default_value=False,
+ default_value=[],
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None,
+ ),
+ _descriptor.FieldDescriptor(
+ name="input_format",
+ full_name="sentencepiece.TrainerSpec.input_format",
+ index=1,
+ number=7,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None,
+ ),
+ _descriptor.FieldDescriptor(
+ name="model_prefix",
+ full_name="sentencepiece.TrainerSpec.model_prefix",
+ index=2,
+ number=2,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None,
+ ),
+ _descriptor.FieldDescriptor(
+ name="model_type",
+ full_name="sentencepiece.TrainerSpec.model_type",
+ index=3,
+ number=3,
+ type=14,
+ cpp_type=8,
+ label=1,
+ has_default_value=True,
+ default_value=1,
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None,
+ ),
+ _descriptor.FieldDescriptor(
+ name="vocab_size",
+ full_name="sentencepiece.TrainerSpec.vocab_size",
+ index=4,
+ number=4,
+ type=5,
+ cpp_type=1,
+ label=1,
+ has_default_value=True,
+ default_value=8000,
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None,
+ ),
+ _descriptor.FieldDescriptor(
+ name="accept_language",
+ full_name="sentencepiece.TrainerSpec.accept_language",
+ index=5,
+ number=5,
+ type=9,
+ cpp_type=9,
+ label=3,
+ has_default_value=False,
+ default_value=[],
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None,
+ ),
+ _descriptor.FieldDescriptor(
+ name="self_test_sample_size",
+ full_name="sentencepiece.TrainerSpec.self_test_sample_size",
+ index=6,
+ number=6,
+ type=5,
+ cpp_type=1,
+ label=1,
+ has_default_value=True,
+ default_value=0,
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None,
+ ),
+ _descriptor.FieldDescriptor(
+ name="character_coverage",
+ full_name="sentencepiece.TrainerSpec.character_coverage",
+ index=7,
+ number=10,
+ type=2,
+ cpp_type=6,
+ label=1,
+ has_default_value=True,
+ default_value=float(0.9995),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None,
+ ),
+ _descriptor.FieldDescriptor(
+ name="input_sentence_size",
+ full_name="sentencepiece.TrainerSpec.input_sentence_size",
+ index=8,
+ number=11,
+ type=5,
+ cpp_type=1,
+ label=1,
+ has_default_value=True,
+ default_value=0,
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None,
+ ),
+ _descriptor.FieldDescriptor(
+ name="shuffle_input_sentence",
+ full_name="sentencepiece.TrainerSpec.shuffle_input_sentence",
+ index=9,
+ number=19,
+ type=8,
+ cpp_type=7,
+ label=1,
+ has_default_value=True,
+ default_value=True,
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None,
+ ),
+ _descriptor.FieldDescriptor(
+ name="mining_sentence_size",
+ full_name="sentencepiece.TrainerSpec.mining_sentence_size",
+ index=10,
+ number=12,
+ type=5,
+ cpp_type=1,
+ label=1,
+ has_default_value=False,
+ default_value=0,
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b("\030\001")),
+ ),
+ _descriptor.FieldDescriptor(
+ name="training_sentence_size",
+ full_name="sentencepiece.TrainerSpec.training_sentence_size",
+ index=11,
+ number=13,
+ type=5,
+ cpp_type=1,
+ label=1,
+ has_default_value=False,
+ default_value=0,
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b("\030\001")),
+ ),
+ _descriptor.FieldDescriptor(
+ name="seed_sentencepiece_size",
+ full_name="sentencepiece.TrainerSpec.seed_sentencepiece_size",
+ index=12,
+ number=14,
+ type=5,
+ cpp_type=1,
+ label=1,
+ has_default_value=True,
+ default_value=1000000,
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None,
+ ),
+ _descriptor.FieldDescriptor(
+ name="shrinking_factor",
+ full_name="sentencepiece.TrainerSpec.shrinking_factor",
+ index=13,
+ number=15,
+ type=2,
+ cpp_type=6,
+ label=1,
+ has_default_value=True,
+ default_value=float(0.75),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None,
+ ),
+ _descriptor.FieldDescriptor(
+ name="max_sentence_length",
+ full_name="sentencepiece.TrainerSpec.max_sentence_length",
+ index=14,
+ number=18,
+ type=5,
+ cpp_type=1,
+ label=1,
+ has_default_value=True,
+ default_value=4192,
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None,
+ ),
+ _descriptor.FieldDescriptor(
+ name="num_threads",
+ full_name="sentencepiece.TrainerSpec.num_threads",
+ index=15,
+ number=16,
+ type=5,
+ cpp_type=1,
+ label=1,
+ has_default_value=True,
+ default_value=16,
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None,
+ ),
+ _descriptor.FieldDescriptor(
+ name="num_sub_iterations",
+ full_name="sentencepiece.TrainerSpec.num_sub_iterations",
+ index=16,
+ number=17,
+ type=5,
+ cpp_type=1,
+ label=1,
+ has_default_value=True,
+ default_value=2,
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None,
+ ),
+ _descriptor.FieldDescriptor(
+ name="max_sentencepiece_length",
+ full_name="sentencepiece.TrainerSpec.max_sentencepiece_length",
+ index=17,
+ number=20,
+ type=5,
+ cpp_type=1,
+ label=1,
+ has_default_value=True,
+ default_value=16,
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None,
+ ),
+ _descriptor.FieldDescriptor(
+ name="split_by_unicode_script",
+ full_name="sentencepiece.TrainerSpec.split_by_unicode_script",
+ index=18,
+ number=21,
+ type=8,
+ cpp_type=7,
+ label=1,
+ has_default_value=True,
+ default_value=True,
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None,
+ ),
+ _descriptor.FieldDescriptor(
+ name="split_by_number",
+ full_name="sentencepiece.TrainerSpec.split_by_number",
+ index=19,
+ number=23,
+ type=8,
+ cpp_type=7,
+ label=1,
+ has_default_value=True,
+ default_value=True,
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None,
+ ),
+ _descriptor.FieldDescriptor(
+ name="split_by_whitespace",
+ full_name="sentencepiece.TrainerSpec.split_by_whitespace",
+ index=20,
+ number=22,
+ type=8,
+ cpp_type=7,
+ label=1,
+ has_default_value=True,
+ default_value=True,
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None,
+ ),
+ _descriptor.FieldDescriptor(
+ name="treat_whitespace_as_suffix",
+ full_name="sentencepiece.TrainerSpec.treat_whitespace_as_suffix",
+ index=21,
+ number=24,
+ type=8,
+ cpp_type=7,
+ label=1,
+ has_default_value=True,
+ default_value=False,
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None,
+ ),
+ _descriptor.FieldDescriptor(
+ name="control_symbols",
+ full_name="sentencepiece.TrainerSpec.control_symbols",
+ index=22,
+ number=30,
+ type=9,
+ cpp_type=9,
+ label=3,
+ has_default_value=False,
+ default_value=[],
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None,
+ ),
+ _descriptor.FieldDescriptor(
+ name="user_defined_symbols",
+ full_name="sentencepiece.TrainerSpec.user_defined_symbols",
+ index=23,
+ number=31,
+ type=9,
+ cpp_type=9,
+ label=3,
+ has_default_value=False,
+ default_value=[],
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None,
+ ),
+ _descriptor.FieldDescriptor(
+ name="hard_vocab_limit",
+ full_name="sentencepiece.TrainerSpec.hard_vocab_limit",
+ index=24,
+ number=33,
+ type=8,
+ cpp_type=7,
+ label=1,
+ has_default_value=True,
+ default_value=True,
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None,
+ ),
+ _descriptor.FieldDescriptor(
+ name="use_all_vocab",
+ full_name="sentencepiece.TrainerSpec.use_all_vocab",
+ index=25,
+ number=34,
+ type=8,
+ cpp_type=7,
+ label=1,
+ has_default_value=True,
+ default_value=False,
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None,
+ ),
+ _descriptor.FieldDescriptor(
+ name="unk_id",
+ full_name="sentencepiece.TrainerSpec.unk_id",
+ index=26,
+ number=40,
+ type=5,
+ cpp_type=1,
+ label=1,
+ has_default_value=True,
+ default_value=0,
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None,
+ ),
+ _descriptor.FieldDescriptor(
+ name="bos_id",
+ full_name="sentencepiece.TrainerSpec.bos_id",
+ index=27,
+ number=41,
+ type=5,
+ cpp_type=1,
+ label=1,
+ has_default_value=True,
+ default_value=1,
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None,
+ ),
+ _descriptor.FieldDescriptor(
+ name="eos_id",
+ full_name="sentencepiece.TrainerSpec.eos_id",
+ index=28,
+ number=42,
+ type=5,
+ cpp_type=1,
+ label=1,
+ has_default_value=True,
+ default_value=2,
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None,
+ ),
+ _descriptor.FieldDescriptor(
+ name="pad_id",
+ full_name="sentencepiece.TrainerSpec.pad_id",
+ index=29,
+ number=43,
+ type=5,
+ cpp_type=1,
+ label=1,
+ has_default_value=True,
+ default_value=-1,
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None,
+ ),
+ _descriptor.FieldDescriptor(
+ name="unk_piece",
+ full_name="sentencepiece.TrainerSpec.unk_piece",
+ index=30,
+ number=45,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=True,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None,
+ ),
+ _descriptor.FieldDescriptor(
+ name="bos_piece",
+ full_name="sentencepiece.TrainerSpec.bos_piece",
+ index=31,
+ number=46,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=True,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None,
+ ),
+ _descriptor.FieldDescriptor(
+ name="eos_piece",
+ full_name="sentencepiece.TrainerSpec.eos_piece",
+ index=32,
+ number=47,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=True,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None,
+ ),
+ _descriptor.FieldDescriptor(
+ name="pad_piece",
+ full_name="sentencepiece.TrainerSpec.pad_piece",
+ index=33,
+ number=48,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=True,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None,
+ ),
+ _descriptor.FieldDescriptor(
+ name="unk_surface",
+ full_name="sentencepiece.TrainerSpec.unk_surface",
+ index=34,
+ number=44,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=True,
+ default_value=_b(" \342\201\207 ").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None,
+ ),
+ ],
+ extensions=[],
+ nested_types=[],
+ enum_types=[
+ _TRAINERSPEC_MODELTYPE,
+ ],
+ options=None,
+ is_extendable=True,
+ syntax="proto2",
+ extension_ranges=[
+ (200, 536870912),
+ ],
+ oneofs=[],
+ serialized_start=45,
+ serialized_end=1185,
+)
+
+
+_NORMALIZERSPEC = _descriptor.Descriptor(
+ name="NormalizerSpec",
+ full_name="sentencepiece.NormalizerSpec",
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name="name",
+ full_name="sentencepiece.NormalizerSpec.name",
+ index=0,
+ number=1,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None,
+ ),
+ _descriptor.FieldDescriptor(
+ name="precompiled_charsmap",
+ full_name="sentencepiece.NormalizerSpec.precompiled_charsmap",
+ index=1,
+ number=2,
+ type=12,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b(""),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None,
+ ),
+ _descriptor.FieldDescriptor(
+ name="add_dummy_prefix",
+ full_name="sentencepiece.NormalizerSpec.add_dummy_prefix",
+ index=2,
+ number=3,
+ type=8,
+ cpp_type=7,
+ label=1,
+ has_default_value=True,
+ default_value=True,
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None,
+ ),
+ _descriptor.FieldDescriptor(
+ name="remove_extra_whitespaces",
+ full_name="sentencepiece.NormalizerSpec.remove_extra_whitespaces",
+ index=3,
+ number=4,
+ type=8,
+ cpp_type=7,
+ label=1,
+ has_default_value=True,
+ default_value=True,
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None,
+ ),
+ _descriptor.FieldDescriptor(
+ name="escape_whitespaces",
+ full_name="sentencepiece.NormalizerSpec.escape_whitespaces",
+ index=4,
+ number=5,
+ type=8,
+ cpp_type=7,
+ label=1,
+ has_default_value=True,
+ default_value=True,
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None,
+ ),
+ _descriptor.FieldDescriptor(
+ name="normalization_rule_tsv",
+ full_name="sentencepiece.NormalizerSpec.normalization_rule_tsv",
+ index=5,
+ number=6,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None,
+ ),
+ ],
+ extensions=[],
+ nested_types=[],
+ enum_types=[],
+ options=None,
+ is_extendable=True,
+ syntax="proto2",
+ extension_ranges=[
+ (200, 536870912),
+ ],
+ oneofs=[],
+ serialized_start=1188,
+ serialized_end=1397,
+)
+
+
+_SELFTESTDATA_SAMPLE = _descriptor.Descriptor(
+ name="Sample",
+ full_name="sentencepiece.SelfTestData.Sample",
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name="input",
+ full_name="sentencepiece.SelfTestData.Sample.input",
+ index=0,
+ number=1,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None,
+ ),
+ _descriptor.FieldDescriptor(
+ name="expected",
+ full_name="sentencepiece.SelfTestData.Sample.expected",
+ index=1,
+ number=2,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None,
+ ),
+ ],
+ extensions=[],
+ nested_types=[],
+ enum_types=[],
+ options=None,
+ is_extendable=False,
+ syntax="proto2",
+ extension_ranges=[],
+ oneofs=[],
+ serialized_start=1468,
+ serialized_end=1509,
+)
+
+_SELFTESTDATA = _descriptor.Descriptor(
+ name="SelfTestData",
+ full_name="sentencepiece.SelfTestData",
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name="samples",
+ full_name="sentencepiece.SelfTestData.samples",
+ index=0,
+ number=1,
+ type=11,
+ cpp_type=10,
+ label=3,
+ has_default_value=False,
+ default_value=[],
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None,
+ ),
+ ],
+ extensions=[],
+ nested_types=[
+ _SELFTESTDATA_SAMPLE,
+ ],
+ enum_types=[],
+ options=None,
+ is_extendable=True,
+ syntax="proto2",
+ extension_ranges=[
+ (200, 536870912),
+ ],
+ oneofs=[],
+ serialized_start=1399,
+ serialized_end=1520,
+)
+
+
+_MODELPROTO_SENTENCEPIECE = _descriptor.Descriptor(
+ name="SentencePiece",
+ full_name="sentencepiece.ModelProto.SentencePiece",
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name="piece",
+ full_name="sentencepiece.ModelProto.SentencePiece.piece",
+ index=0,
+ number=1,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None,
+ ),
+ _descriptor.FieldDescriptor(
+ name="score",
+ full_name="sentencepiece.ModelProto.SentencePiece.score",
+ index=1,
+ number=2,
+ type=2,
+ cpp_type=6,
+ label=1,
+ has_default_value=False,
+ default_value=float(0),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None,
+ ),
+ _descriptor.FieldDescriptor(
+ name="type",
+ full_name="sentencepiece.ModelProto.SentencePiece.type",
+ index=2,
+ number=3,
+ type=14,
+ cpp_type=8,
+ label=1,
+ has_default_value=True,
+ default_value=1,
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None,
+ ),
+ ],
+ extensions=[],
+ nested_types=[],
+ enum_types=[
+ _MODELPROTO_SENTENCEPIECE_TYPE,
+ ],
+ options=None,
+ is_extendable=True,
+ syntax="proto2",
+ extension_ranges=[
+ (200, 536870912),
+ ],
+ oneofs=[],
+ serialized_start=1754,
+ serialized_end=1954,
+)
+
+_MODELPROTO = _descriptor.Descriptor(
+ name="ModelProto",
+ full_name="sentencepiece.ModelProto",
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name="pieces",
+ full_name="sentencepiece.ModelProto.pieces",
+ index=0,
+ number=1,
+ type=11,
+ cpp_type=10,
+ label=3,
+ has_default_value=False,
+ default_value=[],
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None,
+ ),
+ _descriptor.FieldDescriptor(
+ name="trainer_spec",
+ full_name="sentencepiece.ModelProto.trainer_spec",
+ index=1,
+ number=2,
+ type=11,
+ cpp_type=10,
+ label=1,
+ has_default_value=False,
+ default_value=None,
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None,
+ ),
+ _descriptor.FieldDescriptor(
+ name="normalizer_spec",
+ full_name="sentencepiece.ModelProto.normalizer_spec",
+ index=2,
+ number=3,
+ type=11,
+ cpp_type=10,
+ label=1,
+ has_default_value=False,
+ default_value=None,
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None,
+ ),
+ _descriptor.FieldDescriptor(
+ name="self_test_data",
+ full_name="sentencepiece.ModelProto.self_test_data",
+ index=3,
+ number=4,
+ type=11,
+ cpp_type=10,
+ label=1,
+ has_default_value=False,
+ default_value=None,
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None,
+ ),
+ ],
+ extensions=[],
+ nested_types=[
+ _MODELPROTO_SENTENCEPIECE,
+ ],
+ enum_types=[],
+ options=None,
+ is_extendable=True,
+ syntax="proto2",
+ extension_ranges=[
+ (200, 536870912),
+ ],
+ oneofs=[],
+ serialized_start=1523,
+ serialized_end=1965,
+)
+
+_TRAINERSPEC.fields_by_name["model_type"].enum_type = _TRAINERSPEC_MODELTYPE
+_TRAINERSPEC_MODELTYPE.containing_type = _TRAINERSPEC
+_SELFTESTDATA_SAMPLE.containing_type = _SELFTESTDATA
+_SELFTESTDATA.fields_by_name["samples"].message_type = _SELFTESTDATA_SAMPLE
+_MODELPROTO_SENTENCEPIECE.fields_by_name["type"].enum_type = _MODELPROTO_SENTENCEPIECE_TYPE
+_MODELPROTO_SENTENCEPIECE.containing_type = _MODELPROTO
+_MODELPROTO_SENTENCEPIECE_TYPE.containing_type = _MODELPROTO_SENTENCEPIECE
+_MODELPROTO.fields_by_name["pieces"].message_type = _MODELPROTO_SENTENCEPIECE
+_MODELPROTO.fields_by_name["trainer_spec"].message_type = _TRAINERSPEC
+_MODELPROTO.fields_by_name["normalizer_spec"].message_type = _NORMALIZERSPEC
+_MODELPROTO.fields_by_name["self_test_data"].message_type = _SELFTESTDATA
+DESCRIPTOR.message_types_by_name["TrainerSpec"] = _TRAINERSPEC
+DESCRIPTOR.message_types_by_name["NormalizerSpec"] = _NORMALIZERSPEC
+DESCRIPTOR.message_types_by_name["SelfTestData"] = _SELFTESTDATA
+DESCRIPTOR.message_types_by_name["ModelProto"] = _MODELPROTO
+
+TrainerSpec = _reflection.GeneratedProtocolMessageType(
+ "TrainerSpec",
+ (_message.Message,),
+ dict(
+ DESCRIPTOR=_TRAINERSPEC,
+ __module__="sentencepiece_model_pb2"
+ # @@protoc_insertion_point(class_scope:sentencepiece.TrainerSpec)
+ ),
+)
+_sym_db.RegisterMessage(TrainerSpec)
+
+NormalizerSpec = _reflection.GeneratedProtocolMessageType(
+ "NormalizerSpec",
+ (_message.Message,),
+ dict(
+ DESCRIPTOR=_NORMALIZERSPEC,
+ __module__="sentencepiece_model_pb2"
+ # @@protoc_insertion_point(class_scope:sentencepiece.NormalizerSpec)
+ ),
+)
+_sym_db.RegisterMessage(NormalizerSpec)
+
+SelfTestData = _reflection.GeneratedProtocolMessageType(
+ "SelfTestData",
+ (_message.Message,),
+ dict(
+ Sample=_reflection.GeneratedProtocolMessageType(
+ "Sample",
+ (_message.Message,),
+ dict(
+ DESCRIPTOR=_SELFTESTDATA_SAMPLE,
+ __module__="sentencepiece_model_pb2"
+ # @@protoc_insertion_point(class_scope:sentencepiece.SelfTestData.Sample)
+ ),
+ ),
+ DESCRIPTOR=_SELFTESTDATA,
+ __module__="sentencepiece_model_pb2"
+ # @@protoc_insertion_point(class_scope:sentencepiece.SelfTestData)
+ ),
+)
+_sym_db.RegisterMessage(SelfTestData)
+_sym_db.RegisterMessage(SelfTestData.Sample)
+
+ModelProto = _reflection.GeneratedProtocolMessageType(
+ "ModelProto",
+ (_message.Message,),
+ dict(
+ SentencePiece=_reflection.GeneratedProtocolMessageType(
+ "SentencePiece",
+ (_message.Message,),
+ dict(
+ DESCRIPTOR=_MODELPROTO_SENTENCEPIECE,
+ __module__="sentencepiece_model_pb2"
+ # @@protoc_insertion_point(class_scope:sentencepiece.ModelProto.SentencePiece)
+ ),
+ ),
+ DESCRIPTOR=_MODELPROTO,
+ __module__="sentencepiece_model_pb2"
+ # @@protoc_insertion_point(class_scope:sentencepiece.ModelProto)
+ ),
+)
+_sym_db.RegisterMessage(ModelProto)
+_sym_db.RegisterMessage(ModelProto.SentencePiece)
+
+
+DESCRIPTOR.has_options = True
+DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b("H\003"))
+_TRAINERSPEC.fields_by_name["mining_sentence_size"].has_options = True
+_TRAINERSPEC.fields_by_name["mining_sentence_size"]._options = _descriptor._ParseOptions(
+ descriptor_pb2.FieldOptions(), _b("\030\001")
+)
+_TRAINERSPEC.fields_by_name["training_sentence_size"].has_options = True
+_TRAINERSPEC.fields_by_name["training_sentence_size"]._options = _descriptor._ParseOptions(
+ descriptor_pb2.FieldOptions(), _b("\030\001")
+)
+# @@protoc_insertion_point(module_scope)
diff --git a/tests/test_tokenization_albert.py b/tests/test_tokenization_albert.py
index d1a7c65e22..724b98327e 100644
--- a/tests/test_tokenization_albert.py
+++ b/tests/test_tokenization_albert.py
@@ -17,7 +17,7 @@
import os
import unittest
-from transformers.tokenization_albert import AlbertTokenizer
+from transformers.tokenization_albert import AlbertTokenizer, AlbertTokenizerFast
from .test_tokenization_common import TokenizerTesterMixin
@@ -28,6 +28,8 @@ SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixture
class AlbertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = AlbertTokenizer
+ rust_tokenizer_class = AlbertTokenizerFast
+ test_rust_tokenizer = True
def setUp(self):
super().setUp()
@@ -41,6 +43,28 @@ class AlbertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
output_text = "this is a test"
return input_text, output_text
+ def test_rust_and_python_full_tokenizers(self):
+ if not self.test_rust_tokenizer:
+ return
+
+ tokenizer = self.get_tokenizer()
+ rust_tokenizer = self.get_rust_tokenizer()
+
+ sequence = "I was born in 92000, and this is falsé."
+
+ tokens = tokenizer.tokenize(sequence)
+ rust_tokens = rust_tokenizer.tokenize(sequence)
+ self.assertListEqual(tokens, rust_tokens)
+
+ ids = tokenizer.encode(sequence, add_special_tokens=False)
+ rust_ids = rust_tokenizer.encode(sequence, add_special_tokens=False)
+ self.assertListEqual(ids, rust_ids)
+
+ rust_tokenizer = self.get_rust_tokenizer()
+ ids = tokenizer.encode(sequence)
+ rust_ids = rust_tokenizer.encode(sequence)
+ self.assertListEqual(ids, rust_ids)
+
def test_full_tokenizer(self):
tokenizer = AlbertTokenizer(SAMPLE_VOCAB, keep_accents=True)
diff --git a/tests/test_tokenization_bart.py b/tests/test_tokenization_bart.py
index bbd448b24a..0aa1c74684 100644
--- a/tests/test_tokenization_bart.py
+++ b/tests/test_tokenization_bart.py
@@ -12,6 +12,8 @@ from .test_tokenization_common import TokenizerTesterMixin
class TestTokenizationBart(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = BartTokenizer
+ rust_tokenizer_class = BartTokenizerFast
+ test_rust_tokenizer = True
def setUp(self):
super().setUp()
diff --git a/tests/test_tokenization_bert.py b/tests/test_tokenization_bert.py
index 4421d30de4..015e534678 100644
--- a/tests/test_tokenization_bert.py
+++ b/tests/test_tokenization_bert.py
@@ -35,7 +35,9 @@ from .test_tokenization_common import TokenizerTesterMixin
class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = BertTokenizer
+ rust_tokenizer_class = BertTokenizerFast
test_rust_tokenizer = True
+ space_between_special_tokens = True
def setUp(self):
super().setUp()
@@ -61,9 +63,6 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
- def get_rust_tokenizer(self, **kwargs):
- return BertTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
-
def get_input_output_texts(self, tokenizer):
input_text = "UNwant\u00E9d,running"
output_text = "unwanted, running"
diff --git a/tests/test_tokenization_bert_japanese.py b/tests/test_tokenization_bert_japanese.py
index b14f19f9ad..9953dc72d4 100644
--- a/tests/test_tokenization_bert_japanese.py
+++ b/tests/test_tokenization_bert_japanese.py
@@ -15,6 +15,7 @@
import os
+import pickle
import unittest
from transformers.testing_utils import custom_tokenizers
@@ -33,6 +34,7 @@ from .test_tokenization_common import TokenizerTesterMixin
class BertJapaneseTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = BertJapaneseTokenizer
+ space_between_special_tokens = True
def setUp(self):
super().setUp()
@@ -87,6 +89,26 @@ class BertJapaneseTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
self.assertListEqual(tokens, ["こんにちは", "、", "世界", "。", "こん", "##ばんは", "、", "世界", "。"])
self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [3, 12, 10, 14, 4, 9, 12, 10, 14])
+ def test_pickle_mecab_tokenizer(self):
+ tokenizer = self.tokenizer_class(self.vocab_file, word_tokenizer_type="mecab")
+ self.assertIsNotNone(tokenizer)
+
+ text = "こんにちは、世界。\nこんばんは、世界。"
+ tokens = tokenizer.tokenize(text)
+ self.assertListEqual(tokens, ["こんにちは", "、", "世界", "。", "こん", "##ばんは", "、", "世界", "。"])
+ self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [3, 12, 10, 14, 4, 9, 12, 10, 14])
+
+ filename = os.path.join(self.tmpdirname, "tokenizer.bin")
+ with open(filename, "wb") as handle:
+ pickle.dump(tokenizer, handle)
+
+ with open(filename, "rb") as handle:
+ tokenizer_new = pickle.load(handle)
+
+ tokens_loaded = tokenizer_new.tokenize(text)
+
+ self.assertListEqual(tokens, tokens_loaded)
+
def test_mecab_tokenizer_ipadic(self):
tokenizer = MecabTokenizer(mecab_dic="ipadic")
diff --git a/tests/test_tokenization_camembert.py b/tests/test_tokenization_camembert.py
new file mode 100644
index 0000000000..c8eae66d48
--- /dev/null
+++ b/tests/test_tokenization_camembert.py
@@ -0,0 +1,64 @@
+# coding=utf-8
+# Copyright 2018 Google T5 Authors and HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import unittest
+
+from transformers.testing_utils import _torch_available
+from transformers.tokenization_camembert import CamembertTokenizer, CamembertTokenizerFast
+
+from .test_tokenization_common import TokenizerTesterMixin
+
+
+SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")
+
+FRAMEWORK = "pt" if _torch_available else "tf"
+
+
+class CamembertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+
+ tokenizer_class = CamembertTokenizer
+ rust_tokenizer_class = CamembertTokenizerFast
+ test_rust_tokenizer = True
+
+ def setUp(self):
+ super().setUp()
+
+ # We have a SentencePiece fixture for testing
+ tokenizer = CamembertTokenizer(SAMPLE_VOCAB)
+ tokenizer.save_pretrained(self.tmpdirname)
+
+ def test_rust_and_python_full_tokenizers(self):
+ if not self.test_rust_tokenizer:
+ return
+
+ tokenizer = self.get_tokenizer()
+ rust_tokenizer = self.get_rust_tokenizer()
+
+ sequence = "I was born in 92000, and this is falsé."
+
+ tokens = tokenizer.tokenize(sequence)
+ rust_tokens = rust_tokenizer.tokenize(sequence)
+ self.assertListEqual(tokens, rust_tokens)
+
+ ids = tokenizer.encode(sequence, add_special_tokens=False)
+ rust_ids = rust_tokenizer.encode(sequence, add_special_tokens=False)
+ self.assertListEqual(ids, rust_ids)
+
+ rust_tokenizer = self.get_rust_tokenizer()
+ ids = tokenizer.encode(sequence)
+ rust_ids = rust_tokenizer.encode(sequence)
+ self.assertListEqual(ids, rust_ids)
diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py
index d5bfe9a7c5..74c57e5e3c 100644
--- a/tests/test_tokenization_common.py
+++ b/tests/test_tokenization_common.py
@@ -56,7 +56,9 @@ def merge_model_tokenizer_mappings(
class TokenizerTesterMixin:
tokenizer_class = None
+ rust_tokenizer_class = None
test_rust_tokenizer = False
+ space_between_special_tokens = False
def setUp(self):
self.tmpdirname = tempfile.mkdtemp()
@@ -68,12 +70,15 @@ class TokenizerTesterMixin:
input_txt = self.get_clean_sequence(tokenizer)[0]
return input_txt, input_txt
- def get_clean_sequence(self, tokenizer, with_prefix_space=False, max_length=20) -> Tuple[str, list]:
+ def get_clean_sequence(self, tokenizer, with_prefix_space=False, max_length=20, min_length=5) -> Tuple[str, list]:
toks = [(i, tokenizer.decode([i], clean_up_tokenization_spaces=False)) for i in range(len(tokenizer))]
toks = list(filter(lambda t: re.match(r"^[ a-zA-Z]+$", t[1]), toks))
toks = list(filter(lambda t: [t[0]] == tokenizer.encode(t[1], add_special_tokens=False), toks))
if max_length is not None and len(toks) > max_length:
toks = toks[:max_length]
+ if min_length is not None and len(toks) < min_length and len(toks) > 0:
+ while len(toks) < min_length:
+ toks = toks + toks
# toks_str = [t[1] for t in toks]
toks_ids = [t[0] for t in toks]
@@ -99,7 +104,7 @@ class TokenizerTesterMixin:
return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
def get_rust_tokenizer(self, **kwargs) -> PreTrainedTokenizerFast:
- raise NotImplementedError
+ return self.rust_tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
# def get_input_output_texts(self) -> Tuple[str, str]:
# """Feel free to overwrite"""
@@ -118,6 +123,29 @@ class TokenizerTesterMixin:
for i in range(len(batch_encode_plus_sequences["input_ids"]))
]
+ def test_rust_and_python_full_tokenizers(self):
+ if not self.test_rust_tokenizer:
+ return
+
+ tokenizer = self.get_tokenizer()
+ rust_tokenizer = self.get_rust_tokenizer()
+
+ sequence, _ = self.get_input_output_texts(tokenizer)
+
+ # We don't have an exact equivalence on `tokenize()` between Rust and Slow
+ # Slow tokenizer only split tokens, Rust tokenizers will replace with
+ # tokens = tokenizer.tokenize(sequence)
+ # rust_tokens = rust_tokenizer.tokenize(sequence)
+ # self.assertListEqual(tokens, rust_tokens)
+
+ ids = tokenizer.encode(sequence, add_special_tokens=False)
+ rust_ids = rust_tokenizer.encode(sequence, add_special_tokens=False)
+ self.assertListEqual(ids, rust_ids)
+
+ ids = tokenizer.encode(sequence, add_special_tokens=True)
+ rust_ids = rust_tokenizer.encode(sequence, add_special_tokens=True)
+ self.assertListEqual(ids, rust_ids)
+
def test_tokenizers_common_properties(self):
tokenizers = self.get_tokenizers()
for tokenizer in tokenizers:
@@ -241,6 +269,9 @@ class TokenizerTesterMixin:
tokenizers = self.get_tokenizers(fast=False, do_lower_case=True)
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
+ if not hasattr(tokenizer, "do_lower_case") or not tokenizer.do_lower_case:
+ continue
+
special_token = tokenizer.all_special_tokens[0]
text = special_token + " aaaaa bbbbbb low cccccccccdddddddd l " + special_token
@@ -272,6 +303,9 @@ class TokenizerTesterMixin:
tokenizers = self.get_tokenizers(fast=False, do_lower_case=False)
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
+ if hasattr(tokenizer, "do_lower_case") and tokenizer.do_lower_case:
+ continue
+
special_token = tokenizer.all_special_tokens[0]
text = special_token + " aaaaa bbbbbb low cccccccccdddddddd l " + special_token
@@ -282,7 +316,7 @@ class TokenizerTesterMixin:
toks0 = tokenizer.tokenize(text) # toks before adding new_toks
added = tokenizer.add_tokens(new_toks)
- self.assertEqual(added, 4)
+ self.assertIn(added, [2, 4])
toks = tokenizer.tokenize(text)
toks2 = tokenizer.tokenize(text2)
@@ -390,12 +424,17 @@ class TokenizerTesterMixin:
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
- new_toks = ["[ABC]", "[DEF]"] # TODO(thom) add this one back when Rust toks are ready: , "GHI IHG"]
+ # new_toks = ["[ABC]", "[DEF]"] # TODO(thom) add this one back when Rust toks are ready: , "GHI IHG"]
+ new_toks = [AddedToken("[ABC]", normalized=False), AddedToken("[DEF]", normalized=False)]
tokenizer.add_tokens(new_toks)
- input = "[ABC] [DEF] [ABC] [DEF]" # TODO(thom) add back cf above: "[ABC] [DEF] [ABC] GHI IHG [DEF]"
+ input = "[ABC][DEF][ABC][DEF]" # TODO(thom) add back cf above: "[ABC] [DEF] [ABC] GHI IHG [DEF]"
+ if self.space_between_special_tokens:
+ output = "[ABC] [DEF] [ABC] [DEF]"
+ else:
+ output = input
encoded = tokenizer.encode(input, add_special_tokens=False)
- decoded = tokenizer.decode(encoded)
- self.assertEqual(decoded, input)
+ decoded = tokenizer.decode(encoded, spaces_between_special_tokens=self.space_between_special_tokens)
+ self.assertIn(decoded, [output, output.lower()])
def test_pretrained_model_lists(self):
weights_list = list(self.tokenizer_class.max_model_input_sizes.keys())
@@ -447,7 +486,7 @@ class TokenizerTesterMixin:
sequence = tokenizer.encode(seq_0, add_special_tokens=False)
total_length = len(sequence)
- assert total_length > 1, "Issue with the testing sequence, please update it it's too short"
+ assert total_length > 4, "Issue with the testing sequence, please update it it's too short"
# Test with max model input length
model_max_length = tokenizer.model_max_length
@@ -546,6 +585,7 @@ class TokenizerTesterMixin:
model_max_length = tokenizer.model_max_length
self.assertEqual(model_max_length, 100)
seq_2 = seq_0 * model_max_length
+ assert len(seq_2) > model_max_length
sequence1 = tokenizer(seq_1, add_special_tokens=False)
total_length1 = len(sequence1["input_ids"])
@@ -559,9 +599,9 @@ class TokenizerTesterMixin:
[False, True, "longest"] if tokenizer.pad_token and tokenizer.pad_token_id >= 0 else [False]
)
for padding_state in padding_strategies:
- with self.subTest(f"Padding: {padding_state}"):
+ with self.subTest(f"{tokenizer.__class__.__name__} Padding: {padding_state}"):
for truncation_state in [True, "longest_first", "only_first"]:
- with self.subTest(f"Truncation: {truncation_state}"):
+ with self.subTest(f"{tokenizer.__class__.__name__} Truncation: {truncation_state}"):
output = tokenizer(seq_2, seq_1, padding=padding_state, truncation=truncation_state)
self.assertEqual(len(output["input_ids"]), model_max_length)
@@ -748,34 +788,47 @@ class TokenizerTesterMixin:
# # This is not supported with the Rust tokenizers
# # self.assertEqual(tokenizer.encode(input_ids, add_special_tokens=True), formatted_input)
- def test_swap_special_token(self):
- tokenizers = self.get_tokenizers(do_lower_case=False)
- for tokenizer in tokenizers:
- with self.subTest(f"{tokenizer.__class__.__name__}"):
- mask = ""
- sequence = "Encode this sequence"
- sequence_masked_0 = "Encode sequence"
- sequence_masked_1 = " this sequence"
+ # def test_swap_special_token(self):
+ # tokenizers = self.get_tokenizers(do_lower_case=False)
+ # for tokenizer in tokenizers:
+ # with self.subTest(f"{tokenizer.__class__.__name__}"):
+ # # Our mask token
+ # mask = ""
+ # # We take a single word in the middle of the vocabulary
+ # all_tokens = sorted(tokenizer.get_vocab().keys())
+ # word = tokenizer.decode(tokenizer.encode(all_tokens[len(all_tokens)//2], add_special_tokens=False)[:1])
- # Add tokens so that masked token isn't split
- tokenizer.add_tokens(sequence.split())
- tokenizer.add_special_tokens({"mask_token": mask})
- mask_ind = tokenizer.convert_tokens_to_ids(mask)
- encoded = tokenizer.encode(sequence, add_special_tokens=False)
+ # sequence_0 = "Encode " + word + " sequence"
+ # sequence_masked_0 = "Encode " + mask + " sequence"
- # Test first masked sequence
- encoded_masked = tokenizer.encode(sequence_masked_0, add_special_tokens=False)
- mask_loc = encoded_masked.index(mask_ind)
- encoded_masked[mask_loc] = encoded[mask_loc]
+ # sequence_1 = word + " this sequence"
+ # sequence_masked_1 = mask + " this sequence"
- self.assertEqual(encoded_masked, encoded)
+ # # Add tokens so that masked token isn't split
+ # # tokens = [AddedToken(t, lstrip=True, normalized=False) for t in sequence.split()]
+ # # tokenizer.add_tokens(tokens)
+ # tokenizer.add_special_tokens(
+ # {"mask_token": AddedToken(mask, normalized=False)}
+ # ) # Eat left space on Byte-level BPE tokenizers
+ # mask_ind = tokenizer.convert_tokens_to_ids(mask)
- # Test second masked sequence
- encoded_masked = tokenizer.encode(sequence_masked_1, add_special_tokens=False)
- mask_loc = encoded_masked.index(mask_ind)
- encoded_masked[mask_loc] = encoded[mask_loc]
+ # # Test first masked sequence
+ # encoded_0 = tokenizer.encode(sequence_0, add_special_tokens=False)
+ # encoded_masked = tokenizer.encode(sequence_masked_0, add_special_tokens=False)
+ # assert len(encoded_masked) == len(encoded_0)
+ # mask_loc = encoded_masked.index(mask_ind)
+ # encoded_masked[mask_loc] = encoded_0[mask_loc]
- self.assertEqual(encoded_masked, encoded)
+ # self.assertEqual(encoded_masked, encoded_0)
+
+ # # Test second masked sequence
+ # encoded_1 = tokenizer.encode(sequence_1, add_special_tokens=False)
+ # encoded_masked = tokenizer.encode(sequence_masked_1, add_special_tokens=False)
+ # assert len(encoded_masked) == len(encoded_1)
+ # mask_loc = encoded_masked.index(mask_ind)
+ # encoded_masked[mask_loc] = encoded_1[mask_loc]
+
+ # self.assertEqual(encoded_masked, encoded_1)
def test_special_tokens_mask(self):
tokenizers = self.get_tokenizers(do_lower_case=False)
@@ -919,10 +972,10 @@ class TokenizerTesterMixin:
def test_padding_to_multiple_of(self):
tokenizers = self.get_tokenizers()
for tokenizer in tokenizers:
- if tokenizer.pad_token is None:
- self.skipTest("No padding token.")
- else:
- with self.subTest(f"{tokenizer.__class__.__name__}"):
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+ if tokenizer.pad_token is None:
+ self.skipTest("No padding token.")
+ else:
empty_tokens = tokenizer("", padding=True, pad_to_multiple_of=8)
normal_tokens = tokenizer("This is a sample input", padding=True, pad_to_multiple_of=8)
for key, value in empty_tokens.items():
@@ -1063,14 +1116,15 @@ class TokenizerTesterMixin:
tokenizers = self.get_tokenizers(do_lower_case=False)
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
- vocab = tokenizer.get_vocab()
+ vocab_dict = tokenizer.get_vocab()
+ self.assertIsInstance(vocab_dict, dict)
+ self.assertGreaterEqual(len(tokenizer), len(vocab_dict))
- self.assertIsInstance(vocab, dict)
+ vocab = [tokenizer.convert_ids_to_tokens(i) for i in range(len(tokenizer))]
self.assertEqual(len(vocab), len(tokenizer))
tokenizer.add_tokens(["asdfasdfasdfasdf"])
- vocab = tokenizer.get_vocab()
- self.assertIsInstance(vocab, dict)
+ vocab = [tokenizer.convert_ids_to_tokens(i) for i in range(len(tokenizer))]
self.assertEqual(len(vocab), len(tokenizer))
def test_conversion_reversible(self):
@@ -1079,6 +1133,8 @@ class TokenizerTesterMixin:
with self.subTest(f"{tokenizer.__class__.__name__}"):
vocab = tokenizer.get_vocab()
for word, ind in vocab.items():
+ if word == tokenizer.unk_token:
+ continue
self.assertEqual(tokenizer.convert_tokens_to_ids(word), ind)
self.assertEqual(tokenizer.convert_ids_to_tokens(ind), word)
@@ -1173,12 +1229,13 @@ class TokenizerTesterMixin:
def test_added_token_serializable(self):
tokenizers = self.get_tokenizers(do_lower_case=False)
for tokenizer in tokenizers:
- new_token = AddedToken("new_token", lstrip=True)
- tokenizer.add_special_tokens({"additional_special_tokens": [new_token]})
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+ new_token = AddedToken("new_token", lstrip=True)
+ tokenizer.add_special_tokens({"additional_special_tokens": [new_token]})
- with tempfile.TemporaryDirectory() as tmp_dir_name:
- tokenizer.save_pretrained(tmp_dir_name)
- tokenizer.from_pretrained(tmp_dir_name)
+ with tempfile.TemporaryDirectory() as tmp_dir_name:
+ tokenizer.save_pretrained(tmp_dir_name)
+ tokenizer.from_pretrained(tmp_dir_name)
def test_batch_encode_plus_padding(self):
# Test that padded sequences are equivalent between batch_encode_plus and encode_plus
@@ -1243,6 +1300,9 @@ class TokenizerTesterMixin:
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
+ if hasattr(tokenizer, "add_prefix_space") and not tokenizer.add_prefix_space:
+ continue
+
# Prepare a sequence from our tokenizer vocabulary
sequence, ids = self.get_clean_sequence(tokenizer, with_prefix_space=True, max_length=20)
# sequence = " " + sequence # To be sure the byte-level tokenizers are feeling good
@@ -1345,12 +1405,14 @@ class TokenizerTesterMixin:
def test_prepare_for_model(self):
tokenizers = self.get_tokenizers(do_lower_case=False)
for tokenizer in tokenizers:
- string_sequence = "Testing the prepare_for_model method."
- ids = tokenizer.encode(string_sequence, add_special_tokens=False)
- input_dict = tokenizer.encode_plus(string_sequence)
- prepared_input_dict = tokenizer.prepare_for_model(ids)
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+ string_sequence = "Testing the prepare_for_model method."
+ ids = tokenizer.encode(string_sequence, add_special_tokens=False)
+ prepared_input_dict = tokenizer.prepare_for_model(ids, add_special_tokens=True)
- self.assertEqual(input_dict, prepared_input_dict)
+ input_dict = tokenizer.encode_plus(string_sequence, add_special_tokens=True)
+
+ self.assertEqual(input_dict, prepared_input_dict)
def test_batch_encode_plus_overflowing_tokens(self):
tokenizers = self.get_tokenizers(do_lower_case=False)
diff --git a/tests/test_tokenization_ctrl.py b/tests/test_tokenization_ctrl.py
index 59d543e1f6..34b2ec9789 100644
--- a/tests/test_tokenization_ctrl.py
+++ b/tests/test_tokenization_ctrl.py
@@ -25,6 +25,7 @@ from .test_tokenization_common import TokenizerTesterMixin
class CTRLTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = CTRLTokenizer
+ test_rust_tokenizer = False
def setUp(self):
super().setUp()
diff --git a/tests/test_tokenization_distilbert.py b/tests/test_tokenization_distilbert.py
index bee28425c7..b076e2c779 100644
--- a/tests/test_tokenization_distilbert.py
+++ b/tests/test_tokenization_distilbert.py
@@ -23,9 +23,8 @@ from .test_tokenization_bert import BertTokenizationTest
class DistilBertTokenizationTest(BertTokenizationTest):
tokenizer_class = DistilBertTokenizer
-
- def get_rust_tokenizer(self, **kwargs):
- return DistilBertTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
+ rust_tokenizer_class = DistilBertTokenizerFast
+ test_rust_tokenizer = True
@slow
def test_sequence_builders(self):
diff --git a/tests/test_tokenization_dpr.py b/tests/test_tokenization_dpr.py
index 2043d4e9f9..d9ec74ec5d 100644
--- a/tests/test_tokenization_dpr.py
+++ b/tests/test_tokenization_dpr.py
@@ -32,25 +32,22 @@ from .test_tokenization_bert import BertTokenizationTest
class DPRContextEncoderTokenizationTest(BertTokenizationTest):
tokenizer_class = DPRContextEncoderTokenizer
-
- def get_rust_tokenizer(self, **kwargs):
- return DPRContextEncoderTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
+ rust_tokenizer_class = DPRContextEncoderTokenizerFast
+ test_rust_tokenizer = True
class DPRQuestionEncoderTokenizationTest(BertTokenizationTest):
tokenizer_class = DPRQuestionEncoderTokenizer
-
- def get_rust_tokenizer(self, **kwargs):
- return DPRQuestionEncoderTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
+ rust_tokenizer_class = DPRQuestionEncoderTokenizerFast
+ test_rust_tokenizer = True
class DPRReaderTokenizationTest(BertTokenizationTest):
tokenizer_class = DPRReaderTokenizer
-
- def get_rust_tokenizer(self, **kwargs):
- return DPRReaderTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
+ rust_tokenizer_class = DPRReaderTokenizerFast
+ test_rust_tokenizer = True
@slow
def test_decode_best_spans(self):
diff --git a/tests/test_tokenization_fast.py b/tests/test_tokenization_fast.py
index 1184a6c32f..818b357b01 100644
--- a/tests/test_tokenization_fast.py
+++ b/tests/test_tokenization_fast.py
@@ -1,23 +1,52 @@
import logging
+import shutil
+import tempfile
import unittest
from collections import namedtuple
from itertools import takewhile
from transformers import (
+ AlbertTokenizer,
+ AlbertTokenizerFast,
+ BartTokenizer,
+ BartTokenizerFast,
BertTokenizer,
BertTokenizerFast,
+ CamembertTokenizer,
+ CamembertTokenizerFast,
DistilBertTokenizer,
+ DistilBertTokenizerFast,
+ DPRContextEncoderTokenizer,
+ DPRContextEncoderTokenizerFast,
+ DPRQuestionEncoderTokenizer,
+ DPRQuestionEncoderTokenizerFast,
+ DPRReaderTokenizer,
+ DPRReaderTokenizerFast,
+ FunnelTokenizer,
+ FunnelTokenizerFast,
GPT2Tokenizer,
GPT2TokenizerFast,
+ LxmertTokenizer,
+ LxmertTokenizerFast,
+ MBartTokenizer,
+ MBartTokenizerFast,
OpenAIGPTTokenizer,
- PreTrainedTokenizer,
+ OpenAIGPTTokenizerFast,
+ PegasusTokenizer,
+ PegasusTokenizerFast,
+ ReformerTokenizer,
+ ReformerTokenizerFast,
RobertaTokenizer,
+ RobertaTokenizerFast,
+ T5Tokenizer,
+ T5TokenizerFast,
+ XLMRobertaTokenizer,
+ XLMRobertaTokenizerFast,
+ XLNetTokenizer,
+ XLNetTokenizerFast,
is_torch_available,
)
from transformers.testing_utils import get_tests_dir
-from transformers.tokenization_distilbert import DistilBertTokenizerFast
-from transformers.tokenization_openai import OpenAIGPTTokenizerFast
-from transformers.tokenization_roberta import RobertaTokenizerFast
logger = logging.getLogger(__name__)
@@ -40,245 +69,261 @@ class CommonFastTokenizerTest(unittest.TestCase):
TOKENIZERS_CLASSES = frozenset([])
def setUp(self) -> None:
+ # Tokenizer.filter makes it possible to filter which Tokenizer to case based on all the
+ # information available in Tokenizer (name, rust class, python class, vocab key name)
+ self.tokenizers_list = [
+ (tok_case, pretrained_name, dict(t for t in tok_case.kwargs) if tok_case.kwargs else {})
+ for tok_case in self.TOKENIZERS_CLASSES
+ for pretrained_name in tok_case.python_cls.pretrained_vocab_files_map[tok_case.vocab_key].keys()
+ if tok_case.filter is None or (tok_case.filter is not None and tok_case.filter(tok_case, pretrained_name))
+ ]
with open(f"{get_tests_dir()}/fixtures/sample_text.txt", encoding="utf-8") as f_data:
self._data = f_data.read().replace("\n\n", "\n").strip()
- def test_all_tokenizers(self):
- for tok_case in self.TOKENIZERS_CLASSES:
- for pretrained_name in tok_case.python_cls.pretrained_vocab_files_map[tok_case.vocab_key].keys():
+ self.tmpdirname = tempfile.mkdtemp()
- # Tokenizer.filter makes it possible to filter which Tokenizer to case based on all the
- # information available in Tokenizer (name, rust class, python class, vocab key name)
- if tok_case.filter is None or (
- tok_case.filter is not None and tok_case.filter(tok_case, pretrained_name)
- ):
- kwargs = dict(t for t in tok_case.kwargs) if tok_case.kwargs else {}
- with self.subTest("{} ({})".format(tok_case.name, pretrained_name)):
- tokenizer_r = tok_case.rust_cls.from_pretrained(pretrained_name, **kwargs)
- tokenizer_p = tok_case.python_cls.from_pretrained(pretrained_name, **kwargs)
+ def tearDown(self):
+ shutil.rmtree(self.tmpdirname)
- self.fast_align_python(tokenizer_r, tokenizer_p, tok_case, pretrained_name)
- self.fast_only(tokenizer_r)
+ def test_is_fast(self):
+ for tok_case, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest("{} ({})".format(tok_case.name, pretrained_name)):
+ tokenizer_r = tok_case.rust_cls.from_pretrained(pretrained_name, **kwargs)
+ tokenizer_p = tok_case.python_cls.from_pretrained(pretrained_name, **kwargs)
- def test_pretokenized_tokenizers(self):
- for tok_case in self.TOKENIZERS_CLASSES:
- for pretrained_name in tok_case.python_cls.pretrained_vocab_files_map[tok_case.vocab_key].keys():
+ # Check is_fast is set correctly
+ self.assertFalse(tokenizer_p.is_fast)
+ self.assertTrue(tokenizer_r.is_fast)
- # Tokenizer.filter makes it possible to filter which Tokenizer to case based on all the
- # information available in Tokenizer (name, rust class, python class, vocab key name)
- if tok_case.filter is None or (
- tok_case.filter is not None and tok_case.filter(tok_case, pretrained_name)
- ):
- with self.subTest("{} ({})".format(tok_case.name, pretrained_name)):
- tokenizer_r = tok_case.rust_cls.from_pretrained(pretrained_name, add_prefix_space=True)
- tokenizer_p = tok_case.python_cls.from_pretrained(pretrained_name, add_prefix_space=True)
+ def test_fast_only_inputs(self):
+ for tok_case, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest("{} ({})".format(tok_case.name, pretrained_name)):
+ tokenizer_r = tok_case.rust_cls.from_pretrained(pretrained_name, **kwargs)
- self.assert_pretokenized_inputs(tokenizer_r, tokenizer_p)
+ # Ensure None raise an error
+ self.assertRaises(TypeError, tokenizer_r.tokenize, None)
+ self.assertRaises(TypeError, tokenizer_r.encode, None)
+ self.assertRaises(TypeError, tokenizer_r.encode_plus, None)
+ self.assertRaises(TypeError, tokenizer_r.batch_encode_plus, None)
- def fast_align_python(self, tokenizer_r, tokenizer_p, tok_case, pretrained_name):
- # Check is_fast is set correctly
- self.assertFalse(tokenizer_p.is_fast)
- self.assertTrue(tokenizer_r.is_fast)
+ def test_alignement_methods(self):
+ for tok_case, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest("{} ({})".format(tok_case.name, pretrained_name)):
+ tokenizer_r = tok_case.rust_cls.from_pretrained(pretrained_name, **kwargs)
- # Check that Rust and Python align
- self.assert_tokenization_python_rust_equals(tokenizer_r, tokenizer_p)
- self.assert_num_special_tokens_to_add_equal(tokenizer_r, tokenizer_p)
- self.assert_max_length_equal(tokenizer_r, tokenizer_p)
- self.assert_special_tokens_map_equal(tokenizer_r, tokenizer_p)
- self.assert_embeded_special_tokens(tokenizer_r, tokenizer_p)
- self.assert_padding(tokenizer_r, tokenizer_p)
- self.assert_create_token_type_ids(tokenizer_r, tokenizer_p)
- self.assert_prepare_for_model(tokenizer_r, tokenizer_p)
+ words = ["Wonderful", "no", "inspiration", "example", "with", "subtoken"]
+ text = " ".join(words)
+ batch_size = 3
- def fast_only(self, tokenizer_r):
- # Ensure None raise an error
- self.assertRaises(ValueError, tokenizer_r.tokenize, None)
- self.assertRaises(ValueError, tokenizer_r.encode, None)
- self.assertRaises(ValueError, tokenizer_r.encode_plus, None)
- self.assertRaises(ValueError, tokenizer_r.batch_encode_plus, None)
+ encoding = tokenizer_r.encode_plus(text, add_special_tokens=False)
- self.assert_add_tokens(tokenizer_r)
- self.assert_offsets_mapping(tokenizer_r)
- self.assert_add_special_tokens(tokenizer_r)
- self.assert_alignement_methods(tokenizer_r)
- self.assert_batch_encode_dynamic_overflowing(tokenizer_r)
+ batch_encoding = tokenizer_r.batch_encode_plus([text] * batch_size, add_special_tokens=False)
+ num_tokens = len(encoding["input_ids"])
- def assert_alignement_methods(self, tokenizer_r):
- words = ["Wonderful", "no", "inspiration", "example", "with", "subtoken"]
- text = " ".join(words)
- batch_size = 3
+ last_word_index = len(words) - 1
+ last_token_index = num_tokens - 1
+ last_batch_index = batch_size - 1
+ last_char_index = len(text) - 1
- encoding = tokenizer_r.encode_plus(text, add_special_tokens=False)
+ # words, tokens
+ self.assertEqual(len(encoding.words(0)), num_tokens)
+ self.assertEqual(max(encoding.words(0)), last_word_index)
+ self.assertEqual(min(encoding.words(0)), 0)
+ self.assertEqual(len(batch_encoding.words(last_batch_index)), num_tokens)
+ self.assertEqual(max(batch_encoding.words(last_batch_index)), last_word_index)
+ self.assertEqual(min(batch_encoding.words(last_batch_index)), 0)
+ self.assertEqual(len(encoding.tokens(0)), num_tokens)
- batch_encoding = tokenizer_r.batch_encode_plus([text] * batch_size, add_special_tokens=False)
- num_tokens = len(encoding["input_ids"])
+ # Assert token_to_word
+ self.assertEqual(encoding.token_to_word(0), 0)
+ self.assertEqual(encoding.token_to_word(0, 0), 0)
+ self.assertEqual(encoding.token_to_word(last_token_index), last_word_index)
+ self.assertEqual(encoding.token_to_word(0, last_token_index), last_word_index)
+ self.assertEqual(batch_encoding.token_to_word(1, 0), 0)
+ self.assertEqual(batch_encoding.token_to_word(0, last_token_index), last_word_index)
+ self.assertEqual(batch_encoding.token_to_word(last_batch_index, last_token_index), last_word_index)
- last_word_index = len(words) - 1
- last_token_index = num_tokens - 1
- last_batch_index = batch_size - 1
- last_char_index = len(text) - 1
+ # Assert word_to_tokens
+ self.assertEqual(encoding.word_to_tokens(0).start, 0)
+ self.assertEqual(encoding.word_to_tokens(0, 0).start, 0)
+ self.assertEqual(encoding.word_to_tokens(last_word_index).end, last_token_index + 1)
+ self.assertEqual(encoding.word_to_tokens(0, last_word_index).end, last_token_index + 1)
+ self.assertEqual(batch_encoding.word_to_tokens(1, 0).start, 0)
+ self.assertEqual(batch_encoding.word_to_tokens(0, last_word_index).end, last_token_index + 1)
+ self.assertEqual(
+ batch_encoding.word_to_tokens(last_batch_index, last_word_index).end, last_token_index + 1
+ )
- # words, tokens
- self.assertEqual(len(encoding.words(0)), num_tokens)
- self.assertEqual(max(encoding.words(0)), last_word_index)
- self.assertEqual(min(encoding.words(0)), 0)
- self.assertEqual(len(batch_encoding.words(last_batch_index)), num_tokens)
- self.assertEqual(max(batch_encoding.words(last_batch_index)), last_word_index)
- self.assertEqual(min(batch_encoding.words(last_batch_index)), 0)
- self.assertEqual(len(encoding.tokens(0)), num_tokens)
+ # Assert token_to_chars
+ self.assertEqual(encoding.token_to_chars(0).start, 0)
+ self.assertEqual(encoding.token_to_chars(0, 0).start, 0)
+ self.assertEqual(encoding.token_to_chars(last_token_index).end, last_char_index + 1)
+ self.assertEqual(encoding.token_to_chars(0, last_token_index).end, last_char_index + 1)
+ self.assertEqual(batch_encoding.token_to_chars(1, 0).start, 0)
+ self.assertEqual(batch_encoding.token_to_chars(0, last_token_index).end, last_char_index + 1)
+ self.assertEqual(
+ batch_encoding.token_to_chars(last_batch_index, last_token_index).end, last_char_index + 1
+ )
- # Assert token_to_word
- self.assertEqual(encoding.token_to_word(0), 0)
- self.assertEqual(encoding.token_to_word(0, 0), 0)
- self.assertEqual(encoding.token_to_word(last_token_index), last_word_index)
- self.assertEqual(encoding.token_to_word(0, last_token_index), last_word_index)
- self.assertEqual(batch_encoding.token_to_word(1, 0), 0)
- self.assertEqual(batch_encoding.token_to_word(0, last_token_index), last_word_index)
- self.assertEqual(batch_encoding.token_to_word(last_batch_index, last_token_index), last_word_index)
+ # Assert char_to_token
+ self.assertEqual(encoding.char_to_token(0), 0)
+ self.assertEqual(encoding.char_to_token(0, 0), 0)
+ self.assertEqual(encoding.char_to_token(last_char_index), last_token_index)
+ self.assertEqual(encoding.char_to_token(0, last_char_index), last_token_index)
+ self.assertEqual(batch_encoding.char_to_token(1, 0), 0)
+ self.assertEqual(batch_encoding.char_to_token(0, last_char_index), last_token_index)
+ self.assertEqual(batch_encoding.char_to_token(last_batch_index, last_char_index), last_token_index)
- # Assert word_to_tokens
- self.assertEqual(encoding.word_to_tokens(0).start, 0)
- self.assertEqual(encoding.word_to_tokens(0, 0).start, 0)
- self.assertEqual(encoding.word_to_tokens(last_word_index).end, last_token_index + 1)
- self.assertEqual(encoding.word_to_tokens(0, last_word_index).end, last_token_index + 1)
- self.assertEqual(batch_encoding.word_to_tokens(1, 0).start, 0)
- self.assertEqual(batch_encoding.word_to_tokens(0, last_word_index).end, last_token_index + 1)
- self.assertEqual(batch_encoding.word_to_tokens(last_batch_index, last_word_index).end, last_token_index + 1)
+ # Assert char_to_word
+ self.assertEqual(encoding.char_to_word(0), 0)
+ self.assertEqual(encoding.char_to_word(0, 0), 0)
+ self.assertEqual(encoding.char_to_word(last_char_index), last_word_index)
+ self.assertEqual(encoding.char_to_word(0, last_char_index), last_word_index)
+ self.assertEqual(batch_encoding.char_to_word(1, 0), 0)
+ self.assertEqual(batch_encoding.char_to_word(0, last_char_index), last_word_index)
+ self.assertEqual(batch_encoding.char_to_word(last_batch_index, last_char_index), last_word_index)
- # Assert token_to_chars
- self.assertEqual(encoding.token_to_chars(0).start, 0)
- self.assertEqual(encoding.token_to_chars(0, 0).start, 0)
- self.assertEqual(encoding.token_to_chars(last_token_index).end, last_char_index + 1)
- self.assertEqual(encoding.token_to_chars(0, last_token_index).end, last_char_index + 1)
- self.assertEqual(batch_encoding.token_to_chars(1, 0).start, 0)
- self.assertEqual(batch_encoding.token_to_chars(0, last_token_index).end, last_char_index + 1)
- self.assertEqual(batch_encoding.token_to_chars(last_batch_index, last_token_index).end, last_char_index + 1)
+ # Assert word_to_chars
+ self.assertEqual(encoding.word_to_chars(0).start, 0)
+ self.assertEqual(encoding.word_to_chars(0, 0).start, 0)
+ self.assertEqual(encoding.word_to_chars(last_word_index).end, last_char_index + 1)
+ self.assertEqual(encoding.word_to_chars(0, last_word_index).end, last_char_index + 1)
+ self.assertEqual(batch_encoding.word_to_chars(1, 0).start, 0)
+ self.assertEqual(batch_encoding.word_to_chars(0, last_word_index).end, last_char_index + 1)
+ self.assertEqual(
+ batch_encoding.word_to_chars(last_batch_index, last_word_index).end, last_char_index + 1
+ )
- # Assert char_to_token
- self.assertEqual(encoding.char_to_token(0), 0)
- self.assertEqual(encoding.char_to_token(0, 0), 0)
- self.assertEqual(encoding.char_to_token(last_char_index), last_token_index)
- self.assertEqual(encoding.char_to_token(0, last_char_index), last_token_index)
- self.assertEqual(batch_encoding.char_to_token(1, 0), 0)
- self.assertEqual(batch_encoding.char_to_token(0, last_char_index), last_token_index)
- self.assertEqual(batch_encoding.char_to_token(last_batch_index, last_char_index), last_token_index)
+ def test_tokenization_python_rust_equals(self):
+ for tok_case, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest("{} ({})".format(tok_case.name, pretrained_name)):
+ tokenizer_r = tok_case.rust_cls.from_pretrained(pretrained_name, **kwargs)
+ tokenizer_p = tok_case.python_cls.from_pretrained(pretrained_name, **kwargs)
- # Assert char_to_word
- self.assertEqual(encoding.char_to_word(0), 0)
- self.assertEqual(encoding.char_to_word(0, 0), 0)
- self.assertEqual(encoding.char_to_word(last_char_index), last_word_index)
- self.assertEqual(encoding.char_to_word(0, last_char_index), last_word_index)
- self.assertEqual(batch_encoding.char_to_word(1, 0), 0)
- self.assertEqual(batch_encoding.char_to_word(0, last_char_index), last_word_index)
- self.assertEqual(batch_encoding.char_to_word(last_batch_index, last_char_index), last_word_index)
+ # Ensure basic input match
+ input_p = tokenizer_p.encode_plus(self._data)
+ input_r = tokenizer_r.encode_plus(self._data)
- # Assert word_to_chars
- self.assertEqual(encoding.word_to_chars(0).start, 0)
- self.assertEqual(encoding.word_to_chars(0, 0).start, 0)
- self.assertEqual(encoding.word_to_chars(last_word_index).end, last_char_index + 1)
- self.assertEqual(encoding.word_to_chars(0, last_word_index).end, last_char_index + 1)
- self.assertEqual(batch_encoding.word_to_chars(1, 0).start, 0)
- self.assertEqual(batch_encoding.word_to_chars(0, last_word_index).end, last_char_index + 1)
- self.assertEqual(batch_encoding.word_to_chars(last_batch_index, last_word_index).end, last_char_index + 1)
+ for key in filter(lambda x: x in ["input_ids", "token_type_ids", "attention_mask"], input_p.keys()):
+ self.assertSequenceEqual(input_p[key], input_r[key])
- def assert_tokenization_python_rust_equals(self, tokenizer_r, tokenizer_p):
- # Ensure basic input match
- input_p = tokenizer_p.encode_plus(self._data)
- input_r = tokenizer_r.encode_plus(self._data)
+ input_pairs_p = tokenizer_p.encode_plus(self._data, self._data)
+ input_pairs_r = tokenizer_r.encode_plus(self._data, self._data)
- for key in filter(lambda x: x in ["input_ids", "token_type_ids", "attention_mask"], input_p.keys()):
- self.assertSequenceEqual(input_p[key], input_r[key])
+ for key in filter(lambda x: x in ["input_ids", "token_type_ids", "attention_mask"], input_p.keys()):
+ self.assertSequenceEqual(input_pairs_p[key], input_pairs_r[key])
- input_pairs_p = tokenizer_p.encode_plus(self._data, self._data)
- input_pairs_r = tokenizer_r.encode_plus(self._data, self._data)
+ # Ensure truncation match
+ input_p = tokenizer_p.encode_plus(self._data, max_length=512, truncation=True)
+ input_r = tokenizer_r.encode_plus(self._data, max_length=512, truncation=True)
- for key in filter(lambda x: x in ["input_ids", "token_type_ids", "attention_mask"], input_p.keys()):
- self.assertSequenceEqual(input_pairs_p[key], input_pairs_r[key])
+ for key in filter(lambda x: x in ["input_ids", "token_type_ids", "attention_mask"], input_p.keys()):
+ self.assertSequenceEqual(input_p[key], input_r[key])
- # Ensure truncation match
- input_p = tokenizer_p.encode_plus(self._data, max_length=512, truncation=True)
- input_r = tokenizer_r.encode_plus(self._data, max_length=512, truncation=True)
+ # Ensure truncation with stride match
+ input_p = tokenizer_p.encode_plus(
+ self._data, max_length=512, truncation=True, stride=3, return_overflowing_tokens=True
+ )
+ input_r = tokenizer_r.encode_plus(
+ self._data, max_length=512, truncation=True, stride=3, return_overflowing_tokens=True
+ )
- for key in filter(lambda x: x in ["input_ids", "token_type_ids", "attention_mask"], input_p.keys()):
- self.assertSequenceEqual(input_p[key], input_r[key])
+ for key in filter(lambda x: x in ["input_ids", "token_type_ids", "attention_mask"], input_p.keys()):
+ self.assertSequenceEqual(input_p[key], input_r[key][0])
- # Ensure truncation with stride match
- input_p = tokenizer_p.encode_plus(
- self._data, max_length=512, truncation=True, stride=3, return_overflowing_tokens=True
- )
- input_r = tokenizer_r.encode_plus(
- self._data, max_length=512, truncation=True, stride=3, return_overflowing_tokens=True
- )
+ def test_num_special_tokens_to_add_equal(self):
+ for tok_case, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest("{} ({})".format(tok_case.name, pretrained_name)):
+ tokenizer_r = tok_case.rust_cls.from_pretrained(pretrained_name, **kwargs)
+ tokenizer_p = tok_case.python_cls.from_pretrained(pretrained_name, **kwargs)
- for key in filter(lambda x: x in ["input_ids", "token_type_ids", "attention_mask"], input_p.keys()):
- self.assertSequenceEqual(input_p[key], input_r[key][0])
+ # Check we have the same number of added_tokens for both pair and non-pair inputs.
+ self.assertEqual(
+ tokenizer_r.num_special_tokens_to_add(False), tokenizer_p.num_special_tokens_to_add(False)
+ )
+ self.assertEqual(
+ tokenizer_r.num_special_tokens_to_add(True), tokenizer_p.num_special_tokens_to_add(True)
+ )
- def assert_num_special_tokens_to_add_equal(self, tokenizer_r, tokenizer_p):
- # Check we have the same number of added_tokens for both pair and non-pair inputs.
- self.assertEqual(tokenizer_r.num_special_tokens_to_add(False), tokenizer_p.num_special_tokens_to_add(False))
- self.assertEqual(tokenizer_r.num_special_tokens_to_add(True), tokenizer_p.num_special_tokens_to_add(True))
+ def test_max_length_equal(self):
+ for tok_case, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest("{} ({})".format(tok_case.name, pretrained_name)):
+ tokenizer_r = tok_case.rust_cls.from_pretrained(pretrained_name, **kwargs)
+ tokenizer_p = tok_case.python_cls.from_pretrained(pretrained_name, **kwargs)
- def assert_max_length_equal(self, tokenizer_r, tokenizer_p):
- # Check we have the correct max_length for both pair and non-pair inputs.
- self.assertEqual(tokenizer_r.max_len_single_sentence, tokenizer_p.max_len_single_sentence)
- self.assertEqual(tokenizer_r.max_len_sentences_pair, tokenizer_p.max_len_sentences_pair)
+ # Check we have the correct max_length for both pair and non-pair inputs.
+ self.assertEqual(tokenizer_r.max_len_single_sentence, tokenizer_p.max_len_single_sentence)
+ self.assertEqual(tokenizer_r.max_len_sentences_pair, tokenizer_p.max_len_sentences_pair)
- def assert_special_tokens_map_equal(self, tokenizer_r, tokenizer_p):
- # Assert the set of special tokens match.
- self.assertSequenceEqual(
- tokenizer_p.special_tokens_map.items(),
- tokenizer_r.special_tokens_map.items(),
- )
+ def test_special_tokens_map_equal(self):
+ for tok_case, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest("{} ({})".format(tok_case.name, pretrained_name)):
+ tokenizer_r = tok_case.rust_cls.from_pretrained(pretrained_name, **kwargs)
+ tokenizer_p = tok_case.python_cls.from_pretrained(pretrained_name, **kwargs)
- def assert_add_tokens(self, tokenizer_r):
- vocab_size = tokenizer_r.vocab_size
- self.assertEqual(tokenizer_r.add_tokens(""), 0)
- self.assertEqual(tokenizer_r.add_tokens("testoken"), 1)
- self.assertEqual(tokenizer_r.add_tokens(["testoken1", "testtoken2"]), 2)
- self.assertEqual(len(tokenizer_r), vocab_size + 3)
+ # Assert the set of special tokens match.
+ self.assertSequenceEqual(
+ tokenizer_p.special_tokens_map.items(),
+ tokenizer_r.special_tokens_map.items(),
+ )
- self.assertEqual(tokenizer_r.add_special_tokens({}), 0)
- self.assertEqual(tokenizer_r.add_special_tokens({"bos_token": "[BOS]", "eos_token": "[EOS]"}), 2)
- self.assertRaises(
- AssertionError, tokenizer_r.add_special_tokens, {"additional_special_tokens": ""}
- )
- self.assertEqual(tokenizer_r.add_special_tokens({"additional_special_tokens": [""]}), 1)
- self.assertEqual(
- tokenizer_r.add_special_tokens({"additional_special_tokens": ["", ""]}), 2
- )
- self.assertEqual(len(tokenizer_r), vocab_size + 8)
+ def test_add_tokens(self):
+ for tok_case, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest("{} ({})".format(tok_case.name, pretrained_name)):
+ tokenizer_r = tok_case.rust_cls.from_pretrained(pretrained_name, **kwargs)
- def assert_offsets_mapping(self, tokenizer_r):
- text = "Wonderful no inspiration example with subtoken"
- pair = "Along with an awesome pair"
+ vocab_size = len(tokenizer_r)
+ self.assertEqual(tokenizer_r.add_tokens(""), 0)
+ self.assertEqual(tokenizer_r.add_tokens("testoken"), 1)
+ self.assertEqual(tokenizer_r.add_tokens(["testoken1", "testtoken2"]), 2)
+ self.assertEqual(len(tokenizer_r), vocab_size + 3)
- # No pair
- tokens_with_offsets = tokenizer_r.encode_plus(
- text, return_special_tokens_mask=True, return_offsets_mapping=True, add_special_tokens=True
- )
- added_tokens = tokenizer_r.num_special_tokens_to_add(False)
- offsets = tokens_with_offsets["offset_mapping"]
+ self.assertEqual(tokenizer_r.add_special_tokens({}), 0)
+ self.assertEqual(tokenizer_r.add_special_tokens({"bos_token": "[BOS]", "eos_token": "[EOS]"}), 2)
+ self.assertRaises(
+ AssertionError, tokenizer_r.add_special_tokens, {"additional_special_tokens": ""}
+ )
+ self.assertEqual(tokenizer_r.add_special_tokens({"additional_special_tokens": [""]}), 1)
+ self.assertEqual(
+ tokenizer_r.add_special_tokens({"additional_special_tokens": ["", ""]}), 2
+ )
+ self.assertEqual(len(tokenizer_r), vocab_size + 8)
- # Assert there is the same number of tokens and offsets
- self.assertEqual(len(offsets), len(tokens_with_offsets["input_ids"]))
+ def test_offsets_mapping(self):
+ for tok_case, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest("{} ({})".format(tok_case.name, pretrained_name)):
+ tokenizer_r = tok_case.rust_cls.from_pretrained(pretrained_name, **kwargs)
- # Assert there is online added_tokens special_tokens
- self.assertEqual(sum(tokens_with_offsets["special_tokens_mask"]), added_tokens)
+ text = "Wonderful no inspiration example with subtoken"
+ pair = "Along with an awesome pair"
- # Pairs
- tokens_with_offsets = tokenizer_r.encode_plus(
- text, pair, return_special_tokens_mask=True, return_offsets_mapping=True, add_special_tokens=True
- )
- added_tokens = tokenizer_r.num_special_tokens_to_add(True)
- offsets = tokens_with_offsets["offset_mapping"]
+ # No pair
+ tokens_with_offsets = tokenizer_r.encode_plus(
+ text, return_special_tokens_mask=True, return_offsets_mapping=True, add_special_tokens=True
+ )
+ added_tokens = tokenizer_r.num_special_tokens_to_add(False)
+ offsets = tokens_with_offsets["offset_mapping"]
- # Assert there is the same number of tokens and offsets
- self.assertEqual(len(offsets), len(tokens_with_offsets["input_ids"]))
+ # Assert there is the same number of tokens and offsets
+ self.assertEqual(len(offsets), len(tokens_with_offsets["input_ids"]))
- # Assert there is online added_tokens special_tokens
- self.assertEqual(sum(tokens_with_offsets["special_tokens_mask"]), added_tokens)
+ # Assert there is online added_tokens special_tokens
+ self.assertEqual(sum(tokens_with_offsets["special_tokens_mask"]), added_tokens)
- def assert_batch_encode_dynamic_overflowing(self, tokenizer: PreTrainedTokenizer):
+ # Pairs
+ tokens_with_offsets = tokenizer_r.encode_plus(
+ text, pair, return_special_tokens_mask=True, return_offsets_mapping=True, add_special_tokens=True
+ )
+ added_tokens = tokenizer_r.num_special_tokens_to_add(True)
+ offsets = tokens_with_offsets["offset_mapping"]
+
+ # Assert there is the same number of tokens and offsets
+ self.assertEqual(len(offsets), len(tokens_with_offsets["input_ids"]))
+
+ # Assert there is online added_tokens special_tokens
+ self.assertEqual(sum(tokens_with_offsets["special_tokens_mask"]), added_tokens)
+
+ def test_batch_encode_dynamic_overflowing(self):
"""
When calling batch_encode with multiple sequence it can returns different number of
overflowing encoding for each sequence:
@@ -289,437 +334,515 @@ class CommonFastTokenizerTest(unittest.TestCase):
]
This needs to be padded so that it can represented as a tensor
"""
- returned_tensor = "pt" if is_torch_available() else "tf"
+ for tok_case, pretrained_name, kwargs in self.tokenizers_list:
+ tokenizer = tok_case.rust_cls.from_pretrained(pretrained_name, **kwargs)
- if not tokenizer.pad_token or tokenizer.pad_token_id < 0:
- return
+ with self.subTest("{} ({}, {})".format(tok_case.name, pretrained_name, tokenizer.__class__.__name__)):
- tokens = tokenizer.encode_plus(
- "HuggingFace is solving NLP one commit at a time",
- max_length=6,
- padding=True,
- truncation=True,
- return_tensors=returned_tensor,
- return_overflowing_tokens=True,
- )
+ returned_tensor = "pt" if is_torch_available() else "tf"
- for key in filter(lambda x: "overflow_to_sample_mapping" not in x, tokens.keys()):
- self.assertEqual(len(tokens[key].shape), 2)
+ if not tokenizer.pad_token or tokenizer.pad_token_id < 0:
+ return
- # Mono sample
- tokens = tokenizer.batch_encode_plus(
- ["HuggingFace is solving NLP one commit at a time"],
- max_length=6,
- padding=True,
- truncation="only_first",
- return_tensors=returned_tensor,
- return_overflowing_tokens=True,
- )
-
- for key in filter(lambda x: "overflow_to_sample_mapping" not in x, tokens.keys()):
- self.assertEqual(len(tokens[key].shape), 2)
- self.assertEqual(tokens[key].shape[-1], 6)
-
- # Multi sample
- tokens = tokenizer.batch_encode_plus(
- ["HuggingFace is solving NLP one commit at a time", "Very tiny input"],
- max_length=6,
- padding=True,
- truncation="only_first",
- return_tensors=returned_tensor,
- return_overflowing_tokens=True,
- )
-
- for key in filter(lambda x: "overflow_to_sample_mapping" not in x, tokens.keys()):
- self.assertEqual(len(tokens[key].shape), 2)
- self.assertEqual(tokens[key].shape[-1], 6)
-
- def assert_pretokenized_inputs(self, tokenizer_r, tokenizer_p):
- # Input string
- pretokenized_input_simple = "This is a sample input".split()
- pretokenized_input_pair = "This is a sample pair".split()
-
- # Test encode for pretokenized inputs
- output_r = tokenizer_r.encode(pretokenized_input_simple, is_split_into_words=True)
- output_p = tokenizer_p.encode(pretokenized_input_simple, is_split_into_words=True)
- self.assertEqual(output_p, output_r)
-
- kwargs = {
- "is_split_into_words": True,
- "return_token_type_ids": True,
- "return_attention_mask": True,
- "return_overflowing_tokens": False,
- "return_special_tokens_mask": True,
- "return_offsets_mapping": False, # Not implemented in python tokenizers
- }
- batch_kwargs = {
- "is_split_into_words": True,
- "return_token_type_ids": True,
- "return_attention_mask": True, # we have an 's' here
- "return_overflowing_tokens": False,
- "return_special_tokens_mask": True, # we have an 's' here
- "return_offsets_mapping": False, # Not implemented in python tokenizers
- }
- # Test encode_plus for pretokenized inputs
- output_r = tokenizer_r.encode_plus(pretokenized_input_simple, **kwargs)
- output_p = tokenizer_p.encode_plus(pretokenized_input_simple, **kwargs)
- for key in output_p.keys():
- self.assertEqual(output_p[key], output_r[key])
-
- # Test batch_encode_plus for pretokenized inputs
- input_batch = ([pretokenized_input_simple] * 2) + [pretokenized_input_simple + pretokenized_input_pair]
- output_r = tokenizer_r.batch_encode_plus(input_batch, **batch_kwargs)
- output_p = tokenizer_p.batch_encode_plus(input_batch, **batch_kwargs)
- for key in output_p.keys():
- self.assertEqual(output_p[key], output_r[key])
-
- # Test encode for pretokenized inputs pairs
- output_r = tokenizer_r.encode(pretokenized_input_simple, pretokenized_input_pair, is_split_into_words=True)
- output_p = tokenizer_p.encode(pretokenized_input_simple, pretokenized_input_pair, is_split_into_words=True)
- self.assertEqual(output_p, output_r)
-
- # Test encode_plus for pretokenized inputs
- output_r = tokenizer_r.encode_plus(pretokenized_input_simple, pretokenized_input_pair, **kwargs)
- output_p = tokenizer_p.encode_plus(pretokenized_input_simple, pretokenized_input_pair, **kwargs)
- for key in output_p.keys():
- self.assertEqual(output_p[key], output_r[key])
-
- # Test batch_encode_plus for pretokenized inputs
- input_batch_pair = ([pretokenized_input_simple, pretokenized_input_pair] * 2) + [
- pretokenized_input_simple + pretokenized_input_pair,
- pretokenized_input_pair,
- ]
- output_r = tokenizer_r.batch_encode_plus(input_batch_pair, **batch_kwargs)
- output_p = tokenizer_p.batch_encode_plus(input_batch_pair, **batch_kwargs)
- for key in output_p.keys():
- self.assertEqual(output_p[key], output_r[key])
-
- def assert_create_token_type_ids(self, tokenizer_r, tokenizer_p):
- input_simple = [1, 2, 3]
- input_pair = [1, 2, 3]
-
- # Generate output
- output_r = tokenizer_r.create_token_type_ids_from_sequences(input_simple)
- output_p = tokenizer_p.create_token_type_ids_from_sequences(input_simple)
- self.assertEqual(output_p, output_r)
-
- # Generate pair output
- output_r = tokenizer_r.create_token_type_ids_from_sequences(input_simple, input_pair)
- output_p = tokenizer_p.create_token_type_ids_from_sequences(input_simple, input_pair)
- self.assertEqual(output_p, output_r)
-
- def assert_build_inputs_with_special_tokens(self, tokenizer_r, tokenizer_p):
- # Input string
- input_simple = tokenizer_p.tokenize("This is a sample input")
- input_pair = tokenizer_p.tokenize("This is a sample pair")
-
- # Generate output
- output_r = tokenizer_r.build_inputs_with_special_tokens(input_simple)
- output_p = tokenizer_p.build_inputs_with_special_tokens(input_simple)
- self.assertEqual(output_p, output_r)
-
- # Generate pair output
- output_r = tokenizer_r.build_inputs_with_special_tokens(input_simple, input_pair)
- output_p = tokenizer_p.build_inputs_with_special_tokens(input_simple, input_pair)
- self.assertEqual(output_p, output_r)
-
- # Input tokens id
- input_simple = tokenizer_p.encode("This is a sample input")
- input_pair = tokenizer_p.encode("This is a sample pair")
-
- # Generate output
- output_r = tokenizer_r.build_inputs_with_special_tokens(input_simple)
- output_p = tokenizer_p.build_inputs_with_special_tokens(input_simple)
- self.assertEqual(output_p, output_r)
-
- # Generate pair output
- output_r = tokenizer_r.build_inputs_with_special_tokens(input_simple, input_pair)
- output_p = tokenizer_p.build_inputs_with_special_tokens(input_simple, input_pair)
- self.assertEqual(output_p, output_r)
-
- def assert_padding(self, tokenizer_r, tokenizer_p, max_length=15):
- def assert_padded_input_match(input_r: list, input_p: list, max_length: int):
-
- # Ensure we match max_length
- self.assertEqual(len(input_r), max_length)
- self.assertEqual(len(input_p), max_length)
-
- # Ensure the number of padded tokens is the same
- padded_tokens_r = list(takewhile(lambda i: i == tokenizer_r.pad_token_id, reversed(input_r)))
- padded_tokens_p = list(takewhile(lambda i: i == tokenizer_p.pad_token_id, reversed(input_p)))
- self.assertSequenceEqual(padded_tokens_r, padded_tokens_p)
-
- def assert_batch_padded_input_match(input_r: dict, input_p: dict, max_length: int):
- for i_r in input_r.values():
- self.assertEqual(len(i_r), 2), self.assertEqual(len(i_r[0]), max_length), self.assertEqual(
- len(i_r[1]), max_length
- )
- self.assertEqual(len(i_r), 2), self.assertEqual(len(i_r[0]), max_length), self.assertEqual(
- len(i_r[1]), max_length
+ tokens = tokenizer.encode_plus(
+ "HuggingFace is solving NLP one commit at a time",
+ max_length=6,
+ padding=True,
+ truncation=True,
+ return_tensors=returned_tensor,
+ return_overflowing_tokens=True,
)
- for i_r, i_p in zip(input_r["input_ids"], input_p["input_ids"]):
- assert_padded_input_match(i_r, i_p, max_length)
+ for key in filter(lambda x: "overflow_to_sample_mapping" not in x, tokens.keys()):
+ self.assertEqual(len(tokens[key].shape), 2)
- for i_r, i_p in zip(input_r["attention_mask"], input_p["attention_mask"]):
- self.assertSequenceEqual(i_r, i_p)
-
- # Encode - Simple input
- input_r = tokenizer_r.encode("This is a simple input", max_length=max_length, pad_to_max_length=True)
- input_p = tokenizer_p.encode("This is a simple input", max_length=max_length, pad_to_max_length=True)
- assert_padded_input_match(input_r, input_p, max_length)
- input_r = tokenizer_r.encode("This is a simple input", max_length=max_length, padding="max_length")
- input_p = tokenizer_p.encode("This is a simple input", max_length=max_length, padding="max_length")
- assert_padded_input_match(input_r, input_p, max_length)
-
- input_r = tokenizer_r.encode("This is a simple input", padding="longest")
- input_p = tokenizer_p.encode("This is a simple input", padding=True)
- assert_padded_input_match(input_r, input_p, len(input_r))
-
- # Encode - Pair input
- input_r = tokenizer_r.encode(
- "This is a simple input", "This is a pair", max_length=max_length, pad_to_max_length=True
- )
- input_p = tokenizer_p.encode(
- "This is a simple input", "This is a pair", max_length=max_length, pad_to_max_length=True
- )
- assert_padded_input_match(input_r, input_p, max_length)
- input_r = tokenizer_r.encode(
- "This is a simple input", "This is a pair", max_length=max_length, padding="max_length"
- )
- input_p = tokenizer_p.encode(
- "This is a simple input", "This is a pair", max_length=max_length, padding="max_length"
- )
- assert_padded_input_match(input_r, input_p, max_length)
- input_r = tokenizer_r.encode("This is a simple input", "This is a pair", padding=True)
- input_p = tokenizer_p.encode("This is a simple input", "This is a pair", padding="longest")
- assert_padded_input_match(input_r, input_p, len(input_r))
-
- # Encode_plus - Simple input
- input_r = tokenizer_r.encode_plus("This is a simple input", max_length=max_length, pad_to_max_length=True)
- input_p = tokenizer_p.encode_plus("This is a simple input", max_length=max_length, pad_to_max_length=True)
- assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length)
- self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
- input_r = tokenizer_r.encode_plus("This is a simple input", max_length=max_length, padding="max_length")
- input_p = tokenizer_p.encode_plus("This is a simple input", max_length=max_length, padding="max_length")
- assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length)
- self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
-
- input_r = tokenizer_r.encode_plus("This is a simple input", padding="longest")
- input_p = tokenizer_p.encode_plus("This is a simple input", padding=True)
- assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]))
-
- self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
-
- # Encode_plus - Pair input
- input_r = tokenizer_r.encode_plus(
- "This is a simple input", "This is a pair", max_length=max_length, pad_to_max_length=True
- )
- input_p = tokenizer_p.encode_plus(
- "This is a simple input", "This is a pair", max_length=max_length, pad_to_max_length=True
- )
- assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length)
- self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
- input_r = tokenizer_r.encode_plus(
- "This is a simple input", "This is a pair", max_length=max_length, padding="max_length"
- )
- input_p = tokenizer_p.encode_plus(
- "This is a simple input", "This is a pair", max_length=max_length, padding="max_length"
- )
- assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length)
- self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
- input_r = tokenizer_r.encode_plus("This is a simple input", "This is a pair", padding="longest")
- input_p = tokenizer_p.encode_plus("This is a simple input", "This is a pair", padding=True)
- assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]))
- self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
-
- # Batch_encode_plus - Simple input
- input_r = tokenizer_r.batch_encode_plus(
- ["This is a simple input 1", "This is a simple input 2"], max_length=max_length, pad_to_max_length=True
- )
- input_p = tokenizer_p.batch_encode_plus(
- ["This is a simple input 1", "This is a simple input 2"], max_length=max_length, pad_to_max_length=True
- )
- assert_batch_padded_input_match(input_r, input_p, max_length)
-
- input_r = tokenizer_r.batch_encode_plus(
- ["This is a simple input 1", "This is a simple input 2"],
- max_length=max_length,
- padding="max_length",
- )
- input_p = tokenizer_p.batch_encode_plus(
- ["This is a simple input 1", "This is a simple input 2"],
- max_length=max_length,
- padding="max_length",
- )
- assert_batch_padded_input_match(input_r, input_p, max_length)
-
- input_r = tokenizer_r.batch_encode_plus(
- ["This is a simple input 1", "This is a simple input 2"],
- max_length=max_length,
- padding="longest",
- )
- input_p = tokenizer_p.batch_encode_plus(
- ["This is a simple input 1", "This is a simple input 2"],
- max_length=max_length,
- padding=True,
- )
- assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]))
-
- input_r = tokenizer_r.batch_encode_plus(
- ["This is a simple input 1", "This is a simple input 2"], padding="longest"
- )
- input_p = tokenizer_p.batch_encode_plus(["This is a simple input 1", "This is a simple input 2"], padding=True)
- assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]))
-
- # Batch_encode_plus - Pair input
- input_r = tokenizer_r.batch_encode_plus(
- [
- ("This is a simple input 1", "This is a simple input 2"),
- ("This is a simple pair 1", "This is a simple pair 2"),
- ],
- max_length=max_length,
- truncation=True,
- padding="max_length",
- )
- input_p = tokenizer_p.batch_encode_plus(
- [
- ("This is a simple input 1", "This is a simple input 2"),
- ("This is a simple pair 1", "This is a simple pair 2"),
- ],
- max_length=max_length,
- truncation=True,
- padding="max_length",
- )
- assert_batch_padded_input_match(input_r, input_p, max_length)
-
- input_r = tokenizer_r.batch_encode_plus(
- [
- ("This is a simple input 1", "This is a simple input 2"),
- ("This is a simple pair 1", "This is a simple pair 2"),
- ],
- padding=True,
- )
- input_p = tokenizer_p.batch_encode_plus(
- [
- ("This is a simple input 1", "This is a simple input 2"),
- ("This is a simple pair 1", "This is a simple pair 2"),
- ],
- padding="longest",
- )
- assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]))
-
- # Using pad on single examples after tokenization
- input_r = tokenizer_r.encode_plus("This is a input 1")
- input_r = tokenizer_r.pad(input_r)
-
- input_p = tokenizer_r.encode_plus("This is a input 1")
- input_p = tokenizer_r.pad(input_p)
-
- assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]))
-
- # Using pad on single examples after tokenization
- input_r = tokenizer_r.encode_plus("This is a input 1")
- input_r = tokenizer_r.pad(input_r, max_length=max_length, padding="max_length")
-
- input_p = tokenizer_r.encode_plus("This is a input 1")
- input_p = tokenizer_r.pad(input_p, max_length=max_length, padding="max_length")
-
- assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length)
-
- # Using pad after tokenization
- input_r = tokenizer_r.batch_encode_plus(
- ["This is a input 1", "This is a much longer input whilch should be padded"]
- )
- input_r = tokenizer_r.pad(input_r)
-
- input_p = tokenizer_r.batch_encode_plus(
- ["This is a input 1", "This is a much longer input whilch should be padded"]
- )
- input_p = tokenizer_r.pad(input_p)
-
- assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]))
-
- # Using pad after tokenization
- input_r = tokenizer_r.batch_encode_plus(
- ["This is a input 1", "This is a much longer input whilch should be padded"]
- )
- input_r = tokenizer_r.pad(input_r, max_length=max_length, padding="max_length")
-
- input_p = tokenizer_r.batch_encode_plus(
- ["This is a input 1", "This is a much longer input whilch should be padded"]
- )
- input_p = tokenizer_r.pad(input_p, max_length=max_length, padding="max_length")
-
- assert_batch_padded_input_match(input_r, input_p, max_length)
-
- def assert_save_pretrained(self, tokenizer_r, tokenizer_p):
- # Checks it save with the same files
- self.assertSequenceEqual(tokenizer_r.save_vocabulary("."), tokenizer_p.save_vocabulary("."))
-
- # Checks everything loads correctly in the same way
- tokenizer_rp, tokenizer_pp = tokenizer_r.from_pretrained("."), tokenizer_p.from_pretrained(".")
-
- # Check special tokens are set accordingly on Rust and Python
- for key in tokenizer_pp.special_tokens_map:
- self.assertTrue(hasattr(tokenizer_rp, key))
- # self.assertEqual(getattr(tokenizer_rp, key), getattr(tokenizer_pp, key))
- # self.assertEqual(getattr(tokenizer_rp, key + "_id"), getattr(tokenizer_pp, key + "_id"))
-
- def assert_embeded_special_tokens(self, tokenizer_r, tokenizer_p):
- sentence = "A, AllenNLP sentence."
- tokens_r = tokenizer_r.encode_plus(
- sentence, add_special_tokens=True, return_attention_mask=False, return_token_type_ids=True
- )
- tokens_p = tokenizer_p.encode_plus(
- sentence, add_special_tokens=True, return_attention_mask=False, return_token_type_ids=True
- )
-
- for key in tokens_p.keys():
- self.assertEqual(tokens_r[key], tokens_p[key])
-
- self.assertEqual(sum(tokens_r["token_type_ids"]), 0)
- self.assertEqual(sum(tokens_p["token_type_ids"]), 0)
-
- tokens_r = tokenizer_r.convert_ids_to_tokens(tokens_r["input_ids"])
- tokens_p = tokenizer_p.convert_ids_to_tokens(tokens_p["input_ids"])
- self.assertSequenceEqual(tokens_r, tokens_p)
-
- def assert_add_special_tokens(self, tokenizer_r):
- simple_num_special_tokens_to_add = tokenizer_r.num_special_tokens_to_add(pair=False)
- # pair_num_special_tokens_to_add = tokenizer_r.num_special_tokens_to_add(pair=True)
-
- for text in ["", " "]:
- # tokenize()
- no_special_tokens = tokenizer_r.tokenize(text, add_special_tokens=False)
- with_special_tokens = tokenizer_r.tokenize(text, add_special_tokens=True)
- self.assertEqual(len(no_special_tokens), len(with_special_tokens) - simple_num_special_tokens_to_add)
-
- # encode()
- no_special_tokens = tokenizer_r.encode(text, add_special_tokens=False)
- with_special_tokens = tokenizer_r.encode(text, add_special_tokens=True)
- self.assertEqual(len(no_special_tokens), len(with_special_tokens) - simple_num_special_tokens_to_add)
-
- # encode_plus()
- no_special_tokens = tokenizer_r.encode_plus(text, add_special_tokens=False)
- with_special_tokens = tokenizer_r.encode_plus(text, add_special_tokens=True)
- for key in no_special_tokens.keys():
- self.assertEqual(
- len(no_special_tokens[key]), len(with_special_tokens[key]) - simple_num_special_tokens_to_add
+ # Mono sample
+ tokens = tokenizer.batch_encode_plus(
+ ["HuggingFace is solving NLP one commit at a time"],
+ max_length=6,
+ padding=True,
+ truncation="only_first",
+ return_tensors=returned_tensor,
+ return_overflowing_tokens=True,
)
- # # batch_encode_plus
- no_special_tokens = tokenizer_r.batch_encode_plus([text, text], add_special_tokens=False)
- with_special_tokens = tokenizer_r.batch_encode_plus([text, text], add_special_tokens=True)
- for key in no_special_tokens.keys():
- for i_no, i_with in zip(no_special_tokens[key], with_special_tokens[key]):
- self.assertEqual(len(i_no), len(i_with) - simple_num_special_tokens_to_add)
+ for key in filter(lambda x: "overflow_to_sample_mapping" not in x, tokens.keys()):
+ self.assertEqual(len(tokens[key].shape), 2)
+ self.assertEqual(tokens[key].shape[-1], 6)
- def assert_prepare_for_model(self, tokenizer_r, tokenizer_p):
- string_sequence = "Asserting that both tokenizers are equal"
- python_output = tokenizer_p.prepare_for_model(tokenizer_p.encode(string_sequence))
- rust_output = tokenizer_r.prepare_for_model(tokenizer_r.encode(string_sequence))
- self.assertEqual(python_output, rust_output)
+ # Multi sample
+ tokens = tokenizer.batch_encode_plus(
+ ["HuggingFace is solving NLP one commit at a time", "Very tiny input"],
+ max_length=6,
+ padding=True,
+ truncation="only_first",
+ return_tensors=returned_tensor,
+ return_overflowing_tokens=True,
+ )
+
+ for key in filter(lambda x: "overflow_to_sample_mapping" not in x, tokens.keys()):
+ self.assertEqual(len(tokens[key].shape), 2)
+ self.assertEqual(tokens[key].shape[-1], 6)
+
+ def test_pretokenized_inputs(self):
+ for tok_case, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest("{} ({})".format(tok_case.name, pretrained_name)):
+ tokenizer_r = tok_case.rust_cls.from_pretrained(pretrained_name, **kwargs)
+ tokenizer_p = tok_case.python_cls.from_pretrained(pretrained_name, **kwargs)
+ # Input string
+ pretokenized_input_simple = "This is a sample input".split()
+ pretokenized_input_pair = "This is a sample pair".split()
+
+ # Test encode for pretokenized inputs
+ output_r = tokenizer_r.encode(
+ pretokenized_input_simple, is_split_into_words=True, add_special_tokens=False
+ )
+ output_p = tokenizer_p.encode(
+ pretokenized_input_simple, is_split_into_words=True, add_special_tokens=False
+ )
+ self.assertEqual(output_p, output_r)
+
+ kwargs = {
+ "is_split_into_words": True,
+ # "return_token_type_ids": True, # Use the defaults for each tokenizers
+ # "return_attention_mask": True, # Use the defaults for each tokenizers
+ "return_overflowing_tokens": False,
+ "return_special_tokens_mask": True,
+ "return_offsets_mapping": False, # Not implemented in python tokenizers
+ # "add_special_tokens": False,
+ }
+ batch_kwargs = {
+ "is_split_into_words": True,
+ # "return_token_type_ids": True, # Use the defaults for each tokenizers
+ # "return_attention_mask": True, # Use the defaults for each tokenizers
+ "return_overflowing_tokens": False,
+ "return_special_tokens_mask": True,
+ "return_offsets_mapping": False, # Not implemented in python tokenizers
+ # "add_special_tokens": False,
+ }
+ # Test encode_plus for pretokenized inputs
+ output_r = tokenizer_r.encode_plus(pretokenized_input_simple, **kwargs)
+ output_p = tokenizer_p.encode_plus(pretokenized_input_simple, **kwargs)
+ for key in output_p.keys():
+ self.assertEqual(output_p[key], output_r[key])
+
+ # Test batch_encode_plus for pretokenized inputs
+ input_batch = ([pretokenized_input_simple] * 2) + [pretokenized_input_simple + pretokenized_input_pair]
+ output_r = tokenizer_r.batch_encode_plus(input_batch, **batch_kwargs)
+ output_p = tokenizer_p.batch_encode_plus(input_batch, **batch_kwargs)
+ for key in output_p.keys():
+ self.assertEqual(output_p[key], output_r[key])
+
+ # Test encode for pretokenized inputs pairs
+ output_r = tokenizer_r.encode(
+ pretokenized_input_simple, pretokenized_input_pair, is_split_into_words=True
+ )
+ output_p = tokenizer_p.encode(
+ pretokenized_input_simple, pretokenized_input_pair, is_split_into_words=True
+ )
+ self.assertEqual(output_p, output_r)
+
+ # Test encode_plus for pretokenized inputs
+ output_r = tokenizer_r.encode_plus(pretokenized_input_simple, pretokenized_input_pair, **kwargs)
+ output_p = tokenizer_p.encode_plus(pretokenized_input_simple, pretokenized_input_pair, **kwargs)
+ for key in output_p.keys():
+ self.assertEqual(output_p[key], output_r[key])
+
+ # Test batch_encode_plus for pretokenized inputs
+ input_batch_pair = ([pretokenized_input_simple, pretokenized_input_pair] * 2) + [
+ pretokenized_input_simple + pretokenized_input_pair,
+ pretokenized_input_pair,
+ ]
+ output_r = tokenizer_r.batch_encode_plus(input_batch_pair, **batch_kwargs)
+ output_p = tokenizer_p.batch_encode_plus(input_batch_pair, **batch_kwargs)
+ for key in output_p.keys():
+ self.assertEqual(output_p[key], output_r[key])
+
+ def test_create_token_type_ids(self):
+ for tok_case, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest("{} ({})".format(tok_case.name, pretrained_name)):
+ tokenizer_r = tok_case.rust_cls.from_pretrained(pretrained_name, **kwargs)
+ tokenizer_p = tok_case.python_cls.from_pretrained(pretrained_name, **kwargs)
+ input_simple = [1, 2, 3]
+ input_pair = [1, 2, 3]
+
+ # Generate output
+ output_r = tokenizer_r.create_token_type_ids_from_sequences(input_simple)
+ output_p = tokenizer_p.create_token_type_ids_from_sequences(input_simple)
+ self.assertEqual(output_p, output_r)
+
+ # Generate pair output
+ output_r = tokenizer_r.create_token_type_ids_from_sequences(input_simple, input_pair)
+ output_p = tokenizer_p.create_token_type_ids_from_sequences(input_simple, input_pair)
+ self.assertEqual(output_p, output_r)
+
+ def test_build_inputs_with_special_tokens(self):
+ for tok_case, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest("{} ({})".format(tok_case.name, pretrained_name)):
+ tokenizer_r = tok_case.rust_cls.from_pretrained(pretrained_name, **kwargs)
+ tokenizer_p = tok_case.python_cls.from_pretrained(pretrained_name, **kwargs)
+ # # Input string
+ # input_simple = tokenizer_p.tokenize("This is a sample input", add_special_tokens=False)
+ # input_pair = tokenizer_p.tokenize("This is a sample pair", add_special_tokens=False)
+
+ # # Generate output
+ # output_r = tokenizer_r.build_inputs_with_special_tokens(input_simple)
+ # output_p = tokenizer_p.build_inputs_with_special_tokens(input_simple)
+ # self.assertEqual(output_p, output_r)
+
+ # # Generate pair output
+ # output_r = tokenizer_r.build_inputs_with_special_tokens(input_simple, input_pair)
+ # output_p = tokenizer_p.build_inputs_with_special_tokens(input_simple, input_pair)
+ # self.assertEqual(output_p, output_r)
+
+ # Input tokens id
+ input_simple = tokenizer_p.encode("This is a sample input", add_special_tokens=False)
+ input_pair = tokenizer_p.encode("This is a sample pair", add_special_tokens=False)
+
+ # Generate output
+ output_r = tokenizer_r.build_inputs_with_special_tokens(input_simple)
+ output_p = tokenizer_p.build_inputs_with_special_tokens(input_simple)
+ self.assertEqual(output_p, output_r)
+
+ # Generate pair output
+ output_r = tokenizer_r.build_inputs_with_special_tokens(input_simple, input_pair)
+ output_p = tokenizer_p.build_inputs_with_special_tokens(input_simple, input_pair)
+ self.assertEqual(output_p, output_r)
+
+ def test_padding(self, max_length=50):
+ for tok_case, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest("{} ({})".format(tok_case.name, pretrained_name)):
+ tokenizer_r = tok_case.rust_cls.from_pretrained(pretrained_name, **kwargs)
+ tokenizer_p = tok_case.python_cls.from_pretrained(pretrained_name, **kwargs)
+
+ def assert_padded_input_match(input_r: list, input_p: list, max_length: int):
+
+ # Ensure we match max_length
+ self.assertEqual(len(input_r), max_length)
+ self.assertEqual(len(input_p), max_length)
+
+ # Ensure the number of padded tokens is the same
+ padded_tokens_r = list(takewhile(lambda i: i == tokenizer_r.pad_token_id, reversed(input_r)))
+ padded_tokens_p = list(takewhile(lambda i: i == tokenizer_p.pad_token_id, reversed(input_p)))
+ self.assertSequenceEqual(padded_tokens_r, padded_tokens_p)
+
+ def assert_batch_padded_input_match(input_r: dict, input_p: dict, max_length: int):
+ for i_r in input_r.values():
+ self.assertEqual(len(i_r), 2), self.assertEqual(len(i_r[0]), max_length), self.assertEqual(
+ len(i_r[1]), max_length
+ )
+ self.assertEqual(len(i_r), 2), self.assertEqual(len(i_r[0]), max_length), self.assertEqual(
+ len(i_r[1]), max_length
+ )
+
+ for i_r, i_p in zip(input_r["input_ids"], input_p["input_ids"]):
+ assert_padded_input_match(i_r, i_p, max_length)
+
+ for i_r, i_p in zip(input_r["attention_mask"], input_p["attention_mask"]):
+ self.assertSequenceEqual(i_r, i_p)
+
+ # Encode - Simple input
+ input_r = tokenizer_r.encode("This is a simple input", max_length=max_length, pad_to_max_length=True)
+ input_p = tokenizer_p.encode("This is a simple input", max_length=max_length, pad_to_max_length=True)
+ assert_padded_input_match(input_r, input_p, max_length)
+ input_r = tokenizer_r.encode("This is a simple input", max_length=max_length, padding="max_length")
+ input_p = tokenizer_p.encode("This is a simple input", max_length=max_length, padding="max_length")
+ assert_padded_input_match(input_r, input_p, max_length)
+
+ input_r = tokenizer_r.encode("This is a simple input", padding="longest")
+ input_p = tokenizer_p.encode("This is a simple input", padding=True)
+ assert_padded_input_match(input_r, input_p, len(input_r))
+
+ # Encode - Pair input
+ input_r = tokenizer_r.encode(
+ "This is a simple input", "This is a pair", max_length=max_length, pad_to_max_length=True
+ )
+ input_p = tokenizer_p.encode(
+ "This is a simple input", "This is a pair", max_length=max_length, pad_to_max_length=True
+ )
+ assert_padded_input_match(input_r, input_p, max_length)
+ input_r = tokenizer_r.encode(
+ "This is a simple input", "This is a pair", max_length=max_length, padding="max_length"
+ )
+ input_p = tokenizer_p.encode(
+ "This is a simple input", "This is a pair", max_length=max_length, padding="max_length"
+ )
+ assert_padded_input_match(input_r, input_p, max_length)
+ input_r = tokenizer_r.encode("This is a simple input", "This is a pair", padding=True)
+ input_p = tokenizer_p.encode("This is a simple input", "This is a pair", padding="longest")
+ assert_padded_input_match(input_r, input_p, len(input_r))
+
+ # Encode_plus - Simple input
+ input_r = tokenizer_r.encode_plus(
+ "This is a simple input", max_length=max_length, pad_to_max_length=True
+ )
+ input_p = tokenizer_p.encode_plus(
+ "This is a simple input", max_length=max_length, pad_to_max_length=True
+ )
+ assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length)
+ self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
+ input_r = tokenizer_r.encode_plus(
+ "This is a simple input", max_length=max_length, padding="max_length"
+ )
+ input_p = tokenizer_p.encode_plus(
+ "This is a simple input", max_length=max_length, padding="max_length"
+ )
+ assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length)
+ self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
+
+ input_r = tokenizer_r.encode_plus("This is a simple input", padding="longest")
+ input_p = tokenizer_p.encode_plus("This is a simple input", padding=True)
+ assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]))
+
+ self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
+
+ # Encode_plus - Pair input
+ input_r = tokenizer_r.encode_plus(
+ "This is a simple input", "This is a pair", max_length=max_length, pad_to_max_length=True
+ )
+ input_p = tokenizer_p.encode_plus(
+ "This is a simple input", "This is a pair", max_length=max_length, pad_to_max_length=True
+ )
+ assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length)
+ self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
+ input_r = tokenizer_r.encode_plus(
+ "This is a simple input", "This is a pair", max_length=max_length, padding="max_length"
+ )
+ input_p = tokenizer_p.encode_plus(
+ "This is a simple input", "This is a pair", max_length=max_length, padding="max_length"
+ )
+ assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length)
+ self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
+ input_r = tokenizer_r.encode_plus("This is a simple input", "This is a pair", padding="longest")
+ input_p = tokenizer_p.encode_plus("This is a simple input", "This is a pair", padding=True)
+ assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]))
+ self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
+
+ # Batch_encode_plus - Simple input
+ input_r = tokenizer_r.batch_encode_plus(
+ ["This is a simple input 1", "This is a simple input 2"],
+ max_length=max_length,
+ pad_to_max_length=True,
+ )
+ input_p = tokenizer_p.batch_encode_plus(
+ ["This is a simple input 1", "This is a simple input 2"],
+ max_length=max_length,
+ pad_to_max_length=True,
+ )
+ assert_batch_padded_input_match(input_r, input_p, max_length)
+
+ input_r = tokenizer_r.batch_encode_plus(
+ ["This is a simple input 1", "This is a simple input 2"],
+ max_length=max_length,
+ padding="max_length",
+ )
+ input_p = tokenizer_p.batch_encode_plus(
+ ["This is a simple input 1", "This is a simple input 2"],
+ max_length=max_length,
+ padding="max_length",
+ )
+ assert_batch_padded_input_match(input_r, input_p, max_length)
+
+ input_r = tokenizer_r.batch_encode_plus(
+ ["This is a simple input 1", "This is a simple input 2"],
+ max_length=max_length,
+ padding="longest",
+ )
+ input_p = tokenizer_p.batch_encode_plus(
+ ["This is a simple input 1", "This is a simple input 2"],
+ max_length=max_length,
+ padding=True,
+ )
+ assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]))
+
+ input_r = tokenizer_r.batch_encode_plus(
+ ["This is a simple input 1", "This is a simple input 2"], padding="longest"
+ )
+ input_p = tokenizer_p.batch_encode_plus(
+ ["This is a simple input 1", "This is a simple input 2"], padding=True
+ )
+ assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]))
+
+ # Batch_encode_plus - Pair input
+ input_r = tokenizer_r.batch_encode_plus(
+ [
+ ("This is a simple input 1", "This is a simple input 2"),
+ ("This is a simple pair 1", "This is a simple pair 2"),
+ ],
+ max_length=max_length,
+ truncation=True,
+ padding="max_length",
+ )
+ input_p = tokenizer_p.batch_encode_plus(
+ [
+ ("This is a simple input 1", "This is a simple input 2"),
+ ("This is a simple pair 1", "This is a simple pair 2"),
+ ],
+ max_length=max_length,
+ truncation=True,
+ padding="max_length",
+ )
+ assert_batch_padded_input_match(input_r, input_p, max_length)
+
+ input_r = tokenizer_r.batch_encode_plus(
+ [
+ ("This is a simple input 1", "This is a simple input 2"),
+ ("This is a simple pair 1", "This is a simple pair 2"),
+ ],
+ padding=True,
+ )
+ input_p = tokenizer_p.batch_encode_plus(
+ [
+ ("This is a simple input 1", "This is a simple input 2"),
+ ("This is a simple pair 1", "This is a simple pair 2"),
+ ],
+ padding="longest",
+ )
+ assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]))
+
+ # Using pad on single examples after tokenization
+ input_r = tokenizer_r.encode_plus("This is a input 1")
+ input_r = tokenizer_r.pad(input_r)
+
+ input_p = tokenizer_r.encode_plus("This is a input 1")
+ input_p = tokenizer_r.pad(input_p)
+
+ assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]))
+
+ # Using pad on single examples after tokenization
+ input_r = tokenizer_r.encode_plus("This is a input 1")
+ input_r = tokenizer_r.pad(input_r, max_length=max_length, padding="max_length")
+
+ input_p = tokenizer_r.encode_plus("This is a input 1")
+ input_p = tokenizer_r.pad(input_p, max_length=max_length, padding="max_length")
+
+ assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length)
+
+ # Using pad after tokenization
+ input_r = tokenizer_r.batch_encode_plus(
+ ["This is a input 1", "This is a much longer input whilch should be padded"]
+ )
+ input_r = tokenizer_r.pad(input_r)
+
+ input_p = tokenizer_r.batch_encode_plus(
+ ["This is a input 1", "This is a much longer input whilch should be padded"]
+ )
+ input_p = tokenizer_r.pad(input_p)
+
+ assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]))
+
+ # Using pad after tokenization
+ input_r = tokenizer_r.batch_encode_plus(
+ ["This is a input 1", "This is a much longer input whilch should be padded"]
+ )
+ input_r = tokenizer_r.pad(input_r, max_length=max_length, padding="max_length")
+
+ input_p = tokenizer_r.batch_encode_plus(
+ ["This is a input 1", "This is a much longer input whilch should be padded"]
+ )
+ input_p = tokenizer_r.pad(input_p, max_length=max_length, padding="max_length")
+
+ assert_batch_padded_input_match(input_r, input_p, max_length)
+
+ def test_save_pretrained(self):
+ for tok_case, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest("{} ({})".format(tok_case.name, pretrained_name)):
+ tokenizer_r = tok_case.rust_cls.from_pretrained(pretrained_name, **kwargs)
+ tokenizer_p = tok_case.python_cls.from_pretrained(pretrained_name, **kwargs)
+ # Checks it save with the same files
+ self.assertSequenceEqual(
+ tokenizer_r.save_vocabulary(self.tmpdirname), tokenizer_p.save_vocabulary(self.tmpdirname)
+ )
+
+ # Checks everything loads correctly in the same way
+ tokenizer_rp, tokenizer_pp = tokenizer_r.from_pretrained(self.tmpdirname), tokenizer_p.from_pretrained(
+ self.tmpdirname
+ )
+
+ # Check special tokens are set accordingly on Rust and Python
+ for key in tokenizer_pp.special_tokens_map:
+ self.assertTrue(hasattr(tokenizer_rp, key))
+ # self.assertEqual(getattr(tokenizer_rp, key), getattr(tokenizer_pp, key))
+ # self.assertEqual(getattr(tokenizer_rp, key + "_id"), getattr(tokenizer_pp, key + "_id"))
+
+ def test_embeded_special_tokens(self):
+ for tok_case, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest("{} ({})".format(tok_case.name, pretrained_name)):
+ tokenizer_r = tok_case.rust_cls.from_pretrained(pretrained_name, **kwargs)
+ tokenizer_p = tok_case.python_cls.from_pretrained(pretrained_name, **kwargs)
+ sentence = "A, AllenNLP sentence."
+ tokens_r = tokenizer_r.encode_plus(
+ sentence,
+ add_special_tokens=True,
+ )
+ tokens_p = tokenizer_p.encode_plus(
+ sentence,
+ add_special_tokens=True,
+ )
+
+ for key in tokens_p.keys():
+ self.assertEqual(tokens_r[key], tokens_p[key])
+
+ if "token_type_ids" in tokens_r:
+ self.assertEqual(sum(tokens_r["token_type_ids"]), sum(tokens_p["token_type_ids"]))
+
+ tokens_r = tokenizer_r.convert_ids_to_tokens(tokens_r["input_ids"])
+ tokens_p = tokenizer_p.convert_ids_to_tokens(tokens_p["input_ids"])
+ self.assertSequenceEqual(tokens_r, tokens_p)
+
+ def test_add_special_tokens(self):
+ for tok_case, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest("{} ({})".format(tok_case.name, pretrained_name)):
+ tokenizer_r = tok_case.rust_cls.from_pretrained(pretrained_name, **kwargs)
+
+ simple_num_special_tokens_to_add = tokenizer_r.num_special_tokens_to_add(pair=False)
+ # pair_num_special_tokens_to_add = tokenizer_r.num_special_tokens_to_add(pair=True)
+
+ for text in ["", " "]:
+ # tokenize()
+ no_special_tokens = tokenizer_r.tokenize(text, add_special_tokens=False)
+ with_special_tokens = tokenizer_r.tokenize(text, add_special_tokens=True)
+ self.assertEqual(
+ len(no_special_tokens), len(with_special_tokens) - simple_num_special_tokens_to_add
+ )
+
+ # encode()
+ no_special_tokens = tokenizer_r.encode(text, add_special_tokens=False)
+ with_special_tokens = tokenizer_r.encode(text, add_special_tokens=True)
+ self.assertEqual(
+ len(no_special_tokens), len(with_special_tokens) - simple_num_special_tokens_to_add
+ )
+
+ # encode_plus()
+ no_special_tokens = tokenizer_r.encode_plus(text, add_special_tokens=False)
+ with_special_tokens = tokenizer_r.encode_plus(text, add_special_tokens=True)
+ for key in no_special_tokens.keys():
+ self.assertEqual(
+ len(no_special_tokens[key]),
+ len(with_special_tokens[key]) - simple_num_special_tokens_to_add,
+ )
+
+ # # batch_encode_plus
+ no_special_tokens = tokenizer_r.batch_encode_plus([text, text], add_special_tokens=False)
+ with_special_tokens = tokenizer_r.batch_encode_plus([text, text], add_special_tokens=True)
+ for key in no_special_tokens.keys():
+ for i_no, i_with in zip(no_special_tokens[key], with_special_tokens[key]):
+ self.assertEqual(len(i_no), len(i_with) - simple_num_special_tokens_to_add)
+
+ def test_prepare_for_model(self):
+ for tok_case, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest("{} ({})".format(tok_case.name, pretrained_name)):
+ tokenizer_r = tok_case.rust_cls.from_pretrained(pretrained_name, **kwargs)
+ tokenizer_p = tok_case.python_cls.from_pretrained(pretrained_name, **kwargs)
+ string_sequence = "Asserting that both tokenizers are equal"
+ python_output = tokenizer_p.prepare_for_model(
+ tokenizer_p.encode(string_sequence, add_special_tokens=False)
+ )
+ rust_output = tokenizer_r.prepare_for_model(
+ tokenizer_r.encode(string_sequence, add_special_tokens=False)
+ )
+ for key in python_output:
+ self.assertEqual(python_output[key], rust_output[key])
class WordPieceFastTokenizerTest(CommonFastTokenizerTest):
@@ -733,61 +856,86 @@ class WordPieceFastTokenizerTest(CommonFastTokenizerTest):
Tokenizer(
"DistilBert", DistilBertTokenizerFast, DistilBertTokenizer, "vocab_file", filter_non_english, None
),
+ Tokenizer(
+ "DPRReaderTokenizer",
+ DPRReaderTokenizerFast,
+ DPRReaderTokenizer,
+ "vocab_file",
+ filter_non_english,
+ None,
+ ),
+ Tokenizer(
+ "DPRQuestionEncoderTokenizer",
+ DPRQuestionEncoderTokenizerFast,
+ DPRQuestionEncoderTokenizer,
+ "vocab_file",
+ filter_non_english,
+ None,
+ ),
+ Tokenizer(
+ "DPRContextEncoderTokenizer",
+ DPRContextEncoderTokenizerFast,
+ DPRContextEncoderTokenizer,
+ "vocab_file",
+ filter_non_english,
+ None,
+ ),
+ Tokenizer("FunnelTokenizer", FunnelTokenizerFast, FunnelTokenizer, "vocab_file", filter_non_english, None),
+ Tokenizer("LxmertTokenizer", LxmertTokenizerFast, LxmertTokenizer, "vocab_file", filter_non_english, None),
]
)
- def fast_only(self, tokenizer_r):
- super().fast_only(tokenizer_r)
- self.assert_offsets_with_special_characters(tokenizer_r)
+ def test_offsets_with_special_characters(self):
+ for tok_case, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest("{} ({})".format(tok_case.name, pretrained_name)):
+ tokenizer_r = tok_case.rust_cls.from_pretrained(pretrained_name, **kwargs)
- def assert_add_special_tokens(self, tokenizer_r):
- super().assert_add_special_tokens(tokenizer_r)
+ sentence = f"A, naïve {tokenizer_r.mask_token} AllenNLP sentence."
+ tokens = tokenizer_r.encode_plus(
+ sentence,
+ return_attention_mask=False,
+ return_token_type_ids=False,
+ return_offsets_mapping=True,
+ add_special_tokens=True,
+ )
- def assert_offsets_with_special_characters(self, tokenizer_r):
- sentence = "A, naïve [MASK] AllenNLP sentence."
- tokens = tokenizer_r.encode_plus(
- sentence,
- return_attention_mask=False,
- return_token_type_ids=False,
- return_offsets_mapping=True,
- add_special_tokens=True,
- )
+ do_lower_case = tokenizer_r.do_lower_case if hasattr(tokenizer_r, "do_lower_case") else False
+ expected_results = (
+ [
+ ((0, 0), tokenizer_r.cls_token),
+ ((0, 1), "A"),
+ ((1, 2), ","),
+ ((3, 5), "na"),
+ ((5, 6), "##ï"),
+ ((6, 8), "##ve"),
+ ((9, 15), tokenizer_r.mask_token),
+ ((16, 21), "Allen"),
+ ((21, 23), "##NL"),
+ ((23, 24), "##P"),
+ ((25, 33), "sentence"),
+ ((33, 34), "."),
+ ((0, 0), tokenizer_r.sep_token),
+ ]
+ if not do_lower_case
+ else [
+ ((0, 0), tokenizer_r.cls_token),
+ ((0, 1), "a"),
+ ((1, 2), ","),
+ ((3, 8), "naive"),
+ ((9, 15), tokenizer_r.mask_token),
+ ((16, 21), "allen"),
+ ((21, 23), "##nl"),
+ ((23, 24), "##p"),
+ ((25, 33), "sentence"),
+ ((33, 34), "."),
+ ((0, 0), tokenizer_r.sep_token),
+ ]
+ )
- do_lower_case = tokenizer_r.init_kwargs.get("do_lower_case")
- expected_results = (
- [
- ((0, 0), "[CLS]"),
- ((0, 1), "A"),
- ((1, 2), ","),
- ((3, 5), "na"),
- ((5, 6), "##ï"),
- ((6, 8), "##ve"),
- ((9, 15), "[MASK]"),
- ((16, 21), "Allen"),
- ((21, 23), "##NL"),
- ((23, 24), "##P"),
- ((25, 33), "sentence"),
- ((33, 34), "."),
- ((0, 0), "[SEP]"),
- ]
- if not do_lower_case
- else [
- ((0, 0), "[CLS]"),
- ((0, 1), "a"),
- ((1, 2), ","),
- ((3, 8), "naive"),
- ((9, 15), "[MASK]"),
- ((16, 21), "allen"),
- ((21, 23), "##nl"),
- ((23, 24), "##p"),
- ((25, 33), "sentence"),
- ((33, 34), "."),
- ((0, 0), "[SEP]"),
- ]
- )
-
- self.assertEqual([e[1] for e in expected_results], tokenizer_r.convert_ids_to_tokens(tokens["input_ids"]))
- self.assertEqual([e[0] for e in expected_results], tokens["offset_mapping"])
+ self.assertEqual(
+ [e[1] for e in expected_results], tokenizer_r.convert_ids_to_tokens(tokens["input_ids"])
+ )
+ self.assertEqual([e[0] for e in expected_results], tokens["offset_mapping"])
class RobertaFastTokenizerTest(CommonFastTokenizerTest):
@@ -800,32 +948,52 @@ class RobertaFastTokenizerTest(CommonFastTokenizerTest):
"vocab_file",
filter_roberta_detectors,
(("cls_token", ""),),
- )
+ ),
+ Tokenizer(
+ "Bart",
+ BartTokenizerFast,
+ BartTokenizer,
+ "vocab_file",
+ None,
+ None,
+ ),
]
)
- def assert_embeded_special_tokens(self, tokenizer_r, tokenizer_p):
- sentence = "A, AllenNLP sentence."
- tokens_r = tokenizer_r.encode_plus(sentence, add_special_tokens=True, return_token_type_ids=True)
- tokens_p = tokenizer_p.encode_plus(sentence, add_special_tokens=True, return_token_type_ids=True)
+ def test_pretokenized_inputs(self):
+ pass
- # Rust correctly handles the space before the mask while python doesnt
- self.assertSequenceEqual(tokens_r["input_ids"], [0, 250, 6, 50264, 3823, 487, 21992, 3645, 4, 2])
- self.assertSequenceEqual(tokens_p["input_ids"], [0, 250, 6, 50264, 3823, 487, 21992, 3645, 4, 2])
+ def test_embeded_special_tokens(self):
+ for tok_case, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest("{} ({})".format(tok_case.name, pretrained_name)):
+ tokenizer_r = tok_case.rust_cls.from_pretrained(pretrained_name, **kwargs)
+ tokenizer_p = tok_case.python_cls.from_pretrained(pretrained_name, **kwargs)
+ sentence = "A, AllenNLP sentence."
+ tokens_r = tokenizer_r.encode_plus(sentence, add_special_tokens=True, return_token_type_ids=True)
+ tokens_p = tokenizer_p.encode_plus(sentence, add_special_tokens=True, return_token_type_ids=True)
- # token_type_ids should put 0 everywhere
- self.assertEqual(sum(tokens_r["token_type_ids"]), sum(tokens_p["token_type_ids"]))
+ # token_type_ids should put 0 everywhere
+ self.assertEqual(sum(tokens_r["token_type_ids"]), sum(tokens_p["token_type_ids"]))
- # attention_mask should put 1 everywhere, so sum over length should be 1
- self.assertEqual(
- sum(tokens_r["attention_mask"]) / len(tokens_r["attention_mask"]),
- sum(tokens_p["attention_mask"]) / len(tokens_p["attention_mask"]),
- )
+ # attention_mask should put 1 everywhere, so sum over length should be 1
+ self.assertEqual(
+ sum(tokens_r["attention_mask"]) / len(tokens_r["attention_mask"]),
+ sum(tokens_p["attention_mask"]) / len(tokens_p["attention_mask"]),
+ )
- tokens_r = tokenizer_r.convert_ids_to_tokens(tokens_r["input_ids"])
- tokens_p = tokenizer_p.convert_ids_to_tokens(tokens_p["input_ids"])
- self.assertSequenceEqual(tokens_r, ["", "A", ",", "", "ĠAllen", "N", "LP", "Ġsentence", ".", ""])
- self.assertSequenceEqual(tokens_p, ["", "A", ",", "", "ĠAllen", "N", "LP", "Ġsentence", ".", ""])
+ tokens_r_str = tokenizer_r.convert_ids_to_tokens(tokens_r["input_ids"])
+ tokens_p_str = tokenizer_p.convert_ids_to_tokens(tokens_p["input_ids"])
+
+ # Rust correctly handles the space before the mask while python doesnt
+ self.assertSequenceEqual(tokens_p["input_ids"], [0, 250, 6, 50264, 3823, 487, 21992, 3645, 4, 2])
+ self.assertSequenceEqual(tokens_r["input_ids"], [0, 250, 6, 50264, 3823, 487, 21992, 3645, 4, 2])
+
+ self.assertSequenceEqual(
+ tokens_p_str, ["", "A", ",", "", "ĠAllen", "N", "LP", "Ġsentence", ".", ""]
+ )
+ self.assertSequenceEqual(
+ tokens_r_str, ["", "A", ",", "", "ĠAllen", "N", "LP", "Ġsentence", ".", ""]
+ )
class NoPaddingTokenFastTokenizerMatchingTest(CommonFastTokenizerTest):
@@ -834,62 +1002,75 @@ class NoPaddingTokenFastTokenizerMatchingTest(CommonFastTokenizerTest):
Tokenizer("GPT2", GPT2TokenizerFast, GPT2Tokenizer, "vocab_file", None, [("add_prefix_space", True)]),
]
- def fast_align_python(self, tokenizer_r, tokenizer_p, tok_case, pretrained_name):
- # Check is_fast is set correctly
- self.assertFalse(tokenizer_p.is_fast)
- self.assertTrue(tokenizer_r.is_fast)
+ def test_pretokenized_inputs(self):
+ pass
- # Check that Rust and Python align
- self.assert_tokenization_python_rust_equals(tokenizer_r, tokenizer_p)
- self.assert_num_special_tokens_to_add_equal(tokenizer_r, tokenizer_p)
- self.assert_max_length_equal(tokenizer_r, tokenizer_p)
- self.assert_special_tokens_map_equal(tokenizer_r, tokenizer_p)
- self.assert_embeded_special_tokens(tokenizer_r, tokenizer_p)
- self.assert_padding(tokenizer_r, tokenizer_p)
+ def test_padding(self, max_length=15):
+ for tok_case, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest("{} ({})".format(tok_case.name, pretrained_name)):
+ tokenizer_r = tok_case.rust_cls.from_pretrained(pretrained_name, **kwargs)
- # Specific for
- kwargs = {}
- if tok_case.kwargs is not None:
- kwargs = dict(tok_case.kwargs)
- tokenizer_r = tok_case.rust_cls.from_pretrained(pretrained_name, **kwargs)
- self.assert_pretokenized_inputs(tokenizer_r, tokenizer_p)
+ # Simple input
+ s = "This is a simple input"
+ s2 = ["This is a simple input 1", "This is a simple input 2"]
+ p = ("This is a simple input", "This is a pair")
+ p2 = [
+ ("This is a simple input 1", "This is a simple input 2"),
+ ("This is a simple pair 1", "This is a simple pair 2"),
+ ]
- def assert_padding(self, tokenizer_r, tokenizer_p, max_length=15):
- # Simple input
- s = "This is a simple input"
- s2 = ["This is a simple input 1", "This is a simple input 2"]
- p = ("This is a simple input", "This is a pair")
- p2 = [
- ("This is a simple input 1", "This is a simple input 2"),
- ("This is a simple pair 1", "This is a simple pair 2"),
+ # Simple input tests
+ self.assertRaises(ValueError, tokenizer_r.encode, s, max_length=max_length, padding="max_length")
+
+ # Simple input
+ self.assertRaises(ValueError, tokenizer_r.encode_plus, s, max_length=max_length, padding="max_length")
+
+ # Simple input
+ self.assertRaises(
+ ValueError,
+ tokenizer_r.batch_encode_plus,
+ s2,
+ max_length=max_length,
+ padding="max_length",
+ )
+
+ # Pair input
+ self.assertRaises(ValueError, tokenizer_r.encode, p, max_length=max_length, padding="max_length")
+
+ # Pair input
+ self.assertRaises(ValueError, tokenizer_r.encode_plus, p, max_length=max_length, padding="max_length")
+
+ # Pair input
+ self.assertRaises(
+ ValueError,
+ tokenizer_r.batch_encode_plus,
+ p2,
+ max_length=max_length,
+ padding="max_length",
+ )
+
+
+class SentencePieceFastTokenizerTest(CommonFastTokenizerTest):
+ """
+ Override specific methods to test SentencePiece behavior
+ """
+
+ TOKENIZERS_CLASSES = frozenset(
+ [
+ Tokenizer("Albert", AlbertTokenizerFast, AlbertTokenizer, "vocab_file", None, None),
+ Tokenizer("Camembert", CamembertTokenizerFast, CamembertTokenizer, "vocab_file", None, None),
+ Tokenizer("T5", T5TokenizerFast, T5Tokenizer, "vocab_file", None, None),
+ Tokenizer(
+ "MBart",
+ MBartTokenizerFast,
+ MBartTokenizer,
+ "vocab_file",
+ None,
+ None,
+ ),
+ Tokenizer("Pegasus", PegasusTokenizerFast, PegasusTokenizer, "vocab_file", None, None),
+ Tokenizer("Reformer", ReformerTokenizerFast, ReformerTokenizer, "vocab_file", None, None),
+ Tokenizer("XLMRoberta", XLMRobertaTokenizerFast, XLMRobertaTokenizer, "vocab_file", None, None),
+ Tokenizer("XLNet", XLNetTokenizerFast, XLNetTokenizer, "vocab_file", None, None),
]
-
- # Simple input tests
- self.assertRaises(ValueError, tokenizer_r.encode, s, max_length=max_length, padding="max_length")
-
- # Simple input
- self.assertRaises(ValueError, tokenizer_r.encode_plus, s, max_length=max_length, padding="max_length")
-
- # Simple input
- self.assertRaises(
- ValueError,
- tokenizer_r.batch_encode_plus,
- s2,
- max_length=max_length,
- padding="max_length",
- )
-
- # Pair input
- self.assertRaises(ValueError, tokenizer_r.encode, p, max_length=max_length, padding="max_length")
-
- # Pair input
- self.assertRaises(ValueError, tokenizer_r.encode_plus, p, max_length=max_length, padding="max_length")
-
- # Pair input
- self.assertRaises(
- ValueError,
- tokenizer_r.batch_encode_plus,
- p2,
- max_length=max_length,
- padding="max_length",
- )
+ )
diff --git a/tests/test_tokenization_funnel.py b/tests/test_tokenization_funnel.py
index 8c9ce7a3d4..11945ffc52 100644
--- a/tests/test_tokenization_funnel.py
+++ b/tests/test_tokenization_funnel.py
@@ -26,6 +26,7 @@ class FunnelTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = FunnelTokenizer
test_rust_tokenizer = True
+ space_between_special_tokens = True
def setUp(self):
super().setUp()
diff --git a/tests/test_tokenization_gpt2.py b/tests/test_tokenization_gpt2.py
index ad23b6f8fc..29420d0b03 100644
--- a/tests/test_tokenization_gpt2.py
+++ b/tests/test_tokenization_gpt2.py
@@ -26,6 +26,7 @@ from .test_tokenization_common import TokenizerTesterMixin
class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = GPT2Tokenizer
+ rust_tokenizer_class = GPT2TokenizerFast
test_rust_tokenizer = True
def setUp(self):
diff --git a/tests/test_tokenization_lxmert.py b/tests/test_tokenization_lxmert.py
index e3c157568c..953bca4832 100644
--- a/tests/test_tokenization_lxmert.py
+++ b/tests/test_tokenization_lxmert.py
@@ -18,7 +18,7 @@ import os
import unittest
from transformers.tokenization_bert import VOCAB_FILES_NAMES
-from transformers.tokenization_lxmert import LxmertTokenizer
+from transformers.tokenization_lxmert import LxmertTokenizer, LxmertTokenizerFast
from .test_tokenization_common import TokenizerTesterMixin
@@ -26,6 +26,9 @@ from .test_tokenization_common import TokenizerTesterMixin
class LxmertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = LxmertTokenizer
+ rust_tokenizer_class = LxmertTokenizerFast
+ test_rust_tokenizer = True
+ space_between_special_tokens = True
def setUp(self):
super().setUp()
@@ -49,9 +52,6 @@ class LxmertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
- def get_tokenizer(self, **kwargs):
- return LxmertTokenizer.from_pretrained(self.tmpdirname, **kwargs)
-
def get_input_output_texts(self, tokenizer):
input_text = "UNwant\u00E9d,running"
output_text = "unwanted, running"
@@ -63,3 +63,25 @@ class LxmertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokens = tokenizer.tokenize("UNwant\u00E9d,running")
self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
+
+ def test_rust_and_python_full_tokenizers(self):
+ if not self.test_rust_tokenizer:
+ return
+
+ tokenizer = self.get_tokenizer()
+ rust_tokenizer = self.get_rust_tokenizer()
+
+ sequence = "I was born in 92000, and this is falsé."
+
+ tokens = tokenizer.tokenize(sequence)
+ rust_tokens = rust_tokenizer.tokenize(sequence)
+ self.assertListEqual(tokens, rust_tokens)
+
+ ids = tokenizer.encode(sequence, add_special_tokens=False)
+ rust_ids = rust_tokenizer.encode(sequence, add_special_tokens=False)
+ self.assertListEqual(ids, rust_ids)
+
+ rust_tokenizer = self.get_rust_tokenizer()
+ ids = tokenizer.encode(sequence)
+ rust_ids = rust_tokenizer.encode(sequence)
+ self.assertListEqual(ids, rust_ids)
diff --git a/tests/test_tokenization_marian.py b/tests/test_tokenization_marian.py
index 4948dffb18..23bf4bd519 100644
--- a/tests/test_tokenization_marian.py
+++ b/tests/test_tokenization_marian.py
@@ -38,6 +38,7 @@ FRAMEWORK = "pt" if _torch_available else "tf"
class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = MarianTokenizer
+ test_rust_tokenizer = False
def setUp(self):
super().setUp()
diff --git a/tests/test_tokenization_mbart.py b/tests/test_tokenization_mbart.py
index e6f77d7514..71b84077d8 100644
--- a/tests/test_tokenization_mbart.py
+++ b/tests/test_tokenization_mbart.py
@@ -1,7 +1,7 @@
import tempfile
import unittest
-from transformers import AutoTokenizer, BatchEncoding, MBartTokenizer, is_torch_available
+from transformers import AutoTokenizer, BatchEncoding, MBartTokenizer, MBartTokenizerFast, is_torch_available
from transformers.testing_utils import require_torch
from .test_tokenization_common import TokenizerTesterMixin
@@ -17,6 +17,8 @@ RO_CODE = 250020
class MBartTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = MBartTokenizer
+ rust_tokenizer_class = MBartTokenizerFast
+ test_rust_tokenizer = True
def setUp(self):
super().setUp()
diff --git a/tests/test_tokenization_openai.py b/tests/test_tokenization_openai.py
index 62e80ca4a1..88f253d0ab 100644
--- a/tests/test_tokenization_openai.py
+++ b/tests/test_tokenization_openai.py
@@ -18,7 +18,7 @@ import json
import os
import unittest
-from transformers.tokenization_openai import VOCAB_FILES_NAMES, OpenAIGPTTokenizer
+from transformers.tokenization_openai import VOCAB_FILES_NAMES, OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
from .test_tokenization_common import TokenizerTesterMixin
@@ -26,6 +26,8 @@ from .test_tokenization_common import TokenizerTesterMixin
class OpenAIGPTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = OpenAIGPTTokenizer
+ rust_tokenizer_class = OpenAIGPTTokenizerFast
+ test_rust_tokenizer = True
def setUp(self):
super().setUp()
diff --git a/tests/test_tokenization_pegasus.py b/tests/test_tokenization_pegasus.py
index 88a0a1bed4..3943322bfe 100644
--- a/tests/test_tokenization_pegasus.py
+++ b/tests/test_tokenization_pegasus.py
@@ -3,7 +3,7 @@ from pathlib import Path
from transformers.file_utils import cached_property
from transformers.testing_utils import require_torch
-from transformers.tokenization_pegasus import PegasusTokenizer
+from transformers.tokenization_pegasus import PegasusTokenizer, PegasusTokenizerFast
from .test_tokenization_common import TokenizerTesterMixin
@@ -11,6 +11,8 @@ from .test_tokenization_common import TokenizerTesterMixin
class PegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = PegasusTokenizer
+ rust_tokenizer_class = PegasusTokenizerFast
+ test_rust_tokenizer = True
def setUp(self):
super().setUp()
diff --git a/tests/test_tokenization_reformer.py b/tests/test_tokenization_reformer.py
index a5f5509f3d..f134958e81 100644
--- a/tests/test_tokenization_reformer.py
+++ b/tests/test_tokenization_reformer.py
@@ -19,7 +19,7 @@ import unittest
from transformers.file_utils import cached_property
from transformers.testing_utils import require_torch, slow
-from transformers.tokenization_reformer import SPIECE_UNDERLINE, ReformerTokenizer
+from transformers.tokenization_reformer import SPIECE_UNDERLINE, ReformerTokenizer, ReformerTokenizerFast
from .test_tokenization_common import TokenizerTesterMixin
@@ -30,6 +30,8 @@ SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixture
class ReformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = ReformerTokenizer
+ rust_tokenizer_class = ReformerTokenizerFast
+ test_rust_tokenizer = True
def setUp(self):
super().setUp()
@@ -37,6 +39,28 @@ class ReformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer = ReformerTokenizer(SAMPLE_VOCAB, keep_accents=True)
tokenizer.save_pretrained(self.tmpdirname)
+ def test_rust_and_python_full_tokenizers(self):
+ if not self.test_rust_tokenizer:
+ return
+
+ tokenizer = self.get_tokenizer()
+ rust_tokenizer = self.get_rust_tokenizer()
+
+ sequence = "I was born in 92000, and this is falsé."
+
+ tokens = tokenizer.tokenize(sequence)
+ rust_tokens = rust_tokenizer.tokenize(sequence)
+ self.assertListEqual(tokens, rust_tokens)
+
+ ids = tokenizer.encode(sequence, add_special_tokens=False)
+ rust_ids = rust_tokenizer.encode(sequence, add_special_tokens=False)
+ self.assertListEqual(ids, rust_ids)
+
+ rust_tokenizer = self.get_rust_tokenizer()
+ ids = tokenizer.encode(sequence)
+ rust_ids = rust_tokenizer.encode(sequence)
+ self.assertListEqual(ids, rust_ids)
+
def test_full_tokenizer(self):
tokenizer = ReformerTokenizer(SAMPLE_VOCAB, keep_accents=True)
diff --git a/tests/test_tokenization_roberta.py b/tests/test_tokenization_roberta.py
index cbe37f21f1..e96fa58fb9 100644
--- a/tests/test_tokenization_roberta.py
+++ b/tests/test_tokenization_roberta.py
@@ -26,6 +26,8 @@ from .test_tokenization_common import TokenizerTesterMixin
class RobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = RobertaTokenizer
+ rust_tokenizer_class = RobertaTokenizerFast
+ test_rust_tokenizer = True
def setUp(self):
super().setUp()
diff --git a/tests/test_tokenization_t5.py b/tests/test_tokenization_t5.py
index d14d3de53a..234e9f91f1 100644
--- a/tests/test_tokenization_t5.py
+++ b/tests/test_tokenization_t5.py
@@ -20,13 +20,12 @@ import unittest
from transformers import BatchEncoding
from transformers.file_utils import cached_property
from transformers.testing_utils import _torch_available
-from transformers.tokenization_t5 import T5Tokenizer
+from transformers.tokenization_t5 import T5Tokenizer, T5TokenizerFast
+from transformers.tokenization_xlnet import SPIECE_UNDERLINE
from .test_tokenization_common import TokenizerTesterMixin
-SPIECE_UNDERLINE = "▁"
-
SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")
FRAMEWORK = "pt" if _torch_available else "tf"
@@ -35,6 +34,8 @@ FRAMEWORK = "pt" if _torch_available else "tf"
class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = T5Tokenizer
+ rust_tokenizer_class = T5TokenizerFast
+ test_rust_tokenizer = True
def setUp(self):
super().setUp()
@@ -113,6 +114,38 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
def t5_base_tokenizer(self):
return T5Tokenizer.from_pretrained("t5-base")
+ @cached_property
+ def t5_base_tokenizer_fast(self):
+ return T5TokenizerFast.from_pretrained("t5-base")
+
+ def get_tokenizer(self, **kwargs) -> T5Tokenizer:
+ return self.tokenizer_class.from_pretrained(self.tmpdirname, pad_token=None, **kwargs)
+
+ def get_rust_tokenizer(self, **kwargs) -> T5TokenizerFast:
+ return self.rust_tokenizer_class.from_pretrained(self.tmpdirname, pad_token=None, **kwargs)
+
+ def test_rust_and_python_full_tokenizers(self):
+ if not self.test_rust_tokenizer:
+ return
+
+ tokenizer = self.get_tokenizer()
+ rust_tokenizer = self.get_rust_tokenizer()
+
+ sequence = "I was born in 92000, and this is falsé."
+
+ tokens = tokenizer.tokenize(sequence)
+ rust_tokens = rust_tokenizer.tokenize(sequence)
+ self.assertListEqual(tokens, rust_tokens)
+
+ ids = tokenizer.encode(sequence, add_special_tokens=False)
+ rust_ids = rust_tokenizer.encode(sequence, add_special_tokens=False)
+ self.assertListEqual(ids, rust_ids)
+
+ rust_tokenizer = self.get_rust_tokenizer()
+ ids = tokenizer.encode(sequence)
+ rust_ids = rust_tokenizer.encode(sequence)
+ self.assertListEqual(ids, rust_ids)
+
def test_eos_treatment(self):
tokenizer = self.t5_base_tokenizer
batch_with_eos_added = tokenizer(["hi", "I went to the gym", ""])
diff --git a/tests/test_tokenization_transfo_xl.py b/tests/test_tokenization_transfo_xl.py
index 1688a9f3a6..7e51327742 100644
--- a/tests/test_tokenization_transfo_xl.py
+++ b/tests/test_tokenization_transfo_xl.py
@@ -17,20 +17,15 @@
import os
import unittest
-from transformers import is_torch_available
-from transformers.testing_utils import require_torch
+from transformers.tokenization_transfo_xl import VOCAB_FILES_NAMES, TransfoXLTokenizer
from .test_tokenization_common import TokenizerTesterMixin
-if is_torch_available():
- from transformers.tokenization_transfo_xl import VOCAB_FILES_NAMES, TransfoXLTokenizer
-
-
-@require_torch
class TransfoXLTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
- tokenizer_class = TransfoXLTokenizer if is_torch_available() else None
+ tokenizer_class = TransfoXLTokenizer
+ test_rust_tokenizer = False
def setUp(self):
super().setUp()
diff --git a/tests/test_tokenization_xlm.py b/tests/test_tokenization_xlm.py
index 8e9d8946f2..4bd40635f3 100644
--- a/tests/test_tokenization_xlm.py
+++ b/tests/test_tokenization_xlm.py
@@ -27,6 +27,7 @@ from .test_tokenization_common import TokenizerTesterMixin
class XLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = XLMTokenizer
+ test_rust_tokenizer = False
def setUp(self):
super().setUp()
diff --git a/tests/test_tokenization_xlm_roberta.py b/tests/test_tokenization_xlm_roberta.py
index c67e9e2f24..1b64e0091e 100644
--- a/tests/test_tokenization_xlm_roberta.py
+++ b/tests/test_tokenization_xlm_roberta.py
@@ -19,7 +19,7 @@ import unittest
from transformers.file_utils import cached_property
from transformers.testing_utils import slow
-from transformers.tokenization_xlm_roberta import SPIECE_UNDERLINE, XLMRobertaTokenizer
+from transformers.tokenization_xlm_roberta import SPIECE_UNDERLINE, XLMRobertaTokenizer, XLMRobertaTokenizerFast
from .test_tokenization_common import TokenizerTesterMixin
@@ -30,6 +30,8 @@ SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixture
class XLMRobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = XLMRobertaTokenizer
+ rust_tokenizer_class = XLMRobertaTokenizerFast
+ test_rust_tokenizer = True
def setUp(self):
super().setUp()
@@ -118,6 +120,28 @@ class XLMRobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
def big_tokenizer(self):
return XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
+ def test_rust_and_python_full_tokenizers(self):
+ if not self.test_rust_tokenizer:
+ return
+
+ tokenizer = self.get_tokenizer()
+ rust_tokenizer = self.get_rust_tokenizer()
+
+ sequence = "I was born in 92000, and this is falsé."
+
+ tokens = tokenizer.tokenize(sequence)
+ rust_tokens = rust_tokenizer.tokenize(sequence)
+ self.assertListEqual(tokens, rust_tokens)
+
+ ids = tokenizer.encode(sequence, add_special_tokens=False)
+ rust_ids = rust_tokenizer.encode(sequence, add_special_tokens=False)
+ self.assertListEqual(ids, rust_ids)
+
+ rust_tokenizer = self.get_rust_tokenizer()
+ ids = tokenizer.encode(sequence)
+ rust_ids = rust_tokenizer.encode(sequence)
+ self.assertListEqual(ids, rust_ids)
+
@slow
def test_tokenization_base_easy_symbols(self):
symbols = "Hello World!"
diff --git a/tests/test_tokenization_xlnet.py b/tests/test_tokenization_xlnet.py
index 9f92d0a05b..d0ee0da26f 100644
--- a/tests/test_tokenization_xlnet.py
+++ b/tests/test_tokenization_xlnet.py
@@ -18,7 +18,7 @@ import os
import unittest
from transformers.testing_utils import slow
-from transformers.tokenization_xlnet import SPIECE_UNDERLINE, XLNetTokenizer
+from transformers.tokenization_xlnet import SPIECE_UNDERLINE, XLNetTokenizer, XLNetTokenizerFast
from .test_tokenization_common import TokenizerTesterMixin
@@ -29,12 +29,15 @@ SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixture
class XLNetTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = XLNetTokenizer
+ rust_tokenizer_class = XLNetTokenizerFast
+ test_rust_tokenizer = True
def setUp(self):
super().setUp()
# We have a SentencePiece fixture for testing
tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)
+ tokenizer.sanitize_special_tokens()
tokenizer.save_pretrained(self.tmpdirname)
def test_full_tokenizer(self):