From 41cd52a768a222a13da0c6aaae877a92fc6c783c Mon Sep 17 00:00:00 2001 From: Mohan Zhang Date: Wed, 8 Sep 2021 11:48:00 -0400 Subject: [PATCH] fixed document (#13414) --- docs/source/main_classes/trainer.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/main_classes/trainer.rst b/docs/source/main_classes/trainer.rst index 850af5eb99..4c3a947743 100644 --- a/docs/source/main_classes/trainer.rst +++ b/docs/source/main_classes/trainer.rst @@ -64,9 +64,9 @@ classification: class MultilabelTrainer(Trainer): def compute_loss(self, model, inputs, return_outputs=False): - labels = inputs.pop("labels") + labels = inputs.get("labels") outputs = model(**inputs) - logits = outputs.logits + logits = outputs.get('logits') loss_fct = nn.BCEWithLogitsLoss() loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.float().view(-1, self.model.config.num_labels))