Generate: prepare assisted generation for release (#23052)
This commit is contained in:
@@ -4177,7 +4177,6 @@ class GenerationMixin:
|
|||||||
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||||
["It might be possible to get a better understanding of the nature of the problem, but it's not"]
|
["It might be possible to get a better understanding of the nature of the problem, but it's not"]
|
||||||
```"""
|
```"""
|
||||||
# NOTE: the code here is copy/paste from greedy search/sample, except when clearly stated in the comments
|
|
||||||
# Assistant: initialize assistant-related variables
|
# Assistant: initialize assistant-related variables
|
||||||
if not hasattr(assistant_model, "max_assistant_tokens"):
|
if not hasattr(assistant_model, "max_assistant_tokens"):
|
||||||
assistant_model.max_assistant_tokens = 5 # this value, which will be updated, persists across calls
|
assistant_model.max_assistant_tokens = 5 # this value, which will be updated, persists across calls
|
||||||
@@ -4248,20 +4247,20 @@ class GenerationMixin:
|
|||||||
prev_seq_len = model_kwargs["assistant_past_key_values"][0][0].shape[2]
|
prev_seq_len = model_kwargs["assistant_past_key_values"][0][0].shape[2]
|
||||||
# `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model)
|
# `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model)
|
||||||
new_token_len = candidate_input_ids.shape[1] - prev_seq_len
|
new_token_len = candidate_input_ids.shape[1] - prev_seq_len
|
||||||
tmp_inputs = candidate_input_ids[:, -new_token_len:]
|
assist_inputs = candidate_input_ids[:, -new_token_len:]
|
||||||
tmp_attn = torch.ones_like(candidate_input_ids)
|
assist_attn = torch.ones_like(candidate_input_ids)
|
||||||
# TODO (joao): make it compatible with models that use unconventional fwd pass logic, like blip2
|
# TODO (joao): make it compatible with models that use unconventional fwd pass logic, like blip2
|
||||||
if assistant_model.config.is_encoder_decoder:
|
if assistant_model.config.is_encoder_decoder:
|
||||||
assistant_model_outputs = assistant_model(
|
assistant_model_outputs = assistant_model(
|
||||||
decoder_input_ids=tmp_inputs,
|
decoder_input_ids=assist_inputs,
|
||||||
decoder_attention_mask=tmp_attn,
|
decoder_attention_mask=assist_attn,
|
||||||
past_key_values=model_kwargs["assistant_past_key_values"],
|
past_key_values=model_kwargs["assistant_past_key_values"],
|
||||||
encoder_outputs=model_kwargs["assistant_encoder_outputs"],
|
encoder_outputs=model_kwargs["assistant_encoder_outputs"],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
assistant_model_outputs = assistant_model(
|
assistant_model_outputs = assistant_model(
|
||||||
tmp_inputs,
|
assist_inputs,
|
||||||
attention_mask=tmp_attn,
|
attention_mask=assist_attn,
|
||||||
past_key_values=model_kwargs["assistant_past_key_values"],
|
past_key_values=model_kwargs["assistant_past_key_values"],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -4296,16 +4295,17 @@ class GenerationMixin:
|
|||||||
candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
|
candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
|
||||||
|
|
||||||
# 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain
|
# 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain
|
||||||
# `candidate_length + 1` relevant logits from this process (see step 7 on why the +1)
|
# `candidate_length + 1` relevant logits from this process: in the event that all candidates are correct,
|
||||||
|
# we use this forward pass to also pick the subsequent logits in the original model.
|
||||||
|
|
||||||
# 2.1. Run a forward pass on the candidate sequence
|
# 2.1. Run a forward pass on the candidate sequence
|
||||||
if "past_key_values" in model_kwargs:
|
if "past_key_values" in model_kwargs:
|
||||||
og_model_attn = torch.ones_like(candidate_input_ids)
|
model_attn = torch.ones_like(candidate_input_ids)
|
||||||
og_model_input_ids = candidate_input_ids[:, -candidate_length - 1 :]
|
model_input_ids = candidate_input_ids[:, -candidate_length - 1 :]
|
||||||
if self.config.is_encoder_decoder:
|
if self.config.is_encoder_decoder:
|
||||||
outputs = self(
|
outputs = self(
|
||||||
decoder_input_ids=og_model_input_ids,
|
decoder_input_ids=model_input_ids,
|
||||||
decoder_attention_mask=og_model_attn,
|
decoder_attention_mask=model_attn,
|
||||||
past_key_values=model_kwargs["past_key_values"],
|
past_key_values=model_kwargs["past_key_values"],
|
||||||
encoder_outputs=model_kwargs["encoder_outputs"],
|
encoder_outputs=model_kwargs["encoder_outputs"],
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
@@ -4313,8 +4313,8 @@ class GenerationMixin:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
outputs = self(
|
outputs = self(
|
||||||
og_model_input_ids,
|
model_input_ids,
|
||||||
attention_mask=og_model_attn,
|
attention_mask=model_attn,
|
||||||
past_key_values=model_kwargs["past_key_values"],
|
past_key_values=model_kwargs["past_key_values"],
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -4343,21 +4343,43 @@ class GenerationMixin:
|
|||||||
for i in range(candidate_length):
|
for i in range(candidate_length):
|
||||||
new_logits[:, i, :] = logits_warper(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])
|
new_logits[:, i, :] = logits_warper(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])
|
||||||
|
|
||||||
# 3. Obtain the next tokens from the original model logits. If `do_sample` is True, use multinomial
|
# 3. Obtain the next tokens from the original model logits.
|
||||||
# sampling, otherwise use argmax.
|
|
||||||
if do_sample:
|
if do_sample:
|
||||||
probs = new_logits[:, -candidate_length - 1 :, :].softmax(dim=-1)
|
probs = new_logits[:, -candidate_length - 1 :, :].softmax(dim=-1)
|
||||||
sampled_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :]
|
selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :]
|
||||||
next_tokens = sampled_tokens[:, :-1]
|
|
||||||
else:
|
else:
|
||||||
next_tokens = new_logits[:, -candidate_length - 1 : -1, :].argmax(dim=-1)
|
selected_tokens = new_logits[:, -candidate_length - 1 :, :].argmax(dim=-1)
|
||||||
|
|
||||||
# 4. Compare the argmax from the original model logits with the assistant forecasted tokens. We can keep
|
# 4. Compare the argmax from the original model logits with the assistant forecasted tokens. We can keep
|
||||||
# the assistant forecasted tokens until the first mismatch, or until the max length is reached.
|
# the assistant forecasted tokens until the first mismatch, or until the max length is reached.
|
||||||
candidate_new_tokens = candidate_input_ids[:, -candidate_length:]
|
candidate_new_tokens = candidate_input_ids[:, -candidate_length:]
|
||||||
n_matches = ((~(candidate_new_tokens == next_tokens)).cumsum(dim=-1) < 1).sum()
|
n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()
|
||||||
|
|
||||||
# 5. Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic,
|
# 5. Update variables according to the number of matching assistant tokens. Remember: the token generated
|
||||||
|
# by the model after the last candidate match is also valid, as it is generated from a correct sequence.
|
||||||
|
# Because of this last token, assisted generation search reduces to a normal greedy search/sample if there
|
||||||
|
# is no match.
|
||||||
|
|
||||||
|
# 5.1. Ensure we don't generate beyond max_len or an EOS token
|
||||||
|
if last_assistant_token_is_eos and n_matches == candidate_length:
|
||||||
|
n_matches -= 1
|
||||||
|
n_matches = min(n_matches, max_len - cur_len - 1)
|
||||||
|
|
||||||
|
# 5.2. Get the valid continuation, after the matching tokens
|
||||||
|
valid_tokens = selected_tokens[:, : n_matches + 1]
|
||||||
|
input_ids = torch.cat((input_ids, valid_tokens), dim=-1)
|
||||||
|
if streamer is not None:
|
||||||
|
streamer.put(valid_tokens.cpu())
|
||||||
|
new_cur_len = input_ids.shape[-1]
|
||||||
|
|
||||||
|
# 5.3. Discard past key values relative to unused assistant tokens
|
||||||
|
new_cache_size = new_cur_len - 1
|
||||||
|
outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size)
|
||||||
|
model_kwargs["assistant_past_key_values"] = _crop_past_key_values(
|
||||||
|
assistant_model, model_kwargs["assistant_past_key_values"], new_cache_size - 1
|
||||||
|
) # the assistant does not have the token after the last match, hence the -1
|
||||||
|
|
||||||
|
# 6. Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic,
|
||||||
# probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the
|
# probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the
|
||||||
# cost of forecasting incorrect assistant tokens.
|
# cost of forecasting incorrect assistant tokens.
|
||||||
if n_matches == int(assistant_model.max_assistant_tokens):
|
if n_matches == int(assistant_model.max_assistant_tokens):
|
||||||
@@ -4365,37 +4387,7 @@ class GenerationMixin:
|
|||||||
else:
|
else:
|
||||||
assistant_model.max_assistant_tokens = max(1.0, assistant_model.max_assistant_tokens - 1.0)
|
assistant_model.max_assistant_tokens = max(1.0, assistant_model.max_assistant_tokens - 1.0)
|
||||||
|
|
||||||
# 6. Update variables according to the number of matching assistant tokens.
|
# Assistant: main logic end
|
||||||
# 6.1. Ensure we don't generate beyond max_len or an EOS token (remember: one token will be added below)
|
|
||||||
n_matches = min(n_matches, max_len - cur_len - 1)
|
|
||||||
if last_assistant_token_is_eos and n_matches == candidate_length:
|
|
||||||
n_matches -= 1
|
|
||||||
input_ids = candidate_input_ids[:, 0 : cur_len + n_matches]
|
|
||||||
new_cur_len = input_ids.shape[-1]
|
|
||||||
if streamer is not None:
|
|
||||||
streamer.put(candidate_input_ids[:, cur_len : cur_len + n_matches])
|
|
||||||
|
|
||||||
# 6.2. Discard past key values relative to unused assistant tokens
|
|
||||||
outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cur_len)
|
|
||||||
model_kwargs["assistant_past_key_values"] = _crop_past_key_values(
|
|
||||||
assistant_model, model_kwargs["assistant_past_key_values"], new_cur_len
|
|
||||||
)
|
|
||||||
|
|
||||||
# 6.3. Extract the logits for the next token
|
|
||||||
next_token_scores = new_logits[:, n_matches, :]
|
|
||||||
|
|
||||||
# 7. Use the set of logits after the last matching assistant token to obtain the next token. Note that,
|
|
||||||
# because of this step, assisted generation search reduces to a normal greedy search/sample if there is no
|
|
||||||
# match.
|
|
||||||
if do_sample:
|
|
||||||
probs = probs[:, n_matches, :]
|
|
||||||
next_tokens = sampled_tokens[:, n_matches]
|
|
||||||
else:
|
|
||||||
next_tokens = torch.argmax(next_token_scores, dim=-1)
|
|
||||||
|
|
||||||
# Assistant: main logic end; Compared to greedy search/sample, the following (redundant) blocks were
|
|
||||||
# removed below: (1) model input preparation; (2) model forward pass; (3) score preparation; (4) model
|
|
||||||
# cache update.
|
|
||||||
|
|
||||||
if synced_gpus and this_peer_finished:
|
if synced_gpus and this_peer_finished:
|
||||||
continue # don't waste resources running the code we don't need
|
continue # don't waste resources running the code we don't need
|
||||||
@@ -4407,20 +4399,20 @@ class GenerationMixin:
|
|||||||
scores += tuple(new_logits[:, i, :] for i in range(n_matches + 1))
|
scores += tuple(new_logits[:, i, :] for i in range(n_matches + 1))
|
||||||
|
|
||||||
if "past_key_values" not in model_kwargs:
|
if "past_key_values" not in model_kwargs:
|
||||||
last_matching_idx = new_cur_len - 1
|
added_len = new_cur_len
|
||||||
else:
|
else:
|
||||||
last_matching_idx = n_matches
|
added_len = n_matches + 1
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
if self.config.is_encoder_decoder:
|
if self.config.is_encoder_decoder:
|
||||||
cross_attentions = _split_model_outputs(
|
cross_attentions = _split_model_outputs(
|
||||||
cross_attentions, outputs.cross_attentions, cur_len, last_matching_idx
|
cross_attentions, outputs.cross_attentions, cur_len, added_len
|
||||||
)
|
)
|
||||||
decoder_attentions = _split_model_outputs(
|
decoder_attentions = _split_model_outputs(
|
||||||
decoder_attentions,
|
decoder_attentions,
|
||||||
outputs.decoder_attentions,
|
outputs.decoder_attentions,
|
||||||
cur_len,
|
cur_len,
|
||||||
last_matching_idx,
|
added_len,
|
||||||
is_decoder_attention=True,
|
is_decoder_attention=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -4428,28 +4420,19 @@ class GenerationMixin:
|
|||||||
decoder_attentions,
|
decoder_attentions,
|
||||||
outputs.attentions,
|
outputs.attentions,
|
||||||
cur_len,
|
cur_len,
|
||||||
last_matching_idx,
|
added_len,
|
||||||
is_decoder_attention=True,
|
is_decoder_attention=True,
|
||||||
)
|
)
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
if self.config.is_encoder_decoder:
|
if self.config.is_encoder_decoder:
|
||||||
decoder_hidden_states = _split_model_outputs(
|
decoder_hidden_states = _split_model_outputs(
|
||||||
decoder_hidden_states, outputs.decoder_hidden_states, cur_len, last_matching_idx
|
decoder_hidden_states, outputs.decoder_hidden_states, cur_len, added_len
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
decoder_hidden_states = _split_model_outputs(
|
decoder_hidden_states = _split_model_outputs(
|
||||||
decoder_hidden_states, outputs.hidden_states, cur_len, last_matching_idx
|
decoder_hidden_states, outputs.hidden_states, cur_len, added_len
|
||||||
)
|
)
|
||||||
|
|
||||||
# finished sentences should have their next token be a padding token
|
|
||||||
if eos_token_id is not None:
|
|
||||||
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
|
|
||||||
|
|
||||||
# update generated ids, model inputs, and length for next step
|
|
||||||
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
|
||||||
if streamer is not None:
|
|
||||||
streamer.put(next_tokens.cpu())
|
|
||||||
|
|
||||||
model_kwargs = self._update_model_kwargs_for_generation(
|
model_kwargs = self._update_model_kwargs_for_generation(
|
||||||
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
||||||
)
|
)
|
||||||
@@ -4457,7 +4440,10 @@ class GenerationMixin:
|
|||||||
# if eos_token was found in one sentence, set sentence to finished
|
# if eos_token was found in one sentence, set sentence to finished
|
||||||
if eos_token_id_tensor is not None:
|
if eos_token_id_tensor is not None:
|
||||||
unfinished_sequences = unfinished_sequences.mul(
|
unfinished_sequences = unfinished_sequences.mul(
|
||||||
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
|
input_ids[:, -1]
|
||||||
|
.tile(eos_token_id_tensor.shape[0], 1)
|
||||||
|
.ne(eos_token_id_tensor.unsqueeze(1))
|
||||||
|
.prod(dim=0)
|
||||||
)
|
)
|
||||||
|
|
||||||
# stop when each sentence is finished
|
# stop when each sentence is finished
|
||||||
@@ -4531,7 +4517,7 @@ def _crop_past_key_values(model, past_key_values, maximum_length):
|
|||||||
return past_key_values
|
return past_key_values
|
||||||
|
|
||||||
|
|
||||||
def _split_model_outputs(outputs, new_outputs, previous_cur_len, last_matching_idx, is_decoder_attention=False):
|
def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_attention=False):
|
||||||
"""
|
"""
|
||||||
Given the (decoder/cross attentions)/(decoder hidden states) for multiple generated tokens, splits it into a tuple
|
Given the (decoder/cross attentions)/(decoder hidden states) for multiple generated tokens, splits it into a tuple
|
||||||
where each member corresponds to a single generated token.
|
where each member corresponds to a single generated token.
|
||||||
@@ -4541,16 +4527,17 @@ def _split_model_outputs(outputs, new_outputs, previous_cur_len, last_matching_i
|
|||||||
if len(outputs) == 0:
|
if len(outputs) == 0:
|
||||||
new_tuple = ()
|
new_tuple = ()
|
||||||
for layer in new_outputs:
|
for layer in new_outputs:
|
||||||
last_dim_size = previous_cur_len if is_decoder_attention else layer.shape[-1]
|
last_dim_size = cur_len if is_decoder_attention else layer.shape[-1]
|
||||||
new_tuple += (layer[..., :previous_cur_len, :last_dim_size],)
|
new_tuple += (layer[..., :cur_len, :last_dim_size],)
|
||||||
outputs += (new_tuple,)
|
outputs += (new_tuple,)
|
||||||
last_matching_idx -= previous_cur_len
|
# The first iteration contains the prompt + 1 generated token, let's update the length variables accordingly
|
||||||
previous_cur_len += 1
|
cur_len += 1
|
||||||
|
added_len -= cur_len
|
||||||
|
|
||||||
for i in range(last_matching_idx + 1):
|
for i in range(added_len):
|
||||||
new_tuple = ()
|
new_tuple = ()
|
||||||
for layer in new_outputs:
|
for layer in new_outputs:
|
||||||
last_dim_size = previous_cur_len + i if is_decoder_attention else layer.shape[-1]
|
last_dim_size = cur_len + i if is_decoder_attention else layer.shape[-1]
|
||||||
new_tuple += (layer[..., i : i + 1, :last_dim_size],)
|
new_tuple += (layer[..., i : i + 1, :last_dim_size],)
|
||||||
outputs += (new_tuple,)
|
outputs += (new_tuple,)
|
||||||
return outputs
|
return outputs
|
||||||
|
|||||||
@@ -1518,8 +1518,8 @@ class GenerationTesterMixin:
|
|||||||
self._check_outputs(output, input_ids, model.config, use_cache=True)
|
self._check_outputs(output, input_ids, model.config, use_cache=True)
|
||||||
|
|
||||||
def test_assisted_decoding_sample(self):
|
def test_assisted_decoding_sample(self):
|
||||||
# Seeded assisted decoding will not match sample for the same seed, as there are >1 sampling steps per output
|
# Seeded assisted decoding will not match sample for the same seed, as the forward pass does not return the
|
||||||
# token. As such, this test only checks that the output format is correct.
|
# exact same logits (the forward pass of the main model, now with several tokens at once, has causal masking).
|
||||||
|
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
# won't fix: FSMT and Reformer have a different cache variable type (and format).
|
# won't fix: FSMT and Reformer have a different cache variable type (and format).
|
||||||
|
|||||||
Reference in New Issue
Block a user