Add StableAdamW Optimizer (#39446)

* Added StableAdamW as an optimizer option for Trainer. Also wrote tests to verify its behaviour.

* Fixed issue with

* Added docs for StableAdamW. Also fixed a typo in schedule-free optimizers

---------

Co-authored-by: Gautham Krithiwas <gauthamkrithiwas2003@gmail.com>
This commit is contained in:
Marc Sun
2025-07-16 13:35:53 +02:00
committed by GitHub
parent b9ee528246
commit bfc9ddf5c6
7 changed files with 187 additions and 1 deletions

View File

@@ -99,6 +99,7 @@ from transformers.testing_utils import (
require_torch_multi_accelerator,
require_torch_non_multi_accelerator,
require_torch_non_multi_gpu,
require_torch_optimi,
require_torch_tensorrt_fx,
require_torch_tf32,
require_torch_up_to_2_accelerators,
@@ -2518,6 +2519,123 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
# warm up steps << total steps
self.assertTrue(len(decreasing_lrs) > len(increasing_lrs))
@require_torch_optimi
@require_torch_gpu
def test_stable_adamw(self):
    # Smoke test: a tiny Llama model trains end-to-end with optim="stable_adamw"
    # restricted to the attention and MLP modules. There are no assertions; the
    # test passes as long as training does not raise.
    model = LlamaForCausalLM(
        LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
    )
    token_ids = torch.randint(0, 100, (128,))
    dataset = RepeatDataset(token_ids)

    training_args = TrainingArguments(
        output_dir=self.get_auto_remove_tmp_dir(),
        learning_rate=1e-9,
        logging_steps=5,
        optim="stable_adamw",
        optim_target_modules=[r".*attn.*", r".*mlp.*"],
    )

    trainer = Trainer(model=model, args=training_args, train_dataset=dataset)
    trainer.train()
@require_torch_optimi
@require_torch_gpu
def test_stable_adamw_extra_args(self):
    # Smoke test: same setup as the plain StableAdamW run, but with additional
    # optimizer options supplied through the `optim_args` string. There are no
    # assertions; the test passes as long as training does not raise.
    model = LlamaForCausalLM(
        LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
    )
    token_ids = torch.randint(0, 100, (128,))
    dataset = RepeatDataset(token_ids)

    training_args = TrainingArguments(
        output_dir=self.get_auto_remove_tmp_dir(),
        learning_rate=1e-9,
        logging_steps=5,
        optim="stable_adamw",
        optim_args="decouple_lr=True,max_lr=1e-3,kahan_sum=True",
        optim_target_modules=[r".*attn.*", r".*mlp.*"],
    )

    trainer = Trainer(model=model, args=training_args, train_dataset=dataset)
    # Training must run to completion with the extra arguments applied.
    trainer.train()
@require_torch_optimi
@require_torch_gpu
def test_stable_adamw_lr_display_without_scheduler(self):
    # Checks the learning rates reported by the trainer right after the optimizer
    # and scheduler are created, before any training step has been taken.
    model = LlamaForCausalLM(
        LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
    )
    token_ids = torch.randint(0, 100, (128,))
    dataset = RepeatDataset(token_ids)

    base_lr = 1e-9
    total_steps = 10

    training_args = TrainingArguments(
        output_dir=self.get_auto_remove_tmp_dir(),
        learning_rate=base_lr,
        logging_steps=5,
        optim="stable_adamw",
        optim_target_modules=[r".*attn.*", r".*mlp.*"],
    )

    trainer = Trainer(model=model, args=training_args, train_dataset=dataset)
    trainer.create_optimizer_and_scheduler(num_training_steps=total_steps)

    # The displayed learning rates must match the configured value.
    # NOTE(review): two entries are expected, presumably one per optimizer
    # parameter group (targeted vs. non-targeted modules) - confirm.
    reported_lrs = trainer.get_learning_rates()
    self.assertEqual(reported_lrs, [base_lr] * 2)
@require_torch_optimi
@require_torch_gpu
def test_stable_adamw_lr_display_with_scheduler(self):
    # Trains with a cosine schedule plus warmup and checks that the logged
    # learning rates rise to the configured peak and then decay to zero.
    config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
    tiny_llama = LlamaForCausalLM(config)
    x = torch.randint(0, 100, (128,))
    train_dataset = RepeatDataset(x)

    learning_rate = 2e-4
    num_train_epochs = 10
    num_warmup_steps = 5

    args = TrainingArguments(
        self.get_auto_remove_tmp_dir(),
        num_train_epochs=num_train_epochs,
        learning_rate=learning_rate,
        warmup_steps=num_warmup_steps,
        lr_scheduler_type="cosine",
        logging_steps=1,
        optim="stable_adamw",
        optim_target_modules=[r".*attn.*", r".*mlp.*"],
    )
    trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)

    # Only the log history matters here; the training results themselves are ignored.
    trainer.train()

    # Drop the first and the last entries of the log history and keep the
    # learning rate recorded in each remaining entry.
    lrs = [log["learning_rate"] for log in trainer.state.log_history[1:-1]]

    # Position (within the trimmed logs) at which the peak learning rate is expected.
    peak_index = num_warmup_steps - 2

    # The schedule reaches the configured peak and ends at a learning rate of 0.
    self.assertEqual(lrs[peak_index], learning_rate)
    self.assertEqual(lrs[-1], 0)

    # Strictly increasing up to the peak, strictly decreasing afterwards.
    # Consecutive values are compared pairwise, so no index can run past the end.
    increasing_lrs = [prev < nxt for prev, nxt in zip(lrs[:peak_index], lrs[1 : peak_index + 1])]
    decreasing_lrs = [prev > nxt for prev, nxt in zip(lrs[peak_index:], lrs[peak_index + 1 :])]

    self.assertTrue(all(increasing_lrs))
    self.assertTrue(all(decreasing_lrs))

    # warm up steps << total steps
    self.assertGreater(len(decreasing_lrs), len(increasing_lrs))
@require_torch_multi_accelerator
def test_data_is_not_parallelized_when_model_is_parallel(self):
model = RegressionModel()