[seq2seq testing] improve readability (#7845)
This commit is contained in:
@@ -47,58 +47,38 @@ def test_finetune_trainer_slow():
|
|||||||
def run_trainer(eval_steps: int, max_len: str, model_name: str, num_train_epochs: int):
|
def run_trainer(eval_steps: int, max_len: str, model_name: str, num_train_epochs: int):
|
||||||
data_dir = "examples/seq2seq/test_data/wmt_en_ro"
|
data_dir = "examples/seq2seq/test_data/wmt_en_ro"
|
||||||
output_dir = tempfile.mkdtemp(prefix="test_output")
|
output_dir = tempfile.mkdtemp(prefix="test_output")
|
||||||
argv = [
|
argv = f"""
|
||||||
"--model_name_or_path",
|
--model_name_or_path {model_name}
|
||||||
model_name,
|
--data_dir {data_dir}
|
||||||
"--data_dir",
|
--output_dir {output_dir}
|
||||||
data_dir,
|
--overwrite_output_dir
|
||||||
"--output_dir",
|
--n_train 8
|
||||||
output_dir,
|
--n_val 8
|
||||||
"--overwrite_output_dir",
|
--max_source_length {max_len}
|
||||||
"--n_train",
|
--max_target_length {max_len}
|
||||||
"8",
|
--val_max_target_length {max_len}
|
||||||
"--n_val",
|
--do_train
|
||||||
"8",
|
--do_eval
|
||||||
"--max_source_length",
|
--do_predict
|
||||||
max_len,
|
--num_train_epochs {str(num_train_epochs)}
|
||||||
"--max_target_length",
|
--per_device_train_batch_size 4
|
||||||
max_len,
|
--per_device_eval_batch_size 4
|
||||||
"--val_max_target_length",
|
--learning_rate 3e-4
|
||||||
max_len,
|
--warmup_steps 8
|
||||||
"--do_train",
|
--evaluate_during_training
|
||||||
"--do_eval",
|
--predict_with_generate
|
||||||
"--do_predict",
|
--logging_steps 0
|
||||||
"--num_train_epochs",
|
--save_steps {str(eval_steps)}
|
||||||
str(num_train_epochs),
|
--eval_steps {str(eval_steps)}
|
||||||
"--per_device_train_batch_size",
|
--sortish_sampler
|
||||||
"4",
|
--label_smoothing 0.1
|
||||||
"--per_device_eval_batch_size",
|
--adafactor
|
||||||
"4",
|
--task translation
|
||||||
"--learning_rate",
|
--tgt_lang ro_RO
|
||||||
"3e-4",
|
--src_lang en_XX
|
||||||
"--warmup_steps",
|
""".split()
|
||||||
"8",
|
# --eval_beams 2
|
||||||
"--evaluate_during_training",
|
|
||||||
"--predict_with_generate",
|
|
||||||
"--logging_steps",
|
|
||||||
0,
|
|
||||||
"--save_steps",
|
|
||||||
str(eval_steps),
|
|
||||||
"--eval_steps",
|
|
||||||
str(eval_steps),
|
|
||||||
"--sortish_sampler",
|
|
||||||
"--label_smoothing",
|
|
||||||
"0.1",
|
|
||||||
# "--eval_beams",
|
|
||||||
# "2",
|
|
||||||
"--adafactor",
|
|
||||||
"--task",
|
|
||||||
"translation",
|
|
||||||
"--tgt_lang",
|
|
||||||
"ro_RO",
|
|
||||||
"--src_lang",
|
|
||||||
"en_XX",
|
|
||||||
]
|
|
||||||
testargs = ["finetune_trainer.py"] + argv
|
testargs = ["finetune_trainer.py"] + argv
|
||||||
with patch.object(sys, "argv", testargs):
|
with patch.object(sys, "argv", testargs):
|
||||||
main()
|
main()
|
||||||
|
|||||||
Reference in New Issue
Block a user