Fix trainer seq2seq qa.py evaluate log and ft script (#19208)
* fix args option * fix trainer eval log * fix out of memory qa script * do isort, black, flake * fix tokenize target * take it back. * fix: comment
This commit is contained in:
@@ -327,21 +327,28 @@ def main():
|
||||
if data_args.dataset_name is not None:
|
||||
# Downloading and loading a dataset from the hub.
|
||||
raw_datasets = load_dataset(
|
||||
data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir
|
||||
data_args.dataset_name,
|
||||
data_args.dataset_config_name,
|
||||
cache_dir=model_args.cache_dir,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
else:
|
||||
data_files = {}
|
||||
if data_args.train_file is not None:
|
||||
data_files["train"] = data_args.train_file
|
||||
extension = data_args.train_file.split(".")[-1]
|
||||
|
||||
if data_args.validation_file is not None:
|
||||
data_files["validation"] = data_args.validation_file
|
||||
extension = data_args.validation_file.split(".")[-1]
|
||||
if data_args.test_file is not None:
|
||||
data_files["test"] = data_args.test_file
|
||||
extension = data_args.test_file.split(".")[-1]
|
||||
raw_datasets = load_dataset(extension, data_files=data_files, field="data", cache_dir=model_args.cache_dir)
|
||||
raw_datasets = load_dataset(
|
||||
extension,
|
||||
data_files=data_files,
|
||||
cache_dir=model_args.cache_dir,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
||||
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
||||
|
||||
@@ -359,7 +366,7 @@ def main():
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
|
||||
cache_dir=model_args.cache_dir,
|
||||
use_fast=True,
|
||||
use_fast=model_args.use_fast_tokenizer,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
@@ -476,9 +483,10 @@ def main():
|
||||
max_length=max_seq_length,
|
||||
padding=padding,
|
||||
truncation=True,
|
||||
return_overflowing_tokens=True,
|
||||
return_offsets_mapping=True,
|
||||
return_overflowing_tokens=True,
|
||||
)
|
||||
|
||||
# Tokenize targets with the `text_target` keyword argument
|
||||
labels = tokenizer(text_target=targets, max_length=max_answer_length, padding=padding, truncation=True)
|
||||
|
||||
@@ -503,6 +511,7 @@ def main():
|
||||
]
|
||||
|
||||
model_inputs["labels"] = labels["input_ids"]
|
||||
|
||||
return model_inputs
|
||||
|
||||
if training_args.do_train:
|
||||
@@ -627,7 +636,7 @@ def main():
|
||||
eval_examples=eval_examples if training_args.do_eval else None,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
compute_metrics=compute_metrics,
|
||||
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
|
||||
post_process_function=post_processing_function,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user