From 19597998f61934104caa5ead361f09d0e9512336 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Thu, 10 Mar 2022 07:44:51 -0500 Subject: [PATCH] Don't compute metrics in LM examples on TPU (#16029) --- examples/pytorch/language-modeling/run_clm.py | 7 +++++-- examples/pytorch/language-modeling/run_mlm.py | 7 +++++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/examples/pytorch/language-modeling/run_clm.py b/examples/pytorch/language-modeling/run_clm.py index ae50bd2ce9..5534e6901f 100755 --- a/examples/pytorch/language-modeling/run_clm.py +++ b/examples/pytorch/language-modeling/run_clm.py @@ -43,6 +43,7 @@ from transformers import ( Trainer, TrainingArguments, default_data_collator, + is_torch_tpu_available, set_seed, ) from transformers.testing_utils import CaptureLogger @@ -479,8 +480,10 @@ def main(): tokenizer=tokenizer, # Data collator will default to DataCollatorWithPadding, so we change it. data_collator=default_data_collator, - compute_metrics=compute_metrics if training_args.do_eval else None, - preprocess_logits_for_metrics=preprocess_logits_for_metrics if training_args.do_eval else None, + compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None, + preprocess_logits_for_metrics=preprocess_logits_for_metrics + if training_args.do_eval and not is_torch_tpu_available() + else None, ) # Training diff --git a/examples/pytorch/language-modeling/run_mlm.py b/examples/pytorch/language-modeling/run_mlm.py index 9926cccfae..7ceae8b17a 100755 --- a/examples/pytorch/language-modeling/run_mlm.py +++ b/examples/pytorch/language-modeling/run_mlm.py @@ -43,6 +43,7 @@ from transformers import ( HfArgumentParser, Trainer, TrainingArguments, + is_torch_tpu_available, set_seed, ) from transformers.trainer_utils import get_last_checkpoint @@ -513,8 +514,10 @@ def main(): eval_dataset=eval_dataset if training_args.do_eval else None, tokenizer=tokenizer, data_collator=data_collator, - compute_metrics=compute_metrics if training_args.do_eval else None, - preprocess_logits_for_metrics=preprocess_logits_for_metrics if training_args.do_eval else None, + compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None, + preprocess_logits_for_metrics=preprocess_logits_for_metrics + if training_args.do_eval and not is_torch_tpu_available() + else None, ) # Training