fix run_seq2seq.py; porting trainer tests to it (#10162)
* fix run_seq2seq.py; porting DeepSpeed tests to it * unrefactor * defensive programming * defensive programming 2 * port the rest of the trainer tests * style * a cleaner scripts dir finder * cleanup
This commit is contained in:
@@ -115,15 +115,16 @@ class TestDeepSpeed(TestCasePlus):
|
||||
extra_args_str: str = None,
|
||||
remove_args_str: str = None,
|
||||
):
|
||||
data_dir = self.examples_dir / "seq2seq/test_data/wmt_en_ro"
|
||||
data_dir = self.examples_dir / "test_data/wmt_en_ro"
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
args = f"""
|
||||
--model_name_or_path {model_name}
|
||||
--data_dir {data_dir}
|
||||
--train_file {data_dir}/train.json
|
||||
--validation_file {data_dir}/val.json
|
||||
--output_dir {output_dir}
|
||||
--overwrite_output_dir
|
||||
--n_train 8
|
||||
--n_val 8
|
||||
--max_train_samples 8
|
||||
--max_val_samples 8
|
||||
--max_source_length {max_len}
|
||||
--max_target_length {max_len}
|
||||
--val_max_target_length {max_len}
|
||||
@@ -139,8 +140,8 @@ class TestDeepSpeed(TestCasePlus):
|
||||
--label_smoothing_factor 0.1
|
||||
--adafactor
|
||||
--task translation
|
||||
--tgt_lang ro_RO
|
||||
--src_lang en_XX
|
||||
--target_lang ro_RO
|
||||
--source_lang en_XX
|
||||
""".split()
|
||||
|
||||
if extra_args_str is not None:
|
||||
@@ -151,7 +152,7 @@ class TestDeepSpeed(TestCasePlus):
|
||||
args = [x for x in args if x not in remove_args]
|
||||
|
||||
ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config.json".split()
|
||||
script = [f"{self.examples_dir_str}/seq2seq/finetune_trainer.py"]
|
||||
script = [f"{self.examples_dir_str}/seq2seq/run_seq2seq.py"]
|
||||
num_gpus = get_gpu_count() if distributed else 1
|
||||
launcher = f"deepspeed --num_gpus {num_gpus}".split()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user