Merge pull request #1811 from huggingface/special-tokens
Fix special tokens addition in decoder #1807
This commit is contained in:
@@ -160,6 +160,26 @@ class CommonTestCases:
|
|||||||
self.assertEqual(tokens[0], tokenizer.eos_token_id)
|
self.assertEqual(tokens[0], tokenizer.eos_token_id)
|
||||||
self.assertEqual(tokens[-2], tokenizer.pad_token_id)
|
self.assertEqual(tokens[-2], tokenizer.pad_token_id)
|
||||||
|
|
||||||
|
def test_add_special_tokens(self):
|
||||||
|
tokenizer = self.get_tokenizer()
|
||||||
|
input_text, output_text = self.get_input_output_texts()
|
||||||
|
|
||||||
|
special_token = "[SPECIAL TOKEN]"
|
||||||
|
|
||||||
|
tokenizer.add_special_tokens({"cls_token": special_token})
|
||||||
|
encoded_special_token = tokenizer.encode(special_token, add_special_tokens=False)
|
||||||
|
assert len(encoded_special_token) == 1
|
||||||
|
|
||||||
|
text = " ".join([input_text, special_token, output_text])
|
||||||
|
encoded = tokenizer.encode(text, add_special_tokens=False)
|
||||||
|
|
||||||
|
input_encoded = tokenizer.encode(input_text, add_special_tokens=False)
|
||||||
|
output_encoded = tokenizer.encode(output_text, add_special_tokens=False)
|
||||||
|
special_token_id = tokenizer.encode(special_token, add_special_tokens=False)
|
||||||
|
assert encoded == input_encoded + special_token_id + output_encoded
|
||||||
|
|
||||||
|
decoded = tokenizer.decode(encoded, skip_special_tokens=True)
|
||||||
|
assert special_token not in decoded
|
||||||
|
|
||||||
def test_required_methods_tokenizer(self):
|
def test_required_methods_tokenizer(self):
|
||||||
tokenizer = self.get_tokenizer()
|
tokenizer = self.get_tokenizer()
|
||||||
|
|||||||
@@ -1057,7 +1057,7 @@ class PreTrainedTokenizer(object):
|
|||||||
class attributes (cls_token, unk_token...).
|
class attributes (cls_token, unk_token...).
|
||||||
"""
|
"""
|
||||||
all_toks = self.all_special_tokens
|
all_toks = self.all_special_tokens
|
||||||
all_ids = list(self._convert_token_to_id(t) for t in all_toks)
|
all_ids = self.convert_tokens_to_ids(all_toks)
|
||||||
return all_ids
|
return all_ids
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
Reference in New Issue
Block a user