Allow nested tensors in predicted logits (#7542)
This commit is contained in:
@@ -48,6 +48,7 @@ from .trainer_utils import (
|
|||||||
distributed_broadcast_scalars,
|
distributed_broadcast_scalars,
|
||||||
distributed_concat,
|
distributed_concat,
|
||||||
nested_concat,
|
nested_concat,
|
||||||
|
nested_detach,
|
||||||
nested_numpify,
|
nested_numpify,
|
||||||
nested_xla_mesh_reduce,
|
nested_xla_mesh_reduce,
|
||||||
set_seed,
|
set_seed,
|
||||||
@@ -1466,16 +1467,18 @@ class Trainer:
|
|||||||
logits = outputs[:]
|
logits = outputs[:]
|
||||||
if self.args.past_index >= 0:
|
if self.args.past_index >= 0:
|
||||||
self._past = outputs[self.args.past_index if has_labels else self.args.past_index - 1]
|
self._past = outputs[self.args.past_index if has_labels else self.args.past_index - 1]
|
||||||
|
# Remove the past from the logits.
|
||||||
|
logits = logits[: self.args.past_index - 1] + logits[self.args.past_index :]
|
||||||
|
|
||||||
if prediction_loss_only:
|
if prediction_loss_only:
|
||||||
return (loss, None, None)
|
return (loss, None, None)
|
||||||
|
|
||||||
logits = tuple(logit.detach() for logit in logits)
|
logits = nested_detach(logits)
|
||||||
if len(logits) == 1:
|
if len(logits) == 1:
|
||||||
logits = logits[0]
|
logits = logits[0]
|
||||||
|
|
||||||
if has_labels:
|
if has_labels:
|
||||||
labels = tuple(inputs.get(name).detach() for name in self.label_names)
|
labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
|
||||||
if len(labels) == 1:
|
if len(labels) == 1:
|
||||||
labels = labels[0]
|
labels = labels[0]
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -154,6 +154,13 @@ def nested_concat(tensors, new_tensors, dim=0):
|
|||||||
raise ImportError("Torch must be installed to use `nested_concat`")
|
raise ImportError("Torch must be installed to use `nested_concat`")
|
||||||
|
|
||||||
|
|
||||||
|
def nested_detach(tensors):
    """
    Detach `tensors` (even if it's a nested list/tuple of tensors).

    A list or tuple is rebuilt as the same container type with every element detached recursively; any other object
    is handed back as the result of its own `.detach()` method.
    """
    if isinstance(tensors, (list, tuple)):
        # Recurse under the correctly spelled name so nested containers resolve to this function.
        return type(tensors)(nested_detach(t) for t in tensors)
    return tensors.detach()
|
||||||
|
|
||||||
|
|
||||||
def nested_numpify(tensors):
|
def nested_numpify(tensors):
|
||||||
"Numpify `tensors` (even if it's a nested list/tuple of tensors)."
|
"Numpify `tensors` (even if it's a nested list/tuple of tensors)."
|
||||||
if isinstance(tensors, (list, tuple)):
|
if isinstance(tensors, (list, tuple)):
|
||||||
|
|||||||
Reference in New Issue
Block a user