Add early stopping for Bark generation via logits processor (#26675)
* add early stopping logits processor * black formmated * indent * follow method signature * actual logic * check for None * address comments on docstrings and method signature * add unit test under `LogitsProcessorTest` wip * unit test passing * black formatted * condition per sample * add to BarkModelIntegrationTests * wip BarkSemanticModelTest * rename and add to kwargs handling * not add to BarkSemanticModelTest * correct logic and assert last outputs tokens different in test * doc-builder style * read from kwargs as well * assert len of with less than that of without * ruff * add back seed and test case * add original impl default suggestion * doc-builder * rename and use softmax * switch back to LogitsProcessor and update docs wording * camelCase and spelling and saving compute * assert strictly less than * assert less than * expand test_generate_semantic_early_stop instead
This commit is contained in:
@@ -1749,3 +1749,35 @@ class UnbatchedClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
|
|||||||
unconditional_logits = torch.nn.functional.log_softmax(logits[:, -1], dim=-1)
|
unconditional_logits = torch.nn.functional.log_softmax(logits[:, -1], dim=-1)
|
||||||
out = self.guidance_scale * (scores - unconditional_logits) + unconditional_logits
|
out = self.guidance_scale * (scores - unconditional_logits) + unconditional_logits
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class BarkEosPrioritizerLogitsProcessor(LogitsProcessor):
|
||||||
|
r"""This processor ensures that the EOS token is selected if its probability is greater than the `min_eos_p`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
eos_token_id (`Union[int, List[int]]`):
|
||||||
|
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
|
||||||
|
min_eos_p (`float`, *optional*):
|
||||||
|
Minimum end of speech threshold.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, eos_token_id: Union[int, List[int]], min_eos_p: float):
|
||||||
|
if isinstance(eos_token_id, int):
|
||||||
|
eos_token_id = [eos_token_id]
|
||||||
|
self.eos_token_id = eos_token_id
|
||||||
|
if min_eos_p is not None and min_eos_p <= 0:
|
||||||
|
raise ValueError(f"`min_eos_p` has to be a positive float, but is {min_eos_p}")
|
||||||
|
self.min_eos_p = min_eos_p
|
||||||
|
|
||||||
|
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||||
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
|
if self.min_eos_p:
|
||||||
|
probs = torch.nn.functional.softmax(scores.float(), dim=-1)
|
||||||
|
# create scores full of -inf except for the eos_token_id
|
||||||
|
early_stop_scores = torch.ones_like(scores) * -float("inf")
|
||||||
|
early_stop_scores[:, self.eos_token_id] = scores[:, self.eos_token_id]
|
||||||
|
|
||||||
|
do_early_stop = probs[:, self.eos_token_id] > self.min_eos_p
|
||||||
|
scores = torch.where(do_early_stop, early_stop_scores, scores)
|
||||||
|
|
||||||
|
return scores
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ class BarkSemanticGenerationConfig(GenerationConfig):
|
|||||||
semantic_vocab_size=10_000,
|
semantic_vocab_size=10_000,
|
||||||
max_input_semantic_length=256,
|
max_input_semantic_length=256,
|
||||||
semantic_rate_hz=49.9,
|
semantic_rate_hz=49.9,
|
||||||
|
min_eos_p=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""Class that holds a generation configuration for [`BarkSemanticModel`].
|
"""Class that holds a generation configuration for [`BarkSemanticModel`].
|
||||||
@@ -86,6 +87,10 @@ class BarkSemanticGenerationConfig(GenerationConfig):
|
|||||||
Max length of semantic input vector.
|
Max length of semantic input vector.
|
||||||
semantic_rate_hz (`float`, *optional*, defaults to 49.9):
|
semantic_rate_hz (`float`, *optional*, defaults to 49.9):
|
||||||
Semantic rate in Hertz.
|
Semantic rate in Hertz.
|
||||||
|
min_eos_p (`float`, *optional*):
|
||||||
|
Minimum threshold of the probability of the EOS token for it to be sampled. This is an early stopping
|
||||||
|
strategy to mitigate potential unwanted generations at the end of a prompt. The original implementation
|
||||||
|
suggests a default value of 0.2.
|
||||||
"""
|
"""
|
||||||
super().__init__(
|
super().__init__(
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
@@ -107,6 +112,7 @@ class BarkSemanticGenerationConfig(GenerationConfig):
|
|||||||
self.semantic_vocab_size = semantic_vocab_size
|
self.semantic_vocab_size = semantic_vocab_size
|
||||||
self.max_input_semantic_length = max_input_semantic_length
|
self.max_input_semantic_length = max_input_semantic_length
|
||||||
self.semantic_rate_hz = semantic_rate_hz
|
self.semantic_rate_hz = semantic_rate_hz
|
||||||
|
self.min_eos_p = min_eos_p
|
||||||
|
|
||||||
|
|
||||||
class BarkCoarseGenerationConfig(GenerationConfig):
|
class BarkCoarseGenerationConfig(GenerationConfig):
|
||||||
|
|||||||
@@ -21,7 +21,11 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
from ...generation.logits_process import AlternatingCodebooksLogitsProcessor, SuppressTokensLogitsProcessor
|
from ...generation.logits_process import (
|
||||||
|
AlternatingCodebooksLogitsProcessor,
|
||||||
|
BarkEosPrioritizerLogitsProcessor,
|
||||||
|
SuppressTokensLogitsProcessor,
|
||||||
|
)
|
||||||
from ...modeling_outputs import CausalLMOutputWithPast, MaskedLMOutput
|
from ...modeling_outputs import CausalLMOutputWithPast, MaskedLMOutput
|
||||||
from ...modeling_utils import PreTrainedModel, get_parameter_device
|
from ...modeling_utils import PreTrainedModel, get_parameter_device
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
@@ -798,12 +802,17 @@ class BarkSemanticModel(BarkCausalModel):
|
|||||||
|
|
||||||
suppress_tokens_logits_processor = SuppressTokensLogitsProcessor(tokens_to_suppress)
|
suppress_tokens_logits_processor = SuppressTokensLogitsProcessor(tokens_to_suppress)
|
||||||
|
|
||||||
|
min_eos_p = kwargs.get("min_eos_p", semantic_generation_config.min_eos_p)
|
||||||
|
early_stopping_logits_processor = BarkEosPrioritizerLogitsProcessor(
|
||||||
|
eos_token_id=semantic_generation_config.eos_token_id, min_eos_p=min_eos_p
|
||||||
|
)
|
||||||
|
|
||||||
# pass input_ids in order to stay consistent with the transformers generate method even though it is not used
|
# pass input_ids in order to stay consistent with the transformers generate method even though it is not used
|
||||||
# (except to get the input seq_len - that's why we keep the first 257 tokens)
|
# (except to get the input seq_len - that's why we keep the first 257 tokens)
|
||||||
semantic_output = super().generate(
|
semantic_output = super().generate(
|
||||||
torch.ones((batch_size, max_input_semantic_length + 1), dtype=torch.int).to(self.device),
|
torch.ones((batch_size, max_input_semantic_length + 1), dtype=torch.int).to(self.device),
|
||||||
input_embeds=input_embeds,
|
input_embeds=input_embeds,
|
||||||
logits_processor=[suppress_tokens_logits_processor],
|
logits_processor=[suppress_tokens_logits_processor, early_stopping_logits_processor],
|
||||||
generation_config=semantic_generation_config,
|
generation_config=semantic_generation_config,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) # size: 10048
|
) # size: 10048
|
||||||
@@ -1559,7 +1568,8 @@ class BarkModel(BarkPreTrainedModel):
|
|||||||
|
|
||||||
kwargs_semantic = {
|
kwargs_semantic = {
|
||||||
# if "attention_mask" is set, it should not be passed to CoarseModel and FineModel
|
# if "attention_mask" is set, it should not be passed to CoarseModel and FineModel
|
||||||
"attention_mask": kwargs.pop("attention_mask", None)
|
"attention_mask": kwargs.pop("attention_mask", None),
|
||||||
|
"min_eos_p": kwargs.pop("min_eos_p", None),
|
||||||
}
|
}
|
||||||
kwargs_coarse = {}
|
kwargs_coarse = {}
|
||||||
kwargs_fine = {}
|
kwargs_fine = {}
|
||||||
|
|||||||
@@ -53,6 +53,7 @@ if is_torch_available():
|
|||||||
TypicalLogitsWarper,
|
TypicalLogitsWarper,
|
||||||
UnbatchedClassifierFreeGuidanceLogitsProcessor,
|
UnbatchedClassifierFreeGuidanceLogitsProcessor,
|
||||||
)
|
)
|
||||||
|
from transformers.generation.logits_process import BarkEosPrioritizerLogitsProcessor
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@@ -800,3 +801,19 @@ class LogitsProcessorTest(unittest.TestCase):
|
|||||||
self.assertAlmostEqual(out[0].item(), res[0].item())
|
self.assertAlmostEqual(out[0].item(), res[0].item())
|
||||||
self.assertAlmostEqual(out[1].item(), res[1].item())
|
self.assertAlmostEqual(out[1].item(), res[1].item())
|
||||||
self.assertAlmostEqual(out[2].item(), res[2].item())
|
self.assertAlmostEqual(out[2].item(), res[2].item())
|
||||||
|
|
||||||
|
def test_early_stop_processor(self):
|
||||||
|
input_ids = None
|
||||||
|
eos_token_id = 2
|
||||||
|
min_eos_p = 0.1 ## some small float
|
||||||
|
|
||||||
|
scores = self._get_uniform_logits(2, 4)
|
||||||
|
scores[0][eos_token_id] = -6 ## less than log(min_eos_p)
|
||||||
|
|
||||||
|
esp = BarkEosPrioritizerLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p)
|
||||||
|
actual_scores = esp(input_ids, scores)
|
||||||
|
expected_scores_list = [
|
||||||
|
scores[0].tolist(),
|
||||||
|
[float("-inf"), float("-inf"), scores[0][0], float("-inf")],
|
||||||
|
]
|
||||||
|
self.assertListEqual(actual_scores.tolist(), expected_scores_list)
|
||||||
|
|||||||
@@ -917,7 +917,51 @@ class BarkModelIntegrationTests(unittest.TestCase):
|
|||||||
temperature=1.0,
|
temperature=1.0,
|
||||||
semantic_generation_config=self.semantic_generation_config,
|
semantic_generation_config=self.semantic_generation_config,
|
||||||
)
|
)
|
||||||
|
self.assertListEqual(output_ids[0, : len(expected_output_ids)].tolist(), expected_output_ids)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_generate_semantic_early_stop(self):
|
||||||
|
input_ids = self.inputs
|
||||||
|
min_eos_p = 0.01
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
# check first ids
|
||||||
|
expected_output_ids = [7363, 321, 41, 1461, 6915, 952, 326, 41, 41, 927,]
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
# Should be able to read min_eos_p from kwargs
|
||||||
|
with torch.no_grad():
|
||||||
|
torch.manual_seed(0)
|
||||||
|
output_ids_without_min_eos_p = self.model.semantic.generate(
|
||||||
|
**input_ids,
|
||||||
|
do_sample=False,
|
||||||
|
temperature=0.9,
|
||||||
|
semantic_generation_config=self.semantic_generation_config,
|
||||||
|
)
|
||||||
|
torch.manual_seed(0)
|
||||||
|
output_ids_kwargs = self.model.semantic.generate(
|
||||||
|
**input_ids,
|
||||||
|
do_sample=False,
|
||||||
|
temperature=0.9,
|
||||||
|
semantic_generation_config=self.semantic_generation_config,
|
||||||
|
min_eos_p=min_eos_p,
|
||||||
|
)
|
||||||
|
self.assertListEqual(output_ids_without_min_eos_p[0, : len(expected_output_ids)].tolist(), expected_output_ids)
|
||||||
|
self.assertLess(len(output_ids_kwargs[0, :].tolist()), len(output_ids_without_min_eos_p[0, :].tolist()))
|
||||||
|
|
||||||
|
# Should be able to read min_eos_p from the semantic generation config
|
||||||
|
self.semantic_generation_config.min_eos_p = min_eos_p
|
||||||
|
with torch.no_grad():
|
||||||
|
torch.manual_seed(0)
|
||||||
|
output_ids = self.model.semantic.generate(
|
||||||
|
**input_ids,
|
||||||
|
do_sample=False,
|
||||||
|
temperature=0.9,
|
||||||
|
semantic_generation_config=self.semantic_generation_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(output_ids.shape, output_ids_kwargs.shape)
|
||||||
|
self.assertLess(len(output_ids[0, :].tolist()), len(output_ids_without_min_eos_p[0, :].tolist()))
|
||||||
self.assertListEqual(output_ids[0, : len(expected_output_ids)].tolist(), expected_output_ids)
|
self.assertListEqual(output_ids[0, : len(expected_output_ids)].tolist(), expected_output_ids)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@@ -1022,25 +1066,29 @@ class BarkModelIntegrationTests(unittest.TestCase):
|
|||||||
input_ids = self.inputs
|
input_ids = self.inputs
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
torch.manual_seed(0)
|
||||||
self.model.generate(
|
self.model.generate(
|
||||||
**input_ids, do_sample=False, temperature=1.0, coarse_do_sample=True, coarse_temperature=0.7
|
**input_ids, do_sample=False, temperature=1.0, coarse_do_sample=True, coarse_temperature=0.7
|
||||||
)
|
)
|
||||||
self.model.generate(
|
output_ids_without_min_eos_p = self.model.generate(
|
||||||
**input_ids,
|
**input_ids,
|
||||||
do_sample=False,
|
do_sample=True,
|
||||||
temperature=1.0,
|
temperature=0.9,
|
||||||
coarse_do_sample=True,
|
coarse_do_sample=True,
|
||||||
coarse_temperature=0.7,
|
coarse_temperature=0.7,
|
||||||
fine_temperature=0.3,
|
fine_temperature=0.3,
|
||||||
)
|
)
|
||||||
self.model.generate(
|
|
||||||
|
output_ids_with_min_eos_p = self.model.generate(
|
||||||
**input_ids,
|
**input_ids,
|
||||||
do_sample=True,
|
do_sample=True,
|
||||||
temperature=0.6,
|
temperature=0.9,
|
||||||
penalty_alpha=0.6,
|
coarse_temperature=0.7,
|
||||||
semantic_temperature=0.9,
|
fine_temperature=0.3,
|
||||||
coarse_temperature=0.2,
|
min_eos_p=0.1,
|
||||||
fine_temperature=0.1,
|
)
|
||||||
|
self.assertLess(
|
||||||
|
len(output_ids_with_min_eos_p[0, :].tolist()), len(output_ids_without_min_eos_p[0, :].tolist())
|
||||||
)
|
)
|
||||||
|
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
|
|||||||
Reference in New Issue
Block a user