[SpeechT5HifiGan] Handle batched inputs (#21702)

* [SpeechT5HifiGan] Handle batched inputs

* fix docstring

* rebase and new ruff style
This commit is contained in:
Sanchit Gandhi
2023-02-22 11:16:56 +01:00
committed by GitHub
parent 09127c5713
commit 82e61f3445
2 changed files with 40 additions and 6 deletions

View File

@@ -1545,3 +1545,23 @@ class SpeechT5HifiGanTest(ModelTesterMixin, unittest.TestCase):
# skip because it fails on automapping of SpeechT5HifiGanConfig
def test_save_load_fast_init_to_base(self):
pass
def test_batched_inputs_outputs(self):
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
batched_inputs = inputs["spectrogram"].unsqueeze(0).repeat(2, 1, 1)
batched_outputs = model(batched_inputs)
self.assertEqual(
batched_inputs.shape[0], batched_outputs.shape[0], msg="Got different batch dims for input and output"
)
def test_unbatched_inputs_outputs(self):
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
outputs = model(inputs["spectrogram"])
self.assertTrue(outputs.dim() == 1, msg="Got un-batched inputs but batched output")