MBART: support summarization tasks where max_src_len > max_tgt_len (#6003)

* MBART: support summarization tasks

* fix test

* Style

* add tokenizer test
This commit is contained in:
Sam Shleifer
2020-07-28 08:18:11 -04:00
committed by GitHub
parent 842eb45606
commit 3c7fbf35a6
7 changed files with 38 additions and 15 deletions

View File

@@ -300,14 +300,17 @@ def test_mbart_dataset_truncation():
tmp_dir = make_test_data_dir()
max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
trunc = 4
max_src_len = 4
max_tgt_len = 8
assert max_len_target > max_src_len # Truncated
assert max_len_source > max_src_len
src_lang, tgt_lang = "ro_RO", "de_DE" # NOT WHAT IT WAS TRAINED ON
train_dataset = MBartDataset(
tokenizer,
data_dir=tmp_dir,
type_path="train",
max_source_length=trunc,
max_target_length=1000, # ignored
max_source_length=max_src_len,
max_target_length=max_tgt_len, # ignored
src_lang=src_lang,
tgt_lang=tgt_lang,
)
@@ -316,17 +319,15 @@ def test_mbart_dataset_truncation():
assert isinstance(batch, dict)
assert batch["attention_mask"].shape == batch["input_ids"].shape
# show that articles were trimmed.
assert batch["input_ids"].shape[1] == trunc
assert batch["input_ids"].shape[1] == max_src_len
# show that targets are the same len
assert batch["decoder_input_ids"].shape[1] == trunc
assert batch["decoder_input_ids"].shape[1] == max_tgt_len
# check language codes in correct place
assert batch["decoder_input_ids"][0, 0].item() == tokenizer.lang_code_to_id[tgt_lang]
assert batch["decoder_input_ids"][0, -1].item() == tokenizer.eos_token_id
assert batch["input_ids"][0, -2].item() == tokenizer.eos_token_id
assert batch["input_ids"][0, -1].item() == tokenizer.lang_code_to_id[src_lang]
assert max_len_target > trunc # Truncated
assert max_len_source > trunc
break # No need to test every batch