Smangrul/accelerate mp integrate (#23148)

* mixed precision support via accelerate

* fix issues

* fix for the sharded ddp case

* fix flax and tf failing tests

* refactor the place to create `Accelerator` object

* address comments by removing debugging print statements
This commit is contained in:
Sourab Mangrulkar
2023-05-31 12:27:51 +05:30
committed by GitHub
parent de9255de27
commit 9f0646a555
2 changed files with 51 additions and 26 deletions

View File

@@ -1562,6 +1562,15 @@ class TrainingArguments:
FutureWarning,
)
# if training args is specified, it will override the one specified in the accelerate config
if self.half_precision_backend != "apex" and len(self.sharded_ddp) == 0:
mixed_precision_dtype = os.environ.get("ACCELERATE_MIXED_PRECISION", "no")
if self.fp16:
mixed_precision_dtype = "fp16"
elif self.bf16:
mixed_precision_dtype = "bf16"
os.environ["ACCELERATE_MIXED_PRECISION"] = mixed_precision_dtype
def __str__(self):
self_as_dict = asdict(self)