Fix evaluation with label smoothing in Trainer (#10338)
This commit is contained in:
@@ -1888,6 +1888,14 @@ class Trainer:
|
||||
else:
|
||||
ignore_keys = []
|
||||
|
||||
# labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
|
||||
if has_labels:
|
||||
labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
|
||||
if len(labels) == 1:
|
||||
labels = labels[0]
|
||||
else:
|
||||
labels = None
|
||||
|
||||
with torch.no_grad():
|
||||
if has_labels:
|
||||
loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
|
||||
@@ -1918,13 +1926,6 @@ class Trainer:
|
||||
if len(logits) == 1:
|
||||
logits = logits[0]
|
||||
|
||||
if has_labels:
|
||||
labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
|
||||
if len(labels) == 1:
|
||||
labels = labels[0]
|
||||
else:
|
||||
labels = None
|
||||
|
||||
return (loss, logits, labels)
|
||||
|
||||
def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
|
||||
|
||||
Reference in New Issue
Block a user