From 2ca73e5ee320078a275e40a95ad32f040a389d39 Mon Sep 17 00:00:00 2001 From: Charbel Abi Daher <45701489+CharbelAD@users.noreply.github.com> Date: Tue, 28 Nov 2023 08:33:45 +0100 Subject: [PATCH] Fixed passing scheduler-specific kwargs via TrainingArguments lr_scheduler_kwargs (#27595) * Fix passing scheduler-specific kwargs through TrainingArguments `lr_scheduler_kwargs` * Added test for lr_scheduler_kwargs --- src/transformers/trainer.py | 2 +- tests/trainer/test_trainer.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index bd6dd6fe95..8d85e50772 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1111,7 +1111,7 @@ class Trainer: optimizer=self.optimizer if optimizer is None else optimizer, num_warmup_steps=self.args.get_warmup_steps(num_training_steps), num_training_steps=num_training_steps, - **self.args.lr_scheduler_kwargs, + scheduler_specific_kwargs=self.args.lr_scheduler_kwargs, ) self._created_lr_scheduler = True return self.lr_scheduler diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 5d4ae34161..305ccb35d5 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -39,6 +39,7 @@ from transformers import ( IntervalStrategy, PretrainedConfig, TrainingArguments, + get_polynomial_decay_schedule_with_warmup, is_torch_available, logging, ) @@ -643,6 +644,33 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon): self.assertFalse(torch.allclose(trainer.model.b, b)) self.assertEqual(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 1.0) + def test_lr_scheduler_kwargs(self): + # test scheduler kwargs passed via TrainingArguments + train_dataset = RegressionDataset() + model = RegressionModel() + num_steps, num_warmup_steps = 10, 2 + extra_kwargs = {"power": 5.0, "lr_end": 1e-5} # Non-default arguments + args = TrainingArguments( + "./regression", + lr_scheduler_type="polynomial", + lr_scheduler_kwargs=extra_kwargs, + learning_rate=0.2, + warmup_steps=num_warmup_steps, + ) + trainer = Trainer(model, args, train_dataset=train_dataset) + trainer.create_optimizer_and_scheduler(num_training_steps=num_steps) + + # Checking that the scheduler was created + self.assertIsNotNone(trainer.lr_scheduler) + + # Checking that the correct args were passed + sched1 = trainer.lr_scheduler + sched2 = get_polynomial_decay_schedule_with_warmup( + trainer.optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_steps, **extra_kwargs + ) + self.assertEqual(sched1.lr_lambdas[0].args, sched2.lr_lambdas[0].args) + self.assertEqual(sched1.lr_lambdas[0].keywords, sched2.lr_lambdas[0].keywords) + def test_reduce_lr_on_plateau_args(self): # test passed arguments for a custom ReduceLROnPlateau scheduler train_dataset = RegressionDataset(length=64)