[optim] implement AdafactorSchedule (#12123)

* implement AdafactorSchedule

* typo

* fix

* Update src/transformers/optimization.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Stas Bekman
2021-06-14 09:43:48 -07:00
committed by GitHub
parent fe3576488a
commit ff7c81687a
2 changed files with 74 additions and 0 deletions

View File

@@ -589,6 +589,25 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertFalse(torch.allclose(trainer.model.b, b))
self.assertEqual(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 1.0)
@require_torch
def test_adafactor_lr_none(self):
# test the special case where lr=None, since Trainer can't not have lr_scheduler
from transformers.optimization import Adafactor, AdafactorSchedule
train_dataset = RegressionDataset()
args = TrainingArguments("./regression")
model = RegressionModel()
optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
lr_scheduler = AdafactorSchedule(optimizer)
trainer = Trainer(model, args, train_dataset=train_dataset, optimizers=(optimizer, lr_scheduler))
trainer.train()
(a, b) = self.default_trained_model
self.assertFalse(torch.allclose(trainer.model.a, a))
self.assertFalse(torch.allclose(trainer.model.b, b))
self.assertGreater(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 0)
def test_model_init(self):
train_dataset = RegressionDataset()
args = TrainingArguments("./regression", learning_rate=0.1)