shift torch dynamo handling to accelerate (#23168)
* mixed precision support via accelerate
* fix issues
* fix for the sharded ddp case
* fix flax and tf failing tests
* `refactor the place to create `Accelerator` object
* move ddp prep to accelerate
* fix 😅
* resolving comments
* move fsdp handling to accelerate
* fixex
* fix saving
* shift torch dynamo handling to accelerate
This commit is contained in:
committed by
GitHub
parent
0b774074a5
commit
03db591047
@@ -1559,11 +1559,6 @@ class Trainer:
|
|||||||
|
|
||||||
self.accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs)
|
self.accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs)
|
||||||
|
|
||||||
# torch.compile() needs to be called after wrapping the model with FSDP or DDP
|
|
||||||
# to ensure that it accounts for the graph breaks required by those wrappers
|
|
||||||
if self.args.torch_compile:
|
|
||||||
model = torch.compile(model, backend=self.args.torch_compile_backend, mode=self.args.torch_compile_mode)
|
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
|
|||||||
@@ -1371,6 +1371,15 @@ class TrainingArguments:
|
|||||||
self.torch_compile = True
|
self.torch_compile = True
|
||||||
if self.torch_compile and self.torch_compile_backend is None:
|
if self.torch_compile and self.torch_compile_backend is None:
|
||||||
self.torch_compile_backend = "inductor"
|
self.torch_compile_backend = "inductor"
|
||||||
|
|
||||||
|
# accelerate integration for torch compile
|
||||||
|
if self.torch_compile:
|
||||||
|
# set env vars for accelerate
|
||||||
|
prefix = "ACCELERATE_DYNAMO_"
|
||||||
|
os.environ[prefix + "BACKEND"] = self.torch_compile_backend
|
||||||
|
if self.torch_compile_mode is not None:
|
||||||
|
os.environ[prefix + "MODE"] = self.torch_compile_mode
|
||||||
|
|
||||||
if self.framework == "pt" and is_torch_available() and self.torch_compile:
|
if self.framework == "pt" and is_torch_available() and self.torch_compile:
|
||||||
if is_torch_tf32_available():
|
if is_torch_tf32_available():
|
||||||
if self.tf32 is None and not self.fp16 or self.bf16:
|
if self.tf32 is None and not self.fp16 or self.bf16:
|
||||||
|
|||||||
Reference in New Issue
Block a user