From 79d62b2da227b39619afa7f3a86d8aeb95e0f4fa Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Fri, 5 Apr 2024 15:26:44 +0800 Subject: [PATCH] =?UTF-8?q?if=20output=20is=20tuple=20like=20facebook/hf-s?= =?UTF-8?q?eamless-m4t-medium,=20waveform=20is=20=E2=80=A6=20(#29722)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * if output is tuple like facebook/hf-seamless-m4t-medium, waveform is the first element Signed-off-by: Wang, Yi * add test and fix batch issue Signed-off-by: Wang, Yi * add dict output support for seamless_m4t Signed-off-by: Wang, Yi --------- Signed-off-by: Wang, Yi --- .../seamless_m4t/modeling_seamless_m4t.py | 1 - src/transformers/pipelines/pt_utils.py | 5 ++++- src/transformers/pipelines/text_to_audio.py | 5 ++++- .../pipelines/test_pipelines_text_to_audio.py | 21 +++++++++++++++++++ 4 files changed, 29 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index f619dd9e79..c0fe60a643 100755 --- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -3496,7 +3496,6 @@ class SeamlessM4TForTextToSpeech(SeamlessM4TPreTrainedModel): self.device ) kwargs_speech["decoder_input_ids"] = t2u_decoder_input_ids - # second generation unit_ids = self.t2u_model.generate(inputs_embeds=t2u_input_embeds, **kwargs_speech) output_unit_ids = unit_ids.detach().clone() diff --git a/src/transformers/pipelines/pt_utils.py b/src/transformers/pipelines/pt_utils.py index c39f906f64..652d1eb544 100644 --- a/src/transformers/pipelines/pt_utils.py +++ b/src/transformers/pipelines/pt_utils.py @@ -128,9 +128,12 @@ class PipelineIterator(IterableDataset): # Try to infer the size of the batch if isinstance(processed, torch.Tensor): first_tensor = processed + elif isinstance(processed, tuple): + first_tensor = processed[0] else: key = list(processed.keys())[0] first_tensor = processed[key] + if isinstance(first_tensor, list): observed_batch_size = len(first_tensor) else: @@ -140,7 +143,7 @@ class PipelineIterator(IterableDataset): # elements. self.loader_batch_size = observed_batch_size # Setting internal index to unwrap the batch - self._loader_batch_data = processed + self._loader_batch_data = processed[0] if isinstance(processed, tuple) else processed self._loader_batch_index = 0 return self.loader_batch_item() else: diff --git a/src/transformers/pipelines/text_to_audio.py b/src/transformers/pipelines/text_to_audio.py index 58c21cc121..81653f14d6 100644 --- a/src/transformers/pipelines/text_to_audio.py +++ b/src/transformers/pipelines/text_to_audio.py @@ -200,7 +200,10 @@ class TextToAudioPipeline(Pipeline): def postprocess(self, waveform): output_dict = {} - + if isinstance(waveform, dict): + waveform = waveform["waveform"] + elif isinstance(waveform, tuple): + waveform = waveform[0] output_dict["audio"] = waveform.cpu().float().numpy() output_dict["sampling_rate"] = self.sampling_rate diff --git a/tests/pipelines/test_pipelines_text_to_audio.py b/tests/pipelines/test_pipelines_text_to_audio.py index a9f1eccae5..b780d26d79 100644 --- a/tests/pipelines/test_pipelines_text_to_audio.py +++ b/tests/pipelines/test_pipelines_text_to_audio.py @@ -66,6 +66,27 @@ class TextToAudioPipelineTests(unittest.TestCase): audio = [output["audio"] for output in outputs] self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio) + @slow + @require_torch + def test_medium_seamless_m4t_pt(self): + speech_generator = pipeline(task="text-to-audio", model="facebook/hf-seamless-m4t-medium", framework="pt") + + for forward_params in [{"tgt_lang": "eng"}, {"return_intermediate_token_ids": True, "tgt_lang": "eng"}]: + outputs = speech_generator("This is a test", forward_params=forward_params) + self.assertEqual({"audio": ANY(np.ndarray), "sampling_rate": 16000}, outputs) + + # test two examples side-by-side + outputs = speech_generator(["This is a test", "This is a second test"], forward_params=forward_params) + audio = [output["audio"] for output in outputs] + self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio) + + # test batching + outputs = speech_generator( + ["This is a test", "This is a second test"], forward_params=forward_params, batch_size=2 + ) + audio = [output["audio"] for output in outputs] + self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio) + @slow @require_torch def test_small_bark_pt(self):