MBART: support summarization tasks where max_src_len > max_tgt_len (#6003)
* MBART: support summarization tasks
* fix test
* Style
* add tokenizer test
This commit is contained in:
@@ -300,14 +300,17 @@ def test_mbart_dataset_truncation():
|
||||
tmp_dir = make_test_data_dir()
|
||||
max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
|
||||
max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
|
||||
trunc = 4
|
||||
max_src_len = 4
|
||||
max_tgt_len = 8
|
||||
assert max_len_target > max_src_len # Truncated
|
||||
assert max_len_source > max_src_len
|
||||
src_lang, tgt_lang = "ro_RO", "de_DE" # NOT WHAT IT WAS TRAINED ON
|
||||
train_dataset = MBartDataset(
|
||||
tokenizer,
|
||||
data_dir=tmp_dir,
|
||||
type_path="train",
|
||||
max_source_length=trunc,
|
||||
max_target_length=1000, # ignored
|
||||
max_source_length=max_src_len,
|
||||
max_target_length=max_tgt_len, # no longer ignored: decoder_input_ids length is asserted against this below
|
||||
src_lang=src_lang,
|
||||
tgt_lang=tgt_lang,
|
||||
)
|
||||
@@ -316,17 +319,15 @@ def test_mbart_dataset_truncation():
|
||||
assert isinstance(batch, dict)
|
||||
assert batch["attention_mask"].shape == batch["input_ids"].shape
|
||||
# show that articles were trimmed.
|
||||
assert batch["input_ids"].shape[1] == trunc
|
||||
assert batch["input_ids"].shape[1] == max_src_len
|
||||
# show that targets are the same len
|
||||
assert batch["decoder_input_ids"].shape[1] == trunc
|
||||
assert batch["decoder_input_ids"].shape[1] == max_tgt_len
|
||||
# check language codes in correct place
|
||||
assert batch["decoder_input_ids"][0, 0].item() == tokenizer.lang_code_to_id[tgt_lang]
|
||||
assert batch["decoder_input_ids"][0, -1].item() == tokenizer.eos_token_id
|
||||
assert batch["input_ids"][0, -2].item() == tokenizer.eos_token_id
|
||||
assert batch["input_ids"][0, -1].item() == tokenizer.lang_code_to_id[src_lang]
|
||||
|
||||
assert max_len_target > trunc # Truncated
|
||||
assert max_len_source > trunc
|
||||
break # No need to test every batch
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user