prepare_seq2seq_batch makes labels/ decoder_input_ids made later. (#6654)
* broken test * batch parity * tests pass * boom boom * boom boom * split out bart tokenizer tests * fix tests * boom boom * Fixed dataset bug * Fix marian * Undo extra * Get marian working * Fix t5 tok tests * Test passing * Cleanup * better assert msg * require torch * Fix mbart tests * undo extra decoder_attn_mask change * Fix import * pegasus tokenizer can ignore src_lang kwargs * unused kwarg test cov * boom boom * add todo for pegasus issue * cover one word translation edge case * Cleanup * doc
This commit is contained in:
@@ -1555,14 +1555,19 @@ class TokenizerTesterMixin:
|
||||
"vor face decât să înrăutăţească violenţele şi mizeria pentru milioane de oameni.",
|
||||
]
|
||||
batch = tokenizer.prepare_seq2seq_batch(
|
||||
src_texts=src_text, tgt_texts=tgt_text, max_length=3, max_target_length=10, return_tensors="pt"
|
||||
src_texts=src_text,
|
||||
tgt_texts=tgt_text,
|
||||
max_length=3,
|
||||
max_target_length=10,
|
||||
return_tensors="pt",
|
||||
src_lang="en_XX", # this should be ignored (for all but mbart) but not cause an error
|
||||
)
|
||||
self.assertEqual(batch.input_ids.shape[1], 3)
|
||||
self.assertEqual(batch.decoder_input_ids.shape[1], 10)
|
||||
self.assertEqual(batch.labels.shape[1], 10)
|
||||
# max_target_length will default to max_length if not specified
|
||||
batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, max_length=3)
|
||||
self.assertEqual(batch.input_ids.shape[1], 3)
|
||||
self.assertEqual(batch.decoder_input_ids.shape[1], 3)
|
||||
self.assertEqual(batch.labels.shape[1], 3)
|
||||
|
||||
batch_encoder_only = tokenizer.prepare_seq2seq_batch(
|
||||
src_texts=src_text, max_length=3, max_target_length=10, return_tensors="pt"
|
||||
|
||||
Reference in New Issue
Block a user