Add mbart-large-cc25, support translation finetuning (#5129)

improve unittests for finetuning, especially w.r.t testing frozen parameters
fix freeze_embeds for T5
add streamlit setup.cfg
This commit is contained in:
Sam Shleifer
2020-07-07 13:23:01 -04:00
committed by GitHub
parent 141492448b
commit 353b8f1e7a
14 changed files with 521 additions and 204 deletions

View File

@@ -223,10 +223,30 @@ def test_finetune(model):
output_dir=output_dir,
do_predict=True,
task=task,
src_lang="en_XX",
tgt_lang="ro_RO",
freeze_encoder=True,
freeze_embeds=True,
)
assert "n_train" in args_d
args = argparse.Namespace(**args_d)
main(args)
module = main(args)
input_embeds = module.model.get_input_embeddings()
assert not input_embeds.weight.requires_grad
if model == T5_TINY:
lm_head = module.model.lm_head
assert not lm_head.weight.requires_grad
assert (lm_head.weight == input_embeds.weight).all().item()
else:
bart = module.model.model
embed_pos = bart.decoder.embed_positions
assert not embed_pos.weight.requires_grad
assert not bart.shared.weight.requires_grad
# check that embeds are the same
assert bart.decoder.embed_tokens == bart.encoder.embed_tokens
assert bart.decoder.embed_tokens == bart.shared
@pytest.mark.parametrize(
@@ -239,7 +259,12 @@ def test_dataset(tok):
max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
trunc_target = 4
train_dataset = SummarizationDataset(
tokenizer, data_dir=tmp_dir, type_path="train", max_source_length=20, max_target_length=trunc_target,
tokenizer,
data_dir=tmp_dir,
type_path="train",
max_source_length=20,
max_target_length=trunc_target,
tgt_lang="ro_RO",
)
dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
for batch in dataloader: