schedulefree optimizers (#30079)
* schedulefree optimizers * fix train instead of eval for optimizer * fixes and update docs * chore: lint * add tests and drop overly-verbose _32bit suffix * chore: lint * fix for docs * fix code review issues * use duck-typing to avoid per-optimizer patches * fixup style * fixup style * warn if incorrect accelerate version with schedule free Co-authored-by: Aman Gupta Karmani <aman@tmm1.net> --------- Co-authored-by: Aman Karmani <aman@tmm1.net>
This commit is contained in:
@@ -518,6 +518,51 @@ trainer.train()
|
||||
|
||||
This script demonstrates how to fine-tune the `google/gemma-2b` model on the IMDB dataset using the GrokAdamW optimizer. The `TrainingArguments` are configured to use GrokAdamW, and the dataset is passed to the `Trainer` for training.
|
||||
|
||||
## Schedule Free Optimizer
|
||||
|
||||
The Schedule Free optimizers have been introduced in [The Road Less Scheduled](https://hf.co/papers/2405.15682).
|
||||
Schedule-Free learning replaces the momentum of the base optimizer with a combination of averaging and interpolation, to completely remove the need to anneal the learning rate with a traditional schedule.
|
||||
Supported optimizers for SFO are `"schedule_free_adamw"` and `"schedule_free_sgd"`. First install schedulefree from pypi `pip install schedulefree`.
|
||||
|
||||
Below is a simple script to demonstrate how to fine-tune [google/gemma-2b](https://huggingface.co/google/gemma-2b) on IMDB dataset in full precision:
|
||||
|
||||
```python
|
||||
import torch
|
||||
import datasets
|
||||
from transformers import TrainingArguments, AutoTokenizer, AutoModelForCausalLM
|
||||
import trl
|
||||
|
||||
train_dataset = datasets.load_dataset('imdb', split='train')
|
||||
|
||||
args = TrainingArguments(
|
||||
output_dir="./test-schedulefree",
|
||||
max_steps=1000,
|
||||
per_device_train_batch_size=4,
|
||||
optim="schedule_free_adamw",
|
||||
gradient_checkpointing=True,
|
||||
logging_strategy="steps",
|
||||
logging_steps=1,
|
||||
learning_rate=2e-6,
|
||||
save_strategy="no",
|
||||
run_name="sfo-imdb",
|
||||
)
|
||||
|
||||
model_id = "google/gemma-2b"
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True).to(0)
|
||||
|
||||
trainer = trl.SFTTrainer(
|
||||
model=model,
|
||||
args=args,
|
||||
train_dataset=train_dataset,
|
||||
dataset_text_field='text',
|
||||
max_seq_length=1024,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
## Accelerate and Trainer
|
||||
|
||||
The [`Trainer`] class is powered by [Accelerate](https://hf.co/docs/accelerate), a library for easily training PyTorch models in distributed environments with support for integrations such as [FullyShardedDataParallel (FSDP)](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/) and [DeepSpeed](https://www.deepspeed.ai/).
|
||||
|
||||
Reference in New Issue
Block a user