From 7d2001aa44ac2ac9410d75d71845dbbd87f910e2 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Tue, 18 Jun 2019 22:13:30 +0200 Subject: [PATCH] overwrite_output_dir --- examples/run_classifier.py | 5 ++++- examples/run_squad.py | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/examples/run_classifier.py b/examples/run_classifier.py index 7c00e4833d..c3a16f593d 100644 --- a/examples/run_classifier.py +++ b/examples/run_classifier.py @@ -127,6 +127,9 @@ def main(): parser.add_argument("--no_cuda", action='store_true', help="Whether not to use CUDA when available") + parser.add_argument('--overwrite_output_dir', + action='store_true', + help="Overwrite the content of the output directory") parser.add_argument("--local_rank", type=int, default=-1, @@ -191,7 +194,7 @@ def main(): if not args.do_train and not args.do_eval: raise ValueError("At least one of `do_train` or `do_eval` must be True.") - if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train: + if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir: raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) if not os.path.exists(args.output_dir): os.makedirs(args.output_dir) diff --git a/examples/run_squad.py b/examples/run_squad.py index 32e20f9c94..f20dd9d356 100644 --- a/examples/run_squad.py +++ b/examples/run_squad.py @@ -111,6 +111,9 @@ def main(): parser.add_argument('--fp16', action='store_true', help="Whether to use 16-bit float precision instead of 32-bit") + parser.add_argument('--overwrite_output_dir', + action='store_true', + help="Overwrite the content of the output directory") parser.add_argument('--loss_scale', type=float, default=0, help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n" @@ -175,7 +178,7 @@ def main(): raise ValueError( "If `do_predict` is True, then `predict_file` must be specified.") - if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train: + if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir: raise ValueError("Output directory () already exists and is not empty.") if not os.path.exists(args.output_dir): os.makedirs(args.output_dir)