Make add_special_tokens more clear (#20424)
* make add_special_tokens more clear Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -841,7 +841,9 @@ class SpecialTokensMixin:
|
||||
"""
|
||||
return self.add_tokens(self.all_special_tokens_extended, special_tokens=True)
|
||||
|
||||
def add_special_tokens(self, special_tokens_dict: Dict[str, Union[str, AddedToken]]) -> int:
|
||||
def add_special_tokens(
|
||||
self, special_tokens_dict: Dict[str, Union[str, AddedToken]], replace_additional_special_tokens=True
|
||||
) -> int:
|
||||
"""
|
||||
Add a dictionary of special tokens (eos, pad, cls, etc.) to the encoder and link them to class attributes. If
|
||||
special tokens are NOT in the vocabulary, they are added to it (indexed starting from the last index of the
|
||||
@@ -869,6 +871,11 @@ class SpecialTokensMixin:
|
||||
|
||||
Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer
|
||||
assign the index of the `unk_token` to them).
|
||||
replace_additional_special_tokens (`bool`, *optional*,, defaults to `True`):
|
||||
If `True`, the existing list of additional special tokens will be replaced by the one specified in
|
||||
`special_tokens_dict`. Otherwise, `self._additional_special_tokens` is updated. In the former case, the
|
||||
tokens will NOT be removed from the tokenizer's full vocabulary - they are only being flagged as
|
||||
non-special tokens.
|
||||
|
||||
Returns:
|
||||
`int`: Number of tokens added to the vocabulary.
|
||||
@@ -898,17 +905,32 @@ class SpecialTokensMixin:
|
||||
|
||||
if self.verbose:
|
||||
logger.info(f"Assigning {value} to the {key} key of the tokenizer")
|
||||
setattr(self, key, value)
|
||||
|
||||
if key == "additional_special_tokens":
|
||||
assert isinstance(value, (list, tuple)) and all(
|
||||
isinstance(t, (str, AddedToken)) for t in value
|
||||
), f"Tokens {value} for key {key} should all be str or AddedToken instances"
|
||||
|
||||
if replace_additional_special_tokens:
|
||||
setattr(self, key, value)
|
||||
else:
|
||||
# This is a copy of `self._additional_special_tokens`
|
||||
additional_special_tokens = getattr(self, key)
|
||||
additional_special_tokens_set = set(additional_special_tokens)
|
||||
to_add = []
|
||||
for token in value:
|
||||
if str(token) not in additional_special_tokens_set and str(token) not in to_add:
|
||||
to_add.append(token)
|
||||
# update the property
|
||||
additional_special_tokens.extend(to_add)
|
||||
self.additional_special_tokens = additional_special_tokens
|
||||
|
||||
added_tokens += self.add_tokens(value, special_tokens=True)
|
||||
else:
|
||||
assert isinstance(
|
||||
value, (str, AddedToken)
|
||||
), f"Token {value} for key {key} should be a str or an AddedToken instance"
|
||||
setattr(self, key, value)
|
||||
added_tokens += self.add_tokens([value], special_tokens=True)
|
||||
|
||||
return added_tokens
|
||||
|
||||
Reference in New Issue
Block a user