layernorm_decay_fix (#35927)
* layernorm_decay_fix * W293 fix * ruff format fix * black format * ruff format * erase last layer * add test_get_parameter_names_rmsnorm * rmsnorm fix
This commit is contained in:
@@ -237,8 +237,7 @@ from transformers.trainer_pt_utils import get_parameter_names
|
||||
|
||||
training_args = TrainingArguments(per_device_train_batch_size=4, **default_args)
|
||||
|
||||
decay_parameters = get_parameter_names(model, [nn.LayerNorm])
|
||||
decay_parameters = [name for name in decay_parameters if "bias" not in name]
|
||||
decay_parameters = get_parameter_names(model, [nn.LayerNorm], ["bias", "layernorm", "rmsnorm"])
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if n in decay_parameters],
|
||||
|
||||
Reference in New Issue
Block a user