Add early stopping for Bark generation via logits processor (#26675)

* add early stopping logits processor

* black formmated

* indent

* follow method signature

* actual logic

* check for None

* address comments on docstrings and method signature

* add unit test under `LogitsProcessorTest` wip

* unit test passing

* black formatted

* condition per sample

* add to BarkModelIntegrationTests

* wip BarkSemanticModelTest

* rename and add to kwargs handling

* not add to BarkSemanticModelTest

* correct logic and assert last outputs tokens different in test

* doc-builder style

* read from kwargs as well

* assert len of with less than that of without

* ruff

* add back seed and test case

* add original impl default suggestion

* doc-builder

* rename and use softmax

* switch back to LogitsProcessor and update docs wording

* camelCase and spelling and saving compute

* assert strictly less than

* assert less than

* expand test_generate_semantic_early_stop instead
This commit is contained in:
Isaac Chung
2023-10-27 13:07:33 +03:00
committed by GitHub
parent 90ee9cea19
commit e2bffcfafd
5 changed files with 125 additions and 12 deletions

View File

@@ -53,6 +53,7 @@ if is_torch_available():
TypicalLogitsWarper,
UnbatchedClassifierFreeGuidanceLogitsProcessor,
)
from transformers.generation.logits_process import BarkEosPrioritizerLogitsProcessor
@require_torch
@@ -800,3 +801,19 @@ class LogitsProcessorTest(unittest.TestCase):
self.assertAlmostEqual(out[0].item(), res[0].item())
self.assertAlmostEqual(out[1].item(), res[1].item())
self.assertAlmostEqual(out[2].item(), res[2].item())
def test_early_stop_processor(self):
input_ids = None
eos_token_id = 2
min_eos_p = 0.1 ## some small float
scores = self._get_uniform_logits(2, 4)
scores[0][eos_token_id] = -6 ## less than log(min_eos_p)
esp = BarkEosPrioritizerLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p)
actual_scores = esp(input_ids, scores)
expected_scores_list = [
scores[0].tolist(),
[float("-inf"), float("-inf"), scores[0][0], float("-inf")],
]
self.assertListEqual(actual_scores.tolist(), expected_scores_list)