fix t5 special tokens (#8435)
This commit is contained in:
committed by
GitHub
parent
cace39af97
commit
b93569457f
@@ -249,8 +249,17 @@ class T5Tokenizer(PreTrainedTokenizer):
|
|||||||
|
|
||||||
def convert_tokens_to_string(self, tokens):
    """Converts a sequence of tokens (string) in a single string.

    Special tokens are copied to the output verbatim instead of being passed
    through the sentencepiece model. Each run of ordinary tokens between two
    special tokens is decoded with ``self.sp_model.decode_pieces``.

    Args:
        tokens: iterable of token strings.

    Returns:
        The decoded text with leading and trailing whitespace stripped. Every
        special token is followed by a single space in the intermediate
        result, so a trailing special token leaves no trailing space.
    """
    # Read the special tokens once so each membership test is a set lookup
    # rather than a fresh attribute read plus a linear scan per token.
    special_tokens = set(self.all_special_tokens)

    pieces = []
    current_sub_tokens = []
    for token in tokens:
        # make sure that special tokens are not decoded using sentencepiece model
        if token in special_tokens:
            # Flush the ordinary tokens collected so far, then emit the
            # special token itself followed by a separating space.
            pieces.append(self.sp_model.decode_pieces(current_sub_tokens))
            pieces.append(token + " ")
            current_sub_tokens = []
        else:
            current_sub_tokens.append(token)

    # Decode whatever ordinary tokens remain after the last special token.
    pieces.append(self.sp_model.decode_pieces(current_sub_tokens))
    return "".join(pieces).strip()
|
||||||
|
|
||||||
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||||
if not os.path.isdir(save_directory):
|
if not os.path.isdir(save_directory):
|
||||||
|
|||||||
@@ -222,3 +222,18 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
self.assertEqual(expected_src_tokens, src_ids)
|
self.assertEqual(expected_src_tokens, src_ids)
|
||||||
self.assertEqual(expected_tgt_tokens, tgt_ids)
|
self.assertEqual(expected_tgt_tokens, tgt_ids)
|
||||||
|
|
||||||
|
def test_fast_and_slow_same_result(self):
    """Check that the fast and slow T5 tokenizers agree on special tokens.

    The source text contains the literal special tokens ``<pad>``, ``<unk>``
    and ``</s>``. Both tokenizers must encode it to the same ids and decode
    those ids back to the same text.
    """
    src_text = "<pad> Today is <unk> nice day </s>"
    tgt_ids = [0, 1960, 19, 2, 1245, 239, 1]
    tgt_text = "<pad> Today is<unk> nice day</s>"

    # add_special_tokens=False: the special tokens under test are already
    # spelled out in src_text, so none should be appended automatically.
    fast_ids = self.t5_base_tokenizer_fast(src_text, add_special_tokens=False).input_ids
    slow_ids = self.t5_base_tokenizer(src_text, add_special_tokens=False).input_ids
    self.assertEqual(tgt_ids, fast_ids)
    self.assertEqual(tgt_ids, slow_ids)

    # Each tokenizer decodes the ids it produced itself, so the slow
    # encode -> decode round trip is exercised end to end.
    fast_text = self.t5_base_tokenizer_fast.decode(fast_ids)
    slow_text = self.t5_base_tokenizer.decode(slow_ids)
    self.assertEqual(tgt_text, fast_text)
    self.assertEqual(tgt_text, slow_text)
|
||||||
|
|||||||
Reference in New Issue
Block a user