FEAT / Trainer: LOMO optimizer support (#30178)
* add V1 - adalomo not working yet * add todo docs + refactor from comments * adjust LR * add docs * add more elaborated test * Apply suggestions from code review Co-authored-by: Zach Mueller <muellerzr@gmail.com> * fix * push * add accelerate check * fix DDP case * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * fix * init kwargs * safely add attribute * revert to enum logic * Update src/transformers/trainer.py --------- Co-authored-by: Zach Mueller <muellerzr@gmail.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -63,6 +63,7 @@ from transformers.testing_utils import (
|
||||
require_deepspeed,
|
||||
require_galore_torch,
|
||||
require_intel_extension_for_pytorch,
|
||||
require_lomo,
|
||||
require_optuna,
|
||||
require_peft,
|
||||
require_ray,
|
||||
@@ -1229,6 +1230,49 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
trainer.train()
|
||||
trainer.evaluate()
|
||||
|
||||
@require_lomo
|
||||
@require_torch_gpu
|
||||
def test_lomo(self):
|
||||
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||
tiny_llama = LlamaForCausalLM(config)
|
||||
|
||||
previous_params = {n: p.clone() for n, p in tiny_llama.named_parameters()}
|
||||
|
||||
x = torch.randint(0, 100, (128,))
|
||||
train_dataset = RepeatDataset(x)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Trainer without inf/nan filter
|
||||
args = TrainingArguments(tmpdir, learning_rate=1e-2, logging_steps=5, optim="lomo", max_steps=20)
|
||||
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
|
||||
|
||||
# Check this works
|
||||
_ = trainer.train()
|
||||
|
||||
for name, param in tiny_llama.named_parameters():
|
||||
self.assertFalse(torch.allclose(param, previous_params[name].to(param.device), rtol=1e-12, atol=1e-12))
|
||||
|
||||
@require_lomo
|
||||
@require_torch_gpu
|
||||
def test_adalomo(self):
|
||||
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||
tiny_llama = LlamaForCausalLM(config)
|
||||
x = torch.randint(0, 100, (128,))
|
||||
train_dataset = RepeatDataset(x)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Trainer without inf/nan filter
|
||||
args = TrainingArguments(
|
||||
tmpdir,
|
||||
learning_rate=1e-9,
|
||||
logging_steps=5,
|
||||
optim="adalomo",
|
||||
)
|
||||
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
|
||||
|
||||
# Check this works
|
||||
_ = trainer.train()
|
||||
|
||||
def test_galore_matched_modules(self):
|
||||
regex_patterns = [r".*.attn.*", r".*.mlp.*"]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user