layernorm_decay_fix (#35927)

* layernorm_decay_fix

* W293 fix

* ruff format fix

* black format

* ruff format

* erase last layer

* add test_get_parameter_names_rmsnorm

* rmsnorm fix
This commit is contained in:
Ryoo Kwangrok
2025-02-04 19:01:49 +09:00
committed by GitHub
parent 2ba040a71f
commit b1954fd64a
6 changed files with 45 additions and 15 deletions

View File

@@ -237,8 +237,7 @@ from transformers.trainer_pt_utils import get_parameter_names
training_args = TrainingArguments(per_device_train_batch_size=4, **default_args)
-decay_parameters = get_parameter_names(model, [nn.LayerNorm])
-decay_parameters = [name for name in decay_parameters if "bias" not in name]
+decay_parameters = get_parameter_names(model, [nn.LayerNorm], ["bias", "layernorm", "rmsnorm"])
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if n in decay_parameters],