[SpeechT5] Fix HiFiGAN tests (#21788)

This commit is contained in:
Sanchit Gandhi
2023-02-24 16:55:38 +01:00
committed by GitHub
parent 59c1d5b96b
commit 3dae0d7b4f

View File

@@ -1551,9 +1551,13 @@ class SpeechT5HifiGanTest(ModelTesterMixin, unittest.TestCase):
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
model = model_class(config) model = model_class(config)
batched_inputs = inputs["spectrogram"].unsqueeze(0).repeat(2, 1, 1) model.to(torch_device)
model.eval()
batched_inputs = inputs["spectrogram"].unsqueeze(0).repeat(2, 1, 1)
with torch.no_grad():
batched_outputs = model(batched_inputs.to(torch_device))
batched_outputs = model(batched_inputs)
self.assertEqual( self.assertEqual(
batched_inputs.shape[0], batched_outputs.shape[0], msg="Got different batch dims for input and output" batched_inputs.shape[0], batched_outputs.shape[0], msg="Got different batch dims for input and output"
) )
@@ -1563,5 +1567,9 @@ class SpeechT5HifiGanTest(ModelTesterMixin, unittest.TestCase):
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
model = model_class(config) model = model_class(config)
outputs = model(inputs["spectrogram"]) model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(inputs["spectrogram"].to(torch_device))
self.assertTrue(outputs.dim() == 1, msg="Got un-batched inputs but batched output") self.assertTrue(outputs.dim() == 1, msg="Got un-batched inputs but batched output")