From 0270256b275d75432d61633dfa39da1f20f987ca Mon Sep 17 00:00:00 2001
From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Date: Mon, 5 Oct 2020 06:33:15 -0400
Subject: [PATCH] Allow nested tensors in predicted logits (#7542)

---
 src/transformers/trainer.py       | 7 +++++--
 src/transformers/trainer_utils.py | 7 +++++++
 2 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 12409338f4..262ed6df59 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -48,6 +48,7 @@ from .trainer_utils import (
     distributed_broadcast_scalars,
     distributed_concat,
     nested_concat,
+    nested_detach,
     nested_numpify,
     nested_xla_mesh_reduce,
     set_seed,
@@ -1466,16 +1467,18 @@ class Trainer:
                 logits = outputs[:]
             if self.args.past_index >= 0:
                 self._past = outputs[self.args.past_index if has_labels else self.args.past_index - 1]
+                # Remove the past from the logits.
+                logits = logits[: self.args.past_index - 1] + logits[self.args.past_index :]
 
         if prediction_loss_only:
             return (loss, None, None)
 
-        logits = tuple(logit.detach() for logit in logits)
+        logits = nested_detach(logits)
         if len(logits) == 1:
             logits = logits[0]
 
         if has_labels:
-            labels = tuple(inputs.get(name).detach() for name in self.label_names)
+            labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
             if len(labels) == 1:
                 labels = labels[0]
         else:
diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py
index b3207ec359..e816b0772a 100644
--- a/src/transformers/trainer_utils.py
+++ b/src/transformers/trainer_utils.py
@@ -154,6 +154,13 @@ def nested_concat(tensors, new_tensors, dim=0):
         raise ImportError("Torch must be installed to use `nested_concat`")
 
 
+def nested_detach(tensors):
+    "Detach `tensors` (even if it's a nested list/tuple of tensors)."
+    if isinstance(tensors, (list, tuple)):
+        return type(tensors)(nested_detach(t) for t in tensors)
+    return tensors.detach()
+
+
 def nested_numpify(tensors):
     "Numpify `tensors` (even if it's a nested list/tuple of tensors)."
     if isinstance(tensors, (list, tuple)):