[tests] fix typos in inputs (#6818)
This commit is contained in:
@@ -69,12 +69,12 @@ class TestTokenizationBart(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_prepare_seq2seq_batch(self):
|
def test_prepare_seq2seq_batch(self):
|
||||||
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
|
src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
|
||||||
tgt_text = [
|
tgt_text = [
|
||||||
"Summary of the text.",
|
"Summary of the text.",
|
||||||
"Another summary.",
|
"Another summary.",
|
||||||
]
|
]
|
||||||
expected_src_tokens = [0, 250, 251, 17818, 13, 32933, 21645, 1258, 4, 2]
|
expected_src_tokens = [0, 250, 251, 17818, 13, 39186, 1938, 4, 2]
|
||||||
|
|
||||||
for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
|
for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
|
||||||
batch = tokenizer.prepare_seq2seq_batch(
|
batch = tokenizer.prepare_seq2seq_batch(
|
||||||
@@ -82,8 +82,8 @@ class TestTokenizationBart(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
)
|
)
|
||||||
self.assertIsInstance(batch, BatchEncoding)
|
self.assertIsInstance(batch, BatchEncoding)
|
||||||
|
|
||||||
self.assertEqual((2, 10), batch.input_ids.shape)
|
self.assertEqual((2, 9), batch.input_ids.shape)
|
||||||
self.assertEqual((2, 10), batch.attention_mask.shape)
|
self.assertEqual((2, 9), batch.attention_mask.shape)
|
||||||
result = batch.input_ids.tolist()[0]
|
result = batch.input_ids.tolist()[0]
|
||||||
self.assertListEqual(expected_src_tokens, result)
|
self.assertListEqual(expected_src_tokens, result)
|
||||||
# Test that special tokens are reset
|
# Test that special tokens are reset
|
||||||
@@ -91,7 +91,7 @@ class TestTokenizationBart(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
# Test Prepare Seq
|
# Test Prepare Seq
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_seq2seq_batch_empty_target_text(self):
|
def test_seq2seq_batch_empty_target_text(self):
|
||||||
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
|
src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
|
||||||
for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
|
for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
|
||||||
batch = tokenizer.prepare_seq2seq_batch(src_text, return_tensors="pt")
|
batch = tokenizer.prepare_seq2seq_batch(src_text, return_tensors="pt")
|
||||||
# check if input_ids are returned and no labels
|
# check if input_ids are returned and no labels
|
||||||
@@ -102,7 +102,7 @@ class TestTokenizationBart(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_seq2seq_batch_max_target_length(self):
|
def test_seq2seq_batch_max_target_length(self):
|
||||||
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
|
src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
|
||||||
tgt_text = [
|
tgt_text = [
|
||||||
"Summary of the text.",
|
"Summary of the text.",
|
||||||
"Another summary.",
|
"Another summary.",
|
||||||
@@ -131,7 +131,7 @@ class TestTokenizationBart(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
@require_torch
|
@require_torch
|
||||||
def test_special_tokens(self):
|
def test_special_tokens(self):
|
||||||
|
|
||||||
src_text = ["A long paragraph for summrization."]
|
src_text = ["A long paragraph for summarization."]
|
||||||
tgt_text = [
|
tgt_text = [
|
||||||
"Summary of the text.",
|
"Summary of the text.",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -120,12 +120,12 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
def test_prepare_seq2seq_batch(self):
|
def test_prepare_seq2seq_batch(self):
|
||||||
tokenizer = self.t5_base_tokenizer
|
tokenizer = self.t5_base_tokenizer
|
||||||
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
|
src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
|
||||||
tgt_text = [
|
tgt_text = [
|
||||||
"Summary of the text.",
|
"Summary of the text.",
|
||||||
"Another summary.",
|
"Another summary.",
|
||||||
]
|
]
|
||||||
expected_src_tokens = [71, 307, 8986, 21, 4505, 51, 52, 1707, 5, tokenizer.eos_token_id]
|
expected_src_tokens = [71, 307, 8986, 21, 4505, 1635, 1707, 5, tokenizer.eos_token_id]
|
||||||
batch = tokenizer.prepare_seq2seq_batch(
|
batch = tokenizer.prepare_seq2seq_batch(
|
||||||
src_text,
|
src_text,
|
||||||
tgt_texts=tgt_text,
|
tgt_texts=tgt_text,
|
||||||
@@ -135,15 +135,15 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
result = list(batch.input_ids.numpy()[0])
|
result = list(batch.input_ids.numpy()[0])
|
||||||
self.assertListEqual(expected_src_tokens, result)
|
self.assertListEqual(expected_src_tokens, result)
|
||||||
|
|
||||||
self.assertEqual((2, 10), batch.input_ids.shape)
|
self.assertEqual((2, 9), batch.input_ids.shape)
|
||||||
self.assertEqual((2, 10), batch.attention_mask.shape)
|
self.assertEqual((2, 9), batch.attention_mask.shape)
|
||||||
|
|
||||||
# Test that special tokens are reset
|
# Test that special tokens are reset
|
||||||
self.assertEqual(tokenizer.prefix_tokens, [])
|
self.assertEqual(tokenizer.prefix_tokens, [])
|
||||||
|
|
||||||
def test_empty_target_text(self):
|
def test_empty_target_text(self):
|
||||||
tokenizer = self.t5_base_tokenizer
|
tokenizer = self.t5_base_tokenizer
|
||||||
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
|
src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
|
||||||
batch = tokenizer.prepare_seq2seq_batch(src_text, return_tensors=FRAMEWORK)
|
batch = tokenizer.prepare_seq2seq_batch(src_text, return_tensors=FRAMEWORK)
|
||||||
# check if input_ids are returned and no decoder_input_ids
|
# check if input_ids are returned and no decoder_input_ids
|
||||||
self.assertIn("input_ids", batch)
|
self.assertIn("input_ids", batch)
|
||||||
@@ -153,7 +153,7 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
def test_max_target_length(self):
|
def test_max_target_length(self):
|
||||||
tokenizer = self.t5_base_tokenizer
|
tokenizer = self.t5_base_tokenizer
|
||||||
src_text = ["A short paragraph for summrization.", "Another short paragraph for summrization."]
|
src_text = ["A short paragraph for summarization.", "Another short paragraph for summarization."]
|
||||||
tgt_text = [
|
tgt_text = [
|
||||||
"Summary of the text.",
|
"Summary of the text.",
|
||||||
"Another summary.",
|
"Another summary.",
|
||||||
@@ -180,9 +180,9 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
def test_eos_in_input(self):
|
def test_eos_in_input(self):
|
||||||
tokenizer = self.t5_base_tokenizer
|
tokenizer = self.t5_base_tokenizer
|
||||||
src_text = ["A long paragraph for summrization. </s>"]
|
src_text = ["A long paragraph for summarization. </s>"]
|
||||||
tgt_text = ["Summary of the text. </s>"]
|
tgt_text = ["Summary of the text. </s>"]
|
||||||
expected_src_tokens = [71, 307, 8986, 21, 4505, 51, 52, 1707, 5, 1]
|
expected_src_tokens = [71, 307, 8986, 21, 4505, 1635, 1707, 5, 1]
|
||||||
expected_tgt_tokens = [0, 20698, 13, 8, 1499, 5, 1]
|
expected_tgt_tokens = [0, 20698, 13, 8, 1499, 5, 1]
|
||||||
|
|
||||||
batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, return_tensors=FRAMEWORK)
|
batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, return_tensors=FRAMEWORK)
|
||||||
|
|||||||
Reference in New Issue
Block a user