rename prepare_translation_batch -> prepare_seq2seq_batch (#6103)

This commit is contained in:
Sam Shleifer
2020-08-11 15:57:07 -04:00
committed by GitHub
parent 66fa8ceaea
commit be1520d3a3
14 changed files with 208 additions and 123 deletions

View File

@@ -64,7 +64,7 @@ class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
def test_tokenizer_equivalence_en_de(self):
en_de_tokenizer = MarianTokenizer.from_pretrained(f"{ORG_NAME}opus-mt-en-de")
batch = en_de_tokenizer.prepare_translation_batch(["I am a small frog"], return_tensors=None)
batch = en_de_tokenizer.prepare_seq2seq_batch(["I am a small frog"], return_tensors=None)
self.assertIsInstance(batch, BatchEncoding)
expected = [38, 121, 14, 697, 38848, 0]
self.assertListEqual(expected, batch.input_ids[0])
@@ -78,16 +78,12 @@ class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
def test_outputs_not_longer_than_maxlen(self):
tok = self.get_tokenizer()
batch = tok.prepare_translation_batch(
["I am a small frog" * 1000, "I am a small frog"], return_tensors=FRAMEWORK
)
batch = tok.prepare_seq2seq_batch(["I am a small frog" * 1000, "I am a small frog"], return_tensors=FRAMEWORK)
self.assertIsInstance(batch, BatchEncoding)
self.assertEqual(batch.input_ids.shape, (2, 512))
def test_outputs_can_be_shorter(self):
tok = self.get_tokenizer()
batch_smaller = tok.prepare_translation_batch(
["I am a tiny frog", "I am a small frog"], return_tensors=FRAMEWORK
)
batch_smaller = tok.prepare_seq2seq_batch(["I am a tiny frog", "I am a small frog"], return_tensors=FRAMEWORK)
self.assertIsInstance(batch_smaller, BatchEncoding)
self.assertEqual(batch_smaller.input_ids.shape, (2, 10))