add RAdamScheduleFree optimizer (#35313)
* Add RAdamScheduleFree optimizer
* Revert schedulefree version to the minimum requirement
* Refine is_schedulefree_available so that it accepts a min_version argument
* Refine documentation

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -1865,14 +1865,38 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
x = torch.randint(0, 100, (128,))
|
||||
train_dataset = RepeatDataset(x)
|
||||
|
||||
# Trainer without inf/nan filter
|
||||
args = TrainingArguments(
|
||||
self.get_auto_remove_tmp_dir(),
|
||||
learning_rate=1e-9,
|
||||
logging_steps=5,
|
||||
optim="schedule_free_adamw",
|
||||
)
|
||||
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Trainer without inf/nan filter
|
||||
args = TrainingArguments(
|
||||
tmpdir,
|
||||
learning_rate=1e-9,
|
||||
logging_steps=5,
|
||||
optim="schedule_free_adamw",
|
||||
lr_scheduler_type="constant",
|
||||
)
|
||||
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
|
||||
|
||||
# Check this works
|
||||
_ = trainer.train()
|
||||
|
||||
@require_schedulefree
@require_torch_gpu
def test_schedulefree_radam(self):
    """Smoke-test a Trainer run configured with ``optim="schedule_free_radam"``.

    Builds a small randomly initialised Llama causal LM and a repeated
    random-token dataset, then checks that ``Trainer.train()`` completes
    without raising. No numerical assertions are made on the result.
    """
    llama_config = LlamaConfig(
        vocab_size=100,
        hidden_size=32,
        num_hidden_layers=3,
        num_attention_heads=4,
    )
    model = LlamaForCausalLM(llama_config)

    # 128 random token ids drawn from [0, 100), matching the vocab size above.
    token_ids = torch.randint(0, 100, (128,))
    dataset = RepeatDataset(token_ids)

    with tempfile.TemporaryDirectory() as tmpdir:
        # Select the schedule-free RAdam optimizer and request a constant
        # learning-rate schedule explicitly.
        training_args = TrainingArguments(
            tmpdir,
            learning_rate=1e-9,
            logging_steps=5,
            lr_scheduler_type="constant",
            optim="schedule_free_radam",
        )
        trainer = Trainer(model, training_args, train_dataset=dataset)

        # The test passes as long as training runs to completion.
        trainer.train()
|
||||
|
||||
Reference in New Issue
Block a user