Deprecate prepare_seq2seq_batch (#10287)

* Deprecate prepare_seq2seq_batch

* Fix last tests

* Apply suggestions from code review

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* More review comments

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
Sylvain Gugger
2021-02-22 12:36:16 -05:00
committed by GitHub
parent e73a3e1891
commit 9e147d31f6
31 changed files with 325 additions and 320 deletions

View File

@@ -86,11 +86,13 @@ class PegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
def test_large_seq2seq_truncation(self):
src_texts = ["This is going to be way too long." * 150, "short example"]
tgt_texts = ["not super long but more than 5 tokens", "tiny"]
batch = self._large_tokenizer.prepare_seq2seq_batch(
src_texts, tgt_texts=tgt_texts, max_target_length=5, return_tensors="pt"
)
batch = self._large_tokenizer(src_texts, padding=True, truncation=True, return_tensors="pt")
with self._large_tokenizer.as_target_tokenizer():
targets = self._large_tokenizer(
tgt_texts, max_length=5, padding=True, truncation=True, return_tensors="pt"
)
assert batch.input_ids.shape == (2, 1024)
assert batch.attention_mask.shape == (2, 1024)
assert "labels" in batch # because tgt_texts was specified
assert batch.labels.shape == (2, 5)
assert len(batch) == 3 # input_ids, attention_mask, labels. Other things make by BartModel
assert targets["input_ids"].shape == (2, 5)
assert len(batch) == 2 # input_ids, attention_mask.