Smangrul/accelerate mp integrate (#23148)
* mixed precision support via accelerate * fix issues * fix for the sharded ddp case * fix flax and tf failing tests * `refactor the place to create `Accelerator` object * address comments by removing debugging print statements
This commit is contained in:
committed by
GitHub
parent
de9255de27
commit
9f0646a555
@@ -1562,6 +1562,15 @@ class TrainingArguments:
|
||||
FutureWarning,
|
||||
)
|
||||
|
||||
# if training args is specified, it will override the one specified in the accelerate config
|
||||
if self.half_precision_backend != "apex" and len(self.sharded_ddp) == 0:
|
||||
mixed_precision_dtype = os.environ.get("ACCELERATE_MIXED_PRECISION", "no")
|
||||
if self.fp16:
|
||||
mixed_precision_dtype = "fp16"
|
||||
elif self.bf16:
|
||||
mixed_precision_dtype = "bf16"
|
||||
os.environ["ACCELERATE_MIXED_PRECISION"] = mixed_precision_dtype
|
||||
|
||||
def __str__(self):
|
||||
self_as_dict = asdict(self)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user