[tokenizers] Several small improvements and bug fixes (#5287)
* avoid recursion in id checks for fast tokenizers * better typings and fix #5232 * align slow and fast tokenizers behaviors for Roberta and GPT2 * style and quality * fix tests - improve typings
This commit is contained in:
@@ -23,7 +23,7 @@ from functools import lru_cache
|
||||
import regex as re
|
||||
from tokenizers import ByteLevelBPETokenizer
|
||||
|
||||
from .tokenization_utils import PreTrainedTokenizer
|
||||
from .tokenization_utils import AddedToken, PreTrainedTokenizer
|
||||
from .tokenization_utils_base import BatchEncoding
|
||||
from .tokenization_utils_fast import PreTrainedTokenizerFast
|
||||
|
||||
@@ -149,6 +149,9 @@ class GPT2Tokenizer(PreTrainedTokenizer):
|
||||
add_prefix_space=False,
|
||||
**kwargs
|
||||
):
|
||||
bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
|
||||
eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
|
||||
unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
|
||||
super().__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs)
|
||||
|
||||
with open(vocab_file, encoding="utf-8") as vocab_handle:
|
||||
|
||||
@@ -21,7 +21,7 @@ from typing import List, Optional
|
||||
from tokenizers.processors import RobertaProcessing
|
||||
|
||||
from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
|
||||
from .tokenization_utils import AddedToken, PreTrainedTokenizer
|
||||
from .tokenization_utils import AddedToken
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -137,6 +137,16 @@ class RobertaTokenizer(GPT2Tokenizer):
|
||||
add_prefix_space=False,
|
||||
**kwargs
|
||||
):
|
||||
bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
|
||||
eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
|
||||
sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token
|
||||
cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token
|
||||
unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
|
||||
pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
|
||||
|
||||
# Mask token behave like a normal word, i.e. include the space before it
|
||||
mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
|
||||
|
||||
super().__init__(
|
||||
vocab_file=vocab_file,
|
||||
merges_file=merges_file,
|
||||
@@ -152,13 +162,6 @@ class RobertaTokenizer(GPT2Tokenizer):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@PreTrainedTokenizer.mask_token.setter
|
||||
def mask_token(self, value):
|
||||
if not isinstance(value, AddedToken):
|
||||
value = AddedToken(value, lstrip=True)
|
||||
|
||||
self._mask_token = value
|
||||
|
||||
def build_inputs_with_special_tokens(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||
) -> List[int]:
|
||||
@@ -309,6 +312,9 @@ class RobertaTokenizerFast(GPT2TokenizerFast):
|
||||
trim_offsets=True,
|
||||
**kwargs
|
||||
):
|
||||
# Mask token behave like a normal word, i.e. include the space before it
|
||||
mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
|
||||
|
||||
kwargs.setdefault("pad_token", pad_token)
|
||||
kwargs.setdefault("sep_token", sep_token)
|
||||
kwargs.setdefault("cls_token", cls_token)
|
||||
@@ -325,6 +331,9 @@ class RobertaTokenizerFast(GPT2TokenizerFast):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# This will add the necessary special tokens to the vocabulary if needed
|
||||
self.sanitize_special_tokens()
|
||||
|
||||
self.backend_tokenizer._tokenizer.post_processor = RobertaProcessing(
|
||||
sep=(sep_token, self.sep_token_id),
|
||||
cls=(cls_token, self.cls_token_id),
|
||||
@@ -332,15 +341,6 @@ class RobertaTokenizerFast(GPT2TokenizerFast):
|
||||
trim_offsets=trim_offsets,
|
||||
)
|
||||
|
||||
self.sanitize_special_tokens() # This will add the necessary special tokens to the vocabulary if needed.
|
||||
|
||||
@PreTrainedTokenizer.mask_token.setter
|
||||
def mask_token(self, value):
|
||||
if not isinstance(value, AddedToken):
|
||||
value = AddedToken(value, lstrip=True)
|
||||
|
||||
self._mask_token = value
|
||||
|
||||
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
|
||||
output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
|
||||
if token_ids_1 is None:
|
||||
|
||||
@@ -607,7 +607,7 @@ class SpecialTokensMixin:
|
||||
"special token {} has to be either str or AddedToken but got: {}".format(key, type(value))
|
||||
)
|
||||
|
||||
def sanitize_special_tokens(self):
|
||||
def sanitize_special_tokens(self) -> int:
|
||||
""" Make sure that all the special tokens attributes of the tokenizer (tokenizer.mask_token, tokenizer.cls_token, ...)
|
||||
are in the vocabulary. Add the missing ones to the vocabulary if needed.
|
||||
|
||||
@@ -616,7 +616,7 @@ class SpecialTokensMixin:
|
||||
"""
|
||||
return self.add_tokens(self.all_special_tokens_extended, special_tokens=True)
|
||||
|
||||
def add_special_tokens(self, special_tokens_dict):
|
||||
def add_special_tokens(self, special_tokens_dict: Dict[str, Union[str, AddedToken]]) -> int:
|
||||
"""
|
||||
Add a dictionary of special tokens (eos, pad, cls...) to the encoder and link them
|
||||
to class attributes. If special tokens are NOT in the vocabulary, they are added
|
||||
@@ -665,10 +665,14 @@ class SpecialTokensMixin:
|
||||
setattr(self, key, value)
|
||||
|
||||
if key == "additional_special_tokens":
|
||||
assert isinstance(value, (list, tuple)) and all(isinstance(t, str) for t in value)
|
||||
assert isinstance(value, (list, tuple)) and all(
|
||||
isinstance(t, (str, AddedToken)) for t in value
|
||||
), f"Tokens {value} for key {key} should all be str or AddedToken instances"
|
||||
added_tokens += self.add_tokens(value, special_tokens=True)
|
||||
else:
|
||||
assert isinstance(value, str)
|
||||
assert isinstance(
|
||||
value, (str, AddedToken)
|
||||
), f"Token {value} for key {key} should be a str or an AddedToken instance"
|
||||
added_tokens += self.add_tokens([value], special_tokens=True)
|
||||
|
||||
return added_tokens
|
||||
@@ -809,26 +813,36 @@ class SpecialTokensMixin:
|
||||
@property
|
||||
def bos_token_id(self):
|
||||
""" Id of the beginning of sentence token in the vocabulary. Log an error if used while not having been set. """
|
||||
if self._bos_token is None:
|
||||
return None
|
||||
return self.convert_tokens_to_ids(self.bos_token)
|
||||
|
||||
@property
|
||||
def eos_token_id(self):
|
||||
""" Id of the end of sentence token in the vocabulary. Log an error if used while not having been set. """
|
||||
if self._eos_token is None:
|
||||
return None
|
||||
return self.convert_tokens_to_ids(self.eos_token)
|
||||
|
||||
@property
|
||||
def unk_token_id(self):
|
||||
""" Id of the unknown token in the vocabulary. Log an error if used while not having been set. """
|
||||
if self._unk_token is None:
|
||||
return None
|
||||
return self.convert_tokens_to_ids(self.unk_token)
|
||||
|
||||
@property
|
||||
def sep_token_id(self):
|
||||
""" Id of the separation token in the vocabulary. E.g. separate context and query in an input sequence. Log an error if used while not having been set. """
|
||||
if self._sep_token is None:
|
||||
return None
|
||||
return self.convert_tokens_to_ids(self.sep_token)
|
||||
|
||||
@property
|
||||
def pad_token_id(self):
|
||||
""" Id of the padding token in the vocabulary. Log an error if used while not having been set. """
|
||||
if self._pad_token is None:
|
||||
return None
|
||||
return self.convert_tokens_to_ids(self.pad_token)
|
||||
|
||||
@property
|
||||
@@ -839,11 +853,15 @@ class SpecialTokensMixin:
|
||||
@property
|
||||
def cls_token_id(self):
|
||||
""" Id of the classification token in the vocabulary. E.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model. Log an error if used while not having been set. """
|
||||
if self._cls_token is None:
|
||||
return None
|
||||
return self.convert_tokens_to_ids(self.cls_token)
|
||||
|
||||
@property
|
||||
def mask_token_id(self):
|
||||
""" Id of the mask token in the vocabulary. E.g. when training a model with masked-language modeling. Log an error if used while not having been set. """
|
||||
if self._mask_token is None:
|
||||
return None
|
||||
return self.convert_tokens_to_ids(self.mask_token)
|
||||
|
||||
@property
|
||||
|
||||
@@ -185,7 +185,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
|
||||
|
||||
return encoding_dict
|
||||
|
||||
def convert_tokens_to_ids(self, tokens):
|
||||
def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
|
||||
""" Converts a token string (or a sequence of tokens) in a single integer id
|
||||
(or a sequence of ids), using the vocabulary.
|
||||
"""
|
||||
@@ -200,7 +200,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
|
||||
ids.append(self._convert_token_to_id_with_added_voc(token))
|
||||
return ids
|
||||
|
||||
def _convert_token_to_id_with_added_voc(self, token: int) -> str:
|
||||
def _convert_token_to_id_with_added_voc(self, token: str) -> int:
|
||||
index = self._tokenizer.token_to_id(token)
|
||||
if index is None:
|
||||
return self.unk_token_id
|
||||
@@ -209,9 +209,6 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
|
||||
def _convert_id_to_token(self, index: int) -> Optional[str]:
|
||||
return self._tokenizer.id_to_token(int(index))
|
||||
|
||||
def convert_tokens_to_string(self, tokens: List[int], skip_special_tokens: bool = False) -> str:
|
||||
return self._tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
|
||||
|
||||
def _add_tokens(self, new_tokens: List[Union[str, AddedToken]], special_tokens=False) -> int:
|
||||
if special_tokens:
|
||||
return self._tokenizer.add_special_tokens(new_tokens)
|
||||
@@ -223,7 +220,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
|
||||
|
||||
def convert_ids_to_tokens(
|
||||
self, ids: Union[int, List[int]], skip_special_tokens: bool = False
|
||||
) -> Union[int, List[int]]:
|
||||
) -> Union[str, List[str]]:
|
||||
""" Converts a single index or a sequence of indices (integers) in a token "
|
||||
(resp.) a sequence of tokens (str), using the vocabulary and added tokens.
|
||||
|
||||
@@ -240,9 +237,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
|
||||
tokens.append(self._tokenizer.id_to_token(index))
|
||||
return tokens
|
||||
|
||||
def tokenize(
|
||||
self, text: TextInput, pair: Optional[TextInput] = None, add_special_tokens: bool = False
|
||||
) -> List[str]:
|
||||
def tokenize(self, text: str, pair: Optional[str] = None, add_special_tokens: bool = False) -> List[str]:
|
||||
return self._tokenizer.encode(text, pair, add_special_tokens=add_special_tokens).tokens
|
||||
|
||||
def set_truncation_and_padding(
|
||||
|
||||
@@ -54,9 +54,10 @@ class CommonFastTokenizerTest(unittest.TestCase):
|
||||
if tok_case.filter is None or (
|
||||
tok_case.filter is not None and tok_case.filter(tok_case, pretrained_name)
|
||||
):
|
||||
kwargs = dict(t for t in tok_case.kwargs) if tok_case.kwargs else {}
|
||||
with self.subTest("{} ({})".format(tok_case.name, pretrained_name)):
|
||||
tokenizer_r = tok_case.rust_cls.from_pretrained(pretrained_name)
|
||||
tokenizer_p = tok_case.python_cls.from_pretrained(pretrained_name)
|
||||
tokenizer_r = tok_case.rust_cls.from_pretrained(pretrained_name, **kwargs)
|
||||
tokenizer_p = tok_case.python_cls.from_pretrained(pretrained_name, **kwargs)
|
||||
|
||||
self.fast_align_python(tokenizer_r, tokenizer_p, tok_case, pretrained_name)
|
||||
self.fast_only(tokenizer_r)
|
||||
@@ -767,7 +768,16 @@ class WordPieceFastTokenizerTest(CommonFastTokenizerTest):
|
||||
|
||||
class RobertaFastTokenizerTest(CommonFastTokenizerTest):
|
||||
TOKENIZERS_CLASSES = frozenset(
|
||||
[Tokenizer("Roberta", RobertaTokenizerFast, RobertaTokenizer, "vocab_file", filter_roberta_detectors, None)]
|
||||
[
|
||||
Tokenizer(
|
||||
"Roberta",
|
||||
RobertaTokenizerFast,
|
||||
RobertaTokenizer,
|
||||
"vocab_file",
|
||||
filter_roberta_detectors,
|
||||
(("cls_token", "<s>"),),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
def assert_embeded_special_tokens(self, tokenizer_r, tokenizer_p):
|
||||
|
||||
@@ -18,7 +18,7 @@ import json
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from transformers.tokenization_roberta import VOCAB_FILES_NAMES, RobertaTokenizer, RobertaTokenizerFast
|
||||
from transformers.tokenization_roberta import VOCAB_FILES_NAMES, AddedToken, RobertaTokenizer, RobertaTokenizerFast
|
||||
|
||||
from .test_tokenization_common import TokenizerTesterMixin
|
||||
from .utils import slow
|
||||
@@ -139,7 +139,9 @@ class RobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
|
||||
# Testing spaces after special tokenss
|
||||
mask = "<mask>"
|
||||
tokenizer.add_special_tokens({"mask_token": mask})
|
||||
tokenizer.add_special_tokens(
|
||||
{"mask_token": AddedToken(mask, lstrip=True, rstrip=False)}
|
||||
) # mask token has a left space
|
||||
mask_ind = tokenizer.convert_tokens_to_ids(mask)
|
||||
|
||||
sequence = "Encode <mask> sequence"
|
||||
|
||||
Reference in New Issue
Block a user