If output is a tuple, as with facebook/hf-seamless-m4t-medium, waveform is the first element (#29722)

* If the output is a tuple (as with facebook/hf-seamless-m4t-medium), the waveform is the first element

Signed-off-by: Wang, Yi <yi.a.wang@intel.com>

* add test and fix batch issue

Signed-off-by: Wang, Yi <yi.a.wang@intel.com>

* add dict output support for seamless_m4t

Signed-off-by: Wang, Yi <yi.a.wang@intel.com>

---------

Signed-off-by: Wang, Yi <yi.a.wang@intel.com>
This commit is contained in:
Wang, Yi
2024-04-05 15:26:44 +08:00
committed by GitHub
parent 8b52fa6b42
commit 79d62b2da2
4 changed files with 29 additions and 3 deletions

View File

@@ -3496,7 +3496,6 @@ class SeamlessM4TForTextToSpeech(SeamlessM4TPreTrainedModel):
self.device self.device
) )
kwargs_speech["decoder_input_ids"] = t2u_decoder_input_ids kwargs_speech["decoder_input_ids"] = t2u_decoder_input_ids
# second generation # second generation
unit_ids = self.t2u_model.generate(inputs_embeds=t2u_input_embeds, **kwargs_speech) unit_ids = self.t2u_model.generate(inputs_embeds=t2u_input_embeds, **kwargs_speech)
output_unit_ids = unit_ids.detach().clone() output_unit_ids = unit_ids.detach().clone()

View File

@@ -128,9 +128,12 @@ class PipelineIterator(IterableDataset):
# Try to infer the size of the batch # Try to infer the size of the batch
if isinstance(processed, torch.Tensor): if isinstance(processed, torch.Tensor):
first_tensor = processed first_tensor = processed
elif isinstance(processed, tuple):
first_tensor = processed[0]
else: else:
key = list(processed.keys())[0] key = list(processed.keys())[0]
first_tensor = processed[key] first_tensor = processed[key]
if isinstance(first_tensor, list): if isinstance(first_tensor, list):
observed_batch_size = len(first_tensor) observed_batch_size = len(first_tensor)
else: else:
@@ -140,7 +143,7 @@ class PipelineIterator(IterableDataset):
# elements. # elements.
self.loader_batch_size = observed_batch_size self.loader_batch_size = observed_batch_size
# Setting internal index to unwrap the batch # Setting internal index to unwrap the batch
self._loader_batch_data = processed self._loader_batch_data = processed[0] if isinstance(processed, tuple) else processed
self._loader_batch_index = 0 self._loader_batch_index = 0
return self.loader_batch_item() return self.loader_batch_item()
else: else:

View File

@@ -200,7 +200,10 @@ class TextToAudioPipeline(Pipeline):
def postprocess(self, waveform): def postprocess(self, waveform):
output_dict = {} output_dict = {}
if isinstance(waveform, dict):
waveform = waveform["waveform"]
elif isinstance(waveform, tuple):
waveform = waveform[0]
output_dict["audio"] = waveform.cpu().float().numpy() output_dict["audio"] = waveform.cpu().float().numpy()
output_dict["sampling_rate"] = self.sampling_rate output_dict["sampling_rate"] = self.sampling_rate

View File

@@ -66,6 +66,27 @@ class TextToAudioPipelineTests(unittest.TestCase):
audio = [output["audio"] for output in outputs] audio = [output["audio"] for output in outputs]
self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio) self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
@slow
@require_torch
def test_medium_seamless_m4t_pt(self):
    """Exercise the text-to-audio pipeline with facebook/hf-seamless-m4t-medium.

    For each set of forward parameters (``tgt_lang`` alone, then with
    ``return_intermediate_token_ids`` enabled as well) the pipeline is
    called three ways: on a single string, on a list of two strings, and
    on the same list with ``batch_size=2``. Every result must expose the
    audio as a NumPy array; the single-input result must also report a
    sampling rate of 16000.
    """
    tts = pipeline(task="text-to-audio", model="facebook/hf-seamless-m4t-medium", framework="pt")
    prompts = ("This is a test", "This is a second test")
    param_sets = (
        {"tgt_lang": "eng"},
        {"return_intermediate_token_ids": True, "tgt_lang": "eng"},
    )

    for forward_params in param_sets:
        # Single input: one dict with the audio array and the sampling rate.
        single = tts("This is a test", forward_params=forward_params)
        self.assertEqual({"audio": ANY(np.ndarray), "sampling_rate": 16000}, single)

        # Two inputs passed together, without an explicit batch size.
        side_by_side = tts(list(prompts), forward_params=forward_params)
        self.assertEqual(
            [ANY(np.ndarray), ANY(np.ndarray)],
            [item["audio"] for item in side_by_side],
        )

        # Same two inputs, this time processed as one batch of two.
        batched = tts(list(prompts), forward_params=forward_params, batch_size=2)
        self.assertEqual(
            [ANY(np.ndarray), ANY(np.ndarray)],
            [item["audio"] for item in batched],
        )
@slow @slow
@require_torch @require_torch
def test_small_bark_pt(self): def test_small_bark_pt(self):