FEAT / Trainer: LOMO optimizer support (#30178)

* add V1 - adalomo not working yet

* add todo docs + refactor from comments

* adjust LR

* add docs

* add more elaborate test

* Apply suggestions from code review

Co-authored-by: Zach Mueller <muellerzr@gmail.com>

* fix

* push

* add accelerate check

* fix DDP case

* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* fix

* init kwargs

* safely add attribute

* revert to enum logic

* Update src/transformers/trainer.py

---------

Co-authored-by: Zach Mueller <muellerzr@gmail.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Younes Belkada
2024-05-21 10:16:37 +02:00
committed by GitHub
parent c876d12127
commit 8871b26150
7 changed files with 149 additions and 4 deletions

View File

@@ -63,6 +63,7 @@ from transformers.testing_utils import (
require_deepspeed,
require_galore_torch,
require_intel_extension_for_pytorch,
require_lomo,
require_optuna,
require_peft,
require_ray,
@@ -1229,6 +1230,49 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
trainer.train()
trainer.evaluate()
@require_lomo
@require_torch_gpu
def test_lomo(self):
    """Train a tiny Llama with ``optim="lomo"`` and verify every parameter changed.

    The test is gated by the ``require_lomo`` and ``require_torch_gpu``
    decorators. It snapshots all named parameters before training, runs 20
    optimizer steps, and then asserts that no parameter is still equal
    (within a 1e-12 tolerance) to its initial value.
    """
    model = LlamaForCausalLM(
        LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
    )

    # Snapshot of the initial weights, keyed by parameter name.
    initial_weights = {}
    for param_name, weight in model.named_parameters():
        initial_weights[param_name] = weight.clone()

    token_ids = torch.randint(0, 100, (128,))
    dataset = RepeatDataset(token_ids)

    with tempfile.TemporaryDirectory() as output_dir:
        training_args = TrainingArguments(
            output_dir,
            learning_rate=1e-2,
            logging_steps=5,
            optim="lomo",
            max_steps=20,
        )
        lomo_trainer = Trainer(model, training_args, train_dataset=dataset)

        # The return value is irrelevant here; we only need training to run.
        lomo_trainer.train()

    # Compare on the device the trained parameter lives on.
    for param_name, weight in model.named_parameters():
        reference = initial_weights[param_name].to(weight.device)
        self.assertFalse(torch.allclose(weight, reference, rtol=1e-12, atol=1e-12))
@require_lomo
@require_torch_gpu
def test_adalomo(self):
    """Smoke-test training a tiny Llama with ``optim="adalomo"``.

    The test is gated by the ``require_lomo`` and ``require_torch_gpu``
    decorators. It makes no assertion on the resulting weights: it passes
    as long as ``Trainer.train()`` completes without raising.
    """
    model = LlamaForCausalLM(
        LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
    )

    token_ids = torch.randint(0, 100, (128,))
    dataset = RepeatDataset(token_ids)

    with tempfile.TemporaryDirectory() as output_dir:
        # No max_steps is passed, so the run length is left to the
        # TrainingArguments defaults.
        training_args = TrainingArguments(
            output_dir,
            learning_rate=1e-9,
            logging_steps=5,
            optim="adalomo",
        )
        adalomo_trainer = Trainer(model, training_args, train_dataset=dataset)

        # The return value is irrelevant here; we only need training to run.
        adalomo_trainer.train()
def test_galore_matched_modules(self):
regex_patterns = [r".*.attn.*", r".*.mlp.*"]