Set TF32 flag for PyTorch cuDNN backend (#25075)
This commit is contained in:
@@ -1432,6 +1432,7 @@ class TrainingArguments:
|
||||
" otherwise."
|
||||
)
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
else:
|
||||
logger.warning(
|
||||
"The speedups for torchdynamo mostly come wih GPU Ampere or higher and which is not detected here."
|
||||
@@ -1440,11 +1441,13 @@ class TrainingArguments:
|
||||
if self.tf32:
|
||||
if is_torch_tf32_available():
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
else:
|
||||
raise ValueError("--tf32 requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7")
|
||||
else:
|
||||
if is_torch_tf32_available():
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
torch.backends.cudnn.allow_tf32 = False
|
||||
# no need to assert on else
|
||||
|
||||
if self.report_to is None:
|
||||
|
||||
Reference in New Issue
Block a user