prepare_seq2seq_batch makes labels; decoder_input_ids are made later (#6654)
* broken test
* batch parity
* tests pass
* boom boom
* boom boom
* split out bart tokenizer tests
* fix tests
* boom boom
* Fixed dataset bug
* Fix marian
* Undo extra
* Get marian working
* Fix t5 tok tests
* Test passing
* Cleanup
* better assert msg
* require torch
* Fix mbart tests
* undo extra decoder_attn_mask change
* Fix import
* pegasus tokenizer can ignore src_lang kwargs
* unused kwarg test cov
* boom boom
* add todo for pegasus issue
* cover one word translation edge case
* Cleanup
* doc
@@ -63,7 +63,6 @@ class PegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         batch = self.pegasus_large_tokenizer.prepare_seq2seq_batch(src_texts, tgt_texts=tgt_texts, max_target_length=5)
         assert batch.input_ids.shape == (2, 1024)
         assert batch.attention_mask.shape == (2, 1024)
-        assert "decoder_input_ids" in batch  # because tgt_texts was specified
-        assert batch.decoder_input_ids.shape == (2, 5)
-        assert batch.decoder_attention_mask.shape == (2, 5)
-        assert len(batch) == 4  # no extra keys
+        assert "labels" in batch  # because tgt_texts was specified
+        assert batch.labels.shape == (2, 5)
+        assert len(batch) == 3  # input_ids, attention_mask, labels. Other things make by BartModel
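
The test change above reflects that the batch now carries only labels, with decoder_input_ids derived from them afterwards. The sketch below illustrates one common way such a derivation can work: shift the labels one position to the right and prepend a start token. It is an illustration only, not the code from this commit; the function name shift_right, the token ids, and the handling of -100 are all assumptions made for the example.

```python
from typing import List


def shift_right(labels: List[List[int]], start_token_id: int, pad_token_id: int) -> List[List[int]]:
    """Hypothetical helper: build decoder inputs from labels by shifting right."""
    shifted = []
    for row in labels:
        # Assumption: -100 marks positions ignored by the loss; map them back to padding.
        cleaned = [pad_token_id if tok == -100 else tok for tok in row]
        # Prepend the start token and drop the last label so the length is unchanged.
        shifted.append([start_token_id] + cleaned[:-1])
    return shifted


# Made-up ids: 2 = start token, 1 = end of sequence, 0 = padding.
labels = [[10, 11, 12, 1, 0], [20, 21, 1, 0, 0]]
decoder_input_ids = shift_right(labels, start_token_id=2, pad_token_id=0)
print(decoder_input_ids)
```

Because the decoder inputs can be recomputed from the labels this way, the batch only needs input_ids, attention_mask, and labels, which matches the len(batch) == 3 assertion in the updated test.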