prepare_seq2seq_batch makes labels/ decoder_input_ids made later. (#6654)

* broken test

* batch parity

* tests pass

* boom boom

* boom boom

* split out bart tokenizer tests

* fix tests

* boom boom

* Fixed dataset bug

* Fix marian

* Undo extra

* Get marian working

* Fix t5 tok tests

* Test passing

* Cleanup

* better assert msg

* require torch

* Fix mbart tests

* undo extra decoder_attn_mask change

* Fix import

* pegasus tokenizer can ignore src_lang kwargs

* unused kwarg test cov

* boom boom

* add todo for pegasus issue

* cover one word translation edge case

* Cleanup

* doc
This commit is contained in:
Sam Shleifer
2020-08-28 11:15:17 -04:00
committed by GitHub
parent cb276b41de
commit 9336086ab5
20 changed files with 429 additions and 290 deletions

View File

@@ -1555,14 +1555,19 @@ class TokenizerTesterMixin:
"vor face decât să înrăutăţească violenţele şi mizeria pentru milioane de oameni.",
]
batch = tokenizer.prepare_seq2seq_batch(
src_texts=src_text, tgt_texts=tgt_text, max_length=3, max_target_length=10, return_tensors="pt"
src_texts=src_text,
tgt_texts=tgt_text,
max_length=3,
max_target_length=10,
return_tensors="pt",
src_lang="en_XX", # this should be ignored (for all but mbart) but not cause an error
)
self.assertEqual(batch.input_ids.shape[1], 3)
self.assertEqual(batch.decoder_input_ids.shape[1], 10)
self.assertEqual(batch.labels.shape[1], 10)
# max_target_length will default to max_length if not specified
batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, max_length=3)
self.assertEqual(batch.input_ids.shape[1], 3)
self.assertEqual(batch.decoder_input_ids.shape[1], 3)
self.assertEqual(batch.labels.shape[1], 3)
batch_encoder_only = tokenizer.prepare_seq2seq_batch(
src_texts=src_text, max_length=3, max_target_length=10, return_tensors="pt"