rename prepare_translation_batch -> prepare_seq2seq_batch (#6103)
This commit is contained in:
@@ -97,7 +97,7 @@ class MarianIntegrationTest(unittest.TestCase):
|
||||
self.assertListEqual(self.expected_text, generated_words)
|
||||
|
||||
def translate_src_text(self, **tokenizer_kwargs):
|
||||
model_inputs = self.tokenizer.prepare_translation_batch(src_texts=self.src_text, **tokenizer_kwargs).to(
|
||||
model_inputs = self.tokenizer.prepare_seq2seq_batch(src_texts=self.src_text, **tokenizer_kwargs).to(
|
||||
torch_device
|
||||
)
|
||||
self.assertEqual(self.model.device, model_inputs.input_ids.device)
|
||||
@@ -114,7 +114,7 @@ class TestMarian_EN_DE_More(MarianIntegrationTest):
|
||||
src, tgt = ["I am a small frog"], ["Ich bin ein kleiner Frosch."]
|
||||
expected_ids = [38, 121, 14, 697, 38848, 0]
|
||||
|
||||
model_inputs: dict = self.tokenizer.prepare_translation_batch(src, tgt_texts=tgt).to(torch_device)
|
||||
model_inputs: dict = self.tokenizer.prepare_seq2seq_batch(src, tgt_texts=tgt).to(torch_device)
|
||||
self.assertListEqual(expected_ids, model_inputs.input_ids[0].tolist())
|
||||
|
||||
desired_keys = {
|
||||
@@ -131,12 +131,12 @@ class TestMarian_EN_DE_More(MarianIntegrationTest):
|
||||
|
||||
def test_unk_support(self):
|
||||
t = self.tokenizer
|
||||
ids = t.prepare_translation_batch(["||"]).to(torch_device).input_ids[0].tolist()
|
||||
ids = t.prepare_seq2seq_batch(["||"]).to(torch_device).input_ids[0].tolist()
|
||||
expected = [t.unk_token_id, t.unk_token_id, t.eos_token_id]
|
||||
self.assertEqual(expected, ids)
|
||||
|
||||
def test_pad_not_split(self):
|
||||
input_ids_w_pad = self.tokenizer.prepare_translation_batch(["I am a small frog <pad>"]).input_ids[0].tolist()
|
||||
input_ids_w_pad = self.tokenizer.prepare_seq2seq_batch(["I am a small frog <pad>"]).input_ids[0].tolist()
|
||||
expected_w_pad = [38, 121, 14, 697, 38848, self.tokenizer.pad_token_id, 0] # pad
|
||||
self.assertListEqual(expected_w_pad, input_ids_w_pad)
|
||||
|
||||
@@ -229,7 +229,7 @@ class TestMarian_en_ROMANCE(MarianIntegrationTest):
|
||||
normalized = self.tokenizer.normalize("")
|
||||
self.assertIsInstance(normalized, str)
|
||||
with self.assertRaises(ValueError):
|
||||
self.tokenizer.prepare_translation_batch([""])
|
||||
self.tokenizer.prepare_seq2seq_batch([""])
|
||||
|
||||
def test_pipeline(self):
|
||||
device = 0 if torch_device == "cuda" else -1
|
||||
|
||||
@@ -82,7 +82,7 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
|
||||
|
||||
@slow
|
||||
def test_enro_generate(self):
|
||||
batch: BatchEncoding = self.tokenizer.prepare_translation_batch(self.src_text).to(torch_device)
|
||||
batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(self.src_text).to(torch_device)
|
||||
translated_tokens = self.model.generate(**batch)
|
||||
decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
|
||||
self.assertEqual(self.tgt_text[0], decoded[0])
|
||||
@@ -134,7 +134,7 @@ class MBartCC25IntegrationTest(AbstractSeq2SeqIntegrationTest):
|
||||
|
||||
@unittest.skip("This test is broken, still generates english")
|
||||
def test_cc25_generate(self):
|
||||
inputs = self.tokenizer.prepare_translation_batch([self.src_text[0]]).to(torch_device)
|
||||
inputs = self.tokenizer.prepare_seq2seq_batch([self.src_text[0]]).to(torch_device)
|
||||
translated_tokens = self.model.generate(
|
||||
input_ids=inputs["input_ids"].to(torch_device),
|
||||
decoder_start_token_id=self.tokenizer.lang_code_to_id["ro_RO"],
|
||||
@@ -144,7 +144,7 @@ class MBartCC25IntegrationTest(AbstractSeq2SeqIntegrationTest):
|
||||
|
||||
@slow
|
||||
def test_fill_mask(self):
|
||||
inputs = self.tokenizer.prepare_translation_batch(["One of the best <mask> I ever read!"]).to(torch_device)
|
||||
inputs = self.tokenizer.prepare_seq2seq_batch(["One of the best <mask> I ever read!"]).to(torch_device)
|
||||
outputs = self.model.generate(
|
||||
inputs["input_ids"], decoder_start_token_id=self.tokenizer.lang_code_to_id["en_XX"], num_beams=1
|
||||
)
|
||||
|
||||
@@ -1522,3 +1522,37 @@ class TokenizerTesterMixin:
|
||||
|
||||
if batch_encoded_sequence_fast is None:
|
||||
raise ValueError("Cannot convert list to numpy tensor on batch_encode_plus() (fast)")
|
||||
|
||||
@require_torch
|
||||
def test_prepare_seq2seq_batch(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
if not hasattr(tokenizer, "prepare_seq2seq_batch"):
|
||||
return
|
||||
# Longer text that will definitely require truncation.
|
||||
src_text = [
|
||||
" UN Chief Says There Is No Military Solution in Syria",
|
||||
" Secretary-General Ban Ki-moon says his response to Russia's stepped up military support for Syria is that 'there is no military solution' to the nearly five-year conflict and more weapons will only worsen the violence and misery for millions of people.",
|
||||
]
|
||||
tgt_text = [
|
||||
"Şeful ONU declară că nu există o soluţie militară în Siria",
|
||||
"Secretarul General Ban Ki-moon declară că răspunsul său la intensificarea sprijinului militar al Rusiei "
|
||||
'pentru Siria este că "nu există o soluţie militară" la conflictul de aproape cinci ani şi că noi arme nu '
|
||||
"vor face decât să înrăutăţească violenţele şi mizeria pentru milioane de oameni.",
|
||||
]
|
||||
batch = tokenizer.prepare_seq2seq_batch(
|
||||
src_texts=src_text, tgt_texts=tgt_text, max_length=3, max_target_length=10, return_tensors="pt"
|
||||
)
|
||||
self.assertEqual(batch.input_ids.shape[1], 3)
|
||||
self.assertEqual(batch.decoder_input_ids.shape[1], 10)
|
||||
# max_target_length will default to max_length if not specified
|
||||
batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, max_length=3)
|
||||
self.assertEqual(batch.input_ids.shape[1], 3)
|
||||
self.assertEqual(batch.decoder_input_ids.shape[1], 3)
|
||||
|
||||
batch_encoder_only = tokenizer.prepare_seq2seq_batch(
|
||||
src_texts=src_text, max_length=3, max_target_length=10, return_tensors="pt"
|
||||
)
|
||||
self.assertEqual(batch_encoder_only.input_ids.shape[1], 3)
|
||||
self.assertEqual(batch_encoder_only.attention_mask.shape[1], 3)
|
||||
self.assertNotIn("decoder_input_ids", batch_encoder_only)
|
||||
|
||||
@@ -64,7 +64,7 @@ class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
|
||||
def test_tokenizer_equivalence_en_de(self):
|
||||
en_de_tokenizer = MarianTokenizer.from_pretrained(f"{ORG_NAME}opus-mt-en-de")
|
||||
batch = en_de_tokenizer.prepare_translation_batch(["I am a small frog"], return_tensors=None)
|
||||
batch = en_de_tokenizer.prepare_seq2seq_batch(["I am a small frog"], return_tensors=None)
|
||||
self.assertIsInstance(batch, BatchEncoding)
|
||||
expected = [38, 121, 14, 697, 38848, 0]
|
||||
self.assertListEqual(expected, batch.input_ids[0])
|
||||
@@ -78,16 +78,12 @@ class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
def test_outputs_not_longer_than_maxlen(self):
|
||||
tok = self.get_tokenizer()
|
||||
|
||||
batch = tok.prepare_translation_batch(
|
||||
["I am a small frog" * 1000, "I am a small frog"], return_tensors=FRAMEWORK
|
||||
)
|
||||
batch = tok.prepare_seq2seq_batch(["I am a small frog" * 1000, "I am a small frog"], return_tensors=FRAMEWORK)
|
||||
self.assertIsInstance(batch, BatchEncoding)
|
||||
self.assertEqual(batch.input_ids.shape, (2, 512))
|
||||
|
||||
def test_outputs_can_be_shorter(self):
|
||||
tok = self.get_tokenizer()
|
||||
batch_smaller = tok.prepare_translation_batch(
|
||||
["I am a tiny frog", "I am a small frog"], return_tensors=FRAMEWORK
|
||||
)
|
||||
batch_smaller = tok.prepare_seq2seq_batch(["I am a tiny frog", "I am a small frog"], return_tensors=FRAMEWORK)
|
||||
self.assertIsInstance(batch_smaller, BatchEncoding)
|
||||
self.assertEqual(batch_smaller.input_ids.shape, (2, 10))
|
||||
|
||||
@@ -123,8 +123,8 @@ class MBartEnroIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["en_EN"], 250004)
|
||||
self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["ro_RO"], 250020)
|
||||
|
||||
def test_enro_tokenizer_prepare_translation_batch(self):
|
||||
batch = self.tokenizer.prepare_translation_batch(
|
||||
def test_enro_tokenizer_prepare_seq2seq_batch(self):
|
||||
batch = self.tokenizer.prepare_seq2seq_batch(
|
||||
self.src_text, tgt_texts=self.tgt_text, max_length=len(self.expected_src_tokens),
|
||||
)
|
||||
self.assertIsInstance(batch, BatchEncoding)
|
||||
@@ -140,13 +140,13 @@ class MBartEnroIntegrationTest(unittest.TestCase):
|
||||
|
||||
def test_max_target_length(self):
|
||||
|
||||
batch = self.tokenizer.prepare_translation_batch(
|
||||
batch = self.tokenizer.prepare_seq2seq_batch(
|
||||
self.src_text, tgt_texts=self.tgt_text, max_length=3, max_target_length=10
|
||||
)
|
||||
self.assertEqual(batch.input_ids.shape[1], 3)
|
||||
self.assertEqual(batch.decoder_input_ids.shape[1], 10)
|
||||
# max_target_length will default to max_length if not specified
|
||||
batch = self.tokenizer.prepare_translation_batch(self.src_text, tgt_texts=self.tgt_text, max_length=3)
|
||||
batch = self.tokenizer.prepare_seq2seq_batch(self.src_text, tgt_texts=self.tgt_text, max_length=3)
|
||||
self.assertEqual(batch.input_ids.shape[1], 3)
|
||||
self.assertEqual(batch.decoder_input_ids.shape[1], 3)
|
||||
|
||||
@@ -166,7 +166,7 @@ class MBartEnroIntegrationTest(unittest.TestCase):
|
||||
src_text = ["this is gunna be a long sentence " * 20]
|
||||
assert isinstance(src_text[0], str)
|
||||
desired_max_length = 10
|
||||
ids = self.tokenizer.prepare_translation_batch(
|
||||
ids = self.tokenizer.prepare_seq2seq_batch(
|
||||
src_text, return_tensors=None, max_length=desired_max_length
|
||||
).input_ids[0]
|
||||
self.assertEqual(ids[-2], 2)
|
||||
|
||||
Reference in New Issue
Block a user