[seq2seq testing] multigpu test run via subprocess (#7281)

Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
Stas Bekman
2020-10-21 14:20:53 -07:00
committed by GitHub
parent f8d3695e8c
commit 8b38173398
6 changed files with 294 additions and 15 deletions

View File

@@ -25,6 +25,7 @@ from utils import (
assert_all_frozen,
calculate_bleu,
calculate_rouge,
check_output_dir,
flatten_list,
freeze_embeds,
freeze_params,
@@ -329,6 +330,7 @@ class SummarizationModule(BaseTransformer):
parser.add_argument("--freeze_encoder", action="store_true")
parser.add_argument("--freeze_embeds", action="store_true")
parser.add_argument("--sortish_sampler", action="store_true", default=False)
parser.add_argument("--overwrite_output_dir", action="store_true", default=False)
parser.add_argument("--max_tokens_per_batch", type=int, default=None)
parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
@@ -373,8 +375,8 @@ class TranslationModule(SummarizationModule):
def main(args, model=None) -> SummarizationModule:
Path(args.output_dir).mkdir(exist_ok=True)
if len(os.listdir(args.output_dir)) > 3 and args.do_train:
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
check_output_dir(args, expected_items=3)
if model is None:
if "summarization" in args.task:
model: SummarizationModule = SummarizationModule(args)