From 82e61f34451dbea2de8d2220d51b0609d605dfd7 Mon Sep 17 00:00:00 2001 From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Date: Wed, 22 Feb 2023 11:16:56 +0100 Subject: [PATCH] [SpeechT5HifiGan] Handle batched inputs (#21702) * [SpeechT5HifiGan] Handle batched inputs * fix docstring * rebase and new ruff style --- .../models/speecht5/modeling_speecht5.py | 26 ++++++++++++++----- .../models/speecht5/test_modeling_speecht5.py | 20 ++++++++++++++ 2 files changed, 40 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index 86f7b81ca0..61910c345e 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -3030,19 +3030,27 @@ class SpeechT5HifiGan(PreTrainedModel): def forward(self, spectrogram): r""" - Converts a single log-mel spectogram into a speech waveform. + Converts a log-mel spectrogram into a speech waveform. Passing a batch of log-mel spectrograms returns a batch + of speech waveforms. Passing a single, un-batched log-mel spectrogram returns a single, un-batched speech + waveform. Args: - spectrogram (`torch.FloatTensor` of shape `(sequence_length, config.model_in_dim)`): - Tensor containing the log-mel spectrogram. + spectrogram (`torch.FloatTensor`): + Tensor containing the log-mel spectrograms. Can be batched and of shape `(batch_size, sequence_length, + config.model_in_dim)`, or un-batched and of shape `(sequence_length, config.model_in_dim)`. Returns: - `torch.FloatTensor`: Tensor of shape `(num_frames,)` containing the speech waveform. + `torch.FloatTensor`: Tensor containing the speech waveform. If the input spectrogram is batched, will be of + shape `(batch_size, num_frames)`. If un-batched, will be of shape `(num_frames,)`. 
""" if self.config.normalize_before: spectrogram = (spectrogram - self.mean) / self.scale - hidden_states = spectrogram.transpose(1, 0).unsqueeze(0) + is_batched = spectrogram.dim() == 3 + if not is_batched: + spectrogram = spectrogram.unsqueeze(0) + + hidden_states = spectrogram.transpose(2, 1) hidden_states = self.conv_pre(hidden_states) for i in range(self.num_upsamples): @@ -3058,5 +3066,11 @@ class SpeechT5HifiGan(PreTrainedModel): hidden_states = self.conv_post(hidden_states) hidden_states = torch.tanh(hidden_states) - waveform = hidden_states.squeeze(0).transpose(1, 0).view(-1) + if not is_batched: + # remove batch dim and collapse tensor to 1-d audio waveform + waveform = hidden_states.squeeze(0).transpose(1, 0).view(-1) + else: + # remove seq-len dim since this collapses to 1 + waveform = hidden_states.squeeze(1) + return waveform diff --git a/tests/models/speecht5/test_modeling_speecht5.py b/tests/models/speecht5/test_modeling_speecht5.py index a8dd0ec7c1..d8f0332ae1 100644 --- a/tests/models/speecht5/test_modeling_speecht5.py +++ b/tests/models/speecht5/test_modeling_speecht5.py @@ -1545,3 +1545,23 @@ class SpeechT5HifiGanTest(ModelTesterMixin, unittest.TestCase): # skip because it fails on automapping of SpeechT5HifiGanConfig def test_save_load_fast_init_to_base(self): pass + + def test_batched_inputs_outputs(self): + config, inputs = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + batched_inputs = inputs["spectrogram"].unsqueeze(0).repeat(2, 1, 1) + + batched_outputs = model(batched_inputs) + self.assertEqual( + batched_inputs.shape[0], batched_outputs.shape[0], msg="Got different batch dims for input and output" + ) + + def test_unbatched_inputs_outputs(self): + config, inputs = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + outputs = model(inputs["spectrogram"]) + 
self.assertTrue(outputs.dim() == 1, msg="Got un-batched inputs but batched output")