Add support for GrokAdamW optimizer (#32521)
* add grokadamw * reformat * code review feedback, unit test * reformat * reformat
This commit is contained in:
@@ -62,6 +62,7 @@ from transformers.testing_utils import (
|
||||
require_bitsandbytes,
|
||||
require_deepspeed,
|
||||
require_galore_torch,
|
||||
require_grokadamw,
|
||||
require_intel_extension_for_pytorch,
|
||||
require_lomo,
|
||||
require_optuna,
|
||||
@@ -1366,6 +1367,28 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
# Check this works
|
||||
_ = trainer.train()
|
||||
|
||||
@require_grokadamw
|
||||
@require_torch_gpu
|
||||
def test_grokadamw():
|
||||
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||
tiny_llama = LlamaForCausalLM(config)
|
||||
x = torch.randint(0, 100, (128,))
|
||||
train_dataset = RepeatDataset(x)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Trainer without inf/nan filter
|
||||
args = TrainingArguments(
|
||||
tmpdir,
|
||||
learning_rate=2e-5,
|
||||
logging_steps=5,
|
||||
optim="grokadamw",
|
||||
max_steps=20,
|
||||
)
|
||||
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
|
||||
|
||||
# Check this works
|
||||
_ = trainer.train()
|
||||
|
||||
def test_galore_matched_modules(self):
|
||||
regex_patterns = [r".*.attn.*", r".*.mlp.*"]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user