[SpeechT5HifiGan] Handle batched inputs (#21702)
* [SpeechT5HifiGan] Handle batched inputs * fix docstring * rebase and new ruff style
This commit is contained in:
@@ -1545,3 +1545,23 @@ class SpeechT5HifiGanTest(ModelTesterMixin, unittest.TestCase):
|
||||
# skip because it fails on automapping of SpeechT5HifiGanConfig
|
||||
def test_save_load_fast_init_to_base(self):
|
||||
pass
|
||||
|
||||
def test_batched_inputs_outputs(self):
|
||||
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
batched_inputs = inputs["spectrogram"].unsqueeze(0).repeat(2, 1, 1)
|
||||
|
||||
batched_outputs = model(batched_inputs)
|
||||
self.assertEqual(
|
||||
batched_inputs.shape[0], batched_outputs.shape[0], msg="Got different batch dims for input and output"
|
||||
)
|
||||
|
||||
def test_unbatched_inputs_outputs(self):
|
||||
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
outputs = model(inputs["spectrogram"])
|
||||
self.assertTrue(outputs.dim() == 1, msg="Got un-batched inputs but batched output")
|
||||
|
||||
Reference in New Issue
Block a user