From d996024af7a94b0f11a5ad351217b648ecaed72a Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Tue, 2 Feb 2021 07:00:17 -0500 Subject: [PATCH] Use compute_loss in prediction_step (#9935) --- src/transformers/trainer.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index c25c7cb42d..f73d7d0007 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1312,7 +1312,7 @@ class Trainer: return loss.detach() - def compute_loss(self, model, inputs): + def compute_loss(self, model, inputs, return_outputs=False): """ How the loss is computed by Trainer. By default, all models return the loss in the first element. @@ -1329,10 +1329,12 @@ class Trainer: self._past = outputs[self.args.past_index] if labels is not None: - return self.label_smoother(outputs, labels) + loss = self.label_smoother(outputs, labels) else: # We don't use .loss here since the model may return tuples instead of ModelOutput. - return outputs["loss"] if isinstance(outputs, dict) else outputs[0] + loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] + + return (loss, outputs) if return_outputs else loss def is_local_process_zero(self) -> bool: """ @@ -1718,29 +1720,27 @@ class Trainer: ignore_keys = [] with torch.no_grad(): - if self.use_amp: - with autocast(): - outputs = model(**inputs) - else: - outputs = model(**inputs) if has_labels: - if self.label_smoother is not None and "labels" in inputs: - loss = self.label_smoother(outputs, inputs["labels"]).mean().detach() - else: - loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach() + loss, outputs = self.compute_loss(model, inputs, return_outputs=True) + loss = loss.mean().detach() if isinstance(outputs, dict): logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"]) else: logits = outputs[1:] else: loss = None + if self.use_amp: + with autocast(): + outputs = model(**inputs) + else: + outputs = model(**inputs) if isinstance(outputs, dict): logits = tuple(v for k, v in outputs.items() if k not in ignore_keys) else: logits = outputs - # TODO: this needs to be fixed and made cleaner later. - if self.args.past_index >= 0: - self._past = outputs[self.args.past_index if has_labels else self.args.past_index - 1] + # TODO: this needs to be fixed and made cleaner later. + if self.args.past_index >= 0: + self._past = outputs[self.args.past_index - 1] if prediction_loss_only: return (loss, None, None)