[optim] implement AdafactorSchedule (#12123)
* implement AdafactorSchedule * typo * fix * Update src/transformers/optimization.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -589,6 +589,25 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.assertFalse(torch.allclose(trainer.model.b, b))
|
||||
self.assertEqual(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 1.0)
|
||||
|
||||
@require_torch
|
||||
def test_adafactor_lr_none(self):
|
||||
# test the special case where lr=None, since Trainer can't not have lr_scheduler
|
||||
|
||||
from transformers.optimization import Adafactor, AdafactorSchedule
|
||||
|
||||
train_dataset = RegressionDataset()
|
||||
args = TrainingArguments("./regression")
|
||||
model = RegressionModel()
|
||||
optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
|
||||
lr_scheduler = AdafactorSchedule(optimizer)
|
||||
trainer = Trainer(model, args, train_dataset=train_dataset, optimizers=(optimizer, lr_scheduler))
|
||||
trainer.train()
|
||||
|
||||
(a, b) = self.default_trained_model
|
||||
self.assertFalse(torch.allclose(trainer.model.a, a))
|
||||
self.assertFalse(torch.allclose(trainer.model.b, b))
|
||||
self.assertGreater(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 0)
|
||||
|
||||
def test_model_init(self):
|
||||
train_dataset = RegressionDataset()
|
||||
args = TrainingArguments("./regression", learning_rate=0.1)
|
||||
|
||||
Reference in New Issue
Block a user