[seq2seq testing] multigpu test run via subprocess (#7281)
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
@@ -25,6 +25,7 @@ from utils import (
|
||||
assert_all_frozen,
|
||||
calculate_bleu,
|
||||
calculate_rouge,
|
||||
check_output_dir,
|
||||
flatten_list,
|
||||
freeze_embeds,
|
||||
freeze_params,
|
||||
@@ -329,6 +330,7 @@ class SummarizationModule(BaseTransformer):
|
||||
parser.add_argument("--freeze_encoder", action="store_true")
|
||||
parser.add_argument("--freeze_embeds", action="store_true")
|
||||
parser.add_argument("--sortish_sampler", action="store_true", default=False)
|
||||
parser.add_argument("--overwrite_output_dir", action="store_true", default=False)
|
||||
parser.add_argument("--max_tokens_per_batch", type=int, default=None)
|
||||
parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
|
||||
parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
|
||||
@@ -373,8 +375,8 @@ class TranslationModule(SummarizationModule):
|
||||
|
||||
def main(args, model=None) -> SummarizationModule:
|
||||
Path(args.output_dir).mkdir(exist_ok=True)
|
||||
if len(os.listdir(args.output_dir)) > 3 and args.do_train:
|
||||
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
|
||||
check_output_dir(args, expected_items=3)
|
||||
|
||||
if model is None:
|
||||
if "summarization" in args.task:
|
||||
model: SummarizationModule = SummarizationModule(args)
|
||||
|
||||
Reference in New Issue
Block a user