fix run_seq2seq.py; porting trainer tests to it (#10162)
* fix run_seq2seq.py; porting DeepSpeed tests to it * unrefactor * defensive programming * defensive programming 2 * port the rest of the trainer tests * style * a cleaner scripts dir finder * cleanup
This commit is contained in:
@@ -18,6 +18,7 @@ Fine-tuning the library models for sequence to sequence.
|
||||
"""
|
||||
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
@@ -38,6 +39,7 @@ from transformers import (
|
||||
DataCollatorForSeq2Seq,
|
||||
HfArgumentParser,
|
||||
MBartTokenizer,
|
||||
MBartTokenizerFast,
|
||||
Seq2SeqTrainer,
|
||||
Seq2SeqTrainingArguments,
|
||||
default_data_collator,
|
||||
@@ -53,6 +55,11 @@ with FileLock(".lock") as lock:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def save_json(content, path, indent=4, **json_dump_kwargs):
|
||||
with open(path, "w") as f:
|
||||
json.dump(content, f, indent=indent, sort_keys=True, **json_dump_kwargs)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
"""
|
||||
@@ -351,8 +358,15 @@ def main():
|
||||
)
|
||||
|
||||
# Set decoder_start_token_id
|
||||
if model.config.decoder_start_token_id is None and isinstance(tokenizer, MBartTokenizer):
|
||||
model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.target_lang]
|
||||
if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
|
||||
assert (
|
||||
data_args.target_lang is not None and data_args.source_lang is not None
|
||||
), "mBart requires --target_lang and --source_lang"
|
||||
if isinstance(tokenizer, MBartTokenizer):
|
||||
model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.target_lang]
|
||||
else:
|
||||
model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.target_lang)
|
||||
|
||||
if model.config.decoder_start_token_id is None:
|
||||
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
|
||||
|
||||
@@ -448,6 +462,8 @@ def main():
|
||||
|
||||
if training_args.do_train:
|
||||
train_dataset = datasets["train"]
|
||||
if "train" not in datasets:
|
||||
raise ValueError("--do_train requires a train dataset")
|
||||
if data_args.max_train_samples is not None:
|
||||
train_dataset = train_dataset.select(range(data_args.max_train_samples))
|
||||
train_dataset = train_dataset.map(
|
||||
@@ -460,6 +476,8 @@ def main():
|
||||
|
||||
if training_args.do_eval:
|
||||
max_target_length = data_args.val_max_target_length
|
||||
if "validation" not in datasets:
|
||||
raise ValueError("--do_eval requires a validation dataset")
|
||||
eval_dataset = datasets["validation"]
|
||||
if data_args.max_val_samples is not None:
|
||||
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
|
||||
@@ -473,6 +491,8 @@ def main():
|
||||
|
||||
if training_args.do_predict:
|
||||
max_target_length = data_args.val_max_target_length
|
||||
if "test" not in datasets:
|
||||
raise ValueError("--do_predict requires a test dataset")
|
||||
test_dataset = datasets["test"]
|
||||
if data_args.max_test_samples is not None:
|
||||
test_dataset = test_dataset.select(range(data_args.max_test_samples))
|
||||
@@ -550,6 +570,7 @@ def main():
|
||||
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
|
||||
)
|
||||
|
||||
all_metrics = {}
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
if last_checkpoint is not None:
|
||||
@@ -561,13 +582,17 @@ def main():
|
||||
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
||||
trainer.save_model() # Saves the tokenizer too for easy upload
|
||||
|
||||
output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
|
||||
metrics = train_result.metrics
|
||||
max_train_samples = (
|
||||
data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
|
||||
)
|
||||
metrics["train_samples"] = min(max_train_samples, len(train_dataset))
|
||||
if trainer.is_world_process_zero():
|
||||
with open(output_train_file, "w") as writer:
|
||||
logger.info("***** Train results *****")
|
||||
for key, value in sorted(train_result.metrics.items()):
|
||||
logger.info(f" {key} = {value}")
|
||||
writer.write(f"{key} = {value}\n")
|
||||
logger.info("***** train metrics *****")
|
||||
for key in sorted(metrics.keys()):
|
||||
logger.info(f" {key} = {metrics[key]}")
|
||||
save_json(metrics, os.path.join(training_args.output_dir, "train_results.json"))
|
||||
all_metrics.update(metrics)
|
||||
|
||||
# Need to save the state, since Trainer.save_model saves only the tokenizer with the model
|
||||
trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json"))
|
||||
@@ -577,16 +602,19 @@ def main():
|
||||
if training_args.do_eval:
|
||||
logger.info("*** Evaluate ***")
|
||||
|
||||
results = trainer.evaluate(max_length=data_args.val_max_target_length, num_beams=data_args.num_beams)
|
||||
results = {k: round(v, 4) for k, v in results.items()}
|
||||
metrics = trainer.evaluate(
|
||||
max_length=data_args.val_max_target_length, num_beams=data_args.num_beams, metric_key_prefix="val"
|
||||
)
|
||||
metrics = {k: round(v, 4) for k, v in metrics.items()}
|
||||
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
|
||||
metrics["val_samples"] = min(max_val_samples, len(eval_dataset))
|
||||
|
||||
output_eval_file = os.path.join(training_args.output_dir, "eval_results_seq2seq.txt")
|
||||
if trainer.is_world_process_zero():
|
||||
with open(output_eval_file, "w") as writer:
|
||||
logger.info("***** Eval results *****")
|
||||
for key, value in sorted(results.items()):
|
||||
logger.info(f" {key} = {value}")
|
||||
writer.write(f"{key} = {value}\n")
|
||||
logger.info("***** val metrics *****")
|
||||
for key in sorted(metrics.keys()):
|
||||
logger.info(f" {key} = {metrics[key]}")
|
||||
save_json(metrics, os.path.join(training_args.output_dir, "val_results.json"))
|
||||
all_metrics.update(metrics)
|
||||
|
||||
if training_args.do_predict:
|
||||
logger.info("*** Test ***")
|
||||
@@ -597,16 +625,17 @@ def main():
|
||||
max_length=data_args.val_max_target_length,
|
||||
num_beams=data_args.num_beams,
|
||||
)
|
||||
test_metrics = test_results.metrics
|
||||
test_metrics["test_loss"] = round(test_metrics["test_loss"], 4)
|
||||
metrics = test_results.metrics
|
||||
max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset)
|
||||
metrics["test_samples"] = min(max_test_samples, len(test_dataset))
|
||||
metrics = {k: round(v, 4) for k, v in metrics.items()}
|
||||
|
||||
output_test_result_file = os.path.join(training_args.output_dir, "test_results_seq2seq.txt")
|
||||
if trainer.is_world_process_zero():
|
||||
with open(output_test_result_file, "w") as writer:
|
||||
logger.info("***** Test results *****")
|
||||
for key, value in sorted(test_metrics.items()):
|
||||
logger.info(f" {key} = {value}")
|
||||
writer.write(f"{key} = {value}\n")
|
||||
logger.info("***** test metrics *****")
|
||||
for key in sorted(metrics.keys()):
|
||||
logger.info(f" {key} = {metrics[key]}")
|
||||
save_json(metrics, os.path.join(training_args.output_dir, "test_results.json"))
|
||||
all_metrics.update(metrics)
|
||||
|
||||
if training_args.predict_with_generate:
|
||||
test_preds = tokenizer.batch_decode(
|
||||
@@ -617,6 +646,9 @@ def main():
|
||||
with open(output_test_preds_file, "w") as writer:
|
||||
writer.write("\n".join(test_preds))
|
||||
|
||||
if trainer.is_world_process_zero():
|
||||
save_json(all_metrics, os.path.join(training_args.output_dir, "all_results.json"))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user