From bfc9ddf5c6243c8f8a9615051436a70078f73943 Mon Sep 17 00:00:00 2001 From: Marc Sun <57196510+SunMarc@users.noreply.github.com> Date: Wed, 16 Jul 2025 13:35:53 +0200 Subject: [PATCH] Add StableAdamW Optimizer (#39446) * Added StableAdamW as an optimizer option for Trainer. Also wrote tests to verify its behaviour. * Fixed issue with * Added docs for StableAdamW. Also fixed a typo in schedule free optimizers --------- Co-authored-by: Gautham Krithiwas --- docs/source/en/optimizers.md | 28 +++++- src/transformers/testing_utils.py | 9 ++ src/transformers/trainer.py | 26 ++++++ src/transformers/training_args.py | 1 + src/transformers/utils/__init__.py | 1 + src/transformers/utils/import_utils.py | 5 ++ tests/trainer/test_trainer.py | 118 +++++++++++++++++++++++++ 7 files changed, 187 insertions(+), 1 deletion(-) diff --git a/docs/source/en/optimizers.md b/docs/source/en/optimizers.md index a02b02c359..20bb956ad6 100644 --- a/docs/source/en/optimizers.md +++ b/docs/source/en/optimizers.md @@ -164,7 +164,7 @@ args = TrainingArguments( output_dir="./test-schedulefree", max_steps=1000, per_device_train_batch_size=4, -+ optim="schedule_free_radamw, ++ optim="schedule_free_radamw", + lr_scheduler_type="constant", gradient_checkpointing=True, logging_strategy="steps", @@ -174,3 +174,29 @@ args = TrainingArguments( run_name="sfo", ) ``` + +## StableAdamW + +```bash +pip install torch-optimi +``` + +[StableAdamW](https://arxiv.org/pdf/2304.13013) is a hybrid between AdamW and AdaFactor. It ports AdaFactor's update clipping into AdamW, which removes the need for gradient clipping. Otherwise, it behaves as a drop-in replacement for AdamW. + +> [!TIP] +> If training on large batch sizes or still observing training loss spikes, consider reducing beta_2 between [0.95, 0.99]. + +```diff +args = TrainingArguments( + output_dir="./test-stable-adamw", + max_steps=1000, + per_device_train_batch_size=4, ++ optim="stable_adamw", + gradient_checkpointing=True, + logging_strategy="steps", + logging_steps=1, + learning_rate=2e-6, + save_strategy="no", + run_name="stable-adamw", +) +``` \ No newline at end of file diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 38b4905b37..c582f0e4fb 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -152,6 +152,7 @@ from .utils import ( is_torch_mlu_available, is_torch_neuroncore_available, is_torch_npu_available, + is_torch_optimi_available, is_torch_sdpa_available, is_torch_tensorrt_fx_available, is_torch_tf32_available, @@ -379,6 +380,14 @@ def require_apollo_torch(test_case): return unittest.skipUnless(is_apollo_torch_available(), "test requires APOLLO")(test_case) +def require_torch_optimi(test_case): + """ + Decorator marking a test that requires torch-optimi. These tests are skipped when torch-optimi isn't installed. + https://github.com/jxnl/torch-optimi + """ + return unittest.skipUnless(is_torch_optimi_available(), "test requires torch-optimi")(test_case) + + def require_lomo(test_case): """ Decorator marking a test that requires LOMO. These tests are skipped when LOMO-optim isn't installed. diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index ddc447b6a4..3db5940a8d 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -169,6 +169,7 @@ from .utils import ( is_torch_musa_available, is_torch_neuroncore_available, is_torch_npu_available, + is_torch_optimi_available, is_torch_xla_available, is_torch_xpu_available, is_torchao_available, @@ -1723,6 +1724,31 @@ class Trainer: } ) optimizer_kwargs.update(additional_optim_kwargs) + elif args.optim == OptimizerNames.STABLE_ADAMW: + if not is_torch_optimi_available(): + raise ImportError( + "You need to install `torch-optimi` in order to use stable_adamw optimizers. " + "Install it with `pip install torch-optimi`." + ) + from optimi import StableAdamW + + max_lr = optim_args.pop("max_lr", None) + if max_lr is not None: + max_lr = float(max_lr) + + kahan_sum = optim_args.pop("kahan_sum", None) + if kahan_sum is not None: + kahan_sum = bool(kahan_sum) + + stable_adamw_kwargs = { + "decouple_lr": bool(optim_args.pop("decouple_lr", False)), + "max_lr": max_lr, + "kahan_sum": kahan_sum, + } + + optimizer_cls = StableAdamW + optimizer_kwargs.update(adam_kwargs) + optimizer_kwargs.update(stable_adamw_kwargs) else: raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}") return optimizer_cls, optimizer_kwargs diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 4ae3983c13..e7cf36c68d 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -186,6 +186,7 @@ class OptimizerNames(ExplicitEnum): SCHEDULE_FREE_SGD = "schedule_free_sgd" APOLLO_ADAMW = "apollo_adamw" APOLLO_ADAMW_LAYERWISE = "apollo_adamw_layerwise" + STABLE_ADAMW = "stable_adamw" def _convert_str_dict(passed_value: dict): diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 2220281cdb..9c1132ec5c 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -249,6 +249,7 @@ from .import_utils import ( is_torch_musa_available, is_torch_neuroncore_available, is_torch_npu_available, + is_torch_optimi_available, is_torch_sdpa_available, is_torch_tensorrt_fx_available, is_torch_tf32_available, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 93622f6c3c..c20d3d36f5 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -127,6 +127,7 @@ _galore_torch_available = _is_package_available("galore_torch") _lomo_available = _is_package_available("lomo_optim") _grokadamw_available = _is_package_available("grokadamw") _schedulefree_available, _schedulefree_version = _is_package_available("schedulefree", return_version=True) +_torch_optimi_available = importlib.util.find_spec("optimi") is not None # `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed. _bs4_available = importlib.util.find_spec("bs4") is not None _coloredlogs_available = _is_package_available("coloredlogs") @@ -474,6 +475,10 @@ def is_apollo_torch_available(): return _apollo_torch_available +def is_torch_optimi_available(): + return _torch_optimi_available + + def is_lomo_available(): return _lomo_available diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index e91ff1e21d..9ad624cd08 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -99,6 +99,7 @@ from transformers.testing_utils import ( require_torch_multi_accelerator, require_torch_non_multi_accelerator, require_torch_non_multi_gpu, + require_torch_optimi, require_torch_tensorrt_fx, require_torch_tf32, require_torch_up_to_2_accelerators, @@ -2518,6 +2519,123 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): # warm up steps << total steps self.assertTrue(len(decreasing_lrs) > len(increasing_lrs)) + @require_torch_optimi + @require_torch_gpu + def test_stable_adamw(self): + config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4) + tiny_llama = LlamaForCausalLM(config) + x = torch.randint(0, 100, (128,)) + train_dataset = RepeatDataset(x) + + # Trainer without inf/nan filter + args = TrainingArguments( + self.get_auto_remove_tmp_dir(), + learning_rate=1e-9, + logging_steps=5, + optim="stable_adamw", + optim_target_modules=[r".*attn.*", r".*mlp.*"], + ) + trainer = Trainer(tiny_llama, args, train_dataset=train_dataset) + _ = trainer.train() + + @require_torch_optimi + @require_torch_gpu + def test_stable_adamw_extra_args(self): + config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4) + tiny_llama = LlamaForCausalLM(config) + x = torch.randint(0, 100, (128,)) + train_dataset = RepeatDataset(x) + + # Trainer without inf/nan filter + args = TrainingArguments( + self.get_auto_remove_tmp_dir(), + learning_rate=1e-9, + logging_steps=5, + optim="stable_adamw", + optim_args="decouple_lr=True,max_lr=1e-3,kahan_sum=True", + optim_target_modules=[r".*attn.*", r".*mlp.*"], + ) + trainer = Trainer(tiny_llama, args, train_dataset=train_dataset) + + # Check this works + _ = trainer.train() + + @require_torch_optimi + @require_torch_gpu + def test_stable_adamw_lr_display_without_scheduler(self): + config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4) + tiny_llama = LlamaForCausalLM(config) + x = torch.randint(0, 100, (128,)) + train_dataset = RepeatDataset(x) + + learning_rate = 1e-9 + num_steps = 10 + + # Trainer without inf/nan filter + args = TrainingArguments( + self.get_auto_remove_tmp_dir(), + learning_rate=learning_rate, + logging_steps=5, + optim="stable_adamw", + optim_target_modules=[r".*attn.*", r".*mlp.*"], + ) + trainer = Trainer(tiny_llama, args, train_dataset=train_dataset) + trainer.create_optimizer_and_scheduler(num_training_steps=num_steps) + + # reflects displayed lr in trainer + self.assertEqual(trainer.get_learning_rates(), [learning_rate, learning_rate]) + + @require_torch_optimi + @require_torch_gpu + def test_stable_adamw_lr_display_with_scheduler(self): + config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4) + tiny_llama = LlamaForCausalLM(config) + x = torch.randint(0, 100, (128,)) + train_dataset = RepeatDataset(x) + + learning_rate = 2e-4 + num_train_epochs = 10 + num_warmup_steps = 5 + + # Trainer without inf/nan filter + args = TrainingArguments( + self.get_auto_remove_tmp_dir(), + num_train_epochs=num_train_epochs, + learning_rate=learning_rate, + warmup_steps=num_warmup_steps, + lr_scheduler_type="cosine", + logging_steps=1, + optim="stable_adamw", + optim_target_modules=[r".*attn.*", r".*mlp.*"], + ) + trainer = Trainer(tiny_llama, args, train_dataset=train_dataset) + + # creating log history of trainer, results don't matter + trainer.train() + logs = trainer.state.log_history[1:][:-1] + + # reach given learning rate peak and end with 0 lr + self.assertTrue(logs[num_warmup_steps - 2]["learning_rate"] == learning_rate) + self.assertTrue(logs[-1]["learning_rate"] == 0) + + # increasing and decreasing pattern of lrs + increasing_lrs = [ + logs[i]["learning_rate"] < logs[i + 1]["learning_rate"] + for i in range(len(logs)) + if i < num_warmup_steps - 2 + ] + decreasing_lrs = [ + logs[i]["learning_rate"] > logs[i + 1]["learning_rate"] + for i in range(len(logs) - 1) + if i >= num_warmup_steps - 2 + ] + + self.assertTrue(all(increasing_lrs)) + self.assertTrue(all(decreasing_lrs)) + + # warm up steps << total steps + self.assertTrue(len(decreasing_lrs) > len(increasing_lrs)) + @require_torch_multi_accelerator def test_data_is_not_parallelized_when_model_is_parallel(self): model = RegressionModel()