Simplify and update trl examples (#38772)

* Simplify and update trl examples

* Remove optim_args from SFTConfig in Trainer documentation

* Update docs/source/en/trainer.md

* Apply suggestions from code review

* Update docs/source/en/trainer.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

---------

Co-authored-by: Quentin Gallouédec <qgallouedec@Quentins-MacBook-Pro.local>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
This commit is contained in:
Quentin Gallouédec
2025-06-13 14:03:49 +02:00
committed by GitHub
parent de24fb63ed
commit c989ddd294
7 changed files with 114 additions and 347 deletions

View File

@@ -97,39 +97,22 @@ print(tokenizer.decode(output[0], skip_special_tokens=True))
- Mamba stacks `mixer` layers which are equivalent to `Attention` layers. You can find the main logic of Mamba in the `MambaMixer` class.
- The example below demonstrates how to fine-tune Mamba with [PEFT](https://huggingface.co/docs/peft).
```py
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
model_id = "state-spaces/mamba-130m-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
dataset = load_dataset("Abirate/english_quotes", split="train")
training_args = TrainingArguments(
output_dir="./results",
num_train_epochs=3,
per_device_train_batch_size=4,
logging_dir='./logs',
logging_steps=10,
learning_rate=2e-3
)
lora_config = LoraConfig(
r=8,
target_modules=["x_proj", "embeddings", "in_proj", "out_proj"],
task_type="CAUSAL_LM",
bias="none"
)
trainer = SFTTrainer(
model=model,
processing_class=tokenizer,
```py
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer
from peft import LoraConfig
model_id = "state-spaces/mamba-130m-hf"
dataset = load_dataset("Abirate/english_quotes", split="train")
training_args = SFTConfig(dataset_text_field="quote")
lora_config = LoraConfig(target_modules=["x_proj", "embeddings", "in_proj", "out_proj"])
trainer = SFTTrainer(
model=model_id,
args=training_args,
peft_config=lora_config,
train_dataset=dataset,
dataset_text_field="quote",
)
trainer.train()
train_dataset=dataset,
peft_config=lora_config,
)
trainer.train()
```
## MambaConfig

View File

@@ -103,40 +103,19 @@ print(tokenizer.decode(output[0], skip_special_tokens=True))
- The example below demonstrates how to fine-tune Mamba 2 with [PEFT](https://huggingface.co/docs/peft).
```python
from trl import SFTTrainer
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoTokenizer, Mamba2ForCausalLM, TrainingArguments
model_id = 'mistralai/Mamba-Codestral-7B-v0.1'
tokenizer = AutoTokenizer.from_pretrained(model_id, revision='refs/pr/9', from_slow=True, legacy=False)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left" #enforce padding side left
from trl import SFTConfig, SFTTrainer
model = Mamba2ForCausalLM.from_pretrained(model_id, revision='refs/pr/9')
model_id = "mistralai/Mamba-Codestral-7B-v0.1"
dataset = load_dataset("Abirate/english_quotes", split="train")
# Without CUDA kernels, batch size of 2 occupies one 80GB device
# but precision can be reduced.
# Experiments and trials welcome!
training_args = TrainingArguments(
output_dir="./results",
num_train_epochs=3,
per_device_train_batch_size=2,
logging_dir='./logs',
logging_steps=10,
learning_rate=2e-3
)
lora_config = LoraConfig(
r=8,
target_modules=["embeddings", "in_proj", "out_proj"],
task_type="CAUSAL_LM",
bias="none"
)
training_args = SFTConfig(dataset_text_field="quote", gradient_checkpointing=True, per_device_train_batch_size=4)
lora_config = LoraConfig(target_modules=["x_proj", "embeddings", "in_proj", "out_proj"])
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
model=model_id,
args=training_args,
peft_config=lora_config,
train_dataset=dataset,
dataset_text_field="quote",
peft_config=lora_config,
)
trainer.train()
```

View File

@@ -392,15 +392,15 @@ training_args = TrainingArguments(
[Gradient Low-Rank Projection (GaLore)](https://hf.co/papers/2403.03507) significantly reduces memory usage when training large language models (LLMs). One of GaLores key benefits is *full-parameter* learning, unlike low-rank adaptation methods like [LoRA](https://hf.co/papers/2106.09685), which produces better model performance.
Install the [GaLore](https://github.com/jiaweizzhao/GaLore) library, [TRL](https://hf.co/docs/trl/index), and [Datasets](https://hf.co/docs/datasets/index).
Install the [GaLore](https://github.com/jiaweizzhao/GaLore) and [TRL](https://hf.co/docs/trl/index) libraries.
```bash
pip install galore-torch trl datasets
pip install galore-torch trl
```
Pick a GaLore optimizer (`"galore_adamw"`, `"galore_adafactor"`, `"galore_adamw_8bit`") and pass it to the `optim` parameter in [`TrainingArguments`]. Use the `optim_target_modules` parameter to specify which modules to adapt (can be a list of strings, regex, or a full path).
Pick a GaLore optimizer (`"galore_adamw"`, `"galore_adafactor"`, `"galore_adamw_8bit`") and pass it to the `optim` parameter in [`trl.SFTConfig`]. Use the `optim_target_modules` parameter to specify which modules to adapt (can be a list of strings, regex, or a full path).
Extra parameters supported by GaLore, `rank`, `update_proj_gap`, and `scale`, should be passed to the `optim_args` parameter in [`TrainingArguments`].
Extra parameters supported by GaLore, `rank`, `update_proj_gap`, and `scale`, should be passed to the `optim_args` parameter in [`trl.SFTConfig`].
The example below enables GaLore with [`~trl.SFTTrainer`] that targets the `attn` and `mlp` layers with regex.
@@ -411,29 +411,22 @@ The example below enables GaLore with [`~trl.SFTTrainer`] that targets the `attn
<hfoption id="GaLore optimizer">
```py
import torch
import datasets
import trl
from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM
from trl import SFTConfig, SFTTrainer
train_dataset = datasets.load_dataset('imdb', split='train')
args = TrainingArguments(
args = SFTConfig(
output_dir="./test-galore",
max_steps=100,
per_device_train_batch_size=2,
optim="galore_adamw",
optim_target_modules=[r".*.attn.*", r".*.mlp.*"],
optim_args="rank=64, update_proj_gap=100, scale=0.10",
gradient_checkpointing=True,
)
config = AutoConfig.from_pretrained("google/gemma-2b")
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
model = AutoModelForCausalLM.from_config("google/gemma-2b").to(0)
trainer = trl.SFTTrainer(
model=model,
trainer = SFTTrainer(
model="google/gemma-2b",
args=args,
train_dataset=train_dataset,
dataset_text_field='text',
max_seq_length=512,
)
trainer.train()
```
@@ -444,29 +437,22 @@ trainer.train()
Append `layerwise` to the optimizer name to enable layerwise optimization. For example, `"galore_adamw"` becomes `"galore_adamw_layerwise"`. This feature is still experimental and does not support Distributed Data Parallel (DDP). The code below can only be run on a [single GPU](https://github.com/jiaweizzhao/GaLore?tab=readme-ov-file#train-7b-model-with-a-single-gpu-with-24gb-memory). Other features like gradient clipping and DeepSpeed may not be available out of the box. Feel free to open an [issue](https://github.com/huggingface/transformers/issues) if you encounter any problems!
```py
import torch
import datasets
import trl
from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM
from trl import SFTConfig, SFTTrainer
train_dataset = datasets.load_dataset('imdb', split='train')
args = TrainingArguments(
args = SFTConfig(
output_dir="./test-galore",
max_steps=100,
per_device_train_batch_size=2,
optim="galore_adamw_layerwise",
optim_target_modules=[r".*.attn.*", r".*.mlp.*"],
optim_args="rank=64, update_proj_gap=100, scale=0.10",
gradient_checkpointing=True,
)
config = AutoConfig.from_pretrained("google/gemma-2b")
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
model = AutoModelForCausalLM.from_config("google/gemma-2b").to(0)
trainer = trl.SFTTrainer(
model=model,
trainer = SFTTrainer(
model="google/gemma-2b",
args=args,
train_dataset=train_dataset,
dataset_text_field='text',
max_seq_length=512,
)
trainer.train()
```