From 315f464b0aadf0ef052eb33560d37182cc8c0e2b Mon Sep 17 00:00:00 2001 From: Thomas Wolf Date: Thu, 25 Jun 2020 22:17:14 +0200 Subject: [PATCH] [tokenizers] Several small improvements and bug fixes (#5287) * avoid recursion in id checks for fast tokenizers * better typings and fix #5232 * align slow and fast tokenizers behaviors for Roberta and GPT2 * style and quality * fix tests - improve typings --- src/transformers/tokenization_gpt2.py | 5 ++- src/transformers/tokenization_roberta.py | 34 ++++++++++----------- src/transformers/tokenization_utils_base.py | 26 +++++++++++++--- src/transformers/tokenization_utils_fast.py | 13 +++----- tests/test_tokenization_fast.py | 16 ++++++++-- tests/test_tokenization_roberta.py | 6 ++-- 6 files changed, 64 insertions(+), 36 deletions(-) diff --git a/src/transformers/tokenization_gpt2.py b/src/transformers/tokenization_gpt2.py index f863dca6ce..f5da631ae1 100644 --- a/src/transformers/tokenization_gpt2.py +++ b/src/transformers/tokenization_gpt2.py @@ -23,7 +23,7 @@ from functools import lru_cache import regex as re from tokenizers import ByteLevelBPETokenizer -from .tokenization_utils import PreTrainedTokenizer +from .tokenization_utils import AddedToken, PreTrainedTokenizer from .tokenization_utils_base import BatchEncoding from .tokenization_utils_fast import PreTrainedTokenizerFast @@ -149,6 +149,9 @@ class GPT2Tokenizer(PreTrainedTokenizer): add_prefix_space=False, **kwargs ): + bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token + eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token + unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token super().__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs) with open(vocab_file, encoding="utf-8") as vocab_handle: diff --git a/src/transformers/tokenization_roberta.py 
b/src/transformers/tokenization_roberta.py index 65c8bf72b7..f5ec1f3a00 100644 --- a/src/transformers/tokenization_roberta.py +++ b/src/transformers/tokenization_roberta.py @@ -21,7 +21,7 @@ from typing import List, Optional from tokenizers.processors import RobertaProcessing from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast -from .tokenization_utils import AddedToken, PreTrainedTokenizer +from .tokenization_utils import AddedToken logger = logging.getLogger(__name__) @@ -137,6 +137,16 @@ class RobertaTokenizer(GPT2Tokenizer): add_prefix_space=False, **kwargs ): + bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token + eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token + sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token + cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token + unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token + pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token + + # Mask token behave like a normal word, i.e. 
include the space before it + mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token + super().__init__( vocab_file=vocab_file, merges_file=merges_file, @@ -152,13 +162,6 @@ class RobertaTokenizer(GPT2Tokenizer): **kwargs, ) - @PreTrainedTokenizer.mask_token.setter - def mask_token(self, value): - if not isinstance(value, AddedToken): - value = AddedToken(value, lstrip=True) - - self._mask_token = value - def build_inputs_with_special_tokens( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None ) -> List[int]: @@ -309,6 +312,9 @@ class RobertaTokenizerFast(GPT2TokenizerFast): trim_offsets=True, **kwargs ): + # Mask token behave like a normal word, i.e. include the space before it + mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token + kwargs.setdefault("pad_token", pad_token) kwargs.setdefault("sep_token", sep_token) kwargs.setdefault("cls_token", cls_token) @@ -325,6 +331,9 @@ class RobertaTokenizerFast(GPT2TokenizerFast): **kwargs, ) + # This will add the necessary special tokens to the vocabulary if needed + self.sanitize_special_tokens() + self.backend_tokenizer._tokenizer.post_processor = RobertaProcessing( sep=(sep_token, self.sep_token_id), cls=(cls_token, self.cls_token_id), @@ -332,15 +341,6 @@ class RobertaTokenizerFast(GPT2TokenizerFast): trim_offsets=trim_offsets, ) - self.sanitize_special_tokens() # This will add the necessary special tokens to the vocabulary if needed. 
- - @PreTrainedTokenizer.mask_token.setter - def mask_token(self, value): - if not isinstance(value, AddedToken): - value = AddedToken(value, lstrip=True) - - self._mask_token = value - def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id] if token_ids_1 is None: diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 3181517cd8..2d5c62aa0e 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -607,7 +607,7 @@ class SpecialTokensMixin: "special token {} has to be either str or AddedToken but got: {}".format(key, type(value)) ) - def sanitize_special_tokens(self): + def sanitize_special_tokens(self) -> int: """ Make sure that all the special tokens attributes of the tokenizer (tokenizer.mask_token, tokenizer.cls_token, ...) are in the vocabulary. Add the missing ones to the vocabulary if needed. @@ -616,7 +616,7 @@ class SpecialTokensMixin: """ return self.add_tokens(self.all_special_tokens_extended, special_tokens=True) - def add_special_tokens(self, special_tokens_dict): + def add_special_tokens(self, special_tokens_dict: Dict[str, Union[str, AddedToken]]) -> int: """ Add a dictionary of special tokens (eos, pad, cls...) to the encoder and link them to class attributes. 
If special tokens are NOT in the vocabulary, they are added @@ -665,10 +665,14 @@ class SpecialTokensMixin: setattr(self, key, value) if key == "additional_special_tokens": - assert isinstance(value, (list, tuple)) and all(isinstance(t, str) for t in value) + assert isinstance(value, (list, tuple)) and all( + isinstance(t, (str, AddedToken)) for t in value + ), f"Tokens {value} for key {key} should all be str or AddedToken instances" added_tokens += self.add_tokens(value, special_tokens=True) else: - assert isinstance(value, str) + assert isinstance( + value, (str, AddedToken) + ), f"Token {value} for key {key} should be a str or an AddedToken instance" added_tokens += self.add_tokens([value], special_tokens=True) return added_tokens @@ -809,26 +813,36 @@ class SpecialTokensMixin: @property def bos_token_id(self): """ Id of the beginning of sentence token in the vocabulary. Log an error if used while not having been set. """ + if self._bos_token is None: + return None return self.convert_tokens_to_ids(self.bos_token) @property def eos_token_id(self): """ Id of the end of sentence token in the vocabulary. Log an error if used while not having been set. """ + if self._eos_token is None: + return None return self.convert_tokens_to_ids(self.eos_token) @property def unk_token_id(self): """ Id of the unknown token in the vocabulary. Log an error if used while not having been set. """ + if self._unk_token is None: + return None return self.convert_tokens_to_ids(self.unk_token) @property def sep_token_id(self): """ Id of the separation token in the vocabulary. E.g. separate context and query in an input sequence. Log an error if used while not having been set. """ + if self._sep_token is None: + return None return self.convert_tokens_to_ids(self.sep_token) @property def pad_token_id(self): """ Id of the padding token in the vocabulary. Log an error if used while not having been set. 
""" + if self._pad_token is None: + return None return self.convert_tokens_to_ids(self.pad_token) @property @@ -839,11 +853,15 @@ class SpecialTokensMixin: @property def cls_token_id(self): """ Id of the classification token in the vocabulary. E.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model. Log an error if used while not having been set. """ + if self._cls_token is None: + return None return self.convert_tokens_to_ids(self.cls_token) @property def mask_token_id(self): """ Id of the mask token in the vocabulary. E.g. when training a model with masked-language modeling. Log an error if used while not having been set. """ + if self._mask_token is None: + return None return self.convert_tokens_to_ids(self.mask_token) @property diff --git a/src/transformers/tokenization_utils_fast.py b/src/transformers/tokenization_utils_fast.py index f2aaee3837..199445d2d2 100644 --- a/src/transformers/tokenization_utils_fast.py +++ b/src/transformers/tokenization_utils_fast.py @@ -185,7 +185,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase): return encoding_dict - def convert_tokens_to_ids(self, tokens): + def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]: """ Converts a token string (or a sequence of tokens) in a single integer id (or a sequence of ids), using the vocabulary. 
""" @@ -200,7 +200,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase): ids.append(self._convert_token_to_id_with_added_voc(token)) return ids - def _convert_token_to_id_with_added_voc(self, token: int) -> str: + def _convert_token_to_id_with_added_voc(self, token: str) -> int: index = self._tokenizer.token_to_id(token) if index is None: return self.unk_token_id @@ -209,9 +209,6 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase): def _convert_id_to_token(self, index: int) -> Optional[str]: return self._tokenizer.id_to_token(int(index)) - def convert_tokens_to_string(self, tokens: List[int], skip_special_tokens: bool = False) -> str: - return self._tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens) - def _add_tokens(self, new_tokens: List[Union[str, AddedToken]], special_tokens=False) -> int: if special_tokens: return self._tokenizer.add_special_tokens(new_tokens) @@ -223,7 +220,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase): def convert_ids_to_tokens( self, ids: Union[int, List[int]], skip_special_tokens: bool = False - ) -> Union[int, List[int]]: + ) -> Union[str, List[str]]: """ Converts a single index or a sequence of indices (integers) in a token " (resp.) a sequence of tokens (str), using the vocabulary and added tokens. 
@@ -240,9 +237,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase): tokens.append(self._tokenizer.id_to_token(index)) return tokens - def tokenize( - self, text: TextInput, pair: Optional[TextInput] = None, add_special_tokens: bool = False - ) -> List[str]: + def tokenize(self, text: str, pair: Optional[str] = None, add_special_tokens: bool = False) -> List[str]: return self._tokenizer.encode(text, pair, add_special_tokens=add_special_tokens).tokens def set_truncation_and_padding( diff --git a/tests/test_tokenization_fast.py b/tests/test_tokenization_fast.py index 7f2f662c75..010f29d0fc 100644 --- a/tests/test_tokenization_fast.py +++ b/tests/test_tokenization_fast.py @@ -54,9 +54,10 @@ class CommonFastTokenizerTest(unittest.TestCase): if tok_case.filter is None or ( tok_case.filter is not None and tok_case.filter(tok_case, pretrained_name) ): + kwargs = dict(t for t in tok_case.kwargs) if tok_case.kwargs else {} with self.subTest("{} ({})".format(tok_case.name, pretrained_name)): - tokenizer_r = tok_case.rust_cls.from_pretrained(pretrained_name) - tokenizer_p = tok_case.python_cls.from_pretrained(pretrained_name) + tokenizer_r = tok_case.rust_cls.from_pretrained(pretrained_name, **kwargs) + tokenizer_p = tok_case.python_cls.from_pretrained(pretrained_name, **kwargs) self.fast_align_python(tokenizer_r, tokenizer_p, tok_case, pretrained_name) self.fast_only(tokenizer_r) @@ -767,7 +768,16 @@ class WordPieceFastTokenizerTest(CommonFastTokenizerTest): class RobertaFastTokenizerTest(CommonFastTokenizerTest): TOKENIZERS_CLASSES = frozenset( - [Tokenizer("Roberta", RobertaTokenizerFast, RobertaTokenizer, "vocab_file", filter_roberta_detectors, None)] + [ + Tokenizer( + "Roberta", + RobertaTokenizerFast, + RobertaTokenizer, + "vocab_file", + filter_roberta_detectors, + (("cls_token", "<s>"),), + ) + ] ) def assert_embeded_special_tokens(self, tokenizer_r, tokenizer_p): diff --git a/tests/test_tokenization_roberta.py b/tests/test_tokenization_roberta.py index 
9ec4c02cff..c48603203b 100644 --- a/tests/test_tokenization_roberta.py +++ b/tests/test_tokenization_roberta.py @@ -18,7 +18,7 @@ import json import os import unittest -from transformers.tokenization_roberta import VOCAB_FILES_NAMES, RobertaTokenizer, RobertaTokenizerFast +from transformers.tokenization_roberta import VOCAB_FILES_NAMES, AddedToken, RobertaTokenizer, RobertaTokenizerFast from .test_tokenization_common import TokenizerTesterMixin from .utils import slow @@ -139,7 +139,9 @@ class RobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase): # Testing spaces after special tokenss mask = "<mask>" - tokenizer.add_special_tokens({"mask_token": mask}) + tokenizer.add_special_tokens( + {"mask_token": AddedToken(mask, lstrip=True, rstrip=False)} + ) # mask token has a left space mask_ind = tokenizer.convert_tokens_to_ids(mask) sequence = "Encode sequence"