correctly handle mt5 (#9879)
This commit is contained in:
@@ -563,7 +563,7 @@ def freeze_embeds(model):
|
|||||||
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
|
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
|
||||||
model_type = model.config.model_type
|
model_type = model.config.model_type
|
||||||
|
|
||||||
if model_type == "t5":
|
if model_type in ["t5", "mt5"]:
|
||||||
freeze_params(model.shared)
|
freeze_params(model.shared)
|
||||||
for d in [model.encoder, model.decoder]:
|
for d in [model.encoder, model.decoder]:
|
||||||
freeze_params(d.embed_tokens)
|
freeze_params(d.embed_tokens)
|
||||||
|
|||||||
Reference in New Issue
Block a user