From 3ced9b3eb946f4338ace7da66fc2fcdcc2705080 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Mon, 8 Mar 2021 16:40:11 -0500 Subject: [PATCH] Check layer types for Optimizer construction (#10598) * Check layer types for Optimizer construction * Duplicate class --- src/transformers/trainer.py | 8 +++++--- src/transformers/trainer_pt_utils.py | 16 ++++++++++++++++ tests/test_trainer.py | 26 ++++++++++++++++++++++++++ tests/test_trainer_utils.py | 24 ++++++++++++++++++++++++ 4 files changed, 71 insertions(+), 3 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 0fa496dcc7..aaf9c1e627 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -80,6 +80,7 @@ from .trainer_pt_utils import ( SequentialDistributedSampler, distributed_broadcast_scalars, distributed_concat, + get_parameter_names, nested_concat, nested_detach, nested_numpify, @@ -613,14 +614,15 @@ class Trainer: Trainer's init through :obj:`optimizers`, or subclass and override this method in a subclass. """ if self.optimizer is None: - no_decay = ["bias", "LayerNorm.weight"] + decay_parameters = get_parameter_names(self.model, [torch.nn.LayerNorm]) + decay_parameters = [name for name in decay_parameters if "bias" not in name] optimizer_grouped_parameters = [ { - "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)], + "params": [p for n, p in self.model.named_parameters() if n in decay_parameters], "weight_decay": self.args.weight_decay, }, { - "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)], + "params": [p for n, p in self.model.named_parameters() if n not in decay_parameters], "weight_decay": 0.0, }, ] diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py index ed92222612..ae8e249490 100644 --- a/src/transformers/trainer_pt_utils.py +++ b/src/transformers/trainer_pt_utils.py @@ -672,3 +672,19 @@ def save_state(self): path = os.path.join(self.args.output_dir, "trainer_state.json") self.state.save_to_json(path) + + +def get_parameter_names(model, forbidden_layer_types): + """ + Returns the names of the model parameters that are not inside a forbidden layer. + """ + result = [] + for name, child in model.named_children(): + result += [ + f"{name}.{n}" + for n in get_parameter_names(child, forbidden_layer_types) + if not isinstance(child, tuple(forbidden_layer_types)) + ] + # Add model specific parameters (defined with nn.Parameter) since they are not in any child. + result += list(model._parameters.keys()) + return result diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 09801dd6aa..2742c2b4dc 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -193,6 +193,20 @@ if is_torch_available(): loss = torch.nn.functional.mse_loss(y, labels) return (loss, y, y) if self.double_output else (loss, y) + class TstLayer(torch.nn.Module): + def __init__(self, hidden_size): + super().__init__() + self.linear1 = torch.nn.Linear(hidden_size, hidden_size) + self.ln1 = torch.nn.LayerNorm(hidden_size) + self.linear2 = torch.nn.Linear(hidden_size, hidden_size) + self.ln2 = torch.nn.LayerNorm(hidden_size) + self.bias = torch.nn.Parameter(torch.zeros(hidden_size)) + + def forward(self, x): + h = self.ln1(torch.nn.functional.relu(self.linear1(x))) + h = torch.nn.functional.relu(self.linear2(x)) + return self.ln2(x + h + self.bias) + def get_regression_trainer(a=0, b=0, double_output=False, train_len=64, eval_len=64, pretrained=True, **kwargs): label_names = kwargs.get("label_names", None) train_dataset = RegressionDataset(length=train_len, label_names=label_names) @@ -991,6 +1005,18 @@ class TrainerIntegrationTest(unittest.TestCase): # perfect world: fp32_init/2 == fp16_eval self.assertAlmostEqual(fp16_eval, fp32_init / 2, delta=5_000) + def test_no_wd_param_group(self): + model = torch.nn.Sequential(TstLayer(128), torch.nn.ModuleList([TstLayer(128), TstLayer(128)])) + trainer = Trainer(model=model) + trainer.create_optimizer_and_scheduler(10) + # fmt: off + wd_names = ['0.linear1.weight', '0.linear2.weight', '1.0.linear1.weight', '1.0.linear2.weight', '1.1.linear1.weight', '1.1.linear2.weight'] + # fmt: on + wd_params = [p for n, p in model.named_parameters() if n in wd_names] + no_wd_params = [p for n, p in model.named_parameters() if n not in wd_names] + self.assertListEqual(trainer.optimizer.param_groups[0]["params"], wd_params) + self.assertListEqual(trainer.optimizer.param_groups[1]["params"], no_wd_params) + @require_torch @require_optuna diff --git a/tests/test_trainer_utils.py b/tests/test_trainer_utils.py index 19dfa9b1d1..f56ef140e8 100644 --- a/tests/test_trainer_utils.py +++ b/tests/test_trainer_utils.py @@ -30,8 +30,23 @@ if is_torch_available(): DistributedTensorGatherer, LabelSmoother, LengthGroupedSampler, + get_parameter_names, ) + class TstLayer(torch.nn.Module): + def __init__(self, hidden_size): + super().__init__() + self.linear1 = torch.nn.Linear(hidden_size, hidden_size) + self.ln1 = torch.nn.LayerNorm(hidden_size) + self.linear2 = torch.nn.Linear(hidden_size, hidden_size) + self.ln2 = torch.nn.LayerNorm(hidden_size) + self.bias = torch.nn.Parameter(torch.zeros(hidden_size)) + + def forward(self, x): + h = self.ln1(torch.nn.functional.relu(self.linear1(x))) + h = torch.nn.functional.relu(self.linear2(x)) + return self.ln2(x + h + self.bias) + @require_torch class TrainerUtilsTest(unittest.TestCase): @@ -117,3 +132,12 @@ class TrainerUtilsTest(unittest.TestCase): self.assertEqual(lengths[indices_process_0[0]], 50) # The indices should be a permutation of range(100) self.assertEqual(list(sorted(indices_process_0 + indices_process_1)), list(range(100))) + + def test_get_parameter_names(self): + model = torch.nn.Sequential(TstLayer(128), torch.nn.ModuleList([TstLayer(128), TstLayer(128)])) + # fmt: off + self.assertEqual( + get_parameter_names(model, [torch.nn.LayerNorm]), + ['0.linear1.weight', '0.linear1.bias', '0.linear2.weight', '0.linear2.bias', '0.bias', '1.0.linear1.weight', '1.0.linear1.bias', '1.0.linear2.weight', '1.0.linear2.bias', '1.0.bias', '1.1.linear1.weight', '1.1.linear1.bias', '1.1.linear2.weight', '1.1.linear2.bias', '1.1.bias'] + ) + # fmt: on