Fixing camembert tokenization
This commit is contained in:
@@ -51,7 +51,7 @@ class CamembertTokenizer(PreTrainedTokenizer):
|
|||||||
|
|
||||||
def __init__(self, vocab_file, bos_token="<s>", eos_token="</s>", sep_token="</s>",
|
def __init__(self, vocab_file, bos_token="<s>", eos_token="</s>", sep_token="</s>",
|
||||||
cls_token="<s>", unk_token="<unk>", pad_token='<pad>', mask_token='<mask>',
|
cls_token="<s>", unk_token="<unk>", pad_token='<pad>', mask_token='<mask>',
|
||||||
additional_special_tokens=['<s>NOTUSED', '<s>NOTUSED'], **kwargs):
|
additional_special_tokens=['<s>NOTUSED', '</s>NOTUSED'], **kwargs):
|
||||||
super(CamembertTokenizer, self).__init__(max_len=512, bos_token=bos_token, eos_token=eos_token, unk_token=unk_token,
|
super(CamembertTokenizer, self).__init__(max_len=512, bos_token=bos_token, eos_token=eos_token, unk_token=unk_token,
|
||||||
sep_token=sep_token, cls_token=cls_token, pad_token=pad_token,
|
sep_token=sep_token, cls_token=cls_token, pad_token=pad_token,
|
||||||
mask_token=mask_token, additional_special_tokens=additional_special_tokens,
|
mask_token=mask_token, additional_special_tokens=additional_special_tokens,
|
||||||
@@ -125,7 +125,7 @@ class CamembertTokenizer(PreTrainedTokenizer):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def vocab_size(self):
|
def vocab_size(self):
|
||||||
return self.fairseq_offset + len(self.sp_model)
|
return len(self.fairseq_tokens_to_ids) + len(self.sp_model)
|
||||||
|
|
||||||
def _tokenize(self, text):
|
def _tokenize(self, text):
|
||||||
return self.sp_model.EncodeAsPieces(text)
|
return self.sp_model.EncodeAsPieces(text)
|
||||||
@@ -134,6 +134,9 @@ class CamembertTokenizer(PreTrainedTokenizer):
|
|||||||
""" Converts a token (str/unicode) in an id using the vocab. """
|
""" Converts a token (str/unicode) in an id using the vocab. """
|
||||||
if token in self.fairseq_tokens_to_ids:
|
if token in self.fairseq_tokens_to_ids:
|
||||||
return self.fairseq_tokens_to_ids[token]
|
return self.fairseq_tokens_to_ids[token]
|
||||||
|
elif self.sp_model.PieceToId(token) == 0:
|
||||||
|
# Convert sentence piece unk token to fairseq unk token index
|
||||||
|
return self.unk_token_id
|
||||||
return self.fairseq_offset + self.sp_model.PieceToId(token)
|
return self.fairseq_offset + self.sp_model.PieceToId(token)
|
||||||
|
|
||||||
def _convert_id_to_token(self, index):
|
def _convert_id_to_token(self, index):
|
||||||
|
|||||||
Reference in New Issue
Block a user