Fix result saving errors of pytorch examples (#20276)
This commit is contained in:
@@ -766,10 +766,11 @@ def main():
|
||||
if args.push_to_hub:
|
||||
repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
|
||||
|
||||
with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
|
||||
json.dump(
|
||||
{"eval_accuracy": eval_metric["accuracy"], "train_loss": total_loss.item() / len(train_dataloader)}, f
|
||||
)
|
||||
all_results = {f"eval_{k}": v for k, v in eval_metric.items()}
|
||||
if args.with_tracking:
|
||||
all_results.update({"train_loss": total_loss.item() / len(train_dataloader)})
|
||||
with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
|
||||
json.dump(all_results, f)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user