Better added_tokens handling
This commit is contained in:
@@ -1413,6 +1413,9 @@ class PreTrainedTokenizer(object):
|
||||
|
||||
|
||||
class PreTrainedTokenizerFast(PreTrainedTokenizer):
|
||||
_tokenizer = None
|
||||
_decoder = None
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(PreTrainedTokenizerFast, self).__init__(**kwargs)
|
||||
|
||||
@@ -1435,8 +1438,49 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer):
|
||||
def __len__(self):
|
||||
return self.tokenizer.get_vocab_size(with_added_tokens=True)
|
||||
|
||||
@PreTrainedTokenizer.bos_token.setter
|
||||
def bos_token(self, value):
|
||||
self._bos_token = value
|
||||
self._update_special_tokens()
|
||||
|
||||
@PreTrainedTokenizer.eos_token.setter
|
||||
def eos_token(self, value):
|
||||
self._eos_token = value
|
||||
self._update_special_tokens()
|
||||
|
||||
@PreTrainedTokenizer.unk_token.setter
|
||||
def unk_token(self, value):
|
||||
self._unk_token = value
|
||||
self._update_special_tokens()
|
||||
|
||||
@PreTrainedTokenizer.sep_token.setter
|
||||
def sep_token(self, value):
|
||||
self._sep_token = value
|
||||
self._update_special_tokens()
|
||||
|
||||
@PreTrainedTokenizer.pad_token.setter
|
||||
def pad_token(self, value):
|
||||
self._pad_token = value
|
||||
self._update_special_tokens()
|
||||
|
||||
@PreTrainedTokenizer.cls_token.setter
|
||||
def cls_token(self, value):
|
||||
self._cls_token = value
|
||||
self._update_special_tokens()
|
||||
|
||||
@PreTrainedTokenizer.mask_token.setter
|
||||
def mask_token(self, value):
|
||||
self._mask_token = value
|
||||
self._update_special_tokens()
|
||||
|
||||
@PreTrainedTokenizer.additional_special_tokens.setter
|
||||
def additional_special_tokens(self, value):
|
||||
self._additional_special_tokens = value
|
||||
self._update_special_tokens()
|
||||
|
||||
def _update_special_tokens(self):
|
||||
self.tokenizer.add_special_tokens(self.all_special_tokens)
|
||||
if self._tokenizer is not None:
|
||||
self._tokenizer.add_special_tokens(self.all_special_tokens)
|
||||
|
||||
@staticmethod
|
||||
def _convert_encoding(
|
||||
@@ -1522,6 +1566,11 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer):
|
||||
def add_tokens(self, new_tokens):
|
||||
self.tokenizer.add_tokens(new_tokens)
|
||||
|
||||
def add_special_tokens(self, special_tokens_dict):
|
||||
added = super().add_special_tokens(special_tokens_dict)
|
||||
self._update_special_tokens()
|
||||
return added
|
||||
|
||||
def encode_batch(
|
||||
self,
|
||||
texts,
|
||||
|
||||
Reference in New Issue
Block a user