Add early stopping for Bark generation via logits processor (#26675)

* add early stopping logits processor

* black formmated

* indent

* follow method signature

* actual logic

* check for None

* address comments on docstrings and method signature

* add unit test under `LogitsProcessorTest` wip

* unit test passing

* black formatted

* condition per sample

* add to BarkModelIntegrationTests

* wip BarkSemanticModelTest

* rename and add to kwargs handling

* not add to BarkSemanticModelTest

* correct logic and assert last outputs tokens different in test

* doc-builder style

* read from kwargs as well

* assert len of with less than that of without

* ruff

* add back seed and test case

* add original impl default suggestion

* doc-builder

* rename and use softmax

* switch back to LogitsProcessor and update docs wording

* camelCase and spelling and saving compute

* assert strictly less than

* assert less than

* expand test_generate_semantic_early_stop instead
This commit is contained in:
Isaac Chung
2023-10-27 13:07:33 +03:00
committed by GitHub
parent 90ee9cea19
commit e2bffcfafd
5 changed files with 125 additions and 12 deletions

View File

@@ -917,7 +917,51 @@ class BarkModelIntegrationTests(unittest.TestCase):
temperature=1.0,
semantic_generation_config=self.semantic_generation_config,
)
self.assertListEqual(output_ids[0, : len(expected_output_ids)].tolist(), expected_output_ids)
@slow
def test_generate_semantic_early_stop(self):
input_ids = self.inputs
min_eos_p = 0.01
# fmt: off
# check first ids
expected_output_ids = [7363, 321, 41, 1461, 6915, 952, 326, 41, 41, 927,]
# fmt: on
# Should be able to read min_eos_p from kwargs
with torch.no_grad():
torch.manual_seed(0)
output_ids_without_min_eos_p = self.model.semantic.generate(
**input_ids,
do_sample=False,
temperature=0.9,
semantic_generation_config=self.semantic_generation_config,
)
torch.manual_seed(0)
output_ids_kwargs = self.model.semantic.generate(
**input_ids,
do_sample=False,
temperature=0.9,
semantic_generation_config=self.semantic_generation_config,
min_eos_p=min_eos_p,
)
self.assertListEqual(output_ids_without_min_eos_p[0, : len(expected_output_ids)].tolist(), expected_output_ids)
self.assertLess(len(output_ids_kwargs[0, :].tolist()), len(output_ids_without_min_eos_p[0, :].tolist()))
# Should be able to read min_eos_p from the semantic generation config
self.semantic_generation_config.min_eos_p = min_eos_p
with torch.no_grad():
torch.manual_seed(0)
output_ids = self.model.semantic.generate(
**input_ids,
do_sample=False,
temperature=0.9,
semantic_generation_config=self.semantic_generation_config,
)
self.assertEqual(output_ids.shape, output_ids_kwargs.shape)
self.assertLess(len(output_ids[0, :].tolist()), len(output_ids_without_min_eos_p[0, :].tolist()))
self.assertListEqual(output_ids[0, : len(expected_output_ids)].tolist(), expected_output_ids)
@slow
@@ -1022,26 +1066,30 @@ class BarkModelIntegrationTests(unittest.TestCase):
input_ids = self.inputs
with torch.no_grad():
torch.manual_seed(0)
self.model.generate(
**input_ids, do_sample=False, temperature=1.0, coarse_do_sample=True, coarse_temperature=0.7
)
self.model.generate(
output_ids_without_min_eos_p = self.model.generate(
**input_ids,
do_sample=False,
temperature=1.0,
do_sample=True,
temperature=0.9,
coarse_do_sample=True,
coarse_temperature=0.7,
fine_temperature=0.3,
)
self.model.generate(
output_ids_with_min_eos_p = self.model.generate(
**input_ids,
do_sample=True,
temperature=0.6,
penalty_alpha=0.6,
semantic_temperature=0.9,
coarse_temperature=0.2,
fine_temperature=0.1,
temperature=0.9,
coarse_temperature=0.7,
fine_temperature=0.3,
min_eos_p=0.1,
)
self.assertLess(
len(output_ids_with_min_eos_p[0, :].tolist()), len(output_ids_without_min_eos_p[0, :].tolist())
)
@require_torch_gpu
@slow