Simplify and update trl examples (#38772)
* Simplify and update trl examples * Remove optim_args from SFTConfig in Trainer documentation * Update docs/source/en/trainer.md * Apply suggestions from code review * Update docs/source/en/trainer.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --------- Co-authored-by: Quentin Gallouédec <qgallouedec@Quentins-MacBook-Pro.local> Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
de24fb63ed
commit
c989ddd294
@@ -392,15 +392,15 @@ training_args = TrainingArguments(
|
||||
|
||||
[Gradient Low-Rank Projection (GaLore)](https://hf.co/papers/2403.03507) significantly reduces memory usage when training large language models (LLMs). One of GaLores key benefits is *full-parameter* learning, unlike low-rank adaptation methods like [LoRA](https://hf.co/papers/2106.09685), which produces better model performance.
|
||||
|
||||
Install the [GaLore](https://github.com/jiaweizzhao/GaLore) library, [TRL](https://hf.co/docs/trl/index), and [Datasets](https://hf.co/docs/datasets/index).
|
||||
Install the [GaLore](https://github.com/jiaweizzhao/GaLore) and [TRL](https://hf.co/docs/trl/index) libraries.
|
||||
|
||||
```bash
|
||||
pip install galore-torch trl datasets
|
||||
pip install galore-torch trl
|
||||
```
|
||||
|
||||
Pick a GaLore optimizer (`"galore_adamw"`, `"galore_adafactor"`, `"galore_adamw_8bit`") and pass it to the `optim` parameter in [`TrainingArguments`]. Use the `optim_target_modules` parameter to specify which modules to adapt (can be a list of strings, regex, or a full path).
|
||||
Pick a GaLore optimizer (`"galore_adamw"`, `"galore_adafactor"`, `"galore_adamw_8bit`") and pass it to the `optim` parameter in [`trl.SFTConfig`]. Use the `optim_target_modules` parameter to specify which modules to adapt (can be a list of strings, regex, or a full path).
|
||||
|
||||
Extra parameters supported by GaLore, `rank`, `update_proj_gap`, and `scale`, should be passed to the `optim_args` parameter in [`TrainingArguments`].
|
||||
Extra parameters supported by GaLore, `rank`, `update_proj_gap`, and `scale`, should be passed to the `optim_args` parameter in [`trl.SFTConfig`].
|
||||
|
||||
The example below enables GaLore with [`~trl.SFTTrainer`] that targets the `attn` and `mlp` layers with regex.
|
||||
|
||||
@@ -411,29 +411,22 @@ The example below enables GaLore with [`~trl.SFTTrainer`] that targets the `attn
|
||||
<hfoption id="GaLore optimizer">
|
||||
|
||||
```py
|
||||
import torch
|
||||
import datasets
|
||||
import trl
|
||||
from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM
|
||||
from trl import SFTConfig, SFTTrainer
|
||||
|
||||
train_dataset = datasets.load_dataset('imdb', split='train')
|
||||
args = TrainingArguments(
|
||||
args = SFTConfig(
|
||||
output_dir="./test-galore",
|
||||
max_steps=100,
|
||||
per_device_train_batch_size=2,
|
||||
optim="galore_adamw",
|
||||
optim_target_modules=[r".*.attn.*", r".*.mlp.*"],
|
||||
optim_args="rank=64, update_proj_gap=100, scale=0.10",
|
||||
gradient_checkpointing=True,
|
||||
)
|
||||
config = AutoConfig.from_pretrained("google/gemma-2b")
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
|
||||
model = AutoModelForCausalLM.from_config("google/gemma-2b").to(0)
|
||||
trainer = trl.SFTTrainer(
|
||||
model=model,
|
||||
trainer = SFTTrainer(
|
||||
model="google/gemma-2b",
|
||||
args=args,
|
||||
train_dataset=train_dataset,
|
||||
dataset_text_field='text',
|
||||
max_seq_length=512,
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
@@ -444,29 +437,22 @@ trainer.train()
|
||||
Append `layerwise` to the optimizer name to enable layerwise optimization. For example, `"galore_adamw"` becomes `"galore_adamw_layerwise"`. This feature is still experimental and does not support Distributed Data Parallel (DDP). The code below can only be run on a [single GPU](https://github.com/jiaweizzhao/GaLore?tab=readme-ov-file#train-7b-model-with-a-single-gpu-with-24gb-memory). Other features like gradient clipping and DeepSpeed may not be available out of the box. Feel free to open an [issue](https://github.com/huggingface/transformers/issues) if you encounter any problems!
|
||||
|
||||
```py
|
||||
import torch
|
||||
import datasets
|
||||
import trl
|
||||
from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM
|
||||
from trl import SFTConfig, SFTTrainer
|
||||
|
||||
train_dataset = datasets.load_dataset('imdb', split='train')
|
||||
args = TrainingArguments(
|
||||
args = SFTConfig(
|
||||
output_dir="./test-galore",
|
||||
max_steps=100,
|
||||
per_device_train_batch_size=2,
|
||||
optim="galore_adamw_layerwise",
|
||||
optim_target_modules=[r".*.attn.*", r".*.mlp.*"],
|
||||
optim_args="rank=64, update_proj_gap=100, scale=0.10",
|
||||
gradient_checkpointing=True,
|
||||
)
|
||||
config = AutoConfig.from_pretrained("google/gemma-2b")
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
|
||||
model = AutoModelForCausalLM.from_config("google/gemma-2b").to(0)
|
||||
trainer = trl.SFTTrainer(
|
||||
model=model,
|
||||
trainer = SFTTrainer(
|
||||
model="google/gemma-2b",
|
||||
args=args,
|
||||
train_dataset=train_dataset,
|
||||
dataset_text_field='text',
|
||||
max_seq_length=512,
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
Reference in New Issue
Block a user