T5Tokenizer adds EOS token if not already added (#5866)

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Sam Shleifer
2020-08-25 14:56:08 -04:00
committed by GitHub
parent e11d923bfc
commit 624495706c
3 changed files with 117 additions and 50 deletions

View File

@@ -18,6 +18,7 @@ import os
import unittest
from transformers import BatchEncoding
from transformers.file_utils import cached_property
from transformers.testing_utils import _torch_available
from transformers.tokenization_t5 import T5Tokenizer
from transformers.tokenization_xlnet import SPIECE_UNDERLINE
@@ -107,28 +108,37 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
],
)
@cached_property
def t5_base_tokenizer(self):
return T5Tokenizer.from_pretrained("t5-base")
def test_eos_treatment(self):
tokenizer = self.t5_base_tokenizer
batch_with_eos_added = tokenizer(["hi</s>", "I went to the gym</s>", "</s>"])
batch_without_eos_added = tokenizer(["hi", "I went to the gym", ""])
self.assertListEqual(batch_with_eos_added["input_ids"], batch_without_eos_added["input_ids"])
def test_prepare_seq2seq_batch(self):
tokenizer = T5Tokenizer.from_pretrained("t5-small")
tokenizer = self.t5_base_tokenizer
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
tgt_text = [
"Summary of the text.",
"Another summary.",
]
expected_src_tokens = [71, 307, 8986, 21, 4505, 51, 52, 1707, 5]
batch = tokenizer.prepare_seq2seq_batch(
src_text, tgt_texts=tgt_text, max_length=len(expected_src_tokens), return_tensors=FRAMEWORK
)
expected_src_tokens = [71, 307, 8986, 21, 4505, 51, 52, 1707, 5, tokenizer.eos_token_id]
batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, return_tensors=FRAMEWORK,)
self.assertIsInstance(batch, BatchEncoding)
self.assertEqual((2, 9), batch.input_ids.shape)
self.assertEqual((2, 9), batch.attention_mask.shape)
result = list(batch.input_ids.numpy()[0])
self.assertListEqual(expected_src_tokens, result)
self.assertEqual((2, 10), batch.input_ids.shape)
self.assertEqual((2, 10), batch.attention_mask.shape)
# Test that special tokens are reset
self.assertEqual(tokenizer.prefix_tokens, [])
def test_empty_target_text(self):
tokenizer = T5Tokenizer.from_pretrained("t5-small")
tokenizer = self.t5_base_tokenizer
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
batch = tokenizer.prepare_seq2seq_batch(src_text, return_tensors=FRAMEWORK)
# check if input_ids are returned and no decoder_input_ids
@@ -138,7 +148,7 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
self.assertNotIn("decoder_attention_mask", batch)
def test_max_target_length(self):
tokenizer = T5Tokenizer.from_pretrained("t5-small")
tokenizer = self.t5_base_tokenizer
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
tgt_text = [
"Summary of the text.",
@@ -158,7 +168,7 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
self.assertEqual(32, batch["decoder_attention_mask"].shape[1])
def test_outputs_not_longer_than_maxlen(self):
tokenizer = T5Tokenizer.from_pretrained("t5-small")
tokenizer = self.t5_base_tokenizer
batch = tokenizer.prepare_seq2seq_batch(
["I am a small frog" * 1000, "I am a small frog"], return_tensors=FRAMEWORK
@@ -167,7 +177,7 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
self.assertEqual(batch.input_ids.shape, (2, 512))
def test_eos_in_input(self):
tokenizer = T5Tokenizer.from_pretrained("t5-small")
tokenizer = self.t5_base_tokenizer
src_text = ["A long paragraph for summrization. </s>"]
tgt_text = ["Summary of the text. </s>"]
expected_src_tokens = [71, 307, 8986, 21, 4505, 51, 52, 1707, 5, 1]