Deprecate prepare_seq2seq_batch (#10287)

* Deprecate prepare_seq2seq_batch

* Fix last tests

* Apply suggestions from code review

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* More review comments

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
Sylvain Gugger
2021-02-22 12:36:16 -05:00
committed by GitHub
parent e73a3e1891
commit 9e147d31f6
31 changed files with 325 additions and 320 deletions

View File

@@ -38,16 +38,12 @@ class BarthezTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
self.tokenizer = tokenizer
@require_torch
def test_prepare_seq2seq_batch(self):
def test_prepare_batch(self):
src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
tgt_text = [
"Summary of the text.",
"Another summary.",
]
expected_src_tokens = [0, 57, 3018, 70307, 91, 2]
batch = self.tokenizer.prepare_seq2seq_batch(
src_text, tgt_texts=tgt_text, max_length=len(expected_src_tokens), return_tensors="pt"
batch = self.tokenizer(
src_text, max_length=len(expected_src_tokens), padding=True, truncation=True, return_tensors="pt"
)
self.assertIsInstance(batch, BatchEncoding)