[SpeechT5] Fix HiFiGAN tests (#21788)
This commit is contained in:
@@ -1551,9 +1551,13 @@ class SpeechT5HifiGanTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
batched_inputs = inputs["spectrogram"].unsqueeze(0).repeat(2, 1, 1)
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
batched_inputs = inputs["spectrogram"].unsqueeze(0).repeat(2, 1, 1)
|
||||||
|
with torch.no_grad():
|
||||||
|
batched_outputs = model(batched_inputs.to(torch_device))
|
||||||
|
|
||||||
batched_outputs = model(batched_inputs)
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
batched_inputs.shape[0], batched_outputs.shape[0], msg="Got different batch dims for input and output"
|
batched_inputs.shape[0], batched_outputs.shape[0], msg="Got different batch dims for input and output"
|
||||||
)
|
)
|
||||||
@@ -1563,5 +1567,9 @@ class SpeechT5HifiGanTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
outputs = model(inputs["spectrogram"])
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(inputs["spectrogram"].to(torch_device))
|
||||||
self.assertTrue(outputs.dim() == 1, msg="Got un-batched inputs but batched output")
|
self.assertTrue(outputs.dim() == 1, msg="Got un-batched inputs but batched output")
|
||||||
|
|||||||
Reference in New Issue
Block a user