Deprecate prepare_seq2seq_batch (#10287)

* Deprecate prepare_seq2seq_batch

* Fix last tests

* Apply suggestions from code review

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* More review comments

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
Sylvain Gugger
2021-02-22 12:36:16 -05:00
committed by GitHub
parent e73a3e1891
commit 9e147d31f6
31 changed files with 325 additions and 320 deletions

View File

@@ -349,7 +349,7 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
@slow
def test_enro_generate_one(self):
batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(
batch: BatchEncoding = self.tokenizer(
["UN Chief Says There Is No Military Solution in Syria"], return_tensors="pt"
).to(torch_device)
translated_tokens = self.model.generate(**batch)
@@ -359,9 +359,7 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
@slow
def test_enro_generate_batch(self):
batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(self.src_text, return_tensors="pt").to(
torch_device
)
batch: BatchEncoding = self.tokenizer(self.src_text, return_tensors="pt").to(torch_device)
translated_tokens = self.model.generate(**batch)
decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
assert self.tgt_text == decoded
@@ -412,7 +410,7 @@ class MBartCC25IntegrationTest(AbstractSeq2SeqIntegrationTest):
@unittest.skip("This test is broken, still generates english")
def test_cc25_generate(self):
inputs = self.tokenizer.prepare_seq2seq_batch([self.src_text[0]], return_tensors="pt").to(torch_device)
inputs = self.tokenizer([self.src_text[0]], return_tensors="pt").to(torch_device)
translated_tokens = self.model.generate(
input_ids=inputs["input_ids"].to(torch_device),
decoder_start_token_id=self.tokenizer.lang_code_to_id["ro_RO"],
@@ -422,9 +420,7 @@ class MBartCC25IntegrationTest(AbstractSeq2SeqIntegrationTest):
@slow
def test_fill_mask(self):
inputs = self.tokenizer.prepare_seq2seq_batch(["One of the best <mask> I ever read!"], return_tensors="pt").to(
torch_device
)
inputs = self.tokenizer(["One of the best <mask> I ever read!"], return_tensors="pt").to(torch_device)
outputs = self.model.generate(
inputs["input_ids"], decoder_start_token_id=self.tokenizer.lang_code_to_id["en_XX"], num_beams=1
)