From c0281feb506b3cd8e9cfe19aa931ad05e295cffa Mon Sep 17 00:00:00 2001 From: davidleonfdez <45669232+davidleonfdez@users.noreply.github.com> Date: Thu, 3 Mar 2022 19:41:03 +0000 Subject: [PATCH] Fix #15898 (#15928) --- examples/pytorch/language-modeling/run_clm.py | 4 ++++ examples/pytorch/language-modeling/run_mlm.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/examples/pytorch/language-modeling/run_clm.py b/examples/pytorch/language-modeling/run_clm.py index 2ba861ab6f..ae50bd2ce9 100755 --- a/examples/pytorch/language-modeling/run_clm.py +++ b/examples/pytorch/language-modeling/run_clm.py @@ -454,6 +454,10 @@ def main(): eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) def preprocess_logits_for_metrics(logits, labels): + if isinstance(logits, tuple): + # Depending on the model and config, logits may contain extra tensors, + # like past_key_values, but logits always come first + logits = logits[0] return logits.argmax(dim=-1) metric = load_metric("accuracy") diff --git a/examples/pytorch/language-modeling/run_mlm.py b/examples/pytorch/language-modeling/run_mlm.py index 862166f736..9926cccfae 100755 --- a/examples/pytorch/language-modeling/run_mlm.py +++ b/examples/pytorch/language-modeling/run_mlm.py @@ -477,6 +477,10 @@ def main(): eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) def preprocess_logits_for_metrics(logits, labels): + if isinstance(logits, tuple): + # Depending on the model and config, logits may contain extra tensors, + # like past_key_values, but logits always come first + logits = logits[0] return logits.argmax(dim=-1) metric = load_metric("accuracy")