schedulefree optimizers (#30079)

* schedulefree optimizers

* fix train instead of eval for optimizer

* fixes and update docs

* chore: lint

* add tests and drop overly-verbose _32bit suffix

* chore: lint

* fix for docs

* fix code review issues

* use duck-typing to avoid per-optimizer patches

* fixup style

* fixup style

* warn if incorrect accelerate version with schedule free

Co-authored-by: Aman Gupta Karmani <aman@tmm1.net>

---------

Co-authored-by: Aman Karmani <aman@tmm1.net>
This commit is contained in:
Wing Lian
2024-09-09 03:51:39 -04:00
committed by GitHub
parent 60226fdc1d
commit 62aecd85ff
9 changed files with 124 additions and 0 deletions

View File

@@ -70,6 +70,7 @@ from transformers.testing_utils import (
require_peft,
require_ray,
require_safetensors,
require_schedulefree,
require_sentencepiece,
require_sigopt,
require_tensorboard,
@@ -1442,6 +1443,27 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
# Check this works
_ = trainer.train()
@require_schedulefree
@require_torch_gpu
def test_schedulefree_adam(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)
with tempfile.TemporaryDirectory() as tmpdir:
# Trainer without inf/nan filter
args = TrainingArguments(
tmpdir,
learning_rate=1e-9,
logging_steps=5,
optim="schedule_free_adamw",
)
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
# Check this works
_ = trainer.train()
def test_galore_matched_modules(self):
regex_patterns = [r".*.attn.*", r".*.mlp.*"]