From 03db59104714f60b591ef48073840b288ee4cdc0 Mon Sep 17 00:00:00 2001 From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> Date: Wed, 31 May 2023 14:42:07 +0530 Subject: [PATCH] shift torch dynamo handling to accelerate (#23168) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * mixed precision support via accelerate * fix issues * fix for the sharded ddp case * fix flax and tf failing tests * refactor the place to create `Accelerator` object * move ddp prep to accelerate * fix 😅 * resolving comments * move fsdp handling to accelerate * fixes * fix saving * shift torch dynamo handling to accelerate --- src/transformers/trainer.py | 5 ----- src/transformers/training_args.py | 9 +++++++++ 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 491d37cbc2..1531a65e46 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1559,11 +1559,6 @@ class Trainer: self.accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs) - # torch.compile() needs to be called after wrapping the model with FSDP or DDP - # to ensure that it accounts for the graph breaks required by those wrappers - if self.args.torch_compile: - model = torch.compile(model, backend=self.args.torch_compile_backend, mode=self.args.torch_compile_mode) - return model def train( diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 001c8c5eeb..403d437843 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -1371,6 +1371,15 @@ class TrainingArguments: self.torch_compile = True if self.torch_compile and self.torch_compile_backend is None: self.torch_compile_backend = "inductor" + + # accelerate integration for torch compile + if self.torch_compile: + # set env vars for accelerate + prefix = "ACCELERATE_DYNAMO_" + os.environ[prefix + "BACKEND"] = self.torch_compile_backend + if 
self.torch_compile_mode is not None: + os.environ[prefix + "MODE"] = self.torch_compile_mode + if self.framework == "pt" and is_torch_available() and self.torch_compile: if is_torch_tf32_available(): if self.tf32 is None and not self.fp16 or self.bf16: