Make add_special_tokens more clear (#20424)
* make add_special_tokens more clear Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -841,7 +841,9 @@ class SpecialTokensMixin:
|
|||||||
"""
|
"""
|
||||||
return self.add_tokens(self.all_special_tokens_extended, special_tokens=True)
|
return self.add_tokens(self.all_special_tokens_extended, special_tokens=True)
|
||||||
|
|
||||||
def add_special_tokens(self, special_tokens_dict: Dict[str, Union[str, AddedToken]]) -> int:
|
def add_special_tokens(
|
||||||
|
self, special_tokens_dict: Dict[str, Union[str, AddedToken]], replace_additional_special_tokens=True
|
||||||
|
) -> int:
|
||||||
"""
|
"""
|
||||||
Add a dictionary of special tokens (eos, pad, cls, etc.) to the encoder and link them to class attributes. If
|
Add a dictionary of special tokens (eos, pad, cls, etc.) to the encoder and link them to class attributes. If
|
||||||
special tokens are NOT in the vocabulary, they are added to it (indexed starting from the last index of the
|
special tokens are NOT in the vocabulary, they are added to it (indexed starting from the last index of the
|
||||||
@@ -869,6 +871,11 @@ class SpecialTokensMixin:
|
|||||||
|
|
||||||
Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer
|
Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer
|
||||||
assign the index of the `unk_token` to them).
|
assign the index of the `unk_token` to them).
|
||||||
|
replace_additional_special_tokens (`bool`, *optional*, defaults to `True`):
|
||||||
|
If `True`, the existing list of additional special tokens will be replaced by the one specified in
|
||||||
|
`special_tokens_dict`. Otherwise, `self._additional_special_tokens` is extended with the tokens not already in it. In the former case, the
|
||||||
|
tokens will NOT be removed from the tokenizer's full vocabulary - they are only being flagged as
|
||||||
|
non-special tokens.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
`int`: Number of tokens added to the vocabulary.
|
`int`: Number of tokens added to the vocabulary.
|
||||||
@@ -898,17 +905,32 @@ class SpecialTokensMixin:
|
|||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
logger.info(f"Assigning {value} to the {key} key of the tokenizer")
|
logger.info(f"Assigning {value} to the {key} key of the tokenizer")
|
||||||
setattr(self, key, value)
|
|
||||||
|
|
||||||
if key == "additional_special_tokens":
|
if key == "additional_special_tokens":
|
||||||
assert isinstance(value, (list, tuple)) and all(
|
assert isinstance(value, (list, tuple)) and all(
|
||||||
isinstance(t, (str, AddedToken)) for t in value
|
isinstance(t, (str, AddedToken)) for t in value
|
||||||
), f"Tokens {value} for key {key} should all be str or AddedToken instances"
|
), f"Tokens {value} for key {key} should all be str or AddedToken instances"
|
||||||
|
|
||||||
|
if replace_additional_special_tokens:
|
||||||
|
setattr(self, key, value)
|
||||||
|
else:
|
||||||
|
# This is a copy of `self._additional_special_tokens`
|
||||||
|
additional_special_tokens = getattr(self, key)
|
||||||
|
additional_special_tokens_set = set(additional_special_tokens)
|
||||||
|
to_add = []
|
||||||
|
for token in value:
|
||||||
|
if str(token) not in additional_special_tokens_set and str(token) not in to_add:
|
||||||
|
to_add.append(token)
|
||||||
|
# update the property
|
||||||
|
additional_special_tokens.extend(to_add)
|
||||||
|
self.additional_special_tokens = additional_special_tokens
|
||||||
|
|
||||||
added_tokens += self.add_tokens(value, special_tokens=True)
|
added_tokens += self.add_tokens(value, special_tokens=True)
|
||||||
else:
|
else:
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
value, (str, AddedToken)
|
value, (str, AddedToken)
|
||||||
), f"Token {value} for key {key} should be a str or an AddedToken instance"
|
), f"Token {value} for key {key} should be a str or an AddedToken instance"
|
||||||
|
setattr(self, key, value)
|
||||||
added_tokens += self.add_tokens([value], special_tokens=True)
|
added_tokens += self.add_tokens([value], special_tokens=True)
|
||||||
|
|
||||||
return added_tokens
|
return added_tokens
|
||||||
|
|||||||
Reference in New Issue
Block a user