If the output is a tuple, as with facebook/hf-seamless-m4t-medium, the waveform is the first element (#29722)
* If the output is a tuple, as with facebook/hf-seamless-m4t-medium, the waveform is the first element

  Signed-off-by: Wang, Yi <yi.a.wang@intel.com>

* Add test and fix batch issue

  Signed-off-by: Wang, Yi <yi.a.wang@intel.com>

* Add dict output support for seamless_m4t

  Signed-off-by: Wang, Yi <yi.a.wang@intel.com>

---------

Signed-off-by: Wang, Yi <yi.a.wang@intel.com>
This commit is contained in:
@@ -3496,7 +3496,6 @@ class SeamlessM4TForTextToSpeech(SeamlessM4TPreTrainedModel):
|
|||||||
self.device
|
self.device
|
||||||
)
|
)
|
||||||
kwargs_speech["decoder_input_ids"] = t2u_decoder_input_ids
|
kwargs_speech["decoder_input_ids"] = t2u_decoder_input_ids
|
||||||
|
|
||||||
# second generation
|
# second generation
|
||||||
unit_ids = self.t2u_model.generate(inputs_embeds=t2u_input_embeds, **kwargs_speech)
|
unit_ids = self.t2u_model.generate(inputs_embeds=t2u_input_embeds, **kwargs_speech)
|
||||||
output_unit_ids = unit_ids.detach().clone()
|
output_unit_ids = unit_ids.detach().clone()
|
||||||
|
|||||||
@@ -128,9 +128,12 @@ class PipelineIterator(IterableDataset):
|
|||||||
# Try to infer the size of the batch
|
# Try to infer the size of the batch
|
||||||
if isinstance(processed, torch.Tensor):
|
if isinstance(processed, torch.Tensor):
|
||||||
first_tensor = processed
|
first_tensor = processed
|
||||||
|
elif isinstance(processed, tuple):
|
||||||
|
first_tensor = processed[0]
|
||||||
else:
|
else:
|
||||||
key = list(processed.keys())[0]
|
key = list(processed.keys())[0]
|
||||||
first_tensor = processed[key]
|
first_tensor = processed[key]
|
||||||
|
|
||||||
if isinstance(first_tensor, list):
|
if isinstance(first_tensor, list):
|
||||||
observed_batch_size = len(first_tensor)
|
observed_batch_size = len(first_tensor)
|
||||||
else:
|
else:
|
||||||
@@ -140,7 +143,7 @@ class PipelineIterator(IterableDataset):
|
|||||||
# elements.
|
# elements.
|
||||||
self.loader_batch_size = observed_batch_size
|
self.loader_batch_size = observed_batch_size
|
||||||
# Setting internal index to unwrap the batch
|
# Setting internal index to unwrap the batch
|
||||||
self._loader_batch_data = processed
|
self._loader_batch_data = processed[0] if isinstance(processed, tuple) else processed
|
||||||
self._loader_batch_index = 0
|
self._loader_batch_index = 0
|
||||||
return self.loader_batch_item()
|
return self.loader_batch_item()
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -200,7 +200,10 @@ class TextToAudioPipeline(Pipeline):
|
|||||||
|
|
||||||
def postprocess(self, waveform):
|
def postprocess(self, waveform):
|
||||||
output_dict = {}
|
output_dict = {}
|
||||||
|
if isinstance(waveform, dict):
|
||||||
|
waveform = waveform["waveform"]
|
||||||
|
elif isinstance(waveform, tuple):
|
||||||
|
waveform = waveform[0]
|
||||||
output_dict["audio"] = waveform.cpu().float().numpy()
|
output_dict["audio"] = waveform.cpu().float().numpy()
|
||||||
output_dict["sampling_rate"] = self.sampling_rate
|
output_dict["sampling_rate"] = self.sampling_rate
|
||||||
|
|
||||||
|
|||||||
@@ -66,6 +66,27 @@ class TextToAudioPipelineTests(unittest.TestCase):
|
|||||||
audio = [output["audio"] for output in outputs]
|
audio = [output["audio"] for output in outputs]
|
||||||
self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
|
self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch
|
||||||
|
def test_medium_seamless_m4t_pt(self):
|
||||||
|
speech_generator = pipeline(task="text-to-audio", model="facebook/hf-seamless-m4t-medium", framework="pt")
|
||||||
|
|
||||||
|
for forward_params in [{"tgt_lang": "eng"}, {"return_intermediate_token_ids": True, "tgt_lang": "eng"}]:
|
||||||
|
outputs = speech_generator("This is a test", forward_params=forward_params)
|
||||||
|
self.assertEqual({"audio": ANY(np.ndarray), "sampling_rate": 16000}, outputs)
|
||||||
|
|
||||||
|
# test two examples side-by-side
|
||||||
|
outputs = speech_generator(["This is a test", "This is a second test"], forward_params=forward_params)
|
||||||
|
audio = [output["audio"] for output in outputs]
|
||||||
|
self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
|
||||||
|
|
||||||
|
# test batching
|
||||||
|
outputs = speech_generator(
|
||||||
|
["This is a test", "This is a second test"], forward_params=forward_params, batch_size=2
|
||||||
|
)
|
||||||
|
audio = [output["audio"] for output in outputs]
|
||||||
|
self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_small_bark_pt(self):
|
def test_small_bark_pt(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user