Add early stopping for Bark generation via logits processor (#26675)
* add early stopping logits processor * black formmated * indent * follow method signature * actual logic * check for None * address comments on docstrings and method signature * add unit test under `LogitsProcessorTest` wip * unit test passing * black formatted * condition per sample * add to BarkModelIntegrationTests * wip BarkSemanticModelTest * rename and add to kwargs handling * not add to BarkSemanticModelTest * correct logic and assert last outputs tokens different in test * doc-builder style * read from kwargs as well * assert len of with less than that of without * ruff * add back seed and test case * add original impl default suggestion * doc-builder * rename and use softmax * switch back to LogitsProcessor and update docs wording * camelCase and spelling and saving compute * assert strictly less than * assert less than * expand test_generate_semantic_early_stop instead
This commit is contained in:
@@ -53,6 +53,7 @@ if is_torch_available():
|
||||
TypicalLogitsWarper,
|
||||
UnbatchedClassifierFreeGuidanceLogitsProcessor,
|
||||
)
|
||||
from transformers.generation.logits_process import BarkEosPrioritizerLogitsProcessor
|
||||
|
||||
|
||||
@require_torch
|
||||
@@ -800,3 +801,19 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
self.assertAlmostEqual(out[0].item(), res[0].item())
|
||||
self.assertAlmostEqual(out[1].item(), res[1].item())
|
||||
self.assertAlmostEqual(out[2].item(), res[2].item())
|
||||
|
||||
def test_early_stop_processor(self):
|
||||
input_ids = None
|
||||
eos_token_id = 2
|
||||
min_eos_p = 0.1 ## some small float
|
||||
|
||||
scores = self._get_uniform_logits(2, 4)
|
||||
scores[0][eos_token_id] = -6 ## less than log(min_eos_p)
|
||||
|
||||
esp = BarkEosPrioritizerLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p)
|
||||
actual_scores = esp(input_ids, scores)
|
||||
expected_scores_list = [
|
||||
scores[0].tolist(),
|
||||
[float("-inf"), float("-inf"), scores[0][0], float("-inf")],
|
||||
]
|
||||
self.assertListEqual(actual_scores.tolist(), expected_scores_list)
|
||||
|
||||
Reference in New Issue
Block a user