Set TF32 flag for PyTorch cuDNN backend (#25075)

This commit is contained in:
Xuehai Pan
2023-07-25 20:04:48 +08:00
committed by GitHub
parent 5dba88b2d2
commit 6bc61aa7af
3 changed files with 6 additions and 0 deletions

View File

@@ -203,6 +203,7 @@ improvement. All you need to do is to add the following to your code:
``` ```
import torch import torch
torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
``` ```
CUDA will automatically switch to using tf32 instead of fp32 where possible, assuming that the used GPU is from the Ampere series. CUDA will automatically switch to using tf32 instead of fp32 where possible, assuming that the used GPU is from the Ampere series.

View File

@@ -1432,6 +1432,7 @@ class TrainingArguments:
" otherwise." " otherwise."
) )
torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
else: else:
logger.warning( logger.warning(
"The speedups for torchdynamo mostly come wih GPU Ampere or higher and which is not detected here." "The speedups for torchdynamo mostly come wih GPU Ampere or higher and which is not detected here."
@@ -1440,11 +1441,13 @@ class TrainingArguments:
if self.tf32: if self.tf32:
if is_torch_tf32_available(): if is_torch_tf32_available():
torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
else: else:
raise ValueError("--tf32 requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7") raise ValueError("--tf32 requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7")
else: else:
if is_torch_tf32_available(): if is_torch_tf32_available():
torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
# no need to assert on else # no need to assert on else
if self.report_to is None: if self.report_to is None:

View File

@@ -167,6 +167,7 @@ class Jukebox1bModelTester(unittest.TestCase):
@slow @slow
def test_conditioning(self): def test_conditioning(self):
torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).eval() model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).eval()
labels = self.prepare_inputs() labels = self.prepare_inputs()
@@ -195,6 +196,7 @@ class Jukebox1bModelTester(unittest.TestCase):
@slow @slow
def test_primed_sampling(self): def test_primed_sampling(self):
torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).eval() model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).eval()
set_seed(0) set_seed(0)