[s2s] add support for overriding config params (#6149)

This commit is contained in:
Stas Bekman
2020-07-29 22:09:46 -07:00
committed by GitHub
parent 54f9fbeff8
commit 3212b8850d
4 changed files with 106 additions and 17 deletions

View File

@@ -277,6 +277,55 @@ def test_finetune(model):
assert bart.decoder.embed_tokens == bart.shared
def test_finetune_extra_model_args():
args_d: dict = CHEAP_ARGS.copy()
task = "summarization"
tmp_dir = make_test_data_dir()
args_d.update(
data_dir=tmp_dir,
tokenizer_name=None,
train_batch_size=2,
eval_batch_size=2,
do_predict=False,
task=task,
src_lang="en_XX",
tgt_lang="ro_RO",
freeze_encoder=True,
freeze_embeds=True,
)
# test models whose config includes the extra_model_args
model = BART_TINY
output_dir = tempfile.mkdtemp(prefix="output_1_")
args_d1 = args_d.copy()
args_d1.update(
model_name_or_path=model, output_dir=output_dir,
)
extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout")
for p in extra_model_params:
args_d1[p] = 0.5
args = argparse.Namespace(**args_d1)
model = main(args)
for p in extra_model_params:
assert getattr(model.config, p) == 0.5, f"failed to override the model config for param {p}"
# test models whose config doesn't include the extra_model_args
model = T5_TINY
output_dir = tempfile.mkdtemp(prefix="output_2_")
args_d2 = args_d.copy()
args_d2.update(
model_name_or_path=model, output_dir=output_dir,
)
unsupported_param = "encoder_layerdrop"
args_d2[unsupported_param] = 0.5
args = argparse.Namespace(**args_d2)
with pytest.raises(Exception) as excinfo:
model = main(args)
assert str(excinfo.value) == f"model config doesn't have a `{unsupported_param}` attribute"
def test_pack_dataset():
tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")