Set TF32 flag for PyTorch cuDNN backend (#25075)

This commit is contained in:
Xuehai Pan
2023-07-25 20:04:48 +08:00
committed by GitHub
parent 5dba88b2d2
commit 6bc61aa7af
3 changed files with 6 additions and 0 deletions

View File

@@ -1432,6 +1432,7 @@ class TrainingArguments:
" otherwise."
)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
else:
logger.warning(
"The speedups for torchdynamo mostly come wih GPU Ampere or higher and which is not detected here."
@@ -1440,11 +1441,13 @@ class TrainingArguments:
if self.tf32:
if is_torch_tf32_available():
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
else:
raise ValueError("--tf32 requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7")
else:
if is_torch_tf32_available():
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
# no need to assert on else
if self.report_to is None: