prepare_seq2seq_batch makes labels; decoder_input_ids are made later. (#6654)

* broken test

* batch parity

* tests pass

* boom boom

* boom boom

* split out bart tokenizer tests

* fix tests

* boom boom

* Fixed dataset bug

* Fix marian

* Undo extra

* Get marian working

* Fix t5 tok tests

* Test passing

* Cleanup

* better assert msg

* require torch

* Fix mbart tests

* undo extra decoder_attn_mask change

* Fix import

* pegasus tokenizer can ignore src_lang kwargs

* unused kwarg test cov

* boom boom

* add todo for pegasus issue

* cover one word translation edge case

* Cleanup

* doc
Sam Shleifer
2020-08-28 11:15:17 -04:00
committed by GitHub
parent cb276b41de
commit 9336086ab5
20 changed files with 429 additions and 290 deletions

@@ -63,7 +63,6 @@ class PegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         batch = self.pegasus_large_tokenizer.prepare_seq2seq_batch(src_texts, tgt_texts=tgt_texts, max_target_length=5)
         assert batch.input_ids.shape == (2, 1024)
         assert batch.attention_mask.shape == (2, 1024)
-        assert "decoder_input_ids" in batch  # because tgt_texts was specified
-        assert batch.decoder_input_ids.shape == (2, 5)
-        assert batch.decoder_attention_mask.shape == (2, 5)
-        assert len(batch) == 4  # no extra keys
+        assert "labels" in batch  # because tgt_texts was specified
+        assert batch.labels.shape == (2, 5)
+        assert len(batch) == 3  # input_ids, attention_mask, labels. Other things made by BartModel
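
The hunk above changes what the test expects: the batch now carries `labels`, and `decoder_input_ids` are built later from those labels. The following is a minimal sketch of that idea on plain Python lists. The helpers `shift_right` and `make_decoder_inputs`, the token ids, and the choice of `start_id=0` are all hypothetical and chosen for illustration; this is not the library's implementation, which may shift tokens differently for each model.

```python
from typing import Dict, List


def shift_right(labels: List[List[int]], start_id: int) -> List[List[int]]:
    """Prepend start_id to each row and drop the last token, keeping the length."""
    # Hypothetical helper: one common way to derive decoder inputs from labels.
    return [[start_id] + row[:-1] for row in labels]


def make_decoder_inputs(
    batch: Dict[str, List[List[int]]], start_id: int
) -> Dict[str, List[List[int]]]:
    """Return a copy of batch with decoder_input_ids derived from labels."""
    out = dict(batch)  # shallow copy, so the tokenizer-style batch is not mutated
    out["decoder_input_ids"] = shift_right(batch["labels"], start_id)
    return out


# Toy batch with the three keys the updated test expects (ids are made up).
batch = {
    "input_ids": [[11, 12, 13, 1], [21, 22, 1, 0]],
    "attention_mask": [[1, 1, 1, 1], [1, 1, 1, 0]],
    "labels": [[31, 32, 1], [41, 1, 0]],
}
assert len(batch) == 3  # input_ids, attention_mask, labels

model_inputs = make_decoder_inputs(batch, start_id=0)
print(model_inputs["decoder_input_ids"])  # [[0, 31, 32], [0, 41, 1]]
```

Keeping only `labels` in the batch means the shifting rule lives in one place on the model side instead of being repeated in every tokenizer.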