From 849367ccf741d8c58aa88ccfe1d52d8636eaf2b7 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Sat, 29 Apr 2023 10:53:30 +0100 Subject: [PATCH] Generate: prepare assisted generation for release (#23052) --- src/transformers/generation/utils.py | 139 ++++++++++++--------------- tests/generation/test_utils.py | 4 +- 2 files changed, 65 insertions(+), 78 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 4a0e621de1..06836d4d4a 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -4177,7 +4177,6 @@ class GenerationMixin: >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) ["It might be possible to get a better understanding of the nature of the problem, but it's not"] ```""" - # NOTE: the code here is copy/paste from greedy search/sample, except when clearly stated in the comments # Assistant: initialize assistant-related variables if not hasattr(assistant_model, "max_assistant_tokens"): assistant_model.max_assistant_tokens = 5 # this value, which will be updated, persists across calls @@ -4248,20 +4247,20 @@ class GenerationMixin: prev_seq_len = model_kwargs["assistant_past_key_values"][0][0].shape[2] # `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model) new_token_len = candidate_input_ids.shape[1] - prev_seq_len - tmp_inputs = candidate_input_ids[:, -new_token_len:] - tmp_attn = torch.ones_like(candidate_input_ids) + assist_inputs = candidate_input_ids[:, -new_token_len:] + assist_attn = torch.ones_like(candidate_input_ids) # TODO (joao): make it compatible with models that use unconventional fwd pass logic, like blip2 if assistant_model.config.is_encoder_decoder: assistant_model_outputs = assistant_model( - decoder_input_ids=tmp_inputs, - decoder_attention_mask=tmp_attn, + decoder_input_ids=assist_inputs, + decoder_attention_mask=assist_attn, past_key_values=model_kwargs["assistant_past_key_values"], encoder_outputs=model_kwargs["assistant_encoder_outputs"], ) else: assistant_model_outputs = assistant_model( - tmp_inputs, - attention_mask=tmp_attn, + assist_inputs, + attention_mask=assist_attn, past_key_values=model_kwargs["assistant_past_key_values"], ) else: @@ -4296,16 +4295,17 @@ class GenerationMixin: candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1] # 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain - # `candidate_length + 1` relevant logits from this process (see step 7 on why the +1) + # `candidate_length + 1` relevant logits from this process: in the event that all candidates are correct, + # we use this forward pass to also pick the subsequent logits in the original model. # 2.1. Run a forward pass on the candidate sequence if "past_key_values" in model_kwargs: - og_model_attn = torch.ones_like(candidate_input_ids) - og_model_input_ids = candidate_input_ids[:, -candidate_length - 1 :] + model_attn = torch.ones_like(candidate_input_ids) + model_input_ids = candidate_input_ids[:, -candidate_length - 1 :] if self.config.is_encoder_decoder: outputs = self( - decoder_input_ids=og_model_input_ids, - decoder_attention_mask=og_model_attn, + decoder_input_ids=model_input_ids, + decoder_attention_mask=model_attn, past_key_values=model_kwargs["past_key_values"], encoder_outputs=model_kwargs["encoder_outputs"], output_attentions=output_attentions, @@ -4313,8 +4313,8 @@ class GenerationMixin: ) else: outputs = self( - og_model_input_ids, - attention_mask=og_model_attn, + model_input_ids, + attention_mask=model_attn, past_key_values=model_kwargs["past_key_values"], output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -4343,21 +4343,43 @@ class GenerationMixin: for i in range(candidate_length): new_logits[:, i, :] = logits_warper(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :]) - # 3. Obtain the next tokens from the original model logits. If `do_sample` is True, use multinomial - # sampling, otherwise use argmax. + # 3. Obtain the next tokens from the original model logits. if do_sample: probs = new_logits[:, -candidate_length - 1 :, :].softmax(dim=-1) - sampled_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :] - next_tokens = sampled_tokens[:, :-1] + selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :] else: - next_tokens = new_logits[:, -candidate_length - 1 : -1, :].argmax(dim=-1) + selected_tokens = new_logits[:, -candidate_length - 1 :, :].argmax(dim=-1) # 4. Compare the argmax from the original model logits with the assistant forecasted tokens. We can keep # the assistant forecasted tokens until the first mismatch, or until the max length is reached. candidate_new_tokens = candidate_input_ids[:, -candidate_length:] - n_matches = ((~(candidate_new_tokens == next_tokens)).cumsum(dim=-1) < 1).sum() + n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum() - # 5. Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic, + # 5. Update variables according to the number of matching assistant tokens. Remember: the token generated + # by the model after the last candidate match is also valid, as it is generated from a correct sequence. + # Because of this last token, assisted generation search reduces to a normal greedy search/sample if there + # is no match. + + # 5.1. Ensure we don't generate beyond max_len or an EOS token + if last_assistant_token_is_eos and n_matches == candidate_length: + n_matches -= 1 + n_matches = min(n_matches, max_len - cur_len - 1) + + # 5.2. Get the valid continuation, after the matching tokens + valid_tokens = selected_tokens[:, : n_matches + 1] + input_ids = torch.cat((input_ids, valid_tokens), dim=-1) + if streamer is not None: + streamer.put(valid_tokens.cpu()) + new_cur_len = input_ids.shape[-1] + + # 5.3. Discard past key values relative to unused assistant tokens + new_cache_size = new_cur_len - 1 + outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size) + model_kwargs["assistant_past_key_values"] = _crop_past_key_values( + assistant_model, model_kwargs["assistant_past_key_values"], new_cache_size - 1 + ) # the assistant does not have the token after the last match, hence the -1 + + # 6. Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic, # probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the # cost of forecasting incorrect assistant tokens. if n_matches == int(assistant_model.max_assistant_tokens): @@ -4365,37 +4387,7 @@ class GenerationMixin: else: assistant_model.max_assistant_tokens = max(1.0, assistant_model.max_assistant_tokens - 1.0) - # 6. Update variables according to the number of matching assistant tokens. - # 6.1. Ensure we don't generate beyond max_len or an EOS token (remember: one token will be added below) - n_matches = min(n_matches, max_len - cur_len - 1) - if last_assistant_token_is_eos and n_matches == candidate_length: - n_matches -= 1 - input_ids = candidate_input_ids[:, 0 : cur_len + n_matches] - new_cur_len = input_ids.shape[-1] - if streamer is not None: - streamer.put(candidate_input_ids[:, cur_len : cur_len + n_matches]) - - # 6.2. Discard past key values relative to unused assistant tokens - outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cur_len) - model_kwargs["assistant_past_key_values"] = _crop_past_key_values( - assistant_model, model_kwargs["assistant_past_key_values"], new_cur_len - ) - - # 6.3. Extract the logits for the next token - next_token_scores = new_logits[:, n_matches, :] - - # 7. Use the set of logits after the last matching assistant token to obtain the next token. Note that, - # because of this step, assisted generation search reduces to a normal greedy search/sample if there is no - # match. - if do_sample: - probs = probs[:, n_matches, :] - next_tokens = sampled_tokens[:, n_matches] - else: - next_tokens = torch.argmax(next_token_scores, dim=-1) - - # Assistant: main logic end; Compared to greedy search/sample, the following (redundant) blocks were - # removed below: (1) model input preparation; (2) model forward pass; (3) score preparation; (4) model - # cache update. + # Assistant: main logic end if synced_gpus and this_peer_finished: continue # don't waste resources running the code we don't need @@ -4407,20 +4399,20 @@ class GenerationMixin: scores += tuple(new_logits[:, i, :] for i in range(n_matches + 1)) if "past_key_values" not in model_kwargs: - last_matching_idx = new_cur_len - 1 + added_len = new_cur_len else: - last_matching_idx = n_matches + added_len = n_matches + 1 if output_attentions: if self.config.is_encoder_decoder: cross_attentions = _split_model_outputs( - cross_attentions, outputs.cross_attentions, cur_len, last_matching_idx + cross_attentions, outputs.cross_attentions, cur_len, added_len ) decoder_attentions = _split_model_outputs( decoder_attentions, outputs.decoder_attentions, cur_len, - last_matching_idx, + added_len, is_decoder_attention=True, ) else: @@ -4428,28 +4420,19 @@ class GenerationMixin: decoder_attentions, outputs.attentions, cur_len, - last_matching_idx, + added_len, is_decoder_attention=True, ) if output_hidden_states: if self.config.is_encoder_decoder: decoder_hidden_states = _split_model_outputs( - decoder_hidden_states, outputs.decoder_hidden_states, cur_len, last_matching_idx + decoder_hidden_states, outputs.decoder_hidden_states, cur_len, added_len ) else: decoder_hidden_states = _split_model_outputs( - decoder_hidden_states, outputs.hidden_states, cur_len, last_matching_idx + decoder_hidden_states, outputs.hidden_states, cur_len, added_len ) - # finished sentences should have their next token be a padding token - if eos_token_id is not None: - next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) - - # update generated ids, model inputs, and length for next step - input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) - if streamer is not None: - streamer.put(next_tokens.cpu()) - model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder ) @@ -4457,7 +4440,10 @@ class GenerationMixin: # if eos_token was found in one sentence, set sentence to finished if eos_token_id_tensor is not None: unfinished_sequences = unfinished_sequences.mul( - next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) + input_ids[:, -1] + .tile(eos_token_id_tensor.shape[0], 1) + .ne(eos_token_id_tensor.unsqueeze(1)) + .prod(dim=0) ) # stop when each sentence is finished @@ -4531,7 +4517,7 @@ def _crop_past_key_values(model, past_key_values, maximum_length): return past_key_values -def _split_model_outputs(outputs, new_outputs, previous_cur_len, last_matching_idx, is_decoder_attention=False): +def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_attention=False): """ Given the (decoder/cross attentions)/(decoder hidden states) for multiple generated tokens, splits it into a tuple where each member corresponds to a single generated token. @@ -4541,16 +4527,17 @@ def _split_model_outputs(outputs, new_outputs, previous_cur_len, last_matching_i if len(outputs) == 0: new_tuple = () for layer in new_outputs: - last_dim_size = previous_cur_len if is_decoder_attention else layer.shape[-1] - new_tuple += (layer[..., :previous_cur_len, :last_dim_size],) + last_dim_size = cur_len if is_decoder_attention else layer.shape[-1] + new_tuple += (layer[..., :cur_len, :last_dim_size],) outputs += (new_tuple,) - last_matching_idx -= previous_cur_len - previous_cur_len += 1 + # The first iteration contains the prompt + 1 generated token, let's update the length variables accordingly + cur_len += 1 + added_len -= cur_len - for i in range(last_matching_idx + 1): + for i in range(added_len): new_tuple = () for layer in new_outputs: - last_dim_size = previous_cur_len + i if is_decoder_attention else layer.shape[-1] + last_dim_size = cur_len + i if is_decoder_attention else layer.shape[-1] new_tuple += (layer[..., i : i + 1, :last_dim_size],) outputs += (new_tuple,) return outputs diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 825fa28993..0dfb4368d7 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1518,8 +1518,8 @@ class GenerationTesterMixin: self._check_outputs(output, input_ids, model.config, use_cache=True) def test_assisted_decoding_sample(self): - # Seeded assisted decoding will not match sample for the same seed, as there are >1 sampling steps per output - # token. As such, this test only checks that the output format is correct. + # Seeded assisted decoding will not match sample for the same seed, as the forward pass does not return the + # exact same logits (the forward pass of the main model, now with several tokens at once, has causal masking). for model_class in self.all_generative_model_classes: # won't fix: FSMT and Reformer have a different cache variable type (and format).