[optim] implement AdafactorSchedule (#12123)
* implement AdafactorSchedule * typo * fix * Update src/transformers/optimization.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -420,6 +420,12 @@ class Adafactor(Optimizer):
|
||||
|
||||
Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
|
||||
|
||||
When using ``lr=None`` with :class:`~transformers.Trainer` you will most likely need to use :class:`~transformers.optimization.AdafactorSchedule` scheduler as following::
|
||||
|
||||
from transformers.optimization import Adafactor, AdafactorSchedule
|
||||
optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
|
||||
lr_scheduler = AdafactorSchedule(optimizer)
|
||||
trainer = Trainer(..., optimizers=(optimizer, lr_scheduler))
|
||||
|
||||
Usage::
|
||||
|
||||
@@ -588,3 +594,52 @@ class Adafactor(Optimizer):
|
||||
p.data.copy_(p_data_fp32)
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
class AdafactorSchedule(LambdaLR):
|
||||
"""
|
||||
Since :class:`~transformers.optimization.Adafactor` performs its own scheduling, if the training loop relies on a
|
||||
scheduler (e.g., for logging), this class creates a proxy object that retrieves the current lr values from the
|
||||
optimizer.
|
||||
|
||||
It returns ``initial_lr`` during startup and the actual ``lr`` during stepping.
|
||||
"""
|
||||
|
||||
def __init__(self, optimizer, initial_lr=0.0):
|
||||
def lr_lambda(_):
|
||||
return initial_lr
|
||||
|
||||
for group in optimizer.param_groups:
|
||||
group["initial_lr"] = initial_lr
|
||||
super().__init__(optimizer, lr_lambda)
|
||||
for group in optimizer.param_groups:
|
||||
del group["initial_lr"]
|
||||
|
||||
def get_lr(self):
|
||||
opt = self.optimizer
|
||||
lrs = [
|
||||
opt._get_lr(group, opt.state[group["params"][0]])
|
||||
for group in opt.param_groups
|
||||
if group["params"][0].grad is not None
|
||||
]
|
||||
if len(lrs) == 0:
|
||||
lrs = self.base_lrs # if called before stepping
|
||||
return lrs
|
||||
|
||||
|
||||
def get_adafactor_schedule(optimizer, initial_lr=0.0):
|
||||
"""
|
||||
Get a proxy schedule for :class:`~transformers.optimization.Adafactor`
|
||||
|
||||
Args:
|
||||
optimizer (:class:`~torch.optim.Optimizer`):
|
||||
The optimizer for which to schedule the learning rate.
|
||||
initial_lr (:obj:`float`, `optional`, defaults to 0.0):
|
||||
Initial lr
|
||||
|
||||
Return:
|
||||
:class:`~transformers.optimization.Adafactor` proxy schedule object.
|
||||
|
||||
|
||||
"""
|
||||
return AdafactorSchedule(optimizer, initial_lr)
|
||||
|
||||
@@ -589,6 +589,25 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.assertFalse(torch.allclose(trainer.model.b, b))
|
||||
self.assertEqual(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 1.0)
|
||||
|
||||
@require_torch
|
||||
def test_adafactor_lr_none(self):
|
||||
# test the special case where lr=None, since Trainer can't not have lr_scheduler
|
||||
|
||||
from transformers.optimization import Adafactor, AdafactorSchedule
|
||||
|
||||
train_dataset = RegressionDataset()
|
||||
args = TrainingArguments("./regression")
|
||||
model = RegressionModel()
|
||||
optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
|
||||
lr_scheduler = AdafactorSchedule(optimizer)
|
||||
trainer = Trainer(model, args, train_dataset=train_dataset, optimizers=(optimizer, lr_scheduler))
|
||||
trainer.train()
|
||||
|
||||
(a, b) = self.default_trained_model
|
||||
self.assertFalse(torch.allclose(trainer.model.a, a))
|
||||
self.assertFalse(torch.allclose(trainer.model.b, b))
|
||||
self.assertGreater(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 0)
|
||||
|
||||
def test_model_init(self):
|
||||
train_dataset = RegressionDataset()
|
||||
args = TrainingArguments("./regression", learning_rate=0.1)
|
||||
|
||||
Reference in New Issue
Block a user