From 293991d44bdd5a71cbae68965ea130893fd770ef Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Wed, 30 Nov 2022 12:56:32 +0100 Subject: [PATCH] Make `add_special_tokens` more clear (#20424) * make add_special_tokens more clear Co-authored-by: ydshieh --- src/transformers/tokenization_utils_base.py | 26 +++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index e3721f8131..4e8aa6f3fb 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -841,7 +841,9 @@ class SpecialTokensMixin: """ return self.add_tokens(self.all_special_tokens_extended, special_tokens=True) - def add_special_tokens(self, special_tokens_dict: Dict[str, Union[str, AddedToken]]) -> int: + def add_special_tokens( + self, special_tokens_dict: Dict[str, Union[str, AddedToken]], replace_additional_special_tokens=True + ) -> int: """ Add a dictionary of special tokens (eos, pad, cls, etc.) to the encoder and link them to class attributes. If special tokens are NOT in the vocabulary, they are added to it (indexed starting from the last index of the @@ -869,6 +871,11 @@ class SpecialTokensMixin: Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer assign the index of the `unk_token` to them). + replace_additional_special_tokens (`bool`, *optional*, defaults to `True`): + If `True`, the existing list of additional special tokens will be replaced by the one specified in + `special_tokens_dict`. Otherwise, `self._additional_special_tokens` is updated. In the former case, the + tokens will NOT be removed from the tokenizer's full vocabulary - they are only being flagged as + non-special tokens. Returns: `int`: Number of tokens added to the vocabulary. 
@@ -898,17 +905,32 @@ class SpecialTokensMixin: if self.verbose: logger.info(f"Assigning {value} to the {key} key of the tokenizer") - setattr(self, key, value) if key == "additional_special_tokens": assert isinstance(value, (list, tuple)) and all( isinstance(t, (str, AddedToken)) for t in value ), f"Tokens {value} for key {key} should all be str or AddedToken instances" + + if replace_additional_special_tokens: + setattr(self, key, value) + else: + # This is a copy of `self._additional_special_tokens` + additional_special_tokens = getattr(self, key) + additional_special_tokens_set = set(additional_special_tokens) + to_add = [] + for token in value: + if str(token) not in additional_special_tokens_set and str(token) not in to_add: + to_add.append(token) + # update the property + additional_special_tokens.extend(to_add) + self.additional_special_tokens = additional_special_tokens + added_tokens += self.add_tokens(value, special_tokens=True) else: assert isinstance( value, (str, AddedToken) ), f"Token {value} for key {key} should be a str or an AddedToken instance" + setattr(self, key, value) added_tokens += self.add_tokens([value], special_tokens=True) return added_tokens