[docs] fix outdated example code in trainer.md (#36066)

fix bugs
This commit is contained in:
Fanli Lin
2025-02-07 02:54:22 +08:00
committed by GitHub
parent 4563ba2c6f
commit 6246c03260

View File

@@ -130,7 +130,7 @@ from torch import nn
from transformers import Trainer
class CustomTrainer(Trainer):
def compute_loss(self, model, inputs, return_outputs=False):
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
labels = inputs.pop("labels")
# forward pass
outputs = model(**inputs)
@@ -156,9 +156,7 @@ class EarlyStoppingCallback(TrainerCallback):
def on_step_end(self, args, state, control, **kwargs):
if state.global_step >= self.num_steps:
return {"should_training_stop": True}
else:
return {}
control.should_training_stop = True
```
Then pass it to the [`Trainer`]'s `callback` parameter.
@@ -737,7 +735,7 @@ accelerate launch --num_processes=2 \
--fsdp_transformer_layer_cls_to_wrap="BertLayer" \
--fsdp_sharding_strategy=1 \
--fsdp_state_dict_type=FULL_STATE_DICT \
./examples/pytorch/text-classification/run_glue.py
./examples/pytorch/text-classification/run_glue.py \
--model_name_or_path google-bert/bert-base-cased \
--task_name $TASK_NAME \
--do_train \