diff --git a/docs/source/en/perf_train_gpu_one.md b/docs/source/en/perf_train_gpu_one.md index b526d7a9e1..5ad0046f18 100644 --- a/docs/source/en/perf_train_gpu_one.md +++ b/docs/source/en/perf_train_gpu_one.md @@ -203,6 +203,7 @@ improvement. All you need to do is to add the following to your code: ``` import torch torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True ``` CUDA will automatically switch to using tf32 instead of fp32 where possible, assuming that the used GPU is from the Ampere series. diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 313caf47e9..b409a84bed 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -1432,6 +1432,7 @@ class TrainingArguments: " otherwise." ) torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True else: logger.warning( "The speedups for torchdynamo mostly come wih GPU Ampere or higher and which is not detected here." @@ -1440,11 +1441,13 @@ class TrainingArguments: if self.tf32: if is_torch_tf32_available(): torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True else: raise ValueError("--tf32 requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7") else: if is_torch_tf32_available(): torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False # no need to assert on else if self.report_to is None: diff --git a/tests/models/jukebox/test_modeling_jukebox.py b/tests/models/jukebox/test_modeling_jukebox.py index 93de887156..c0c78a25f8 100644 --- a/tests/models/jukebox/test_modeling_jukebox.py +++ b/tests/models/jukebox/test_modeling_jukebox.py @@ -167,6 +167,7 @@ class Jukebox1bModelTester(unittest.TestCase): @slow def test_conditioning(self): torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).eval() labels = self.prepare_inputs() @@ -195,6 +196,7 @@ class Jukebox1bModelTester(unittest.TestCase): @slow def test_primed_sampling(self): torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).eval() set_seed(0)