[s2s] adjust finetune + test to work with fsmt (#7263)
This commit is contained in:
@@ -103,6 +103,7 @@ T5_TINY = "patrickvonplaten/t5-tiny-random"
|
||||
BART_TINY = "sshleifer/bart-tiny-random"
|
||||
MBART_TINY = "sshleifer/tiny-mbart"
|
||||
MARIAN_TINY = "sshleifer/tiny-marian-en-de"
|
||||
FSMT_TINY = "stas/tiny-wmt19-en-de"
|
||||
|
||||
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
@@ -374,11 +375,11 @@ def test_run_eval_search(model):
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
[T5_TINY, BART_TINY, MBART_TINY, MARIAN_TINY],
|
||||
[T5_TINY, BART_TINY, MBART_TINY, MARIAN_TINY, FSMT_TINY],
|
||||
)
|
||||
def test_finetune(model):
|
||||
args_d: dict = CHEAP_ARGS.copy()
|
||||
task = "translation" if model in [MBART_TINY, MARIAN_TINY] else "summarization"
|
||||
task = "translation" if model in [MBART_TINY, MARIAN_TINY, FSMT_TINY] else "summarization"
|
||||
args_d["label_smoothing"] = 0.1 if task == "translation" else 0
|
||||
|
||||
tmp_dir = make_test_data_dir()
|
||||
@@ -407,7 +408,13 @@ def test_finetune(model):
|
||||
lm_head = module.model.lm_head
|
||||
assert not lm_head.weight.requires_grad
|
||||
assert (lm_head.weight == input_embeds.weight).all().item()
|
||||
|
||||
elif model == FSMT_TINY:
|
||||
fsmt = module.model.model
|
||||
embed_pos = fsmt.decoder.embed_positions
|
||||
assert not embed_pos.weight.requires_grad
|
||||
assert not fsmt.decoder.embed_tokens.weight.requires_grad
|
||||
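# check that the decoder and encoder token embeddings are not the same (this compares the embed_tokens modules themselves, not their weight values)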
|
||||
assert fsmt.decoder.embed_tokens != fsmt.encoder.embed_tokens
|
||||
else:
|
||||
bart = module.model.model
|
||||
embed_pos = bart.decoder.embed_positions
|
||||
|
||||
Reference in New Issue
Block a user