[tokenizers] Several small improvements and bug fixes (#5287)
* avoid recursion in id checks for fast tokenizers
* better typings and fix #5232
* align slow and fast tokenizer behaviors for Roberta and GPT2
* style and quality
* fix tests - improve typings
This commit is contained in:
@@ -607,7 +607,7 @@ class SpecialTokensMixin:
|
||||
"special token {} has to be either str or AddedToken but got: {}".format(key, type(value))
|
||||
)
|
||||
|
||||
def sanitize_special_tokens(self):
|
||||
def sanitize_special_tokens(self) -> int:
|
||||
""" Make sure that all the special tokens attributes of the tokenizer (tokenizer.mask_token, tokenizer.cls_token, ...)
|
||||
are in the vocabulary. Add the missing ones to the vocabulary if needed.
|
||||
|
||||
@@ -616,7 +616,7 @@ class SpecialTokensMixin:
|
||||
"""
|
||||
return self.add_tokens(self.all_special_tokens_extended, special_tokens=True)
|
||||
|
||||
def add_special_tokens(self, special_tokens_dict):
|
||||
def add_special_tokens(self, special_tokens_dict: Dict[str, Union[str, AddedToken]]) -> int:
|
||||
"""
|
||||
Add a dictionary of special tokens (eos, pad, cls...) to the encoder and link them
|
||||
to class attributes. If special tokens are NOT in the vocabulary, they are added
|
||||
@@ -665,10 +665,14 @@ class SpecialTokensMixin:
|
||||
setattr(self, key, value)
|
||||
|
||||
if key == "additional_special_tokens":
|
||||
assert isinstance(value, (list, tuple)) and all(isinstance(t, str) for t in value)
|
||||
assert isinstance(value, (list, tuple)) and all(
|
||||
isinstance(t, (str, AddedToken)) for t in value
|
||||
), f"Tokens {value} for key {key} should all be str or AddedToken instances"
|
||||
added_tokens += self.add_tokens(value, special_tokens=True)
|
||||
else:
|
||||
assert isinstance(value, str)
|
||||
assert isinstance(
|
||||
value, (str, AddedToken)
|
||||
), f"Token {value} for key {key} should be a str or an AddedToken instance"
|
||||
added_tokens += self.add_tokens([value], special_tokens=True)
|
||||
|
||||
return added_tokens
|
||||
@@ -809,26 +813,36 @@ class SpecialTokensMixin:
|
||||
@property
def bos_token_id(self):
    """ Id of the beginning of sentence token in the vocabulary.

    Returns ``None`` when ``self._bos_token`` is ``None``; otherwise returns the
    result of ``self.convert_tokens_to_ids(self.bos_token)``.
    """
    # Check the private attribute first: an unset token returns early, before
    # ``self.bos_token`` is read or any vocabulary lookup is attempted.
    if self._bos_token is None:
        return None
    return self.convert_tokens_to_ids(self.bos_token)
|
||||
|
||||
@property
def eos_token_id(self):
    """ Id of the end of sentence token in the vocabulary.

    Returns ``None`` when ``self._eos_token`` is ``None``; otherwise returns the
    result of ``self.convert_tokens_to_ids(self.eos_token)``.
    """
    # Check the private attribute first: an unset token returns early, before
    # ``self.eos_token`` is read or any vocabulary lookup is attempted.
    if self._eos_token is None:
        return None
    return self.convert_tokens_to_ids(self.eos_token)
|
||||
|
||||
@property
def unk_token_id(self):
    """ Id of the unknown token in the vocabulary.

    Returns ``None`` when ``self._unk_token`` is ``None``; otherwise returns the
    result of ``self.convert_tokens_to_ids(self.unk_token)``.
    """
    # Check the private attribute first: an unset token returns early, before
    # ``self.unk_token`` is read or any vocabulary lookup is attempted.
    if self._unk_token is None:
        return None
    return self.convert_tokens_to_ids(self.unk_token)
|
||||
|
||||
@property
def sep_token_id(self):
    """ Id of the separation token in the vocabulary. E.g. separate context and query in an input sequence.

    Returns ``None`` when ``self._sep_token`` is ``None``; otherwise returns the
    result of ``self.convert_tokens_to_ids(self.sep_token)``.
    """
    # Check the private attribute first: an unset token returns early, before
    # ``self.sep_token`` is read or any vocabulary lookup is attempted.
    if self._sep_token is None:
        return None
    return self.convert_tokens_to_ids(self.sep_token)
|
||||
|
||||
@property
def pad_token_id(self):
    """ Id of the padding token in the vocabulary.

    Returns ``None`` when ``self._pad_token`` is ``None``; otherwise returns the
    result of ``self.convert_tokens_to_ids(self.pad_token)``.
    """
    # Check the private attribute first: an unset token returns early, before
    # ``self.pad_token`` is read or any vocabulary lookup is attempted.
    if self._pad_token is None:
        return None
    return self.convert_tokens_to_ids(self.pad_token)
|
||||
|
||||
@property
|
||||
@@ -839,11 +853,15 @@ class SpecialTokensMixin:
|
||||
@property
def cls_token_id(self):
    """ Id of the classification token in the vocabulary. E.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model.

    Returns ``None`` when ``self._cls_token`` is ``None``; otherwise returns the
    result of ``self.convert_tokens_to_ids(self.cls_token)``.
    """
    # Check the private attribute first: an unset token returns early, before
    # ``self.cls_token`` is read or any vocabulary lookup is attempted.
    if self._cls_token is None:
        return None
    return self.convert_tokens_to_ids(self.cls_token)
|
||||
|
||||
@property
def mask_token_id(self):
    """ Id of the mask token in the vocabulary. E.g. when training a model with masked-language modeling.

    Returns ``None`` when ``self._mask_token`` is ``None``; otherwise returns the
    result of ``self.convert_tokens_to_ids(self.mask_token)``.
    """
    # Check the private attribute first: an unset token returns early, before
    # ``self.mask_token`` is read or any vocabulary lookup is attempted.
    if self._mask_token is None:
        return None
    return self.convert_tokens_to_ids(self.mask_token)
|
||||
|
||||
@property
|
||||
|
||||
Reference in New Issue
Block a user