Use compute_loss in prediction_step (#9935)
This commit is contained in:
@@ -1312,7 +1312,7 @@ class Trainer:
|
|||||||
|
|
||||||
return loss.detach()
|
return loss.detach()
|
||||||
|
|
||||||
def compute_loss(self, model, inputs):
|
def compute_loss(self, model, inputs, return_outputs=False):
|
||||||
"""
|
"""
|
||||||
How the loss is computed by Trainer. By default, all models return the loss in the first element.
|
How the loss is computed by Trainer. By default, all models return the loss in the first element.
|
||||||
|
|
||||||
@@ -1329,10 +1329,12 @@ class Trainer:
|
|||||||
self._past = outputs[self.args.past_index]
|
self._past = outputs[self.args.past_index]
|
||||||
|
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
return self.label_smoother(outputs, labels)
|
loss = self.label_smoother(outputs, labels)
|
||||||
else:
|
else:
|
||||||
# We don't use .loss here since the model may return tuples instead of ModelOutput.
|
# We don't use .loss here since the model may return tuples instead of ModelOutput.
|
||||||
return outputs["loss"] if isinstance(outputs, dict) else outputs[0]
|
loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
|
||||||
|
|
||||||
|
return (loss, outputs) if return_outputs else loss
|
||||||
|
|
||||||
def is_local_process_zero(self) -> bool:
|
def is_local_process_zero(self) -> bool:
|
||||||
"""
|
"""
|
||||||
@@ -1718,29 +1720,27 @@ class Trainer:
|
|||||||
ignore_keys = []
|
ignore_keys = []
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if self.use_amp:
|
|
||||||
with autocast():
|
|
||||||
outputs = model(**inputs)
|
|
||||||
else:
|
|
||||||
outputs = model(**inputs)
|
|
||||||
if has_labels:
|
if has_labels:
|
||||||
if self.label_smoother is not None and "labels" in inputs:
|
loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
|
||||||
loss = self.label_smoother(outputs, inputs["labels"]).mean().detach()
|
loss = loss.mean().detach()
|
||||||
else:
|
|
||||||
loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach()
|
|
||||||
if isinstance(outputs, dict):
|
if isinstance(outputs, dict):
|
||||||
logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
|
logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
|
||||||
else:
|
else:
|
||||||
logits = outputs[1:]
|
logits = outputs[1:]
|
||||||
else:
|
else:
|
||||||
loss = None
|
loss = None
|
||||||
|
if self.use_amp:
|
||||||
|
with autocast():
|
||||||
|
outputs = model(**inputs)
|
||||||
|
else:
|
||||||
|
outputs = model(**inputs)
|
||||||
if isinstance(outputs, dict):
|
if isinstance(outputs, dict):
|
||||||
logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
|
logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
|
||||||
else:
|
else:
|
||||||
logits = outputs
|
logits = outputs
|
||||||
# TODO: this needs to be fixed and made cleaner later.
|
# TODO: this needs to be fixed and made cleaner later.
|
||||||
if self.args.past_index >= 0:
|
if self.args.past_index >= 0:
|
||||||
self._past = outputs[self.args.past_index if has_labels else self.args.past_index - 1]
|
self._past = outputs[self.args.past_index - 1]
|
||||||
|
|
||||||
if prediction_loss_only:
|
if prediction_loss_only:
|
||||||
return (loss, None, None)
|
return (loss, None, None)
|
||||||
|
|||||||
Reference in New Issue
Block a user