Add StableAdamW Optimizer (#39446)
* Added StableAdamW as an optimizer option for Trainer. Also wrote tests to verify its behaviour. * Fixed issue with * Added docs for StableAdamW. Also fixed a typo in schedule free optimizers --------- Co-authored-by: Gautham Krithiwas <gauthamkrithiwas2003@gmail.com>
This commit is contained in:
@@ -164,7 +164,7 @@ args = TrainingArguments(
|
|||||||
output_dir="./test-schedulefree",
|
output_dir="./test-schedulefree",
|
||||||
max_steps=1000,
|
max_steps=1000,
|
||||||
per_device_train_batch_size=4,
|
per_device_train_batch_size=4,
|
||||||
+ optim="schedule_free_radamw,
|
+ optim="schedule_free_radamw",
|
||||||
+ lr_scheduler_type="constant",
|
+ lr_scheduler_type="constant",
|
||||||
gradient_checkpointing=True,
|
gradient_checkpointing=True,
|
||||||
logging_strategy="steps",
|
logging_strategy="steps",
|
||||||
@@ -174,3 +174,29 @@ args = TrainingArguments(
|
|||||||
run_name="sfo",
|
run_name="sfo",
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## StableAdamW
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install torch-optimi
|
||||||
|
```
|
||||||
|
|
||||||
|
[StableAdamW](https://arxiv.org/pdf/2304.13013) is a hybrid between AdamW and AdaFactor. It ports AdaFactor's update clipping into AdamW, which removes the need for gradient clipping. Otherwise, it behaves as a drop-in replacement for AdamW.
|
||||||
|
|
||||||
|
> [!TIP]
|
||||||
|
> If training on large batch sizes or still observing training loss spikes, consider reducing beta_2 between [0.95, 0.99].
|
||||||
|
|
||||||
|
```diff
|
||||||
|
args = TrainingArguments(
|
||||||
|
output_dir="./test-stable-adamw",
|
||||||
|
max_steps=1000,
|
||||||
|
per_device_train_batch_size=4,
|
||||||
|
+ optim="stable_adamw",
|
||||||
|
gradient_checkpointing=True,
|
||||||
|
logging_strategy="steps",
|
||||||
|
logging_steps=1,
|
||||||
|
learning_rate=2e-6,
|
||||||
|
save_strategy="no",
|
||||||
|
run_name="stable-adamw",
|
||||||
|
)
|
||||||
|
```
|
||||||
@@ -152,6 +152,7 @@ from .utils import (
|
|||||||
is_torch_mlu_available,
|
is_torch_mlu_available,
|
||||||
is_torch_neuroncore_available,
|
is_torch_neuroncore_available,
|
||||||
is_torch_npu_available,
|
is_torch_npu_available,
|
||||||
|
is_torch_optimi_available,
|
||||||
is_torch_sdpa_available,
|
is_torch_sdpa_available,
|
||||||
is_torch_tensorrt_fx_available,
|
is_torch_tensorrt_fx_available,
|
||||||
is_torch_tf32_available,
|
is_torch_tf32_available,
|
||||||
@@ -379,6 +380,14 @@ def require_apollo_torch(test_case):
|
|||||||
return unittest.skipUnless(is_apollo_torch_available(), "test requires APOLLO")(test_case)
|
return unittest.skipUnless(is_apollo_torch_available(), "test requires APOLLO")(test_case)
|
||||||
|
|
||||||
|
|
||||||
|
def require_torch_optimi(test_case):
|
||||||
|
"""
|
||||||
|
Decorator marking a test that requires torch-optimi. These tests are skipped when torch-optimi isn't installed.
|
||||||
|
https://github.com/jxnl/torch-optimi
|
||||||
|
"""
|
||||||
|
return unittest.skipUnless(is_torch_optimi_available(), "test requires torch-optimi")(test_case)
|
||||||
|
|
||||||
|
|
||||||
def require_lomo(test_case):
|
def require_lomo(test_case):
|
||||||
"""
|
"""
|
||||||
Decorator marking a test that requires LOMO. These tests are skipped when LOMO-optim isn't installed.
|
Decorator marking a test that requires LOMO. These tests are skipped when LOMO-optim isn't installed.
|
||||||
|
|||||||
@@ -169,6 +169,7 @@ from .utils import (
|
|||||||
is_torch_musa_available,
|
is_torch_musa_available,
|
||||||
is_torch_neuroncore_available,
|
is_torch_neuroncore_available,
|
||||||
is_torch_npu_available,
|
is_torch_npu_available,
|
||||||
|
is_torch_optimi_available,
|
||||||
is_torch_xla_available,
|
is_torch_xla_available,
|
||||||
is_torch_xpu_available,
|
is_torch_xpu_available,
|
||||||
is_torchao_available,
|
is_torchao_available,
|
||||||
@@ -1723,6 +1724,31 @@ class Trainer:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
optimizer_kwargs.update(additional_optim_kwargs)
|
optimizer_kwargs.update(additional_optim_kwargs)
|
||||||
|
elif args.optim == OptimizerNames.STABLE_ADAMW:
|
||||||
|
if not is_torch_optimi_available():
|
||||||
|
raise ImportError(
|
||||||
|
"You need to install `torch-optimi` in order to use stable_adamw optimizers. "
|
||||||
|
"Install it with `pip install torch-optimi`."
|
||||||
|
)
|
||||||
|
from optimi import StableAdamW
|
||||||
|
|
||||||
|
max_lr = optim_args.pop("max_lr", None)
|
||||||
|
if max_lr is not None:
|
||||||
|
max_lr = float(max_lr)
|
||||||
|
|
||||||
|
kahan_sum = optim_args.pop("kahan_sum", None)
|
||||||
|
if kahan_sum is not None:
|
||||||
|
kahan_sum = bool(kahan_sum)
|
||||||
|
|
||||||
|
stable_adamw_kwargs = {
|
||||||
|
"decouple_lr": bool(optim_args.pop("decouple_lr", False)),
|
||||||
|
"max_lr": max_lr,
|
||||||
|
"kahan_sum": kahan_sum,
|
||||||
|
}
|
||||||
|
|
||||||
|
optimizer_cls = StableAdamW
|
||||||
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
|
optimizer_kwargs.update(stable_adamw_kwargs)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
|
raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
|
||||||
return optimizer_cls, optimizer_kwargs
|
return optimizer_cls, optimizer_kwargs
|
||||||
|
|||||||
@@ -186,6 +186,7 @@ class OptimizerNames(ExplicitEnum):
|
|||||||
SCHEDULE_FREE_SGD = "schedule_free_sgd"
|
SCHEDULE_FREE_SGD = "schedule_free_sgd"
|
||||||
APOLLO_ADAMW = "apollo_adamw"
|
APOLLO_ADAMW = "apollo_adamw"
|
||||||
APOLLO_ADAMW_LAYERWISE = "apollo_adamw_layerwise"
|
APOLLO_ADAMW_LAYERWISE = "apollo_adamw_layerwise"
|
||||||
|
STABLE_ADAMW = "stable_adamw"
|
||||||
|
|
||||||
|
|
||||||
def _convert_str_dict(passed_value: dict):
|
def _convert_str_dict(passed_value: dict):
|
||||||
|
|||||||
@@ -249,6 +249,7 @@ from .import_utils import (
|
|||||||
is_torch_musa_available,
|
is_torch_musa_available,
|
||||||
is_torch_neuroncore_available,
|
is_torch_neuroncore_available,
|
||||||
is_torch_npu_available,
|
is_torch_npu_available,
|
||||||
|
is_torch_optimi_available,
|
||||||
is_torch_sdpa_available,
|
is_torch_sdpa_available,
|
||||||
is_torch_tensorrt_fx_available,
|
is_torch_tensorrt_fx_available,
|
||||||
is_torch_tf32_available,
|
is_torch_tf32_available,
|
||||||
|
|||||||
@@ -127,6 +127,7 @@ _galore_torch_available = _is_package_available("galore_torch")
|
|||||||
_lomo_available = _is_package_available("lomo_optim")
|
_lomo_available = _is_package_available("lomo_optim")
|
||||||
_grokadamw_available = _is_package_available("grokadamw")
|
_grokadamw_available = _is_package_available("grokadamw")
|
||||||
_schedulefree_available, _schedulefree_version = _is_package_available("schedulefree", return_version=True)
|
_schedulefree_available, _schedulefree_version = _is_package_available("schedulefree", return_version=True)
|
||||||
|
_torch_optimi_available = importlib.util.find_spec("optimi") is not None
|
||||||
# `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
|
# `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
|
||||||
_bs4_available = importlib.util.find_spec("bs4") is not None
|
_bs4_available = importlib.util.find_spec("bs4") is not None
|
||||||
_coloredlogs_available = _is_package_available("coloredlogs")
|
_coloredlogs_available = _is_package_available("coloredlogs")
|
||||||
@@ -474,6 +475,10 @@ def is_apollo_torch_available():
|
|||||||
return _apollo_torch_available
|
return _apollo_torch_available
|
||||||
|
|
||||||
|
|
||||||
|
def is_torch_optimi_available():
|
||||||
|
return _torch_optimi_available
|
||||||
|
|
||||||
|
|
||||||
def is_lomo_available():
|
def is_lomo_available():
|
||||||
return _lomo_available
|
return _lomo_available
|
||||||
|
|
||||||
|
|||||||
@@ -99,6 +99,7 @@ from transformers.testing_utils import (
|
|||||||
require_torch_multi_accelerator,
|
require_torch_multi_accelerator,
|
||||||
require_torch_non_multi_accelerator,
|
require_torch_non_multi_accelerator,
|
||||||
require_torch_non_multi_gpu,
|
require_torch_non_multi_gpu,
|
||||||
|
require_torch_optimi,
|
||||||
require_torch_tensorrt_fx,
|
require_torch_tensorrt_fx,
|
||||||
require_torch_tf32,
|
require_torch_tf32,
|
||||||
require_torch_up_to_2_accelerators,
|
require_torch_up_to_2_accelerators,
|
||||||
@@ -2518,6 +2519,123 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
# warm up steps << total steps
|
# warm up steps << total steps
|
||||||
self.assertTrue(len(decreasing_lrs) > len(increasing_lrs))
|
self.assertTrue(len(decreasing_lrs) > len(increasing_lrs))
|
||||||
|
|
||||||
|
@require_torch_optimi
|
||||||
|
@require_torch_gpu
|
||||||
|
def test_stable_adamw(self):
|
||||||
|
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||||
|
tiny_llama = LlamaForCausalLM(config)
|
||||||
|
x = torch.randint(0, 100, (128,))
|
||||||
|
train_dataset = RepeatDataset(x)
|
||||||
|
|
||||||
|
# Trainer without inf/nan filter
|
||||||
|
args = TrainingArguments(
|
||||||
|
self.get_auto_remove_tmp_dir(),
|
||||||
|
learning_rate=1e-9,
|
||||||
|
logging_steps=5,
|
||||||
|
optim="stable_adamw",
|
||||||
|
optim_target_modules=[r".*attn.*", r".*mlp.*"],
|
||||||
|
)
|
||||||
|
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
|
||||||
|
_ = trainer.train()
|
||||||
|
|
||||||
|
@require_torch_optimi
|
||||||
|
@require_torch_gpu
|
||||||
|
def test_stable_adamw_extra_args(self):
|
||||||
|
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||||
|
tiny_llama = LlamaForCausalLM(config)
|
||||||
|
x = torch.randint(0, 100, (128,))
|
||||||
|
train_dataset = RepeatDataset(x)
|
||||||
|
|
||||||
|
# Trainer without inf/nan filter
|
||||||
|
args = TrainingArguments(
|
||||||
|
self.get_auto_remove_tmp_dir(),
|
||||||
|
learning_rate=1e-9,
|
||||||
|
logging_steps=5,
|
||||||
|
optim="stable_adamw",
|
||||||
|
optim_args="decouple_lr=True,max_lr=1e-3,kahan_sum=True",
|
||||||
|
optim_target_modules=[r".*attn.*", r".*mlp.*"],
|
||||||
|
)
|
||||||
|
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
|
||||||
|
|
||||||
|
# Check this works
|
||||||
|
_ = trainer.train()
|
||||||
|
|
||||||
|
@require_torch_optimi
|
||||||
|
@require_torch_gpu
|
||||||
|
def test_stable_adamw_lr_display_without_scheduler(self):
|
||||||
|
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||||
|
tiny_llama = LlamaForCausalLM(config)
|
||||||
|
x = torch.randint(0, 100, (128,))
|
||||||
|
train_dataset = RepeatDataset(x)
|
||||||
|
|
||||||
|
learning_rate = 1e-9
|
||||||
|
num_steps = 10
|
||||||
|
|
||||||
|
# Trainer without inf/nan filter
|
||||||
|
args = TrainingArguments(
|
||||||
|
self.get_auto_remove_tmp_dir(),
|
||||||
|
learning_rate=learning_rate,
|
||||||
|
logging_steps=5,
|
||||||
|
optim="stable_adamw",
|
||||||
|
optim_target_modules=[r".*attn.*", r".*mlp.*"],
|
||||||
|
)
|
||||||
|
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
|
||||||
|
trainer.create_optimizer_and_scheduler(num_training_steps=num_steps)
|
||||||
|
|
||||||
|
# reflects displayed lr in trainer
|
||||||
|
self.assertEqual(trainer.get_learning_rates(), [learning_rate, learning_rate])
|
||||||
|
|
||||||
|
@require_torch_optimi
|
||||||
|
@require_torch_gpu
|
||||||
|
def test_stable_adamw_lr_display_with_scheduler(self):
|
||||||
|
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||||
|
tiny_llama = LlamaForCausalLM(config)
|
||||||
|
x = torch.randint(0, 100, (128,))
|
||||||
|
train_dataset = RepeatDataset(x)
|
||||||
|
|
||||||
|
learning_rate = 2e-4
|
||||||
|
num_train_epochs = 10
|
||||||
|
num_warmup_steps = 5
|
||||||
|
|
||||||
|
# Trainer without inf/nan filter
|
||||||
|
args = TrainingArguments(
|
||||||
|
self.get_auto_remove_tmp_dir(),
|
||||||
|
num_train_epochs=num_train_epochs,
|
||||||
|
learning_rate=learning_rate,
|
||||||
|
warmup_steps=num_warmup_steps,
|
||||||
|
lr_scheduler_type="cosine",
|
||||||
|
logging_steps=1,
|
||||||
|
optim="stable_adamw",
|
||||||
|
optim_target_modules=[r".*attn.*", r".*mlp.*"],
|
||||||
|
)
|
||||||
|
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
|
||||||
|
|
||||||
|
# creating log history of trainer, results don't matter
|
||||||
|
trainer.train()
|
||||||
|
logs = trainer.state.log_history[1:][:-1]
|
||||||
|
|
||||||
|
# reach given learning rate peak and end with 0 lr
|
||||||
|
self.assertTrue(logs[num_warmup_steps - 2]["learning_rate"] == learning_rate)
|
||||||
|
self.assertTrue(logs[-1]["learning_rate"] == 0)
|
||||||
|
|
||||||
|
# increasing and decreasing pattern of lrs
|
||||||
|
increasing_lrs = [
|
||||||
|
logs[i]["learning_rate"] < logs[i + 1]["learning_rate"]
|
||||||
|
for i in range(len(logs))
|
||||||
|
if i < num_warmup_steps - 2
|
||||||
|
]
|
||||||
|
decreasing_lrs = [
|
||||||
|
logs[i]["learning_rate"] > logs[i + 1]["learning_rate"]
|
||||||
|
for i in range(len(logs) - 1)
|
||||||
|
if i >= num_warmup_steps - 2
|
||||||
|
]
|
||||||
|
|
||||||
|
self.assertTrue(all(increasing_lrs))
|
||||||
|
self.assertTrue(all(decreasing_lrs))
|
||||||
|
|
||||||
|
# warm up steps << total steps
|
||||||
|
self.assertTrue(len(decreasing_lrs) > len(increasing_lrs))
|
||||||
|
|
||||||
@require_torch_multi_accelerator
|
@require_torch_multi_accelerator
|
||||||
def test_data_is_not_parallelized_when_model_is_parallel(self):
|
def test_data_is_not_parallelized_when_model_is_parallel(self):
|
||||||
model = RegressionModel()
|
model = RegressionModel()
|
||||||
|
|||||||
Reference in New Issue
Block a user