Remove FSDP wrapping from sub-models. (#34452)

* Remove FSDP wrapping from sub-models.

* solve conflict trainer.py

* make fixup

* add unit test for fsdp_auto_wrap_policy when using auto_find_batch_size

* put back extract_model_from_parallel

* use transformers unwrap_model
This commit is contained in:
AbdelKarim ELJANDOUBI
2024-11-15 23:00:03 +01:00
committed by GitHub
parent b0c0ba7b4d
commit 8d50fda644
2 changed files with 33 additions and 3 deletions

View File

@@ -117,6 +117,33 @@ class TestFSDPTrainerFP8(TestCasePlus):
execute_subprocess_async(cmd, env=self.get_env())
# successful return here == success - any errors would have caused an error in the sub-call
class TestFSDPTrainerWrap(TestCasePlus):
@require_accelerate
@require_torch_multi_gpu
@require_fsdp
def test_trainer(self):
output_dir = self.get_auto_remove_tmp_dir()
cmd = [
"accelerate",
"launch",
"--use_fsdp",
"--main_process_port",
f"{get_torch_dist_unique_port()}",
"--num_processes",
f"{torch.cuda.device_count()}",
"--fsdp_transformer_layer_cls_to_wrap",
"GPT2Block",
f"{self.test_file_dir}/test_trainer_fsdp.py",
"--output_dir",
f"{output_dir}",
"--report_to",
"none",
"--auto_find_batch_size",
"True",
]
execute_subprocess_async(cmd, env=self.get_env())
# successful return here == success - any errors would have caused an error in the sub-call
if __name__ == "__main__":
parser = HfArgumentParser((Seq2SeqTrainingArguments,))