Deprecate prepare_seq2seq_batch (#10287)
* Deprecate prepare_seq2seq_batch
* Fix last tests
* Apply suggestions from code review

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* More review comments

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
@@ -86,11 +86,13 @@ class PegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
def test_large_seq2seq_truncation(self):
|
||||
src_texts = ["This is going to be way too long." * 150, "short example"]
|
||||
tgt_texts = ["not super long but more than 5 tokens", "tiny"]
|
||||
batch = self._large_tokenizer.prepare_seq2seq_batch(
|
||||
src_texts, tgt_texts=tgt_texts, max_target_length=5, return_tensors="pt"
|
||||
)
|
||||
batch = self._large_tokenizer(src_texts, padding=True, truncation=True, return_tensors="pt")
|
||||
with self._large_tokenizer.as_target_tokenizer():
|
||||
targets = self._large_tokenizer(
|
||||
tgt_texts, max_length=5, padding=True, truncation=True, return_tensors="pt"
|
||||
)
|
||||
|
||||
assert batch.input_ids.shape == (2, 1024)
|
||||
assert batch.attention_mask.shape == (2, 1024)
|
||||
assert "labels" in batch # because tgt_texts was specified
|
||||
assert batch.labels.shape == (2, 5)
|
||||
assert len(batch) == 3 # input_ids, attention_mask, labels. Other things make by BartModel
|
||||
assert targets["input_ids"].shape == (2, 5)
|
||||
assert len(batch) == 2 # input_ids, attention_mask.
|
||||
|
||||
Reference in New Issue
Block a user