Set TF32 flag for PyTorch cuDNN backend (#25075)

This commit is contained in:
Xuehai Pan
2023-07-25 20:04:48 +08:00
committed by GitHub
parent 5dba88b2d2
commit 6bc61aa7af
3 changed files with 6 additions and 0 deletions

View File

@@ -167,6 +167,7 @@ class Jukebox1bModelTester(unittest.TestCase):
@slow
def test_conditioning(self):
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).eval()
labels = self.prepare_inputs()
@@ -195,6 +196,7 @@ class Jukebox1bModelTester(unittest.TestCase):
@slow
def test_primed_sampling(self):
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).eval()
set_seed(0)