From 461e8cacf94d1f76367cc9ba2cfd5b9bd3641c81 Mon Sep 17 00:00:00 2001
From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Date: Mon, 22 Feb 2021 16:39:02 -0500
Subject: [PATCH] Fix evaluation with label smoothing in Trainer (#10338)

---
 src/transformers/trainer.py | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index f18563fa4f..962a6fb1ba 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1888,6 +1888,14 @@ class Trainer:
             else:
                 ignore_keys = []
 
+        # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
+        if has_labels:
+            labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
+            if len(labels) == 1:
+                labels = labels[0]
+        else:
+            labels = None
+
         with torch.no_grad():
             if has_labels:
                 loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
@@ -1918,13 +1926,6 @@ class Trainer:
         if len(logits) == 1:
             logits = logits[0]
 
-        if has_labels:
-            labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
-            if len(labels) == 1:
-                labels = labels[0]
-        else:
-            labels = None
-
         return (loss, logits, labels)
 
     def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):