[tokenizers] Several small improvements and bug fixes (#5287)
* avoid recursion in id checks for fast tokenizers
* better typings and fix #5232
* align slow and fast tokenizers behaviors for Roberta and GPT2
* style and quality
* fix tests - improve typings
This commit is contained in:
@@ -23,7 +23,7 @@ from functools import lru_cache
|
|||||||
import regex as re
|
import regex as re
|
||||||
from tokenizers import ByteLevelBPETokenizer
|
from tokenizers import ByteLevelBPETokenizer
|
||||||
|
|
||||||
from .tokenization_utils import PreTrainedTokenizer
|
from .tokenization_utils import AddedToken, PreTrainedTokenizer
|
||||||
from .tokenization_utils_base import BatchEncoding
|
from .tokenization_utils_base import BatchEncoding
|
||||||
from .tokenization_utils_fast import PreTrainedTokenizerFast
|
from .tokenization_utils_fast import PreTrainedTokenizerFast
|
||||||
|
|
||||||
@@ -149,6 +149,9 @@ class GPT2Tokenizer(PreTrainedTokenizer):
|
|||||||
add_prefix_space=False,
|
add_prefix_space=False,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
|
bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
|
||||||
|
eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
|
||||||
|
unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
|
||||||
super().__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs)
|
super().__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs)
|
||||||
|
|
||||||
with open(vocab_file, encoding="utf-8") as vocab_handle:
|
with open(vocab_file, encoding="utf-8") as vocab_handle:
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ from typing import List, Optional
|
|||||||
from tokenizers.processors import RobertaProcessing
|
from tokenizers.processors import RobertaProcessing
|
||||||
|
|
||||||
from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
|
from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
|
||||||
from .tokenization_utils import AddedToken, PreTrainedTokenizer
|
from .tokenization_utils import AddedToken
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -137,6 +137,16 @@ class RobertaTokenizer(GPT2Tokenizer):
|
|||||||
add_prefix_space=False,
|
add_prefix_space=False,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
|
bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
|
||||||
|
eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
|
||||||
|
sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token
|
||||||
|
cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token
|
||||||
|
unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
|
||||||
|
pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
|
||||||
|
|
||||||
|
# Mask token behave like a normal word, i.e. include the space before it
|
||||||
|
mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
vocab_file=vocab_file,
|
vocab_file=vocab_file,
|
||||||
merges_file=merges_file,
|
merges_file=merges_file,
|
||||||
@@ -152,13 +162,6 @@ class RobertaTokenizer(GPT2Tokenizer):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@PreTrainedTokenizer.mask_token.setter
|
|
||||||
def mask_token(self, value):
|
|
||||||
if not isinstance(value, AddedToken):
|
|
||||||
value = AddedToken(value, lstrip=True)
|
|
||||||
|
|
||||||
self._mask_token = value
|
|
||||||
|
|
||||||
def build_inputs_with_special_tokens(
|
def build_inputs_with_special_tokens(
|
||||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||||
) -> List[int]:
|
) -> List[int]:
|
||||||
@@ -309,6 +312,9 @@ class RobertaTokenizerFast(GPT2TokenizerFast):
|
|||||||
trim_offsets=True,
|
trim_offsets=True,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
|
# Mask token behave like a normal word, i.e. include the space before it
|
||||||
|
mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
|
||||||
|
|
||||||
kwargs.setdefault("pad_token", pad_token)
|
kwargs.setdefault("pad_token", pad_token)
|
||||||
kwargs.setdefault("sep_token", sep_token)
|
kwargs.setdefault("sep_token", sep_token)
|
||||||
kwargs.setdefault("cls_token", cls_token)
|
kwargs.setdefault("cls_token", cls_token)
|
||||||
@@ -325,6 +331,9 @@ class RobertaTokenizerFast(GPT2TokenizerFast):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# This will add the necessary special tokens to the vocabulary if needed
|
||||||
|
self.sanitize_special_tokens()
|
||||||
|
|
||||||
self.backend_tokenizer._tokenizer.post_processor = RobertaProcessing(
|
self.backend_tokenizer._tokenizer.post_processor = RobertaProcessing(
|
||||||
sep=(sep_token, self.sep_token_id),
|
sep=(sep_token, self.sep_token_id),
|
||||||
cls=(cls_token, self.cls_token_id),
|
cls=(cls_token, self.cls_token_id),
|
||||||
@@ -332,15 +341,6 @@ class RobertaTokenizerFast(GPT2TokenizerFast):
|
|||||||
trim_offsets=trim_offsets,
|
trim_offsets=trim_offsets,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.sanitize_special_tokens() # This will add the necessary special tokens to the vocabulary if needed.
|
|
||||||
|
|
||||||
@PreTrainedTokenizer.mask_token.setter
|
|
||||||
def mask_token(self, value):
|
|
||||||
if not isinstance(value, AddedToken):
|
|
||||||
value = AddedToken(value, lstrip=True)
|
|
||||||
|
|
||||||
self._mask_token = value
|
|
||||||
|
|
||||||
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
|
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
|
||||||
output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
|
output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
|
||||||
if token_ids_1 is None:
|
if token_ids_1 is None:
|
||||||
|
|||||||
@@ -607,7 +607,7 @@ class SpecialTokensMixin:
|
|||||||
"special token {} has to be either str or AddedToken but got: {}".format(key, type(value))
|
"special token {} has to be either str or AddedToken but got: {}".format(key, type(value))
|
||||||
)
|
)
|
||||||
|
|
||||||
def sanitize_special_tokens(self):
|
def sanitize_special_tokens(self) -> int:
|
||||||
""" Make sure that all the special tokens attributes of the tokenizer (tokenizer.mask_token, tokenizer.cls_token, ...)
|
""" Make sure that all the special tokens attributes of the tokenizer (tokenizer.mask_token, tokenizer.cls_token, ...)
|
||||||
are in the vocabulary. Add the missing ones to the vocabulary if needed.
|
are in the vocabulary. Add the missing ones to the vocabulary if needed.
|
||||||
|
|
||||||
@@ -616,7 +616,7 @@ class SpecialTokensMixin:
|
|||||||
"""
|
"""
|
||||||
return self.add_tokens(self.all_special_tokens_extended, special_tokens=True)
|
return self.add_tokens(self.all_special_tokens_extended, special_tokens=True)
|
||||||
|
|
||||||
def add_special_tokens(self, special_tokens_dict):
|
def add_special_tokens(self, special_tokens_dict: Dict[str, Union[str, AddedToken]]) -> int:
|
||||||
"""
|
"""
|
||||||
Add a dictionary of special tokens (eos, pad, cls...) to the encoder and link them
|
Add a dictionary of special tokens (eos, pad, cls...) to the encoder and link them
|
||||||
to class attributes. If special tokens are NOT in the vocabulary, they are added
|
to class attributes. If special tokens are NOT in the vocabulary, they are added
|
||||||
@@ -665,10 +665,14 @@ class SpecialTokensMixin:
|
|||||||
setattr(self, key, value)
|
setattr(self, key, value)
|
||||||
|
|
||||||
if key == "additional_special_tokens":
|
if key == "additional_special_tokens":
|
||||||
assert isinstance(value, (list, tuple)) and all(isinstance(t, str) for t in value)
|
assert isinstance(value, (list, tuple)) and all(
|
||||||
|
isinstance(t, (str, AddedToken)) for t in value
|
||||||
|
), f"Tokens {value} for key {key} should all be str or AddedToken instances"
|
||||||
added_tokens += self.add_tokens(value, special_tokens=True)
|
added_tokens += self.add_tokens(value, special_tokens=True)
|
||||||
else:
|
else:
|
||||||
assert isinstance(value, str)
|
assert isinstance(
|
||||||
|
value, (str, AddedToken)
|
||||||
|
), f"Token {value} for key {key} should be a str or an AddedToken instance"
|
||||||
added_tokens += self.add_tokens([value], special_tokens=True)
|
added_tokens += self.add_tokens([value], special_tokens=True)
|
||||||
|
|
||||||
return added_tokens
|
return added_tokens
|
||||||
@@ -809,26 +813,36 @@ class SpecialTokensMixin:
|
|||||||
@property
|
@property
|
||||||
def bos_token_id(self):
|
def bos_token_id(self):
|
||||||
""" Id of the beginning of sentence token in the vocabulary. Log an error if used while not having been set. """
|
""" Id of the beginning of sentence token in the vocabulary. Log an error if used while not having been set. """
|
||||||
|
if self._bos_token is None:
|
||||||
|
return None
|
||||||
return self.convert_tokens_to_ids(self.bos_token)
|
return self.convert_tokens_to_ids(self.bos_token)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def eos_token_id(self):
|
def eos_token_id(self):
|
||||||
""" Id of the end of sentence token in the vocabulary. Log an error if used while not having been set. """
|
""" Id of the end of sentence token in the vocabulary. Log an error if used while not having been set. """
|
||||||
|
if self._eos_token is None:
|
||||||
|
return None
|
||||||
return self.convert_tokens_to_ids(self.eos_token)
|
return self.convert_tokens_to_ids(self.eos_token)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def unk_token_id(self):
|
def unk_token_id(self):
|
||||||
""" Id of the unknown token in the vocabulary. Log an error if used while not having been set. """
|
""" Id of the unknown token in the vocabulary. Log an error if used while not having been set. """
|
||||||
|
if self._unk_token is None:
|
||||||
|
return None
|
||||||
return self.convert_tokens_to_ids(self.unk_token)
|
return self.convert_tokens_to_ids(self.unk_token)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def sep_token_id(self):
|
def sep_token_id(self):
|
||||||
""" Id of the separation token in the vocabulary. E.g. separate context and query in an input sequence. Log an error if used while not having been set. """
|
""" Id of the separation token in the vocabulary. E.g. separate context and query in an input sequence. Log an error if used while not having been set. """
|
||||||
|
if self._sep_token is None:
|
||||||
|
return None
|
||||||
return self.convert_tokens_to_ids(self.sep_token)
|
return self.convert_tokens_to_ids(self.sep_token)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def pad_token_id(self):
|
def pad_token_id(self):
|
||||||
""" Id of the padding token in the vocabulary. Log an error if used while not having been set. """
|
""" Id of the padding token in the vocabulary. Log an error if used while not having been set. """
|
||||||
|
if self._pad_token is None:
|
||||||
|
return None
|
||||||
return self.convert_tokens_to_ids(self.pad_token)
|
return self.convert_tokens_to_ids(self.pad_token)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -839,11 +853,15 @@ class SpecialTokensMixin:
|
|||||||
@property
|
@property
|
||||||
def cls_token_id(self):
|
def cls_token_id(self):
|
||||||
""" Id of the classification token in the vocabulary. E.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model. Log an error if used while not having been set. """
|
""" Id of the classification token in the vocabulary. E.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model. Log an error if used while not having been set. """
|
||||||
|
if self._cls_token is None:
|
||||||
|
return None
|
||||||
return self.convert_tokens_to_ids(self.cls_token)
|
return self.convert_tokens_to_ids(self.cls_token)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def mask_token_id(self):
|
def mask_token_id(self):
|
||||||
""" Id of the mask token in the vocabulary. E.g. when training a model with masked-language modeling. Log an error if used while not having been set. """
|
""" Id of the mask token in the vocabulary. E.g. when training a model with masked-language modeling. Log an error if used while not having been set. """
|
||||||
|
if self._mask_token is None:
|
||||||
|
return None
|
||||||
return self.convert_tokens_to_ids(self.mask_token)
|
return self.convert_tokens_to_ids(self.mask_token)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@@ -185,7 +185,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
|
|||||||
|
|
||||||
return encoding_dict
|
return encoding_dict
|
||||||
|
|
||||||
def convert_tokens_to_ids(self, tokens):
|
def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
|
||||||
""" Converts a token string (or a sequence of tokens) in a single integer id
|
""" Converts a token string (or a sequence of tokens) in a single integer id
|
||||||
(or a sequence of ids), using the vocabulary.
|
(or a sequence of ids), using the vocabulary.
|
||||||
"""
|
"""
|
||||||
@@ -200,7 +200,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
|
|||||||
ids.append(self._convert_token_to_id_with_added_voc(token))
|
ids.append(self._convert_token_to_id_with_added_voc(token))
|
||||||
return ids
|
return ids
|
||||||
|
|
||||||
def _convert_token_to_id_with_added_voc(self, token: int) -> str:
|
def _convert_token_to_id_with_added_voc(self, token: str) -> int:
|
||||||
index = self._tokenizer.token_to_id(token)
|
index = self._tokenizer.token_to_id(token)
|
||||||
if index is None:
|
if index is None:
|
||||||
return self.unk_token_id
|
return self.unk_token_id
|
||||||
@@ -209,9 +209,6 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
|
|||||||
def _convert_id_to_token(self, index: int) -> Optional[str]:
|
def _convert_id_to_token(self, index: int) -> Optional[str]:
|
||||||
return self._tokenizer.id_to_token(int(index))
|
return self._tokenizer.id_to_token(int(index))
|
||||||
|
|
||||||
def convert_tokens_to_string(self, tokens: List[int], skip_special_tokens: bool = False) -> str:
|
|
||||||
return self._tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
|
|
||||||
|
|
||||||
def _add_tokens(self, new_tokens: List[Union[str, AddedToken]], special_tokens=False) -> int:
|
def _add_tokens(self, new_tokens: List[Union[str, AddedToken]], special_tokens=False) -> int:
|
||||||
if special_tokens:
|
if special_tokens:
|
||||||
return self._tokenizer.add_special_tokens(new_tokens)
|
return self._tokenizer.add_special_tokens(new_tokens)
|
||||||
@@ -223,7 +220,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
|
|||||||
|
|
||||||
def convert_ids_to_tokens(
|
def convert_ids_to_tokens(
|
||||||
self, ids: Union[int, List[int]], skip_special_tokens: bool = False
|
self, ids: Union[int, List[int]], skip_special_tokens: bool = False
|
||||||
) -> Union[int, List[int]]:
|
) -> Union[str, List[str]]:
|
||||||
""" Converts a single index or a sequence of indices (integers) in a token "
|
""" Converts a single index or a sequence of indices (integers) in a token "
|
||||||
(resp.) a sequence of tokens (str), using the vocabulary and added tokens.
|
(resp.) a sequence of tokens (str), using the vocabulary and added tokens.
|
||||||
|
|
||||||
@@ -240,9 +237,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
|
|||||||
tokens.append(self._tokenizer.id_to_token(index))
|
tokens.append(self._tokenizer.id_to_token(index))
|
||||||
return tokens
|
return tokens
|
||||||
|
|
||||||
def tokenize(
|
def tokenize(self, text: str, pair: Optional[str] = None, add_special_tokens: bool = False) -> List[str]:
|
||||||
self, text: TextInput, pair: Optional[TextInput] = None, add_special_tokens: bool = False
|
|
||||||
) -> List[str]:
|
|
||||||
return self._tokenizer.encode(text, pair, add_special_tokens=add_special_tokens).tokens
|
return self._tokenizer.encode(text, pair, add_special_tokens=add_special_tokens).tokens
|
||||||
|
|
||||||
def set_truncation_and_padding(
|
def set_truncation_and_padding(
|
||||||
|
|||||||
@@ -54,9 +54,10 @@ class CommonFastTokenizerTest(unittest.TestCase):
|
|||||||
if tok_case.filter is None or (
|
if tok_case.filter is None or (
|
||||||
tok_case.filter is not None and tok_case.filter(tok_case, pretrained_name)
|
tok_case.filter is not None and tok_case.filter(tok_case, pretrained_name)
|
||||||
):
|
):
|
||||||
|
kwargs = dict(t for t in tok_case.kwargs) if tok_case.kwargs else {}
|
||||||
with self.subTest("{} ({})".format(tok_case.name, pretrained_name)):
|
with self.subTest("{} ({})".format(tok_case.name, pretrained_name)):
|
||||||
tokenizer_r = tok_case.rust_cls.from_pretrained(pretrained_name)
|
tokenizer_r = tok_case.rust_cls.from_pretrained(pretrained_name, **kwargs)
|
||||||
tokenizer_p = tok_case.python_cls.from_pretrained(pretrained_name)
|
tokenizer_p = tok_case.python_cls.from_pretrained(pretrained_name, **kwargs)
|
||||||
|
|
||||||
self.fast_align_python(tokenizer_r, tokenizer_p, tok_case, pretrained_name)
|
self.fast_align_python(tokenizer_r, tokenizer_p, tok_case, pretrained_name)
|
||||||
self.fast_only(tokenizer_r)
|
self.fast_only(tokenizer_r)
|
||||||
@@ -767,7 +768,16 @@ class WordPieceFastTokenizerTest(CommonFastTokenizerTest):
|
|||||||
|
|
||||||
class RobertaFastTokenizerTest(CommonFastTokenizerTest):
|
class RobertaFastTokenizerTest(CommonFastTokenizerTest):
|
||||||
TOKENIZERS_CLASSES = frozenset(
|
TOKENIZERS_CLASSES = frozenset(
|
||||||
[Tokenizer("Roberta", RobertaTokenizerFast, RobertaTokenizer, "vocab_file", filter_roberta_detectors, None)]
|
[
|
||||||
|
Tokenizer(
|
||||||
|
"Roberta",
|
||||||
|
RobertaTokenizerFast,
|
||||||
|
RobertaTokenizer,
|
||||||
|
"vocab_file",
|
||||||
|
filter_roberta_detectors,
|
||||||
|
(("cls_token", "<s>"),),
|
||||||
|
)
|
||||||
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
def assert_embeded_special_tokens(self, tokenizer_r, tokenizer_p):
|
def assert_embeded_special_tokens(self, tokenizer_r, tokenizer_p):
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ import json
|
|||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers.tokenization_roberta import VOCAB_FILES_NAMES, RobertaTokenizer, RobertaTokenizerFast
|
from transformers.tokenization_roberta import VOCAB_FILES_NAMES, AddedToken, RobertaTokenizer, RobertaTokenizerFast
|
||||||
|
|
||||||
from .test_tokenization_common import TokenizerTesterMixin
|
from .test_tokenization_common import TokenizerTesterMixin
|
||||||
from .utils import slow
|
from .utils import slow
|
||||||
@@ -139,7 +139,9 @@ class RobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
# Testing spaces after special tokenss
|
# Testing spaces after special tokenss
|
||||||
mask = "<mask>"
|
mask = "<mask>"
|
||||||
tokenizer.add_special_tokens({"mask_token": mask})
|
tokenizer.add_special_tokens(
|
||||||
|
{"mask_token": AddedToken(mask, lstrip=True, rstrip=False)}
|
||||||
|
) # mask token has a left space
|
||||||
mask_ind = tokenizer.convert_tokens_to_ids(mask)
|
mask_ind = tokenizer.convert_tokens_to_ids(mask)
|
||||||
|
|
||||||
sequence = "Encode <mask> sequence"
|
sequence = "Encode <mask> sequence"
|
||||||
|
|||||||
Reference in New Issue
Block a user