Fix trainer seq2seq qa.py evaluate log and ft script (#19208)

* fix args option

* fix trainer eval log

* fix out of memory qa script

* do isort, black, flake

* fix tokenize target

* take it back.

* fix: comment
This commit is contained in:
Tatsuki Okada
2022-09-28 23:55:46 +09:00
committed by GitHub
parent 9c6aeba353
commit 4a0b958d61
2 changed files with 32 additions and 11 deletions

View File

@@ -15,12 +15,14 @@
"""
A subclass of `Trainer` specific to Question-Answering tasks
"""
import math
import time
from typing import Dict, List, Optional
from torch.utils.data import Dataset
from transformers import Seq2SeqTrainer, is_torch_tpu_available
from transformers.trainer_utils import PredictionOutput
from transformers.trainer_utils import PredictionOutput, speed_metrics
if is_torch_tpu_available(check_device=False):
@@ -59,6 +61,7 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
# Temporarily disable metric computation, we will do it in the loop here.
compute_metrics = self.compute_metrics
self.compute_metrics = None
start_time = time.time()
eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
try:
output = eval_loop(
@@ -71,6 +74,15 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
)
finally:
self.compute_metrics = compute_metrics
total_batch_size = self.args.eval_batch_size * self.args.world_size
output.metrics.update(
speed_metrics(
metric_key_prefix,
start_time,
num_samples=output.num_samples,
num_steps=math.ceil(output.num_samples / total_batch_size),
)
)
if self.post_process_function is not None and self.compute_metrics is not None:
eval_preds = self.post_process_function(eval_examples, eval_dataset, output)
@@ -81,15 +93,15 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
if not key.startswith(f"{metric_key_prefix}_"):
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
self.log(metrics)
else:
metrics = {}
output.metrics.update(metrics)
self.log(metrics)
if self.args.tpu_metrics_debug or self.args.debug:
# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
xm.master_print(met.metrics_report())
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)
return metrics
def predict(