add RAdamScheduleFree optimizer (#35313)

* add RAdamScheduleFree optimizer

* revert schedulefree version to the minimum requirement

* refine is_schedulefree_available so that it can take min_version

* refine documents

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
nhamanasu
2025-02-12 19:31:51 +09:00
committed by GitHub
parent f5fff672db
commit 377d8e2b9c
5 changed files with 67 additions and 20 deletions

View File

@@ -1865,14 +1865,38 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)
# Trainer without inf/nan filter
args = TrainingArguments(
self.get_auto_remove_tmp_dir(),
learning_rate=1e-9,
logging_steps=5,
optim="schedule_free_adamw",
)
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
with tempfile.TemporaryDirectory() as tmpdir:
# Trainer without inf/nan filter
args = TrainingArguments(
tmpdir,
learning_rate=1e-9,
logging_steps=5,
optim="schedule_free_adamw",
lr_scheduler_type="constant",
)
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
# Check this works
_ = trainer.train()
@require_schedulefree
@require_torch_gpu
def test_schedulefree_radam(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)
with tempfile.TemporaryDirectory() as tmpdir:
# Trainer without inf/nan filter
args = TrainingArguments(
tmpdir,
learning_rate=1e-9,
logging_steps=5,
lr_scheduler_type="constant",
optim="schedule_free_radam",
)
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
# Check this works
_ = trainer.train()