Add early stopping for Bark generation via logits processor (#26675)
* add early stopping logits processor * black formmated * indent * follow method signature * actual logic * check for None * address comments on docstrings and method signature * add unit test under `LogitsProcessorTest` wip * unit test passing * black formatted * condition per sample * add to BarkModelIntegrationTests * wip BarkSemanticModelTest * rename and add to kwargs handling * not add to BarkSemanticModelTest * correct logic and assert last outputs tokens different in test * doc-builder style * read from kwargs as well * assert len of with less than that of without * ruff * add back seed and test case * add original impl default suggestion * doc-builder * rename and use softmax * switch back to LogitsProcessor and update docs wording * camelCase and spelling and saving compute * assert strictly less than * assert less than * expand test_generate_semantic_early_stop instead
This commit is contained in:
@@ -917,7 +917,51 @@ class BarkModelIntegrationTests(unittest.TestCase):
|
||||
temperature=1.0,
|
||||
semantic_generation_config=self.semantic_generation_config,
|
||||
)
|
||||
self.assertListEqual(output_ids[0, : len(expected_output_ids)].tolist(), expected_output_ids)
|
||||
|
||||
@slow
|
||||
def test_generate_semantic_early_stop(self):
|
||||
input_ids = self.inputs
|
||||
min_eos_p = 0.01
|
||||
|
||||
# fmt: off
|
||||
# check first ids
|
||||
expected_output_ids = [7363, 321, 41, 1461, 6915, 952, 326, 41, 41, 927,]
|
||||
# fmt: on
|
||||
|
||||
# Should be able to read min_eos_p from kwargs
|
||||
with torch.no_grad():
|
||||
torch.manual_seed(0)
|
||||
output_ids_without_min_eos_p = self.model.semantic.generate(
|
||||
**input_ids,
|
||||
do_sample=False,
|
||||
temperature=0.9,
|
||||
semantic_generation_config=self.semantic_generation_config,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
output_ids_kwargs = self.model.semantic.generate(
|
||||
**input_ids,
|
||||
do_sample=False,
|
||||
temperature=0.9,
|
||||
semantic_generation_config=self.semantic_generation_config,
|
||||
min_eos_p=min_eos_p,
|
||||
)
|
||||
self.assertListEqual(output_ids_without_min_eos_p[0, : len(expected_output_ids)].tolist(), expected_output_ids)
|
||||
self.assertLess(len(output_ids_kwargs[0, :].tolist()), len(output_ids_without_min_eos_p[0, :].tolist()))
|
||||
|
||||
# Should be able to read min_eos_p from the semantic generation config
|
||||
self.semantic_generation_config.min_eos_p = min_eos_p
|
||||
with torch.no_grad():
|
||||
torch.manual_seed(0)
|
||||
output_ids = self.model.semantic.generate(
|
||||
**input_ids,
|
||||
do_sample=False,
|
||||
temperature=0.9,
|
||||
semantic_generation_config=self.semantic_generation_config,
|
||||
)
|
||||
|
||||
self.assertEqual(output_ids.shape, output_ids_kwargs.shape)
|
||||
self.assertLess(len(output_ids[0, :].tolist()), len(output_ids_without_min_eos_p[0, :].tolist()))
|
||||
self.assertListEqual(output_ids[0, : len(expected_output_ids)].tolist(), expected_output_ids)
|
||||
|
||||
@slow
|
||||
@@ -1022,26 +1066,30 @@ class BarkModelIntegrationTests(unittest.TestCase):
|
||||
input_ids = self.inputs
|
||||
|
||||
with torch.no_grad():
|
||||
torch.manual_seed(0)
|
||||
self.model.generate(
|
||||
**input_ids, do_sample=False, temperature=1.0, coarse_do_sample=True, coarse_temperature=0.7
|
||||
)
|
||||
self.model.generate(
|
||||
output_ids_without_min_eos_p = self.model.generate(
|
||||
**input_ids,
|
||||
do_sample=False,
|
||||
temperature=1.0,
|
||||
do_sample=True,
|
||||
temperature=0.9,
|
||||
coarse_do_sample=True,
|
||||
coarse_temperature=0.7,
|
||||
fine_temperature=0.3,
|
||||
)
|
||||
self.model.generate(
|
||||
|
||||
output_ids_with_min_eos_p = self.model.generate(
|
||||
**input_ids,
|
||||
do_sample=True,
|
||||
temperature=0.6,
|
||||
penalty_alpha=0.6,
|
||||
semantic_temperature=0.9,
|
||||
coarse_temperature=0.2,
|
||||
fine_temperature=0.1,
|
||||
temperature=0.9,
|
||||
coarse_temperature=0.7,
|
||||
fine_temperature=0.3,
|
||||
min_eos_p=0.1,
|
||||
)
|
||||
self.assertLess(
|
||||
len(output_ids_with_min_eos_p[0, :].tolist()), len(output_ids_without_min_eos_p[0, :].tolist())
|
||||
)
|
||||
|
||||
@require_torch_gpu
|
||||
@slow
|
||||
|
||||
Reference in New Issue
Block a user