schedulefree optimizers (#30079)
* schedulefree optimizers * fix train instead of eval for optimizer * fixes and update docs * chore: lint * add tests and drop overly-verbose _32bit suffix * chore: lint * fix for docs * fix code review issues * use duck-typing to avoid per-optimizer patches * fixup style * fixup style * warn if incorrect accelerate version with schedule free Co-authored-by: Aman Gupta Karmani <aman@tmm1.net> --------- Co-authored-by: Aman Karmani <aman@tmm1.net>
This commit is contained in:
@@ -70,6 +70,7 @@ from transformers.testing_utils import (
|
||||
require_peft,
|
||||
require_ray,
|
||||
require_safetensors,
|
||||
require_schedulefree,
|
||||
require_sentencepiece,
|
||||
require_sigopt,
|
||||
require_tensorboard,
|
||||
@@ -1442,6 +1443,27 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
# Check this works
|
||||
_ = trainer.train()
|
||||
|
||||
@require_schedulefree
|
||||
@require_torch_gpu
|
||||
def test_schedulefree_adam(self):
|
||||
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||
tiny_llama = LlamaForCausalLM(config)
|
||||
x = torch.randint(0, 100, (128,))
|
||||
train_dataset = RepeatDataset(x)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Trainer without inf/nan filter
|
||||
args = TrainingArguments(
|
||||
tmpdir,
|
||||
learning_rate=1e-9,
|
||||
logging_steps=5,
|
||||
optim="schedule_free_adamw",
|
||||
)
|
||||
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
|
||||
|
||||
# Check this works
|
||||
_ = trainer.train()
|
||||
|
||||
def test_galore_matched_modules(self):
|
||||
regex_patterns = [r".*.attn.*", r".*.mlp.*"]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user