From 70708cca1a2f3e01d9d72a3aaa7ab078dfef639e Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 10 Nov 2020 20:21:54 +0100 Subject: [PATCH] fix t5 token type ids (#8437) --- src/transformers/tokenization_t5.py | 22 ++++++++++++++++++++++ src/transformers/tokenization_t5_fast.py | 22 ++++++++++++++++++++++ tests/test_tokenization_t5.py | 14 ++++++++++++++ 3 files changed, 58 insertions(+) diff --git a/src/transformers/tokenization_t5.py b/src/transformers/tokenization_t5.py index 7a5e7fd587..781791b5ba 100644 --- a/src/transformers/tokenization_t5.py +++ b/src/transformers/tokenization_t5.py @@ -187,6 +187,28 @@ class T5Tokenizer(PreTrainedTokenizer): else: return token_ids + [self.eos_token_id] + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make + use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (:obj:`List[int]`): + List of IDs. + token_ids_1 (:obj:`List[int]`, `optional`): + Optional second list of IDs for sequence pairs. + + Returns: + :obj:`List[int]`: List of zeros. + """ + eos = [self.eos_token_id] + + if token_ids_1 is None: + return len(token_ids_0 + eos) * [0] + return len(token_ids_0 + eos + token_ids_1 + eos) * [0] + def build_inputs_with_special_tokens( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None ) -> List[int]: diff --git a/src/transformers/tokenization_t5_fast.py b/src/transformers/tokenization_t5_fast.py index e64d8ca724..0aba4763df 100644 --- a/src/transformers/tokenization_t5_fast.py +++ b/src/transformers/tokenization_t5_fast.py @@ -191,6 +191,28 @@ class T5TokenizerFast(PreTrainedTokenizerFast): token_ids_1 = token_ids_1 + [self.eos_token_id] return self.prefix_tokens + token_ids_0 + token_ids_1 + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make + use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (:obj:`List[int]`): + List of IDs. + token_ids_1 (:obj:`List[int]`, `optional`): + Optional second list of IDs for sequence pairs. + + Returns: + :obj:`List[int]`: List of zeros. + """ + eos = [self.eos_token_id] + + if token_ids_1 is None: + return len(token_ids_0 + eos) * [0] + return len(token_ids_0 + eos + token_ids_1 + eos) * [0] + @add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING) def prepare_seq2seq_batch( self, diff --git a/tests/test_tokenization_t5.py b/tests/test_tokenization_t5.py index 05d45d9b69..7ef4b931bf 100644 --- a/tests/test_tokenization_t5.py +++ b/tests/test_tokenization_t5.py @@ -223,6 +223,20 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase): self.assertEqual(expected_src_tokens, src_ids) self.assertEqual(expected_tgt_tokens, tgt_ids) + def test_token_type_ids(self): + src_text_1 = ["A first paragraph for summarization."] + src_text_2 = ["A second paragraph for summarization."] + + fast_token_type_ids = self.t5_base_tokenizer_fast( + src_text_1, src_text_2, add_special_tokens=True, return_token_type_ids=True + ).token_type_ids + slow_token_type_ids = self.t5_base_tokenizer( + src_text_1, src_text_2, add_special_tokens=True, return_token_type_ids=True + ).token_type_ids + + self.assertEqual(slow_token_type_ids, fast_token_type_ids) + self.assertEqual(len(slow_token_type_ids[0]), 18) + def test_fast_and_slow_same_result(self): src_text = " Today is nice day " tgt_ids = [0, 1960, 19, 2, 1245, 239, 1]