correctly handle mt5 (#9879)

This commit is contained in:
Stas Bekman
2021-01-29 08:11:22 -08:00
committed by GitHub
parent 7eadfe166e
commit 6bf94bc0b6

View File

@@ -563,7 +563,7 @@ def freeze_embeds(model):
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
model_type = model.config.model_type
if model_type == "t5":
if model_type in ["t5", "mt5"]:
freeze_params(model.shared)
for d in [model.encoder, model.decoder]:
freeze_params(d.embed_tokens)