FIx Bark batching feature (#27271)
* fix bark batching * make style * add tests and make style
This commit is contained in:
@@ -1067,6 +1067,37 @@ class BarkModelIntegrationTests(unittest.TestCase):
|
||||
self.model.generate(**input_ids, do_sample=True, temperature=0.6, penalty_alpha=0.6)
|
||||
self.model.generate(**input_ids, do_sample=True, temperature=0.6, num_beams=4)
|
||||
|
||||
@slow
|
||||
def test_generate_batching(self):
|
||||
args = {"do_sample": False, "temperature": None}
|
||||
|
||||
s1 = "I love HuggingFace"
|
||||
s2 = "In the light of the moon, a little egg lay on a leaf"
|
||||
voice_preset = "en_speaker_6"
|
||||
input_ids = self.processor([s1, s2], voice_preset=voice_preset).to(torch_device)
|
||||
|
||||
# generate in batch
|
||||
outputs, audio_lengths = self.model.generate(**input_ids, **args, return_output_lengths=True)
|
||||
|
||||
# generate one-by-one
|
||||
s1 = self.processor(s1, voice_preset=voice_preset).to(torch_device)
|
||||
s2 = self.processor(s2, voice_preset=voice_preset).to(torch_device)
|
||||
output1 = self.model.generate(**s1, **args)
|
||||
output2 = self.model.generate(**s2, **args)
|
||||
|
||||
# up until the coarse acoustic model (included), results are the same
|
||||
# the fine acoustic model introduces small differences
|
||||
# first verify if same length (should be the same because it's decided in the coarse model)
|
||||
self.assertEqual(tuple(audio_lengths), (output1.shape[1], output2.shape[1]))
|
||||
|
||||
# then assert almost equal
|
||||
self.assertTrue(torch.allclose(outputs[0, : audio_lengths[0]], output1.squeeze(), atol=2e-3))
|
||||
self.assertTrue(torch.allclose(outputs[1, : audio_lengths[1]], output2.squeeze(), atol=2e-3))
|
||||
|
||||
# now test single input with return_output_lengths = True
|
||||
outputs, _ = self.model.generate(**s1, **args, return_output_lengths=True)
|
||||
self.assertTrue((outputs == output1).all().item())
|
||||
|
||||
@slow
|
||||
def test_generate_end_to_end_with_sub_models_args(self):
|
||||
input_ids = self.inputs
|
||||
|
||||
Reference in New Issue
Block a user