Better added_tokens handling
This commit is contained in:
@@ -1413,6 +1413,9 @@ class PreTrainedTokenizer(object):
|
|||||||
|
|
||||||
|
|
||||||
class PreTrainedTokenizerFast(PreTrainedTokenizer):
|
class PreTrainedTokenizerFast(PreTrainedTokenizer):
|
||||||
|
_tokenizer = None
|
||||||
|
_decoder = None
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
super(PreTrainedTokenizerFast, self).__init__(**kwargs)
|
super(PreTrainedTokenizerFast, self).__init__(**kwargs)
|
||||||
|
|
||||||
@@ -1435,8 +1438,49 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer):
|
|||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self.tokenizer.get_vocab_size(with_added_tokens=True)
|
return self.tokenizer.get_vocab_size(with_added_tokens=True)
|
||||||
|
|
||||||
|
@PreTrainedTokenizer.bos_token.setter
|
||||||
|
def bos_token(self, value):
|
||||||
|
self._bos_token = value
|
||||||
|
self._update_special_tokens()
|
||||||
|
|
||||||
|
@PreTrainedTokenizer.eos_token.setter
|
||||||
|
def eos_token(self, value):
|
||||||
|
self._eos_token = value
|
||||||
|
self._update_special_tokens()
|
||||||
|
|
||||||
|
@PreTrainedTokenizer.unk_token.setter
|
||||||
|
def unk_token(self, value):
|
||||||
|
self._unk_token = value
|
||||||
|
self._update_special_tokens()
|
||||||
|
|
||||||
|
@PreTrainedTokenizer.sep_token.setter
|
||||||
|
def sep_token(self, value):
|
||||||
|
self._sep_token = value
|
||||||
|
self._update_special_tokens()
|
||||||
|
|
||||||
|
@PreTrainedTokenizer.pad_token.setter
|
||||||
|
def pad_token(self, value):
|
||||||
|
self._pad_token = value
|
||||||
|
self._update_special_tokens()
|
||||||
|
|
||||||
|
@PreTrainedTokenizer.cls_token.setter
|
||||||
|
def cls_token(self, value):
|
||||||
|
self._cls_token = value
|
||||||
|
self._update_special_tokens()
|
||||||
|
|
||||||
|
@PreTrainedTokenizer.mask_token.setter
|
||||||
|
def mask_token(self, value):
|
||||||
|
self._mask_token = value
|
||||||
|
self._update_special_tokens()
|
||||||
|
|
||||||
|
@PreTrainedTokenizer.additional_special_tokens.setter
|
||||||
|
def additional_special_tokens(self, value):
|
||||||
|
self._additional_special_tokens = value
|
||||||
|
self._update_special_tokens()
|
||||||
|
|
||||||
def _update_special_tokens(self):
|
def _update_special_tokens(self):
|
||||||
self.tokenizer.add_special_tokens(self.all_special_tokens)
|
if self._tokenizer is not None:
|
||||||
|
self._tokenizer.add_special_tokens(self.all_special_tokens)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _convert_encoding(
|
def _convert_encoding(
|
||||||
@@ -1522,6 +1566,11 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer):
|
|||||||
def add_tokens(self, new_tokens):
|
def add_tokens(self, new_tokens):
|
||||||
self.tokenizer.add_tokens(new_tokens)
|
self.tokenizer.add_tokens(new_tokens)
|
||||||
|
|
||||||
|
def add_special_tokens(self, special_tokens_dict):
|
||||||
|
added = super().add_special_tokens(special_tokens_dict)
|
||||||
|
self._update_special_tokens()
|
||||||
|
return added
|
||||||
|
|
||||||
def encode_batch(
|
def encode_batch(
|
||||||
self,
|
self,
|
||||||
texts,
|
texts,
|
||||||
|
|||||||
Reference in New Issue
Block a user