Fix Bark batching feature (#27271)

* fix bark batching

* make style

* add tests and make style
This commit is contained in:
Yoach Lacombe
2023-11-07 18:32:00 +00:00
committed by GitHub
parent 8f840edd31
commit ac5d4cf6de
2 changed files with 82 additions and 14 deletions

View File

@@ -1067,6 +1067,37 @@ class BarkModelIntegrationTests(unittest.TestCase):
self.model.generate(**input_ids, do_sample=True, temperature=0.6, penalty_alpha=0.6)
self.model.generate(**input_ids, do_sample=True, temperature=0.6, num_beams=4)
@slow
def test_generate_batching(self):
    # Batched generation must agree with generating each prompt on its own.
    # Sampling is switched off so that the two code paths can be compared.
    generation_kwargs = {"do_sample": False, "temperature": None}
    preset = "en_speaker_6"
    short_text = "I love HuggingFace"
    long_text = "In the light of the moon, a little egg lay on a leaf"

    # Path 1: both prompts in a single batch, also requesting per-sample lengths.
    batch_inputs = self.processor([short_text, long_text], voice_preset=preset).to(torch_device)
    batch_audio, batch_lengths = self.model.generate(
        **batch_inputs, **generation_kwargs, return_output_lengths=True
    )

    # Path 2: each prompt processed and generated separately.
    short_inputs = self.processor(short_text, voice_preset=preset).to(torch_device)
    long_inputs = self.processor(long_text, voice_preset=preset).to(torch_device)
    single_audios = (
        self.model.generate(**short_inputs, **generation_kwargs),
        self.model.generate(**long_inputs, **generation_kwargs),
    )

    # Lengths are fixed by the coarse acoustic model, whose results match between
    # the two paths, so the reported lengths must be exactly equal.
    expected_lengths = tuple(audio.shape[1] for audio in single_audios)
    self.assertEqual(tuple(batch_lengths), expected_lengths)

    # The fine acoustic model introduces small differences, so the waveforms are
    # only compared up to an absolute tolerance, on the un-padded part of each row.
    for row, audio in enumerate(single_audios):
        valid = batch_lengths[row]
        self.assertTrue(torch.allclose(batch_audio[row, :valid], audio.squeeze(), atol=2e-3))

    # A single input with return_output_lengths=True must reproduce the plain
    # single-input output exactly.
    audio_with_lengths, _ = self.model.generate(**short_inputs, **generation_kwargs, return_output_lengths=True)
    identical = (audio_with_lengths == single_audios[0]).all().item()
    self.assertTrue(identical)
@slow
def test_generate_end_to_end_with_sub_models_args(self):
input_ids = self.inputs