Set TF32 flag for PyTorch cuDNN backend (#25075)
This commit is contained in:
@@ -203,6 +203,7 @@ improvement. All you need to do is to add the following to your code:
|
||||
```
|
||||
import torch
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
```
|
||||
|
||||
CUDA will automatically switch to using tf32 instead of fp32 where possible, assuming that the used GPU is from the Ampere series.
|
||||
|
||||
@@ -1432,6 +1432,7 @@ class TrainingArguments:
|
||||
" otherwise."
|
||||
)
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
else:
|
||||
logger.warning(
|
||||
"The speedups for torchdynamo mostly come wih GPU Ampere or higher and which is not detected here."
|
||||
@@ -1440,11 +1441,13 @@ class TrainingArguments:
|
||||
if self.tf32:
|
||||
if is_torch_tf32_available():
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
else:
|
||||
raise ValueError("--tf32 requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7")
|
||||
else:
|
||||
if is_torch_tf32_available():
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
torch.backends.cudnn.allow_tf32 = False
|
||||
# no need to assert on else
|
||||
|
||||
if self.report_to is None:
|
||||
|
||||
@@ -167,6 +167,7 @@ class Jukebox1bModelTester(unittest.TestCase):
|
||||
@slow
|
||||
def test_conditioning(self):
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
torch.backends.cudnn.allow_tf32 = False
|
||||
model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).eval()
|
||||
|
||||
labels = self.prepare_inputs()
|
||||
@@ -195,6 +196,7 @@ class Jukebox1bModelTester(unittest.TestCase):
|
||||
@slow
|
||||
def test_primed_sampling(self):
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
torch.backends.cudnn.allow_tf32 = False
|
||||
|
||||
model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).eval()
|
||||
set_seed(0)
|
||||
|
||||
Reference in New Issue
Block a user