Fix result saving errors of pytorch examples (#20276)
This commit is contained in:
@@ -85,7 +85,7 @@ def parse_args():
|
||||
"--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_length",
|
||||
"--max_seq_length",
|
||||
type=int,
|
||||
default=128,
|
||||
help=(
|
||||
@@ -424,7 +424,7 @@ def main():
|
||||
tokenized_examples = tokenizer(
|
||||
first_sentences,
|
||||
second_sentences,
|
||||
max_length=args.max_length,
|
||||
max_length=args.max_seq_length,
|
||||
padding=padding,
|
||||
truncation=True,
|
||||
)
|
||||
@@ -654,8 +654,10 @@ def main():
|
||||
tokenizer.save_pretrained(args.output_dir)
|
||||
if args.push_to_hub:
|
||||
repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
|
||||
with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
|
||||
json.dump({"eval_accuracy": eval_metric["accuracy"]}, f)
|
||||
|
||||
all_results = {f"eval_{k}": v for k, v in eval_metric.items()}
|
||||
with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
|
||||
json.dump(all_results, f)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user