[SpeechT5HifiGan] Handle batched inputs (#21702)
* [SpeechT5HifiGan] Handle batched inputs * fix docstring * rebase and new ruff style
This commit is contained in:
@@ -3030,19 +3030,27 @@ class SpeechT5HifiGan(PreTrainedModel):
|
||||
|
||||
def forward(self, spectrogram):
|
||||
r"""
|
||||
Converts a single log-mel spectrogram into a speech waveform.
|
||||
Converts a log-mel spectrogram into a speech waveform. Passing a batch of log-mel spectrograms returns a batch
|
||||
of speech waveforms. Passing a single, un-batched log-mel spectrogram returns a single, un-batched speech
|
||||
waveform.
|
||||
|
||||
Args:
|
||||
spectrogram (`torch.FloatTensor` of shape `(sequence_length, config.model_in_dim)`):
|
||||
Tensor containing the log-mel spectrogram.
|
||||
spectrogram (`torch.FloatTensor`):
|
||||
Tensor containing the log-mel spectrograms. Can be batched and of shape `(batch_size, sequence_length,
|
||||
config.model_in_dim)`, or un-batched and of shape `(sequence_length, config.model_in_dim)`.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: Tensor of shape `(num_frames,)` containing the speech waveform.
|
||||
`torch.FloatTensor`: Tensor containing the speech waveform. If the input spectrogram is batched, will be of
|
||||
shape `(batch_size, num_frames)`. If un-batched, will be of shape `(num_frames,)`.
|
||||
"""
|
||||
if self.config.normalize_before:
|
||||
spectrogram = (spectrogram - self.mean) / self.scale
|
||||
|
||||
hidden_states = spectrogram.transpose(1, 0).unsqueeze(0)
|
||||
is_batched = spectrogram.dim() == 3
|
||||
if not is_batched:
|
||||
spectrogram = spectrogram.unsqueeze(0)
|
||||
|
||||
hidden_states = spectrogram.transpose(2, 1)
|
||||
|
||||
hidden_states = self.conv_pre(hidden_states)
|
||||
for i in range(self.num_upsamples):
|
||||
@@ -3058,5 +3066,11 @@ class SpeechT5HifiGan(PreTrainedModel):
|
||||
hidden_states = self.conv_post(hidden_states)
|
||||
hidden_states = torch.tanh(hidden_states)
|
||||
|
||||
if not is_batched:
|
||||
# remove batch dim and collapse tensor to 1-d audio waveform
|
||||
waveform = hidden_states.squeeze(0).transpose(1, 0).view(-1)
|
||||
else:
|
||||
# remove seq-len dim since this collapses to 1
|
||||
waveform = hidden_states.squeeze(1)
|
||||
|
||||
return waveform
|
||||
|
||||
@@ -1545,3 +1545,23 @@ class SpeechT5HifiGanTest(ModelTesterMixin, unittest.TestCase):
|
||||
# skip because it fails on automapping of SpeechT5HifiGanConfig
|
||||
def test_save_load_fast_init_to_base(self):
|
||||
pass
|
||||
|
||||
def test_batched_inputs_outputs(self):
    """Check that the leading (batch) dimension of the output matches that of a batched input."""
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_model_classes:
        vocoder = model_class(config)

        # Add a leading axis to the prepared spectrogram and repeat it twice
        # along that axis to form a batch of two.
        spectrogram_batch = inputs_dict["spectrogram"].unsqueeze(0).repeat(2, 1, 1)
        waveform_batch = vocoder(spectrogram_batch)

        input_batch_size = spectrogram_batch.shape[0]
        output_batch_size = waveform_batch.shape[0]
        self.assertEqual(
            input_batch_size, output_batch_size, msg="Got different batch dims for input and output"
        )
|
||||
|
||||
def test_unbatched_inputs_outputs(self):
    """Check that feeding the prepared spectrogram as-is produces a 1-D output tensor."""
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_model_classes:
        vocoder = model_class(config)
        waveform = vocoder(inputs_dict["spectrogram"])

        # The output must not carry a batch dimension.
        is_one_dimensional = waveform.dim() == 1
        self.assertTrue(is_one_dimensional, msg="Got un-batched inputs but batched output")
|
||||
|
||||
Reference in New Issue
Block a user