Add AdaFactor optimizer from fairseq (#6722)
* AdaFactor optimizer ported from fairseq. Tested for T5 finetuning and MLM -- reduced memory consumption compared to ADAM.
* update PR fixes, add basic test
* bug -- incorrect params in test
* bugfix -- import Adafactor into test
* bugfix -- removed accidental T5 include
* resetting T5 to master
* bugfix -- include Adafactor in __init__
* longer loop for adafactor test
* remove double error class declare
* lint
* black
* isort
* Update src/transformers/optimization.py

  Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
* single docstring
* Cleanup docstring

Co-authored-by: Nikolai Y <nikolai.yakovenko@point72.com>
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
committed by
GitHub
parent
4bd7be9a42
commit
971d1802d0
@@ -26,6 +26,7 @@ if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
Adafactor,
|
||||
AdamW,
|
||||
get_constant_schedule,
|
||||
get_constant_schedule_with_warmup,
|
||||
@@ -80,6 +81,31 @@ class OptimizationTest(unittest.TestCase):
|
||||
w.grad.zero_()
|
||||
self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2)
|
||||
|
||||
def test_adafactor(self):
    """Check that Adafactor moves a small tensor onto a fixed target.

    A three-element tensor is optimised against a constant target under
    an MSE loss for 1000 steps. The optimizer runs with an explicit
    learning rate of 1e-2 and with ``relative_step``, ``scale_parameter``
    and ``warmup_init`` all switched off. The final values must match the
    target within a tolerance of 1e-2.
    """
    expected = [0.4, 0.2, -0.5]
    weights = torch.tensor([0.1, -0.2, -0.1], requires_grad=True)
    goal = torch.tensor(expected)
    mse = torch.nn.MSELoss()

    # Explicit learning rate; the optimizer's own step-size heuristics
    # (relative step, parameter scaling, warmup init) are all disabled.
    adafactor = Adafactor(
        params=[weights],
        lr=1e-2,
        relative_step=False,
        scale_parameter=False,
        warmup_init=False,
        eps=(1e-30, 1e-3),
        clip_threshold=1.0,
        decay_rate=-0.8,
        beta1=None,
        weight_decay=0.0,
    )

    for _step in range(1000):
        mse(weights, goal).backward()
        adafactor.step()
        # A bare tensor has no zero_grad(), so reset the gradient by hand.
        grad = weights.grad
        grad.detach_()
        grad.zero_()

    self.assertListAlmostEqual(weights.tolist(), expected, tol=1e-2)
|
||||
|
||||
|
||||
@require_torch
|
||||
class ScheduleInitTest(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user