Set TF32 flag for PyTorch cuDNN backend (#25075)
This commit is contained in:
@@ -203,6 +203,7 @@ improvement. All you need to do is to add the following to your code:
|
|||||||
```
|
```
|
||||||
import torch
|
import torch
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
```
|
```
|
||||||
|
|
||||||
CUDA will automatically switch to using tf32 instead of fp32 where possible, assuming that the used GPU is from the Ampere series.
|
CUDA will automatically switch to using tf32 instead of fp32 where possible, assuming that the used GPU is from the Ampere series.
|
||||||
|
|||||||
@@ -1432,6 +1432,7 @@ class TrainingArguments:
|
|||||||
" otherwise."
|
" otherwise."
|
||||||
)
|
)
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"The speedups for torchdynamo mostly come wih GPU Ampere or higher and which is not detected here."
|
"The speedups for torchdynamo mostly come wih GPU Ampere or higher and which is not detected here."
|
||||||
@@ -1440,11 +1441,13 @@ class TrainingArguments:
|
|||||||
if self.tf32:
|
if self.tf32:
|
||||||
if is_torch_tf32_available():
|
if is_torch_tf32_available():
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
else:
|
else:
|
||||||
raise ValueError("--tf32 requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7")
|
raise ValueError("--tf32 requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7")
|
||||||
else:
|
else:
|
||||||
if is_torch_tf32_available():
|
if is_torch_tf32_available():
|
||||||
torch.backends.cuda.matmul.allow_tf32 = False
|
torch.backends.cuda.matmul.allow_tf32 = False
|
||||||
|
torch.backends.cudnn.allow_tf32 = False
|
||||||
# no need to assert on else
|
# no need to assert on else
|
||||||
|
|
||||||
if self.report_to is None:
|
if self.report_to is None:
|
||||||
|
|||||||
@@ -167,6 +167,7 @@ class Jukebox1bModelTester(unittest.TestCase):
|
|||||||
@slow
|
@slow
|
||||||
def test_conditioning(self):
|
def test_conditioning(self):
|
||||||
torch.backends.cuda.matmul.allow_tf32 = False
|
torch.backends.cuda.matmul.allow_tf32 = False
|
||||||
|
torch.backends.cudnn.allow_tf32 = False
|
||||||
model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).eval()
|
model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).eval()
|
||||||
|
|
||||||
labels = self.prepare_inputs()
|
labels = self.prepare_inputs()
|
||||||
@@ -195,6 +196,7 @@ class Jukebox1bModelTester(unittest.TestCase):
|
|||||||
@slow
|
@slow
|
||||||
def test_primed_sampling(self):
|
def test_primed_sampling(self):
|
||||||
torch.backends.cuda.matmul.allow_tf32 = False
|
torch.backends.cuda.matmul.allow_tf32 = False
|
||||||
|
torch.backends.cudnn.allow_tf32 = False
|
||||||
|
|
||||||
model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).eval()
|
model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).eval()
|
||||||
set_seed(0)
|
set_seed(0)
|
||||||
|
|||||||
Reference in New Issue
Block a user