Fix special tokens addition in decoder
This commit is contained in:
@@ -160,6 +160,26 @@ class CommonTestCases:
|
||||
self.assertEqual(tokens[0], tokenizer.eos_token_id)
|
||||
self.assertEqual(tokens[-2], tokenizer.pad_token_id)
|
||||
|
||||
def test_add_special_tokens(self):
    """Check that a token registered via ``add_special_tokens`` is kept whole
    when encoding and is removed when decoding with ``skip_special_tokens``.

    The test registers ``"[SPECIAL TOKEN]"`` as the ``cls_token`` and then
    verifies three things:

    1. Encoding the special token on its own yields exactly one id.
    2. Encoding ``input_text + " " + special_token + " " + output_text``
       equals the concatenation of the three parts encoded separately.
    3. Decoding that sequence with ``skip_special_tokens=True`` produces a
       string that does not contain the special token.
    """
    tokenizer = self.get_tokenizer()
    input_text, output_text = self.get_input_output_texts()

    special_token = "[SPECIAL TOKEN]"

    tokenizer.add_special_tokens({"cls_token": special_token})

    # Encoded once and reused below; the original encoded the same string
    # a second time with identical arguments.
    special_token_ids = tokenizer.encode(special_token, add_special_tokens=False)
    self.assertEqual(len(special_token_ids), 1)

    text = " ".join([input_text, special_token, output_text])
    encoded = tokenizer.encode(text, add_special_tokens=False)

    input_encoded = tokenizer.encode(input_text, add_special_tokens=False)
    output_encoded = tokenizer.encode(output_text, add_special_tokens=False)
    self.assertEqual(encoded, input_encoded + special_token_ids + output_encoded)

    # With skip_special_tokens=True the registered token must not leak
    # back into the decoded text.
    decoded = tokenizer.decode(encoded, skip_special_tokens=True)
    self.assertNotIn(special_token, decoded)
|
||||
|
||||
def test_required_methods_tokenizer(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
Reference in New Issue
Block a user