diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 491d37cbc2..1531a65e46 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1559,11 +1559,6 @@ class Trainer: self.accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs) - # torch.compile() needs to be called after wrapping the model with FSDP or DDP - # to ensure that it accounts for the graph breaks required by those wrappers - if self.args.torch_compile: - model = torch.compile(model, backend=self.args.torch_compile_backend, mode=self.args.torch_compile_mode) - return model def train( diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 001c8c5eeb..403d437843 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -1371,6 +1371,15 @@ class TrainingArguments: self.torch_compile = True if self.torch_compile and self.torch_compile_backend is None: self.torch_compile_backend = "inductor" + + # accelerate integration for torch compile + if self.torch_compile: + # set env vars for accelerate + prefix = "ACCELERATE_DYNAMO_" + os.environ[prefix + "BACKEND"] = self.torch_compile_backend + if self.torch_compile_mode is not None: + os.environ[prefix + "MODE"] = self.torch_compile_mode + if self.framework == "pt" and is_torch_available() and self.torch_compile: if is_torch_tf32_available(): if self.tf32 is None and not self.fp16 or self.bf16: