From 7729b7747872eabe461a8d2d4ce1068ae0e716a8 Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Thu, 6 Jun 2024 13:37:29 +0500 Subject: [PATCH] Make mamba use cache (#31116) * make mamba use cache * use cache naming as in mamba * fix musicgen --- src/transformers/generation/utils.py | 13 +++++++++---- src/transformers/models/mamba/modeling_mamba.py | 14 ++++++++++++-- .../musicgen_melody/modeling_musicgen_melody.py | 4 +++- tests/models/mamba/test_modeling_mamba.py | 3 +-- 4 files changed, 25 insertions(+), 9 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 967980c714..535e82b8e0 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -627,18 +627,22 @@ class GenerationMixin: def _extract_past_from_model_output(self, outputs: ModelOutput, standardize_cache_format: bool = False): past_key_values = None + cache_name = "past_key_values" if "past_key_values" in outputs: past_key_values = outputs.past_key_values elif "mems" in outputs: past_key_values = outputs.mems elif "past_buckets_states" in outputs: past_key_values = outputs.past_buckets_states + elif "cache_params" in outputs: + past_key_values = outputs.cache_params + cache_name = "cache_params" # Bloom fix: standardizes the cache format when requested if standardize_cache_format and hasattr(self, "_convert_to_standard_cache"): batch_size = outputs.logits.shape[0] past_key_values = self._convert_to_standard_cache(past_key_values, batch_size=batch_size) - return past_key_values + return cache_name, past_key_values def _update_model_kwargs_for_generation( self, @@ -648,10 +652,11 @@ class GenerationMixin: standardize_cache_format: bool = False, num_new_tokens: int = 1, ) -> Dict[str, Any]: - # update past_key_values - model_kwargs["past_key_values"] = self._extract_past_from_model_output( + # update past_key_values keeping its naming used in model code + cache_name, cache = self._extract_past_from_model_output( 
outputs, standardize_cache_format=standardize_cache_format ) + model_kwargs[cache_name] = cache if getattr(outputs, "state", None) is not None: model_kwargs["state"] = outputs.state @@ -2401,7 +2406,7 @@ class GenerationMixin: next_past_key_values = selected_outputs["past_key_values"] else: - next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True) + _, next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True) # Do it in-place layer per layer to save memory if isinstance(next_past_key_values, DynamicCache): next_past_key_values.batch_select_indices(augmented_idx) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index be42ba2330..82cbef3033 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -633,7 +633,12 @@ class MambaForCausalLM(MambaPreTrainedModel): return model_kwargs def prepare_inputs_for_generation( - self, input_ids, cache_params: Optional[MambaCache] = None, inputs_embeds=None, attention_mask=None, **kwargs + self, + input_ids, + inputs_embeds=None, + use_cache=None, + cache_params: Optional[MambaCache] = None, + **kwargs, ): # only last token for inputs_ids if the state is passed along. 
if cache_params is not None: @@ -644,7 +649,12 @@ class MambaForCausalLM(MambaPreTrainedModel): else: model_inputs = {"input_ids": input_ids} - model_inputs["cache_params"] = cache_params + model_inputs.update( + { + "cache_params": cache_params, + "use_cache": use_cache, + } + ) return model_inputs @add_start_docstrings_to_model_forward(MAMBA_INPUTS_DOCSTRING) diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index 6861349edc..2b6ff1b6be 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -2787,9 +2787,11 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel): model_inputs: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: # update past_key_values - model_kwargs["past_key_values"] = self._extract_past_from_model_output( + cache_name, cache = self._extract_past_from_model_output( outputs, standardize_cache_format=standardize_cache_format ) + model_kwargs[cache_name] = cache + if getattr(outputs, "state", None) is not None: model_kwargs["state"] = outputs.state diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index 3b77e26dcc..7aec7add11 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -447,10 +447,9 @@ class MambaIntegrationTests(unittest.TestCase): model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf", torch_dtype=torch.float16) model.to(device) - model.config.use_cache = True input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"].to(device) - out = model.generate(input_ids, do_sample=False, max_new_tokens=10) + out = model.generate(input_ids, do_sample=False, use_cache=True, max_new_tokens=10) output_sentence = tokenizer.decode(out[0, :]) self.assertEqual(output_sentence, "Hey how are you 
doing?\n\nI'm so glad you're here.")