Fix trainer seq2seq qa.py evaluate log and ft script (#19208)
* fix args option * fix trainer eval log * fix out of memory qa script * do isort, black, flake * fix tokenize target * take it back. * fix: comment
This commit is contained in:
@@ -327,21 +327,28 @@ def main():
|
|||||||
if data_args.dataset_name is not None:
|
if data_args.dataset_name is not None:
|
||||||
# Downloading and loading a dataset from the hub.
|
# Downloading and loading a dataset from the hub.
|
||||||
raw_datasets = load_dataset(
|
raw_datasets = load_dataset(
|
||||||
data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir
|
data_args.dataset_name,
|
||||||
|
data_args.dataset_config_name,
|
||||||
|
cache_dir=model_args.cache_dir,
|
||||||
|
use_auth_token=True if model_args.use_auth_token else None,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
data_files = {}
|
data_files = {}
|
||||||
if data_args.train_file is not None:
|
if data_args.train_file is not None:
|
||||||
data_files["train"] = data_args.train_file
|
data_files["train"] = data_args.train_file
|
||||||
extension = data_args.train_file.split(".")[-1]
|
extension = data_args.train_file.split(".")[-1]
|
||||||
|
|
||||||
if data_args.validation_file is not None:
|
if data_args.validation_file is not None:
|
||||||
data_files["validation"] = data_args.validation_file
|
data_files["validation"] = data_args.validation_file
|
||||||
extension = data_args.validation_file.split(".")[-1]
|
extension = data_args.validation_file.split(".")[-1]
|
||||||
if data_args.test_file is not None:
|
if data_args.test_file is not None:
|
||||||
data_files["test"] = data_args.test_file
|
data_files["test"] = data_args.test_file
|
||||||
extension = data_args.test_file.split(".")[-1]
|
extension = data_args.test_file.split(".")[-1]
|
||||||
raw_datasets = load_dataset(extension, data_files=data_files, field="data", cache_dir=model_args.cache_dir)
|
raw_datasets = load_dataset(
|
||||||
|
extension,
|
||||||
|
data_files=data_files,
|
||||||
|
cache_dir=model_args.cache_dir,
|
||||||
|
use_auth_token=True if model_args.use_auth_token else None,
|
||||||
|
)
|
||||||
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
||||||
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
||||||
|
|
||||||
@@ -359,7 +366,7 @@ def main():
|
|||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
|
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
|
||||||
cache_dir=model_args.cache_dir,
|
cache_dir=model_args.cache_dir,
|
||||||
use_fast=True,
|
use_fast=model_args.use_fast_tokenizer,
|
||||||
revision=model_args.model_revision,
|
revision=model_args.model_revision,
|
||||||
use_auth_token=True if model_args.use_auth_token else None,
|
use_auth_token=True if model_args.use_auth_token else None,
|
||||||
)
|
)
|
||||||
@@ -476,9 +483,10 @@ def main():
|
|||||||
max_length=max_seq_length,
|
max_length=max_seq_length,
|
||||||
padding=padding,
|
padding=padding,
|
||||||
truncation=True,
|
truncation=True,
|
||||||
return_overflowing_tokens=True,
|
|
||||||
return_offsets_mapping=True,
|
return_offsets_mapping=True,
|
||||||
|
return_overflowing_tokens=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Tokenize targets with the `text_target` keyword argument
|
# Tokenize targets with the `text_target` keyword argument
|
||||||
labels = tokenizer(text_target=targets, max_length=max_answer_length, padding=padding, truncation=True)
|
labels = tokenizer(text_target=targets, max_length=max_answer_length, padding=padding, truncation=True)
|
||||||
|
|
||||||
@@ -503,6 +511,7 @@ def main():
|
|||||||
]
|
]
|
||||||
|
|
||||||
model_inputs["labels"] = labels["input_ids"]
|
model_inputs["labels"] = labels["input_ids"]
|
||||||
|
|
||||||
return model_inputs
|
return model_inputs
|
||||||
|
|
||||||
if training_args.do_train:
|
if training_args.do_train:
|
||||||
@@ -627,7 +636,7 @@ def main():
|
|||||||
eval_examples=eval_examples if training_args.do_eval else None,
|
eval_examples=eval_examples if training_args.do_eval else None,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
data_collator=data_collator,
|
data_collator=data_collator,
|
||||||
compute_metrics=compute_metrics,
|
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
|
||||||
post_process_function=post_processing_function,
|
post_process_function=post_processing_function,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -15,12 +15,14 @@
|
|||||||
"""
|
"""
|
||||||
A subclass of `Trainer` specific to Question-Answering tasks
|
A subclass of `Trainer` specific to Question-Answering tasks
|
||||||
"""
|
"""
|
||||||
|
import math
|
||||||
|
import time
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
from transformers import Seq2SeqTrainer, is_torch_tpu_available
|
from transformers import Seq2SeqTrainer, is_torch_tpu_available
|
||||||
from transformers.trainer_utils import PredictionOutput
|
from transformers.trainer_utils import PredictionOutput, speed_metrics
|
||||||
|
|
||||||
|
|
||||||
if is_torch_tpu_available(check_device=False):
|
if is_torch_tpu_available(check_device=False):
|
||||||
@@ -59,6 +61,7 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
|
|||||||
# Temporarily disable metric computation, we will do it in the loop here.
|
# Temporarily disable metric computation, we will do it in the loop here.
|
||||||
compute_metrics = self.compute_metrics
|
compute_metrics = self.compute_metrics
|
||||||
self.compute_metrics = None
|
self.compute_metrics = None
|
||||||
|
start_time = time.time()
|
||||||
eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
|
eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
|
||||||
try:
|
try:
|
||||||
output = eval_loop(
|
output = eval_loop(
|
||||||
@@ -71,6 +74,15 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
|
|||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
self.compute_metrics = compute_metrics
|
self.compute_metrics = compute_metrics
|
||||||
|
total_batch_size = self.args.eval_batch_size * self.args.world_size
|
||||||
|
output.metrics.update(
|
||||||
|
speed_metrics(
|
||||||
|
metric_key_prefix,
|
||||||
|
start_time,
|
||||||
|
num_samples=output.num_samples,
|
||||||
|
num_steps=math.ceil(output.num_samples / total_batch_size),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if self.post_process_function is not None and self.compute_metrics is not None:
|
if self.post_process_function is not None and self.compute_metrics is not None:
|
||||||
eval_preds = self.post_process_function(eval_examples, eval_dataset, output)
|
eval_preds = self.post_process_function(eval_examples, eval_dataset, output)
|
||||||
@@ -81,15 +93,15 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
|
|||||||
if not key.startswith(f"{metric_key_prefix}_"):
|
if not key.startswith(f"{metric_key_prefix}_"):
|
||||||
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
|
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
|
||||||
|
|
||||||
|
output.metrics.update(metrics)
|
||||||
|
|
||||||
self.log(metrics)
|
self.log(metrics)
|
||||||
else:
|
|
||||||
metrics = {}
|
|
||||||
|
|
||||||
if self.args.tpu_metrics_debug or self.args.debug:
|
if self.args.tpu_metrics_debug or self.args.debug:
|
||||||
# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
|
# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
|
||||||
xm.master_print(met.metrics_report())
|
xm.master_print(met.metrics_report())
|
||||||
|
|
||||||
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
|
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)
|
||||||
return metrics
|
return metrics
|
||||||
|
|
||||||
def predict(
|
def predict(
|
||||||
|
|||||||
Reference in New Issue
Block a user