diff --git a/docs/source/en/main_classes/trainer.md b/docs/source/en/main_classes/trainer.md index ad3ea57f13..4a767ee076 100644 --- a/docs/source/en/main_classes/trainer.md +++ b/docs/source/en/main_classes/trainer.md @@ -60,7 +60,7 @@ from transformers import Trainer class CustomTrainer(Trainer): def compute_loss(self, model, inputs, return_outputs=False): - labels = inputs.get("labels") + labels = inputs.pop("labels") # forward pass outputs = model(**inputs) logits = outputs.get("logits")