layernorm_decay_fix (#35927)

* layernorm_decay_fix

* W293 fix

* ruff format fix

* black format

* ruff format

* erase last layer

* add test_get_parameter_names_rmsnorm

* rmsnorm fix
This commit is contained in:
Ryoo Kwangrok
2025-02-04 19:01:49 +09:00
committed by GitHub
parent 2ba040a71f
commit b1954fd64a
6 changed files with 45 additions and 15 deletions

View File

@@ -680,8 +680,7 @@ def main():
# Instantiate custom data collator
data_collator = DataCollatorCTCWithPadding(processor=processor)
decay_parameters = get_parameter_names(model, [torch.nn.LayerNorm])
decay_parameters = [name for name in decay_parameters if "bias" not in name]
decay_parameters = get_parameter_names(model, [torch.nn.LayerNorm], ["bias", "layernorm", "rmsnorm"])
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if n in decay_parameters],