Add support for GrokAdamW optimizer (#32521)

* add grokadamw

* reformat

* code review feedback, unit test

* reformat

* reformat
This commit is contained in:
Eric Hartford
2024-08-13 08:20:28 -04:00
committed by GitHub
parent b5016d5de7
commit 481e15604a
7 changed files with 107 additions and 0 deletions

View File

@@ -62,6 +62,7 @@ from transformers.testing_utils import (
require_bitsandbytes,
require_deepspeed,
require_galore_torch,
require_grokadamw,
require_intel_extension_for_pytorch,
require_lomo,
require_optuna,
@@ -1366,6 +1367,28 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
# Check this works
_ = trainer.train()
@require_grokadamw
@require_torch_gpu
def test_grokadamw(self):
    """Smoke-test that ``Trainer`` can train with ``optim="grokadamw"``.

    Builds a tiny randomly-initialised Llama model and a ``RepeatDataset``
    over random token ids, then runs 20 optimisation steps. The test passes
    if ``trainer.train()`` completes without raising.
    """
    # NOTE: this is a method of the test class, so it must accept ``self``;
    # without it the test runner fails with a TypeError before any training
    # code is executed.
    config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
    tiny_llama = LlamaForCausalLM(config)

    # Token ids are drawn from [0, 100) to stay inside the model's vocab_size.
    x = torch.randint(0, 100, (128,))
    train_dataset = RepeatDataset(x)

    with tempfile.TemporaryDirectory() as tmpdir:
        # Select the GrokAdamW optimizer via the `optim` training argument.
        args = TrainingArguments(
            tmpdir,
            learning_rate=2e-5,
            logging_steps=5,
            optim="grokadamw",
            max_steps=20,
        )
        trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)

        # Check this works
        _ = trainer.train()
def test_galore_matched_modules(self):
regex_patterns = [r".*.attn.*", r".*.mlp.*"]