[tokenizers] Several small improvements and bug fixes (#5287)

* avoid recursion in id checks for fast tokenizers

* better typings and fix #5232

* align slow and fast tokenizer behavior for RoBERTa and GPT-2

* style and quality

* fix tests - improve typings
This commit is contained in:
Thomas Wolf
2020-06-25 22:17:14 +02:00
committed by GitHub
parent 24f46ea3f3
commit 315f464b0a
6 changed files with 64 additions and 36 deletions

View File

@@ -607,7 +607,7 @@ class SpecialTokensMixin:
"special token {} has to be either str or AddedToken but got: {}".format(key, type(value))
)
def sanitize_special_tokens(self):
def sanitize_special_tokens(self) -> int:
""" Make sure that all the special tokens attributes of the tokenizer (tokenizer.mask_token, tokenizer.cls_token, ...)
are in the vocabulary. Add the missing ones to the vocabulary if needed.
@@ -616,7 +616,7 @@ class SpecialTokensMixin:
"""
return self.add_tokens(self.all_special_tokens_extended, special_tokens=True)
def add_special_tokens(self, special_tokens_dict):
def add_special_tokens(self, special_tokens_dict: Dict[str, Union[str, AddedToken]]) -> int:
"""
Add a dictionary of special tokens (eos, pad, cls...) to the encoder and link them
to class attributes. If special tokens are NOT in the vocabulary, they are added
@@ -665,10 +665,14 @@ class SpecialTokensMixin:
setattr(self, key, value)
if key == "additional_special_tokens":
assert isinstance(value, (list, tuple)) and all(isinstance(t, str) for t in value)
assert isinstance(value, (list, tuple)) and all(
isinstance(t, (str, AddedToken)) for t in value
), f"Tokens {value} for key {key} should all be str or AddedToken instances"
added_tokens += self.add_tokens(value, special_tokens=True)
else:
assert isinstance(value, str)
assert isinstance(
value, (str, AddedToken)
), f"Token {value} for key {key} should be a str or an AddedToken instance"
added_tokens += self.add_tokens([value], special_tokens=True)
return added_tokens
@@ -809,26 +813,36 @@ class SpecialTokensMixin:
@property
def bos_token_id(self):
""" Id of the beginning of sentence token in the vocabulary. Returns None if the token has not been set. """
# Check the raw attribute first: an unset token short-circuits to None before `self.bos_token` is read.
if self._bos_token is None:
return None
return self.convert_tokens_to_ids(self.bos_token)
@property
def eos_token_id(self):
""" Id of the end of sentence token in the vocabulary. Returns None if the token has not been set. """
# Check the raw attribute first: an unset token short-circuits to None before `self.eos_token` is read.
if self._eos_token is None:
return None
return self.convert_tokens_to_ids(self.eos_token)
@property
def unk_token_id(self):
""" Id of the unknown token in the vocabulary. Returns None if the token has not been set. """
# Check the raw attribute first: an unset token short-circuits to None before `self.unk_token` is read.
if self._unk_token is None:
return None
return self.convert_tokens_to_ids(self.unk_token)
@property
def sep_token_id(self):
""" Id of the separation token in the vocabulary. E.g. separate context and query in an input sequence. Returns None if the token has not been set. """
# Check the raw attribute first: an unset token short-circuits to None before `self.sep_token` is read.
if self._sep_token is None:
return None
return self.convert_tokens_to_ids(self.sep_token)
@property
def pad_token_id(self):
""" Id of the padding token in the vocabulary. Returns None if the token has not been set. """
# Check the raw attribute first: an unset token short-circuits to None before `self.pad_token` is read.
if self._pad_token is None:
return None
return self.convert_tokens_to_ids(self.pad_token)
@property
@@ -839,11 +853,15 @@ class SpecialTokensMixin:
@property
def cls_token_id(self):
""" Id of the classification token in the vocabulary. E.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model. Returns None if the token has not been set. """
# Check the raw attribute first: an unset token short-circuits to None before `self.cls_token` is read.
if self._cls_token is None:
return None
return self.convert_tokens_to_ids(self.cls_token)
@property
def mask_token_id(self):
""" Id of the mask token in the vocabulary. E.g. when training a model with masked-language modeling. Returns None if the token has not been set. """
# Check the raw attribute first: an unset token short-circuits to None before `self.mask_token` is read.
if self._mask_token is None:
return None
return self.convert_tokens_to_ids(self.mask_token)
@property