XLM tokenizer should encode with bos token (#3791)
* XLM tokenizer should encode with bos token * Update tests
This commit is contained in:
@@ -873,11 +873,12 @@ class XLMTokenizer(PreTrainedTokenizer):
|
|||||||
:obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
|
:obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if token_ids_1 is None:
|
bos = [self.bos_token_id]
|
||||||
return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
|
||||||
sep = [self.sep_token_id]
|
sep = [self.sep_token_id]
|
||||||
cls = [self.cls_token_id]
|
|
||||||
return cls + token_ids_0 + sep + token_ids_1 + sep
|
if token_ids_1 is None:
|
||||||
|
return bos + token_ids_0 + sep
|
||||||
|
return bos + token_ids_0 + sep + token_ids_1 + sep
|
||||||
|
|
||||||
def get_special_tokens_mask(
|
def get_special_tokens_mask(
|
||||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
||||||
|
|||||||
@@ -96,5 +96,5 @@ class XLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
|
encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
|
||||||
encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)
|
encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)
|
||||||
|
|
||||||
assert encoded_sentence == [1] + text + [1]
|
assert encoded_sentence == [0] + text + [1]
|
||||||
assert encoded_pair == [1] + text + [1] + text_2 + [1]
|
assert encoded_pair == [0] + text + [1] + text_2 + [1]
|
||||||
|
|||||||
Reference in New Issue
Block a user