fix t5 token type ids (#8437)
This commit is contained in:
committed by
GitHub
parent
9fd1f56236
commit
70708cca1a
@@ -223,6 +223,20 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
self.assertEqual(expected_src_tokens, src_ids)
|
||||
self.assertEqual(expected_tgt_tokens, tgt_ids)
|
||||
|
||||
def test_token_type_ids(self):
|
||||
src_text_1 = ["A first paragraph for summarization."]
|
||||
src_text_2 = ["A second paragraph for summarization."]
|
||||
|
||||
fast_token_type_ids = self.t5_base_tokenizer_fast(
|
||||
src_text_1, src_text_2, add_special_tokens=True, return_token_type_ids=True
|
||||
).token_type_ids
|
||||
slow_token_type_ids = self.t5_base_tokenizer(
|
||||
src_text_1, src_text_2, add_special_tokens=True, return_token_type_ids=True
|
||||
).token_type_ids
|
||||
|
||||
self.assertEqual(slow_token_type_ids, fast_token_type_ids)
|
||||
self.assertEqual(len(slow_token_type_ids[0]), 18)
|
||||
|
||||
def test_fast_and_slow_same_result(self):
|
||||
src_text = "<pad> Today is <unk> nice day </s>"
|
||||
tgt_ids = [0, 1960, 19, 2, 1245, 239, 1]
|
||||
|
||||
Reference in New Issue
Block a user