layernorm_decay_fix (#35927)
* layernorm_decay_fix * W293 fix * ruff format fix * black format * ruff format * erase last layer * add test_get_parameter_names_rmsnorm * rmsnorm fix
This commit is contained in:
@@ -244,6 +244,33 @@ class TrainerUtilsTest(unittest.TestCase):
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
def test_get_parameter_names_rmsnorm(self):
    """Check which parameters `get_parameter_names` keeps for weight decay.

    A small model with a linear layer, an RMSNorm-like submodule and a
    top-level bias is passed to `get_parameter_names` together with an empty
    second argument and the name patterns "bias" and "rmsnorm". Only
    `linear.weight` is expected in the returned collection.
    """
    width = 128

    class RMSNorm(nn.Module):
        # Minimal stand-in for a norm layer: it only owns a weight and a bias.
        def __init__(self, hidden_size):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.bias = nn.Parameter(torch.zeros(hidden_size))

    class ModelWithRMSNorm(nn.Module):
        # Attribute names matter here: the assertions below refer to them.
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(width, width)
            self.rmsnorm = RMSNorm(width)
            self.bias = nn.Parameter(torch.zeros(width))

    # The second argument is left empty, so the exclusions asserted below are
    # driven by the "bias" / "rmsnorm" name patterns alone.
    decay_names = get_parameter_names(ModelWithRMSNorm(), [], ["bias", "rmsnorm"])

    # Expected to stay in the weight-decay group.
    self.assertIn("linear.weight", decay_names)

    # Expected to be left out of the weight-decay group.
    for excluded_name in ("linear.bias", "rmsnorm.weight", "rmsnorm.bias", "bias"):
        self.assertNotIn(excluded_name, decay_names)
|
||||
|
||||
def test_distributed_sampler_with_loop(self):
|
||||
batch_size = 16
|
||||
for length in [23, 64, 123]:
|
||||
|
||||
Reference in New Issue
Block a user