rename prepare_translation_batch -> prepare_seq2seq_batch (#6103)

This commit is contained in:
Sam Shleifer
2020-08-11 15:57:07 -04:00
committed by GitHub
parent 66fa8ceaea
commit be1520d3a3
14 changed files with 208 additions and 123 deletions

View File

@@ -82,7 +82,7 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
@slow
def test_enro_generate(self):
batch: BatchEncoding = self.tokenizer.prepare_translation_batch(self.src_text).to(torch_device)
batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(self.src_text).to(torch_device)
translated_tokens = self.model.generate(**batch)
decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
self.assertEqual(self.tgt_text[0], decoded[0])
@@ -134,7 +134,7 @@ class MBartCC25IntegrationTest(AbstractSeq2SeqIntegrationTest):
@unittest.skip("This test is broken, still generates english")
def test_cc25_generate(self):
inputs = self.tokenizer.prepare_translation_batch([self.src_text[0]]).to(torch_device)
inputs = self.tokenizer.prepare_seq2seq_batch([self.src_text[0]]).to(torch_device)
translated_tokens = self.model.generate(
input_ids=inputs["input_ids"].to(torch_device),
decoder_start_token_id=self.tokenizer.lang_code_to_id["ro_RO"],
@@ -144,7 +144,7 @@ class MBartCC25IntegrationTest(AbstractSeq2SeqIntegrationTest):
@slow
def test_fill_mask(self):
inputs = self.tokenizer.prepare_translation_batch(["One of the best <mask> I ever read!"]).to(torch_device)
inputs = self.tokenizer.prepare_seq2seq_batch(["One of the best <mask> I ever read!"]).to(torch_device)
outputs = self.model.generate(
inputs["input_ids"], decoder_start_token_id=self.tokenizer.lang_code_to_id["en_XX"], num_beams=1
)