Don't compute metrics in LM examples on TPU (#16029)
This commit is contained in:
@@ -43,6 +43,7 @@ from transformers import (
|
|||||||
Trainer,
|
Trainer,
|
||||||
TrainingArguments,
|
TrainingArguments,
|
||||||
default_data_collator,
|
default_data_collator,
|
||||||
|
is_torch_tpu_available,
|
||||||
set_seed,
|
set_seed,
|
||||||
)
|
)
|
||||||
from transformers.testing_utils import CaptureLogger
|
from transformers.testing_utils import CaptureLogger
|
||||||
@@ -479,8 +480,10 @@ def main():
|
|||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
# Data collator will default to DataCollatorWithPadding, so we change it.
|
# Data collator will default to DataCollatorWithPadding, so we change it.
|
||||||
data_collator=default_data_collator,
|
data_collator=default_data_collator,
|
||||||
compute_metrics=compute_metrics if training_args.do_eval else None,
|
compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None,
|
||||||
preprocess_logits_for_metrics=preprocess_logits_for_metrics if training_args.do_eval else None,
|
preprocess_logits_for_metrics=preprocess_logits_for_metrics
|
||||||
|
if training_args.do_eval and not is_torch_tpu_available()
|
||||||
|
else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Training
|
# Training
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ from transformers import (
|
|||||||
HfArgumentParser,
|
HfArgumentParser,
|
||||||
Trainer,
|
Trainer,
|
||||||
TrainingArguments,
|
TrainingArguments,
|
||||||
|
is_torch_tpu_available,
|
||||||
set_seed,
|
set_seed,
|
||||||
)
|
)
|
||||||
from transformers.trainer_utils import get_last_checkpoint
|
from transformers.trainer_utils import get_last_checkpoint
|
||||||
@@ -513,8 +514,10 @@ def main():
|
|||||||
eval_dataset=eval_dataset if training_args.do_eval else None,
|
eval_dataset=eval_dataset if training_args.do_eval else None,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
data_collator=data_collator,
|
data_collator=data_collator,
|
||||||
compute_metrics=compute_metrics if training_args.do_eval else None,
|
compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None,
|
||||||
preprocess_logits_for_metrics=preprocess_logits_for_metrics if training_args.do_eval else None,
|
preprocess_logits_for_metrics=preprocess_logits_for_metrics
|
||||||
|
if training_args.do_eval and not is_torch_tpu_available()
|
||||||
|
else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Training
|
# Training
|
||||||
|
|||||||
Reference in New Issue
Block a user