From 77b59dce9fb804c72e3e9f3eaa01d53a905e6ada Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Tue, 23 Apr 2024 16:23:36 +0500 Subject: [PATCH] Fix on "cache position" for assisted generation (#30068) * clean commit history I hope * get kv seq length correctly * PR suggestions * Update src/transformers/testing_utils.py Co-authored-by: Joao Gante * add comment * give gpt bigcode it's own overriden method * remove code --------- Co-authored-by: Joao Gante --- src/transformers/generation/utils.py | 76 ++++++++++--------- .../gpt_bigcode/modeling_gpt_bigcode.py | 18 +++++ .../models/jamba/modeling_jamba.py | 10 +++ tests/generation/test_utils.py | 11 ++- 4 files changed, 77 insertions(+), 38 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index bf718932a4..9e6a58d3e5 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -641,6 +641,7 @@ class GenerationMixin: model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False, standardize_cache_format: bool = False, + num_new_tokens: int = 1, ) -> Dict[str, Any]: # update past_key_values model_kwargs["past_key_values"] = self._extract_past_from_model_output( @@ -671,7 +672,7 @@ class GenerationMixin: ) if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None: - model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1 + model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens return model_kwargs @@ -1294,6 +1295,21 @@ class GenerationMixin: return generation_config, model_kwargs + def _get_initial_cache_position(self, input_ids, model_kwargs): + """Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length""" + past_length = 0 + if "past_key_values" in model_kwargs: + if isinstance(model_kwargs["past_key_values"], Cache): + past_length = model_kwargs["past_key_values"].get_seq_length() + else: + past_length = model_kwargs["past_key_values"][0][0].shape[2] + if "inputs_embeds" in model_kwargs: + cur_len = model_kwargs["inputs_embeds"].shape[1] + else: + cur_len = input_ids.shape[-1] + model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device) + return model_kwargs + @torch.no_grad() def generate( self, @@ -1560,6 +1576,8 @@ class GenerationMixin: raise ValueError("assisted generate is only supported for batch_size = 1") if not model_kwargs["use_cache"]: raise ValueError("assisted generate requires `use_cache=True`") + if generation_config.cache_implementation == "static": + raise ValueError("assisted generate is not supported with `static_cache`") # 11. Get the candidate generator, given the parameterization candidate_generator = self._get_candidate_generator( @@ -2024,11 +2042,9 @@ class GenerationMixin: ) # keep track of which sequences are already finished - batch_size, cur_len = input_ids.shape - if "inputs_embeds" in model_kwargs: - cur_len = model_kwargs["inputs_embeds"].shape[1] + batch_size = input_ids.shape[0] unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) - model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device) + model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) this_peer_finished = False @@ -2495,12 +2511,10 @@ class GenerationMixin: ) # keep track of which sequences are already finished - batch_size, cur_len = input_ids.shape - if "inputs_embeds" in model_kwargs: - cur_len = model_kwargs["inputs_embeds"].shape[1] + batch_size = input_ids.shape[0] this_peer_finished = False unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) - model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device) + model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): # prepare model inputs @@ -2792,12 +2806,10 @@ class GenerationMixin: ) # keep track of which sequences are already finished - batch_size, cur_len = input_ids.shape - if "inputs_embeds" in model_kwargs: - cur_len = model_kwargs["inputs_embeds"].shape[1] + batch_size = input_ids.shape[0] this_peer_finished = False unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) - model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device) + model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): # prepare model inputs @@ -3108,9 +3120,7 @@ class GenerationMixin: num_beams = beam_scorer.num_beams batch_beam_size, cur_len = input_ids.shape - if "inputs_embeds" in model_kwargs: - cur_len = model_kwargs["inputs_embeds"].shape[1] - model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device) + model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) if num_beams * batch_size != batch_beam_size: raise ValueError( @@ -3514,9 +3524,7 @@ class GenerationMixin: num_beams = beam_scorer.num_beams batch_beam_size, cur_len = input_ids.shape - if "inputs_embeds" in model_kwargs: - cur_len = model_kwargs["inputs_embeds"].shape[1] - model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device) + model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None @@ -3874,9 +3882,7 @@ class GenerationMixin: device = input_ids.device batch_beam_size, cur_len = input_ids.shape - if "inputs_embeds" in model_kwargs: - cur_len = model_kwargs["inputs_embeds"].shape[1] - model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device) + model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) if return_dict_in_generate and output_scores: beam_indices = [tuple(() for _ in range(num_sub_beams * batch_size)) for _ in range(num_beam_groups)] @@ -4292,9 +4298,7 @@ class GenerationMixin: num_beams = constrained_beam_scorer.num_beams batch_beam_size, cur_len = input_ids.shape - if "inputs_embeds" in model_kwargs: - cur_len = model_kwargs["inputs_embeds"].shape[1] - model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device) + model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) if num_beams * batch_size != batch_beam_size: raise ValueError( @@ -4655,11 +4659,9 @@ class GenerationMixin: ) # keep track of which sequences are already finished - batch_size, cur_len = input_ids.shape - if "inputs_embeds" in model_kwargs: - cur_len = model_kwargs["inputs_embeds"].shape[1] + batch_size = input_ids.shape[0] unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) - model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device) + model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) this_peer_finished = False while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): @@ -4679,20 +4681,21 @@ class GenerationMixin: # we use this forward pass to also pick the subsequent logits in the original model. # 2.1. Prepare the model inputs - model_kwargs = _prepare_attention_mask( - model_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder + candidate_kwargs = copy.copy(model_kwargs) + candidate_kwargs = _prepare_attention_mask( + candidate_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder ) - model_kwargs = _prepare_token_type_ids(model_kwargs, candidate_input_ids.shape[1]) - if "cache_position" in model_kwargs: - model_kwargs["cache_position"] = torch.cat( + candidate_kwargs = _prepare_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1]) + if "cache_position" in candidate_kwargs: + candidate_kwargs["cache_position"] = torch.cat( ( - model_kwargs["cache_position"], + candidate_kwargs["cache_position"], torch.arange(cur_len, cur_len + candidate_length, device=input_ids.device, dtype=torch.long), ), dim=0, ) - model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **model_kwargs) + model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs) if "num_logits_to_keep" in model_inputs: model_inputs["num_logits_to_keep"] = candidate_length + 1 @@ -4811,6 +4814,7 @@ class GenerationMixin: outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, + num_new_tokens=n_matches + 1, ) unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 4e3b849848..d61877cb1f 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -1209,6 +1209,24 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel): ) return model_inputs + def _get_initial_cache_position(self, input_ids, model_kwargs): + """ + Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length. + Since gpt bigcode is special, the method is overridden here, other models use it from `generation.utils.py`. + """ + past_length = 0 + if "past_key_values" in model_kwargs: + if self.config.multi_query: + past_length = model_kwargs["past_key_values"][0].shape[1] + else: + past_length = model_kwargs["past_key_values"][0].shape[2] + if "inputs_embeds" in model_kwargs: + cur_len = model_kwargs["inputs_embeds"].shape[1] + else: + cur_len = input_ids.shape[-1] + model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device) + return model_kwargs + @add_start_docstrings_to_model_forward(GPT_BIGCODE_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index 9780d95d4e..dd4e3af1a0 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -231,6 +231,7 @@ class HybridMambaAttentionDynamicCache(DynamicCache): conv_kernel_size = config.mamba_d_conv self.conv_states = [] self.ssm_states = [] + self.transformer_layers = [] for i in range(config.num_hidden_layers): if self.layers_block_type[i] == "mamba": self.conv_states += [ @@ -242,6 +243,7 @@ class HybridMambaAttentionDynamicCache(DynamicCache): else: self.conv_states += [torch.tensor([[]] * batch_size, device=device)] self.ssm_states += [torch.tensor([[]] * batch_size, device=device)] + self.transformer_layers.append(i) self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] @@ -276,6 +278,14 @@ class HybridMambaAttentionDynamicCache(DynamicCache): device = self.ssm_states[layer_idx].device self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # take any layer that contains cache and not empty tensor + layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx + if len(self.key_cache) <= layer_idx: + return 0 + return self.key_cache[layer_idx].shape[-2] + def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 215b258230..eacba9ebc6 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1091,8 +1091,9 @@ class GenerationTesterMixin: ) self.assertListEqual(low_output.tolist(), high_output.tolist()) + @parameterized.expand([("random",), ("same",)]) @is_flaky() # Read NOTE (1) below. If there are API issues, all attempts will fail. - def test_assisted_decoding_matches_greedy_search(self): + def test_assisted_decoding_matches_greedy_search(self, assistant_type): # This test ensures that the assisted generation does not introduce output changes over greedy search. # NOTE (1): The sentence above is true most of the time, there is a tiny difference in the logits due to matmul # shape differences -- and it may result in a different output. The input shape difference happens in the @@ -1151,7 +1152,13 @@ class GenerationTesterMixin: } output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs) - assistant_model = model + # test with the same assistant model or randomly init one + # in the first case all candidate tokens are accepted, in the second none is accepted + # case when some are accepted and some not is hard to reproduce, so let's hope this catches most errors :) + if assistant_type == "random": + assistant_model = model_class(config).to(torch_device).eval() + else: + assistant_model = model assistant_model.generation_config.num_assistant_tokens = 2 # see b) assistant_model.generation_config.num_assistant_tokens_schedule = "constant" # see b) generation_kwargs.update({"assistant_model": assistant_model})