rename prepare_translation_batch -> prepare_seq2seq_batch (#6103)

This commit is contained in:
Sam Shleifer
2020-08-11 15:57:07 -04:00
committed by GitHub
parent 66fa8ceaea
commit be1520d3a3
14 changed files with 208 additions and 123 deletions

View File

@@ -123,8 +123,8 @@ class MBartEnroIntegrationTest(unittest.TestCase):
self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["en_EN"], 250004)
self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["ro_RO"], 250020)
def test_enro_tokenizer_prepare_translation_batch(self):
batch = self.tokenizer.prepare_translation_batch(
def test_enro_tokenizer_prepare_seq2seq_batch(self):
batch = self.tokenizer.prepare_seq2seq_batch(
self.src_text, tgt_texts=self.tgt_text, max_length=len(self.expected_src_tokens),
)
self.assertIsInstance(batch, BatchEncoding)
@@ -140,13 +140,13 @@ class MBartEnroIntegrationTest(unittest.TestCase):
def test_max_target_length(self):
batch = self.tokenizer.prepare_translation_batch(
batch = self.tokenizer.prepare_seq2seq_batch(
self.src_text, tgt_texts=self.tgt_text, max_length=3, max_target_length=10
)
self.assertEqual(batch.input_ids.shape[1], 3)
self.assertEqual(batch.decoder_input_ids.shape[1], 10)
# max_target_length will default to max_length if not specified
batch = self.tokenizer.prepare_translation_batch(self.src_text, tgt_texts=self.tgt_text, max_length=3)
batch = self.tokenizer.prepare_seq2seq_batch(self.src_text, tgt_texts=self.tgt_text, max_length=3)
self.assertEqual(batch.input_ids.shape[1], 3)
self.assertEqual(batch.decoder_input_ids.shape[1], 3)
@@ -166,7 +166,7 @@ class MBartEnroIntegrationTest(unittest.TestCase):
src_text = ["this is gunna be a long sentence " * 20]
assert isinstance(src_text[0], str)
desired_max_length = 10
ids = self.tokenizer.prepare_translation_batch(
ids = self.tokenizer.prepare_seq2seq_batch(
src_text, return_tensors=None, max_length=desired_max_length
).input_ids[0]
self.assertEqual(ids[-2], 2)