Simplify and update trl examples (#38772)
* Simplify and update trl examples * Remove optim_args from SFTConfig in Trainer documentation * Update docs/source/en/trainer.md * Apply suggestions from code review * Update docs/source/en/trainer.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --------- Co-authored-by: Quentin Gallouédec <qgallouedec@Quentins-MacBook-Pro.local> Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
de24fb63ed
commit
c989ddd294
@@ -58,34 +58,18 @@ print(tokenizer.batch_decode(out))
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from trl import SFTTrainer
|
||||
from trl import SFTConfig, SFTTrainer
|
||||
from peft import LoraConfig
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
|
||||
|
||||
model_id = "state-spaces/mamba-130m-hf"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id)
|
||||
dataset = load_dataset("Abirate/english_quotes", split="train")
|
||||
training_args = TrainingArguments(
|
||||
output_dir="./results",
|
||||
num_train_epochs=3,
|
||||
per_device_train_batch_size=4,
|
||||
logging_dir='./logs',
|
||||
logging_steps=10,
|
||||
learning_rate=2e-3
|
||||
)
|
||||
lora_config = LoraConfig(
|
||||
r=8,
|
||||
target_modules=["x_proj", "embeddings", "in_proj", "out_proj"],
|
||||
task_type="CAUSAL_LM",
|
||||
bias="none"
|
||||
)
|
||||
training_args = SFTConfig(dataset_text_field="quote")
|
||||
lora_config = LoraConfig(target_modules=["x_proj", "embeddings", "in_proj", "out_proj"])
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
model=model_id,
|
||||
args=training_args,
|
||||
peft_config=lora_config,
|
||||
train_dataset=dataset,
|
||||
dataset_text_field="quote",
|
||||
peft_config=lora_config,
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
Reference in New Issue
Block a user