From e6ec24fa881446e7c06fd5ab2cbc461899428c54 Mon Sep 17 00:00:00 2001 From: Anthony MOI Date: Thu, 26 Dec 2019 16:49:48 -0500 Subject: [PATCH] Better added_tokens handling --- src/transformers/tokenization_utils.py | 51 +++++++++++++++++++++++++- 1 file changed, 50 insertions(+), 1 deletion(-) diff --git a/src/transformers/tokenization_utils.py b/src/transformers/tokenization_utils.py index 210e47e752..8fa85a2f7c 100644 --- a/src/transformers/tokenization_utils.py +++ b/src/transformers/tokenization_utils.py @@ -1413,6 +1413,9 @@ class PreTrainedTokenizer(object): class PreTrainedTokenizerFast(PreTrainedTokenizer): + _tokenizer = None + _decoder = None + def __init__(self, **kwargs): super(PreTrainedTokenizerFast, self).__init__(**kwargs) @@ -1435,8 +1438,49 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer): def __len__(self): return self.tokenizer.get_vocab_size(with_added_tokens=True) + @PreTrainedTokenizer.bos_token.setter + def bos_token(self, value): + self._bos_token = value + self._update_special_tokens() + + @PreTrainedTokenizer.eos_token.setter + def eos_token(self, value): + self._eos_token = value + self._update_special_tokens() + + @PreTrainedTokenizer.unk_token.setter + def unk_token(self, value): + self._unk_token = value + self._update_special_tokens() + + @PreTrainedTokenizer.sep_token.setter + def sep_token(self, value): + self._sep_token = value + self._update_special_tokens() + + @PreTrainedTokenizer.pad_token.setter + def pad_token(self, value): + self._pad_token = value + self._update_special_tokens() + + @PreTrainedTokenizer.cls_token.setter + def cls_token(self, value): + self._cls_token = value + self._update_special_tokens() + + @PreTrainedTokenizer.mask_token.setter + def mask_token(self, value): + self._mask_token = value + self._update_special_tokens() + + @PreTrainedTokenizer.additional_special_tokens.setter + def additional_special_tokens(self, value): + self._additional_special_tokens = value + self._update_special_tokens() + def _update_special_tokens(self): - self.tokenizer.add_special_tokens(self.all_special_tokens) + if self._tokenizer is not None: + self._tokenizer.add_special_tokens(self.all_special_tokens) @staticmethod def _convert_encoding( @@ -1522,6 +1566,11 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer): def add_tokens(self, new_tokens): self.tokenizer.add_tokens(new_tokens) + def add_special_tokens(self, special_tokens_dict): + added = super().add_special_tokens(special_tokens_dict) + self._update_special_tokens() + return added + def encode_batch( self, texts,