[s2sTrainer] test + code cleanup (#7467)

This commit is contained in:
Sam Shleifer
2020-10-01 00:33:01 -04:00
committed by GitHub
parent 097049b81b
commit 48f23f92a8
5 changed files with 102 additions and 116 deletions

View File

@@ -441,6 +441,25 @@ def freeze_params(model: nn.Module):
par.requires_grad = False
def freeze_embeds(model):
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
model_type = model.config.model_type
if model_type == "t5":
freeze_params(model.shared)
for d in [model.encoder, model.decoder]:
freeze_params(d.embed_tokens)
elif model_type == "fsmt":
for d in [model.model.encoder, model.model.decoder]:
freeze_params(d.embed_positions)
freeze_params(d.embed_tokens)
else:
freeze_params(model.model.shared)
for d in [model.model.encoder, model.model.decoder]:
freeze_params(d.embed_positions)
freeze_params(d.embed_tokens)
def grad_status(model: nn.Module) -> Iterable:
return (par.requires_grad for par in model.parameters())