layernorm_decay_fix (#35927)
* layernorm_decay_fix
* W293 fix
* ruff format fix
* black format
* ruff format
* erase last layer
* add test_get_parameter_names_rmsnorm
* rmsnorm fix
This commit is contained in:
@@ -680,8 +680,7 @@ def main():
|
||||
# Instantiate custom data collator
|
||||
data_collator = DataCollatorCTCWithPadding(processor=processor)
|
||||
|
||||
decay_parameters = get_parameter_names(model, [torch.nn.LayerNorm])
|
||||
decay_parameters = [name for name in decay_parameters if "bias" not in name]
|
||||
decay_parameters = get_parameter_names(model, [torch.nn.LayerNorm], ["bias", "layernorm", "rmsnorm"])
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if n in decay_parameters],
|
||||
|
||||
Reference in New Issue
Block a user