Add StableAdamW Optimizer (#39446)
* Added StableAdamW as an optimizer option for Trainer. Also wrote tests to verify its behaviour. * Fixed issue with * Added docs for StableAdamW. Also fixed a typo in schedule free optimizers --------- Co-authored-by: Gautham Krithiwas <gauthamkrithiwas2003@gmail.com>
This commit is contained in:
@@ -164,7 +164,7 @@ args = TrainingArguments(
|
||||
output_dir="./test-schedulefree",
|
||||
max_steps=1000,
|
||||
per_device_train_batch_size=4,
|
||||
+ optim="schedule_free_radamw,
|
||||
+ optim="schedule_free_radamw",
|
||||
+ lr_scheduler_type="constant",
|
||||
gradient_checkpointing=True,
|
||||
logging_strategy="steps",
|
||||
@@ -174,3 +174,29 @@ args = TrainingArguments(
|
||||
run_name="sfo",
|
||||
)
|
||||
```
|
||||
|
||||
## StableAdamW
|
||||
|
||||
```bash
|
||||
pip install torch-optimi
|
||||
```
|
||||
|
||||
[StableAdamW](https://arxiv.org/pdf/2304.13013) is a hybrid between AdamW and AdaFactor. It ports AdaFactor's update clipping into AdamW, which removes the need for gradient clipping. Otherwise, it behaves as a drop-in replacement for AdamW.
|
||||
|
||||
> [!TIP]
|
||||
> If training on large batch sizes or still observing training loss spikes, consider reducing beta_2 between [0.95, 0.99].
|
||||
|
||||
```diff
|
||||
args = TrainingArguments(
|
||||
output_dir="./test-stable-adamw",
|
||||
max_steps=1000,
|
||||
per_device_train_batch_size=4,
|
||||
+ optim="stable_adamw",
|
||||
gradient_checkpointing=True,
|
||||
logging_strategy="steps",
|
||||
logging_steps=1,
|
||||
learning_rate=2e-6,
|
||||
save_strategy="no",
|
||||
run_name="stable-adamw",
|
||||
)
|
||||
```
|
||||
Reference in New Issue
Block a user