From 031ef8802c2d8bf6fec977fd274e1596d8a5ef21 Mon Sep 17 00:00:00 2001 From: Joaquin Caballero Date: Tue, 6 May 2025 18:51:28 +0300 Subject: [PATCH] fix FSDP + torch.compile bug when saving pretrained model (#37725) * args keep_torch_compile=False in _save and _wrap_model * Fix FSDP execution on evaluation for torch_compile mode * add test trainer FSDP + Torch Compile * fix quality code * make style * Revert " make style" This reverts commit 77e797f8829c50992cc21496be3d9a3e480e1c97. * make style --- src/transformers/trainer.py | 9 +++++---- tests/trainer/test_trainer_fsdp.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 3701df85b1..743cc353ab 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1986,7 +1986,7 @@ class Trainer: return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps) # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again - if self.accelerator.unwrap_model(model) is not model: + if self.accelerator.unwrap_model(model, keep_torch_compile=False) is not model: return model # Mixed precision training with apex @@ -3998,8 +3998,8 @@ class Trainer: if state_dict is None: state_dict = self.model.state_dict() - if isinstance(self.accelerator.unwrap_model(self.model), supported_classes): - self.accelerator.unwrap_model(self.model).save_pretrained( + if isinstance(self.accelerator.unwrap_model(self.model, keep_torch_compile=False), supported_classes): + self.accelerator.unwrap_model(self.model, keep_torch_compile=False).save_pretrained( output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors ) else: @@ -4296,7 +4296,8 @@ class Trainer: start_time = time.time() model = ( self.accelerator.prepare(model) - if self.is_deepspeed_enabled or (self.is_fsdp_enabled and self.accelerator.mixed_precision != "fp8") + if 
self.is_deepspeed_enabled + or (self.is_fsdp_enabled and self.accelerator.mixed_precision != "fp8" and not self.args.torch_compile) else self.accelerator.prepare_model(model, evaluation_mode=True) ) self.model_preparation_time = round(time.time() - start_time, 4) diff --git a/tests/trainer/test_trainer_fsdp.py b/tests/trainer/test_trainer_fsdp.py index 690ebd9d80..1e30f9f454 100644 --- a/tests/trainer/test_trainer_fsdp.py +++ b/tests/trainer/test_trainer_fsdp.py @@ -147,6 +147,34 @@ class TestFSDPTrainerWrap(TestCasePlus): # successful return here == success - any errors would have caused an error in the sub-call +class TestFSDPTrainerTorchCompile(TestCasePlus): + @require_torch_multi_accelerator + @require_accelerate + @run_first + def test_trainer(self): + output_dir = self.get_auto_remove_tmp_dir() + cmd = [ + "accelerate", + "launch", + "--use_fsdp", + "--main_process_port", + f"{get_torch_dist_unique_port()}", + "--num_processes", + f"{backend_device_count(torch_device)}", + "--fsdp_transformer_layer_cls_to_wrap", + "GPT2Block", + f"{self.test_file_dir}/test_trainer_fsdp.py", + "--torch_compile_mode", + "default", + "--output_dir", + f"{output_dir}", + "--report_to", + "none", + ] + execute_subprocess_async(cmd, env=self.get_env()) + # successful return here == success - any errors would have caused an error in the sub-call + + if __name__ == "__main__": parser = HfArgumentParser((Seq2SeqTrainingArguments,)) training_args = parser.parse_args_into_dataclasses()[0]