Set TF32 flag for PyTorch cuDNN backend (#25075)
This commit is contained in:
@@ -167,6 +167,7 @@ class Jukebox1bModelTester(unittest.TestCase):
|
||||
@slow
|
||||
def test_conditioning(self):
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
torch.backends.cudnn.allow_tf32 = False
|
||||
model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).eval()
|
||||
|
||||
labels = self.prepare_inputs()
|
||||
@@ -195,6 +196,7 @@ class Jukebox1bModelTester(unittest.TestCase):
|
||||
@slow
|
||||
def test_primed_sampling(self):
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
torch.backends.cudnn.allow_tf32 = False
|
||||
|
||||
model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).eval()
|
||||
set_seed(0)
|
||||
|
||||
Reference in New Issue
Block a user