Don't compute metrics in LM examples on TPU (#16029)

This commit is contained in:
Sylvain Gugger
2022-03-10 07:44:51 -05:00
committed by GitHub
parent 10591399d6
commit 19597998f6
2 changed files with 10 additions and 4 deletions

View File

@@ -43,6 +43,7 @@ from transformers import (
Trainer,
TrainingArguments,
default_data_collator,
is_torch_tpu_available,
set_seed,
)
from transformers.testing_utils import CaptureLogger
@@ -479,8 +480,10 @@ def main():
tokenizer=tokenizer,
# Data collator will default to DataCollatorWithPadding, so we change it.
data_collator=default_data_collator,
compute_metrics=compute_metrics if training_args.do_eval else None,
preprocess_logits_for_metrics=preprocess_logits_for_metrics if training_args.do_eval else None,
compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None,
preprocess_logits_for_metrics=preprocess_logits_for_metrics
if training_args.do_eval and not is_torch_tpu_available()
else None,
)
# Training

View File

@@ -43,6 +43,7 @@ from transformers import (
HfArgumentParser,
Trainer,
TrainingArguments,
is_torch_tpu_available,
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
@@ -513,8 +514,10 @@ def main():
eval_dataset=eval_dataset if training_args.do_eval else None,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics if training_args.do_eval else None,
preprocess_logits_for_metrics=preprocess_logits_for_metrics if training_args.do_eval else None,
compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None,
preprocess_logits_for_metrics=preprocess_logits_for_metrics
if training_args.do_eval and not is_torch_tpu_available()
else None,
)
# Training