Make mamba use cache (#31116)
* make mamba use cache * use cache naming as in mamba * fix musicgen
This commit is contained in:
committed by
GitHub
parent
f5c0fa9f6f
commit
7729b77478
@@ -627,18 +627,22 @@ class GenerationMixin:
|
|||||||
|
|
||||||
def _extract_past_from_model_output(self, outputs: ModelOutput, standardize_cache_format: bool = False):
|
def _extract_past_from_model_output(self, outputs: ModelOutput, standardize_cache_format: bool = False):
|
||||||
past_key_values = None
|
past_key_values = None
|
||||||
|
cache_name = "past_key_values"
|
||||||
if "past_key_values" in outputs:
|
if "past_key_values" in outputs:
|
||||||
past_key_values = outputs.past_key_values
|
past_key_values = outputs.past_key_values
|
||||||
elif "mems" in outputs:
|
elif "mems" in outputs:
|
||||||
past_key_values = outputs.mems
|
past_key_values = outputs.mems
|
||||||
elif "past_buckets_states" in outputs:
|
elif "past_buckets_states" in outputs:
|
||||||
past_key_values = outputs.past_buckets_states
|
past_key_values = outputs.past_buckets_states
|
||||||
|
elif "cache_params" in outputs:
|
||||||
|
past_key_values = outputs.cache_params
|
||||||
|
cache_name = "cache_params"
|
||||||
|
|
||||||
# Bloom fix: standardizes the cache format when requested
|
# Bloom fix: standardizes the cache format when requested
|
||||||
if standardize_cache_format and hasattr(self, "_convert_to_standard_cache"):
|
if standardize_cache_format and hasattr(self, "_convert_to_standard_cache"):
|
||||||
batch_size = outputs.logits.shape[0]
|
batch_size = outputs.logits.shape[0]
|
||||||
past_key_values = self._convert_to_standard_cache(past_key_values, batch_size=batch_size)
|
past_key_values = self._convert_to_standard_cache(past_key_values, batch_size=batch_size)
|
||||||
return past_key_values
|
return cache_name, past_key_values
|
||||||
|
|
||||||
def _update_model_kwargs_for_generation(
|
def _update_model_kwargs_for_generation(
|
||||||
self,
|
self,
|
||||||
@@ -648,10 +652,11 @@ class GenerationMixin:
|
|||||||
standardize_cache_format: bool = False,
|
standardize_cache_format: bool = False,
|
||||||
num_new_tokens: int = 1,
|
num_new_tokens: int = 1,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
# update past_key_values
|
# update past_key_values keeping its naming used in model code
|
||||||
model_kwargs["past_key_values"] = self._extract_past_from_model_output(
|
cache_name, cache = self._extract_past_from_model_output(
|
||||||
outputs, standardize_cache_format=standardize_cache_format
|
outputs, standardize_cache_format=standardize_cache_format
|
||||||
)
|
)
|
||||||
|
model_kwargs[cache_name] = cache
|
||||||
if getattr(outputs, "state", None) is not None:
|
if getattr(outputs, "state", None) is not None:
|
||||||
model_kwargs["state"] = outputs.state
|
model_kwargs["state"] = outputs.state
|
||||||
|
|
||||||
@@ -2401,7 +2406,7 @@ class GenerationMixin:
|
|||||||
next_past_key_values = selected_outputs["past_key_values"]
|
next_past_key_values = selected_outputs["past_key_values"]
|
||||||
|
|
||||||
else:
|
else:
|
||||||
next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True)
|
_, next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True)
|
||||||
# Do it in-place layer per layer to save memory
|
# Do it in-place layer per layer to save memory
|
||||||
if isinstance(next_past_key_values, DynamicCache):
|
if isinstance(next_past_key_values, DynamicCache):
|
||||||
next_past_key_values.batch_select_indices(augmented_idx)
|
next_past_key_values.batch_select_indices(augmented_idx)
|
||||||
|
|||||||
@@ -633,7 +633,12 @@ class MambaForCausalLM(MambaPreTrainedModel):
|
|||||||
return model_kwargs
|
return model_kwargs
|
||||||
|
|
||||||
def prepare_inputs_for_generation(
|
def prepare_inputs_for_generation(
|
||||||
self, input_ids, cache_params: Optional[MambaCache] = None, inputs_embeds=None, attention_mask=None, **kwargs
|
self,
|
||||||
|
input_ids,
|
||||||
|
inputs_embeds=None,
|
||||||
|
use_cache=None,
|
||||||
|
cache_params: Optional[MambaCache] = None,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
# only last token for inputs_ids if the state is passed along.
|
# only last token for inputs_ids if the state is passed along.
|
||||||
if cache_params is not None:
|
if cache_params is not None:
|
||||||
@@ -644,7 +649,12 @@ class MambaForCausalLM(MambaPreTrainedModel):
|
|||||||
else:
|
else:
|
||||||
model_inputs = {"input_ids": input_ids}
|
model_inputs = {"input_ids": input_ids}
|
||||||
|
|
||||||
model_inputs["cache_params"] = cache_params
|
model_inputs.update(
|
||||||
|
{
|
||||||
|
"cache_params": cache_params,
|
||||||
|
"use_cache": use_cache,
|
||||||
|
}
|
||||||
|
)
|
||||||
return model_inputs
|
return model_inputs
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(MAMBA_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(MAMBA_INPUTS_DOCSTRING)
|
||||||
|
|||||||
@@ -2787,9 +2787,11 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
|
|||||||
model_inputs: Optional[Dict[str, Any]] = None,
|
model_inputs: Optional[Dict[str, Any]] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
# update past_key_values
|
# update past_key_values
|
||||||
model_kwargs["past_key_values"] = self._extract_past_from_model_output(
|
cache_name, cache = self._extract_past_from_model_output(
|
||||||
outputs, standardize_cache_format=standardize_cache_format
|
outputs, standardize_cache_format=standardize_cache_format
|
||||||
)
|
)
|
||||||
|
model_kwargs[cache_name] = cache
|
||||||
|
|
||||||
if getattr(outputs, "state", None) is not None:
|
if getattr(outputs, "state", None) is not None:
|
||||||
model_kwargs["state"] = outputs.state
|
model_kwargs["state"] = outputs.state
|
||||||
|
|
||||||
|
|||||||
@@ -447,10 +447,9 @@ class MambaIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf", torch_dtype=torch.float16)
|
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf", torch_dtype=torch.float16)
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model.config.use_cache = True
|
|
||||||
input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"].to(device)
|
input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"].to(device)
|
||||||
|
|
||||||
out = model.generate(input_ids, do_sample=False, max_new_tokens=10)
|
out = model.generate(input_ids, do_sample=False, use_cache=True, max_new_tokens=10)
|
||||||
output_sentence = tokenizer.decode(out[0, :])
|
output_sentence = tokenizer.decode(out[0, :])
|
||||||
self.assertEqual(output_sentence, "Hey how are you doing?\n\nI'm so glad you're here.")
|
self.assertEqual(output_sentence, "Hey how are you doing?\n\nI'm so glad you're here.")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user