fix FSDP + torch.compile bug when saving pretrained model (#37725)

* args keep_torch_compile=False in _save and _wwrap_method

* Fix FSDP execution on evaluation  for torch_compile mode

* add test trainer FSDP + Torch Compile

* fix quality code

* make style

* Revert " make style"

This reverts commit 77e797f8829c50992cc21496be3d9a3e480e1c97.

* make style
This commit is contained in:
Joaquin Caballero
2025-05-06 18:51:28 +03:00
committed by GitHub
parent 5534b80b7f
commit 031ef8802c
2 changed files with 33 additions and 4 deletions

View File

@@ -147,6 +147,34 @@ class TestFSDPTrainerWrap(TestCasePlus):
# successful return here == success - any errors would have caused an error in the sub-call
class TestFSDPTrainerTorchCompile(TestCasePlus):
@require_torch_multi_accelerator
@require_accelerate
@run_first
def test_trainer(self):
output_dir = self.get_auto_remove_tmp_dir()
cmd = [
"accelerate",
"launch",
"--use_fsdp",
"--main_process_port",
f"{get_torch_dist_unique_port()}",
"--num_processes",
f"{backend_device_count(torch_device)}",
"--fsdp_transformer_layer_cls_to_wrap",
"GPT2Block",
f"{self.test_file_dir}/test_trainer_fsdp.py",
"--torch_compile_mode",
"default",
"--output_dir",
f"{output_dir}",
"--report_to",
"none",
]
execute_subprocess_async(cmd, env=self.get_env())
# successful return here == success - any errors would have caused an error in the sub-call
if __name__ == "__main__":
parser = HfArgumentParser((Seq2SeqTrainingArguments,))
training_args = parser.parse_args_into_dataclasses()[0]