diff --git a/examples/seq2seq/test_finetune_trainer.py b/examples/seq2seq/test_finetune_trainer.py index 2860eeec17..18ebd1e695 100644 --- a/examples/seq2seq/test_finetune_trainer.py +++ b/examples/seq2seq/test_finetune_trainer.py @@ -47,58 +47,38 @@ def test_finetune_trainer_slow(): def run_trainer(eval_steps: int, max_len: str, model_name: str, num_train_epochs: int): data_dir = "examples/seq2seq/test_data/wmt_en_ro" output_dir = tempfile.mkdtemp(prefix="test_output") - argv = [ - "--model_name_or_path", - model_name, - "--data_dir", - data_dir, - "--output_dir", - output_dir, - "--overwrite_output_dir", - "--n_train", - "8", - "--n_val", - "8", - "--max_source_length", - max_len, - "--max_target_length", - max_len, - "--val_max_target_length", - max_len, - "--do_train", - "--do_eval", - "--do_predict", - "--num_train_epochs", - str(num_train_epochs), - "--per_device_train_batch_size", - "4", - "--per_device_eval_batch_size", - "4", - "--learning_rate", - "3e-4", - "--warmup_steps", - "8", - "--evaluate_during_training", - "--predict_with_generate", - "--logging_steps", - 0, - "--save_steps", - str(eval_steps), - "--eval_steps", - str(eval_steps), - "--sortish_sampler", - "--label_smoothing", - "0.1", - # "--eval_beams", - # "2", - "--adafactor", - "--task", - "translation", - "--tgt_lang", - "ro_RO", - "--src_lang", - "en_XX", - ] + argv = f""" + --model_name_or_path {model_name} + --data_dir {data_dir} + --output_dir {output_dir} + --overwrite_output_dir + --n_train 8 + --n_val 8 + --max_source_length {max_len} + --max_target_length {max_len} + --val_max_target_length {max_len} + --do_train + --do_eval + --do_predict + --num_train_epochs {str(num_train_epochs)} + --per_device_train_batch_size 4 + --per_device_eval_batch_size 4 + --learning_rate 3e-4 + --warmup_steps 8 + --evaluate_during_training + --predict_with_generate + --logging_steps 0 + --save_steps {str(eval_steps)} + --eval_steps {str(eval_steps)} + --sortish_sampler + --label_smoothing 0.1 + --adafactor + --task translation + --tgt_lang ro_RO + --src_lang en_XX + """.split() + # --eval_beams 2 + testargs = ["finetune_trainer.py"] + argv with patch.object(sys, "argv", testargs): main()