[seq2seq testing] multigpu test run via subprocess (#7281)
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
@@ -17,7 +17,7 @@ from finetune import main as ft_main
|
||||
from make_student import create_student_by_copying_alternating_layers, get_layers_to_supervise
|
||||
from transformers import AutoModelForSeq2SeqLM, MBartTokenizer, T5ForConditionalGeneration
|
||||
from transformers.modeling_bart import shift_tokens_right
|
||||
from utils import calculate_bleu, freeze_params, label_smoothed_nll_loss, use_task_specific_params
|
||||
from utils import calculate_bleu, check_output_dir, freeze_params, label_smoothed_nll_loss, use_task_specific_params
|
||||
|
||||
|
||||
# need the parent dir module
|
||||
@@ -266,8 +266,7 @@ def create_module(args):
|
||||
|
||||
def distill_main(args):
|
||||
Path(args.output_dir).mkdir(exist_ok=True)
|
||||
if len(os.listdir(args.output_dir)) > 3 and args.do_train:
|
||||
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
|
||||
check_output_dir(args, expected_items=3)
|
||||
|
||||
model = create_module(args)
|
||||
return ft_main(args, model=model)
|
||||
|
||||
Reference in New Issue
Block a user