[Assistant Generation] Improve Encoder Decoder (#26701)
* [Assistant Generation] Improve enc dec * save more * Fix logit processor checks * Clean * make style * fix deprecation * fix generation test * Apply suggestions from code review * fix biogpt * make style
This commit is contained in:
committed by
GitHub
parent
5334796d20
commit
da69de17e8
@@ -227,6 +227,20 @@ class GenerationConfig(PushToHubMixin):
|
||||
decoder_start_token_id (`int`, *optional*):
|
||||
If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
|
||||
|
||||
> Generation parameters exclusive to [assistant generation](https://arxiv.org/abs/2211.17192)
|
||||
|
||||
num_assistant_tokens (`int`, *optional*, defaults to 5):
|
||||
Defines the number of _speculative tokens_ that shall be generated by the assistant model before being
|
||||
checked by the target model at each iteration. Higher values for `num_assistant_tokens` make the generation
|
||||
more _speculative_: if the assistant model is performant, larger speed-ups can be reached; if the assistant
|
||||
model requires lots of corrections, lower speed-ups are reached.
|
||||
|
||||
num_assistant_tokens_schedule (`str`, *optional*, defaults to `"heuristic"`):
|
||||
Defines the schedule at which max assistant tokens shall be changed during inference.
|
||||
- `"heuristic"`: When all _speculative_ tokens are correct, increase `num_assistant_tokens` by 2, else
|
||||
reduce by 1
|
||||
- `"constant"`: `num_assistant_tokens` stays unchanged during generation
|
||||
|
||||
> Wild card
|
||||
|
||||
generation_kwargs:
|
||||
@@ -294,6 +308,10 @@ class GenerationConfig(PushToHubMixin):
|
||||
self.encoder_no_repeat_ngram_size = kwargs.pop("encoder_no_repeat_ngram_size", 0)
|
||||
self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
|
||||
|
||||
# Assistant generation
|
||||
self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 5)
|
||||
self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "heuristic")
|
||||
|
||||
# Wild card
|
||||
self.generation_kwargs = kwargs.pop("generation_kwargs", {})
|
||||
|
||||
|
||||
@@ -1241,6 +1241,10 @@ class GenerationMixin:
|
||||
decoder_model_args = set(inspect.signature(decoder.forward).parameters)
|
||||
model_args |= {f"decoder_{x}" for x in decoder_model_args}
|
||||
|
||||
# allow assistant_encoder_outputs to be passed if we're doing assisted generation
|
||||
if "assistant_encoder_outputs" in model_kwargs:
|
||||
model_args |= {"assistant_encoder_outputs"}
|
||||
|
||||
for key, value in model_kwargs.items():
|
||||
if value is not None and key not in model_args:
|
||||
unused_model_args.append(key)
|
||||
@@ -1612,7 +1616,7 @@ class GenerationMixin:
|
||||
raise ValueError("assisted generate requires `use_cache=True`")
|
||||
|
||||
# 11. If the assistant model is an encoder-decoder, prepare its encoder outputs
|
||||
if assistant_model.config.is_encoder_decoder:
|
||||
if assistant_model.config.is_encoder_decoder and "assistant_encoder_outputs" not in model_kwargs:
|
||||
assistant_model_kwargs = copy.deepcopy(model_kwargs)
|
||||
inputs_tensor, model_input_name, assistant_model_kwargs = assistant_model._prepare_model_inputs(
|
||||
inputs_tensor, assistant_model.generation_config.bos_token_id, assistant_model_kwargs
|
||||
@@ -4347,8 +4351,14 @@ class GenerationMixin:
|
||||
["It might be possible to get a better understanding of the nature of the problem, but it's not"]
|
||||
```"""
|
||||
# Assistant: initialize assistant-related variables
|
||||
if not hasattr(assistant_model, "max_assistant_tokens"):
|
||||
assistant_model.max_assistant_tokens = 5 # this value, which will be updated, persists across calls
|
||||
if hasattr(assistant_model, "num_assistant_tokens"):
|
||||
warnings.warn(
|
||||
"Setting `num_assistant_tokens` via `assistant_model.num_assistant_tokens` is deprecated and will be removed in v.37. Make sure to set `num_assistant_tokens` via the generation_config instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
num_assistant_tokens = assistant_model.num_assistant_tokens
|
||||
else:
|
||||
num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens
|
||||
|
||||
# init values
|
||||
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
||||
@@ -4421,26 +4431,23 @@ class GenerationMixin:
|
||||
# `.generate()` call if we decide to add `past_key_values` as a possible output of generate, as we
|
||||
# need access to the assistant cache to secure strong speedups.
|
||||
candidate_input_ids = input_ids
|
||||
for _ in range(int(assistant_model.max_assistant_tokens)):
|
||||
for _ in range(int(num_assistant_tokens)):
|
||||
# 1.1. use the assistant model to obtain the next candidate logits
|
||||
if "assistant_past_key_values" in model_kwargs:
|
||||
prev_seq_len = model_kwargs["assistant_past_key_values"][0][assistant_kv_indexing].shape[-2]
|
||||
# `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model)
|
||||
new_token_len = candidate_input_ids.shape[1] - prev_seq_len
|
||||
assist_inputs = candidate_input_ids[:, -new_token_len:]
|
||||
assist_attn = torch.ones_like(candidate_input_ids)
|
||||
# TODO (joao): make it compatible with models that use unconventional fwd pass logic, like blip2
|
||||
if assistant_model.config.is_encoder_decoder:
|
||||
assistant_model_outputs = assistant_model(
|
||||
decoder_input_ids=assist_inputs,
|
||||
decoder_attention_mask=assist_attn,
|
||||
past_key_values=model_kwargs["assistant_past_key_values"],
|
||||
encoder_outputs=model_kwargs["assistant_encoder_outputs"],
|
||||
)
|
||||
else:
|
||||
assistant_model_outputs = assistant_model(
|
||||
assist_inputs,
|
||||
attention_mask=assist_attn,
|
||||
past_key_values=model_kwargs["assistant_past_key_values"],
|
||||
)
|
||||
else:
|
||||
@@ -4495,18 +4502,18 @@ class GenerationMixin:
|
||||
# 2.3. Process the new logits
|
||||
new_logits = outputs.logits[:, -candidate_length - 1 :] # excludes the input prompt if present
|
||||
if len(logits_processor) > 0:
|
||||
for i in range(candidate_length):
|
||||
for i in range(candidate_length + 1):
|
||||
new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])
|
||||
if len(logits_warper) > 0:
|
||||
for i in range(candidate_length):
|
||||
for i in range(candidate_length + 1):
|
||||
new_logits[:, i, :] = logits_warper(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])
|
||||
|
||||
# 3. Obtain the next tokens from the original model logits.
|
||||
if do_sample:
|
||||
probs = new_logits[:, -candidate_length - 1 :, :].softmax(dim=-1)
|
||||
probs = new_logits.softmax(dim=-1)
|
||||
selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :]
|
||||
else:
|
||||
selected_tokens = new_logits[:, -candidate_length - 1 :, :].argmax(dim=-1)
|
||||
selected_tokens = new_logits.argmax(dim=-1)
|
||||
|
||||
# 4. Compare the argmax from the original model logits with the assistant forecasted tokens. We can keep
|
||||
# the assistant forecasted tokens until the first mismatch, or until the max length is reached.
|
||||
@@ -4540,13 +4547,13 @@ class GenerationMixin:
|
||||
# 6. Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic,
|
||||
# probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the
|
||||
# cost of forecasting incorrect assistant tokens.
|
||||
if n_matches == int(assistant_model.max_assistant_tokens):
|
||||
assistant_model.max_assistant_tokens += 2.0
|
||||
if assistant_model.generation_config.num_assistant_tokens_schedule == "heuristic":
|
||||
if n_matches == int(num_assistant_tokens):
|
||||
num_assistant_tokens += 2.0
|
||||
else:
|
||||
assistant_model.max_assistant_tokens = max(1.0, assistant_model.max_assistant_tokens - 1.0)
|
||||
num_assistant_tokens = max(1.0, num_assistant_tokens - 1.0)
|
||||
|
||||
# Assistant: main logic end
|
||||
|
||||
if synced_gpus and this_peer_finished:
|
||||
continue # don't waste resources running the code we don't need
|
||||
|
||||
|
||||
@@ -544,7 +544,11 @@ class BioGptModel(BioGptPreTrainedModel):
|
||||
inputs_embeds = self.embed_tokens(input) * self.embed_scale
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.bool, device=inputs_embeds.device)
|
||||
attention_mask = torch.ones(
|
||||
(inputs_embeds.shape[0], inputs_embeds.shape[1] + past_key_values_length),
|
||||
dtype=torch.bool,
|
||||
device=inputs_embeds.device,
|
||||
)
|
||||
elif attention_mask.shape[1] != past_key_values_length + input_shape[1]:
|
||||
raise ValueError(
|
||||
f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
|
||||
|
||||
@@ -2953,7 +2953,8 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
|
||||
return outs
|
||||
|
||||
def prepare_inputs_for_generation(self, *args, foo=False, **kwargs):
|
||||
def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs):
|
||||
kwargs["encoder_outputs"] = encoder_outputs
|
||||
inputs = super().prepare_inputs_for_generation(*args, **kwargs)
|
||||
|
||||
inputs["foo"] = foo
|
||||
@@ -2992,3 +2993,14 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
assistant_model=assistant,
|
||||
)
|
||||
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
|
||||
|
||||
# Check that passing encoder_outputs directly also works as expected
|
||||
encoder_outputs = assistant.get_encoder()(input_ids)
|
||||
|
||||
outputs_assisted = model.generate(
|
||||
foo=True,
|
||||
assistant_model=assistant,
|
||||
encoder_outputs=encoder_outputs,
|
||||
assistant_encoder_outputs=encoder_outputs,
|
||||
)
|
||||
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
|
||||
|
||||
Reference in New Issue
Block a user