layernorm_decay_fix (#35927)

* layernorm_decay_fix

* W293 fix

* ruff format fix

* black format

* ruff format

* erase last layer

* add test_get_parameter_names_rmsnorm

* rmsnorm fix
This commit is contained in:
Ryoo Kwangrok
2025-02-04 19:01:49 +09:00
committed by GitHub
parent 2ba040a71f
commit b1954fd64a
6 changed files with 45 additions and 15 deletions

View File

@@ -244,6 +244,33 @@ class TrainerUtilsTest(unittest.TestCase):
)
# fmt: on
def test_get_parameter_names_rmsnorm(self):
    """Check which parameter names ``get_parameter_names`` keeps for a toy RMSNorm model.

    The model holds a linear layer, a norm-like submodule registered under the
    attribute name ``rmsnorm``, and a bias parameter attached directly to the
    root module. Only ``linear.weight`` is expected to remain in the result.
    """
    width = 128

    class RMSNorm(nn.Module):
        # Parameter container only; no forward pass is needed for this test.
        def __init__(self, hidden_size):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.bias = nn.Parameter(torch.zeros(hidden_size))

    class ModelWithRMSNorm(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(width, width)
            self.rmsnorm = RMSNorm(width)
            self.bias = nn.Parameter(torch.zeros(width))

    # The second argument is an empty list, so this call relies solely on the
    # name fragments passed third ("bias" and "rmsnorm").
    # NOTE(review): presumably the second argument is the forbidden layer
    # *types* -- confirm against the signature of get_parameter_names.
    decayed = get_parameter_names(ModelWithRMSNorm(), [], ["bias", "rmsnorm"])

    # Expected to stay subject to weight decay.
    self.assertIn("linear.weight", decayed)

    # Expected to be excluded from weight decay; checked in a fixed order so
    # the first failing name is reported, as with individual assertions.
    for excluded in ("linear.bias", "rmsnorm.weight", "rmsnorm.bias", "bias"):
        self.assertNotIn(excluded, decayed)
def test_distributed_sampler_with_loop(self):
batch_size = 16
for length in [23, 64, 123]: