fix t5 token type ids (#8437)

This commit is contained in:
Patrick von Platen
2020-11-10 20:21:54 +01:00
committed by GitHub
parent 9fd1f56236
commit 70708cca1a
3 changed files with 58 additions and 0 deletions

View File

@@ -223,6 +223,20 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
self.assertEqual(expected_src_tokens, src_ids)
self.assertEqual(expected_tgt_tokens, tgt_ids)
def test_token_type_ids(self):
src_text_1 = ["A first paragraph for summarization."]
src_text_2 = ["A second paragraph for summarization."]
fast_token_type_ids = self.t5_base_tokenizer_fast(
src_text_1, src_text_2, add_special_tokens=True, return_token_type_ids=True
).token_type_ids
slow_token_type_ids = self.t5_base_tokenizer(
src_text_1, src_text_2, add_special_tokens=True, return_token_type_ids=True
).token_type_ids
self.assertEqual(slow_token_type_ids, fast_token_type_ids)
self.assertEqual(len(slow_token_type_ids[0]), 18)
def test_fast_and_slow_same_result(self):
src_text = "<pad> Today is <unk> nice day </s>"
tgt_ids = [0, 1960, 19, 2, 1245, 239, 1]