Fix: StaticCache & inputs_embeds (#32932)

squash commit
This commit is contained in:
Raushan Turganbay
2024-09-06 09:56:59 +02:00
committed by GitHub
parent 5792c459ed
commit 1759bb9126
3 changed files with 169 additions and 8 deletions

View File

@@ -1481,6 +1481,7 @@ class GenerationMixin:
model_kwargs: Dict, model_kwargs: Dict,
assistant_model: "PreTrainedModel", assistant_model: "PreTrainedModel",
batch_size: int, batch_size: int,
max_cache_length: int,
device: torch.device, device: torch.device,
) -> bool: ) -> bool:
""" """
@@ -1547,8 +1548,8 @@ class GenerationMixin:
) )
model_kwargs[cache_name] = self._get_cache( model_kwargs[cache_name] = self._get_cache(
cache_implementation=generation_config.cache_implementation, cache_implementation=generation_config.cache_implementation,
batch_size=generation_config.num_beams * generation_config.num_return_sequences * batch_size, batch_size=max(generation_config.num_beams, generation_config.num_return_sequences) * batch_size,
max_cache_len=generation_config.max_length, max_cache_len=max_cache_length,
device=device, device=device,
model_kwargs=model_kwargs, model_kwargs=model_kwargs,
) )
@@ -1888,7 +1889,16 @@ class GenerationMixin:
# TODO (joao): remove `user_defined_cache` after v4.47 (remove default conversion to legacy format) # TODO (joao): remove `user_defined_cache` after v4.47 (remove default conversion to legacy format)
cache_name = "past_key_values" if "mamba" not in self.__class__.__name__.lower() else "cache_params" cache_name = "past_key_values" if "mamba" not in self.__class__.__name__.lower() else "cache_params"
user_defined_cache = model_kwargs.get(cache_name) user_defined_cache = model_kwargs.get(cache_name)
self._prepare_cache_for_generation(generation_config, model_kwargs, assistant_model, batch_size, device) max_cache_length = generation_config.max_length
if (
inputs_tensor.shape[1] != input_ids_length
and model_input_name == "inputs_embeds"
and not self.config.is_encoder_decoder
):
max_cache_length += inputs_tensor.shape[1]
self._prepare_cache_for_generation(
generation_config, model_kwargs, assistant_model, batch_size, max_cache_length, device
)
# 8. determine generation mode # 8. determine generation mode
generation_mode = generation_config.get_generation_mode(assistant_model) generation_mode = generation_config.get_generation_mode(assistant_model)
@@ -1936,8 +1946,8 @@ class GenerationMixin:
raise ValueError("assisted generate is only supported for batch_size = 1") raise ValueError("assisted generate is only supported for batch_size = 1")
if not model_kwargs["use_cache"]: if not model_kwargs["use_cache"]:
raise ValueError("assisted generate requires `use_cache=True`") raise ValueError("assisted generate requires `use_cache=True`")
if generation_config.cache_implementation == "static": if generation_config.cache_implementation in ["static", "hybrid", "sliding_window"]:
raise ValueError("assisted generate is not supported with `static_cache`") raise ValueError("assisted generate is not supported with Static cache classes`")
if self._is_stateful: if self._is_stateful:
# In assisted generation we need the ability to confirm whether the model would pick certain tokens, # In assisted generation we need the ability to confirm whether the model would pick certain tokens,
# which is not possible with stateful models (they can't reset to a previous subset of generated text) # which is not possible with stateful models (they can't reset to a previous subset of generated text)

View File

@@ -1453,6 +1453,9 @@ class GenerationTesterMixin:
model = model_class(config).to(torch_device).eval() model = model_class(config).to(torch_device).eval()
signature = inspect.signature(model.forward).parameters.keys() signature = inspect.signature(model.forward).parameters.keys()
# no cache as some models require special cache classes to be init outside forward
model.generation_config.use_cache = False
# Without padding # Without padding
model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, signature) model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, signature)
next_logits_wo_padding = model(**model_kwargs).logits[:, -1, :] next_logits_wo_padding = model(**model_kwargs).logits[:, -1, :]
@@ -1593,6 +1596,59 @@ class GenerationTesterMixin:
outputs_from_embeds_wo_ids.tolist(), outputs_from_embeds_wo_ids.tolist(),
) )
@pytest.mark.generate
def test_generate_from_inputs_embeds_with_static_cache(self):
"""
Test that StaticCache can generate from inputs_embeds and calculates max_cache_length
correctly in `generate()`. We force the model to not stop generation until max-length is reached
to verify that the cache length is indeed set correctly and we don't run out of index when slicing the cache.
"""
for model_class in self.all_generative_model_classes:
if not model_class._supports_static_cache:
self.skipTest(reason="This model does not support the static cache format")
config, input_ids, attention_mask = self._get_input_ids_and_config()
if config.is_encoder_decoder:
self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache")
model = model_class(config).to(torch_device).eval()
if "inputs_embeds" not in inspect.signature(model.prepare_inputs_for_generation).parameters.keys():
self.skipTest(reason="This model does not support `inputs_embeds` in generation")
model.config.use_cache = True
model.config.is_decoder = True
batch_size, seq_length = input_ids.shape
max_cache_len = 30
# here we force to not stop at eos and go until max-length
model.generation_config.eos_token_id = model.config.eos_token_id = -1
generation_kwargs = {
"max_length": max_cache_len,
"cache_implementation": "static",
"return_dict_in_generate": True, # Required to return `past_key_values`
}
head_dim = (
model.config.head_dim
if hasattr(model.config, "head_dim")
else model.config.hidden_size // model.config.num_attention_heads
)
num_key_value_heads = (
model.config.num_attention_heads
if getattr(config, "num_key_value_heads", None) is None
else model.config.num_key_value_heads
)
num_hidden_layers = config.num_hidden_layers
inputs_embeds = model.get_input_embeddings()(input_ids)
outputs = model.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, **generation_kwargs)
# we should get `max_length` in shape, not `max_length - embeds_length`
cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim)
self.assertTrue(isinstance(outputs.past_key_values, StaticCache))
self.assertTrue(len(outputs.past_key_values.key_cache) == num_hidden_layers)
self.assertTrue(outputs.past_key_values.key_cache[0].shape == cache_shape)
@pytest.mark.generate @pytest.mark.generate
def test_generate_continue_from_past_key_values(self): def test_generate_continue_from_past_key_values(self):
# Tests that we can continue generating from past key values, returned from a previous `generate` call # Tests that we can continue generating from past key values, returned from a previous `generate` call

View File

@@ -16,9 +16,10 @@
import unittest import unittest
from parameterized import parameterized
from pytest import mark from pytest import mark
from transformers import AutoModelForCausalLM, AutoTokenizer, Gemma2Config, is_torch_available, pipeline from transformers import AutoModelForCausalLM, AutoTokenizer, Gemma2Config, HybridCache, is_torch_available, pipeline
from transformers.testing_utils import ( from transformers.testing_utils import (
require_flash_attn, require_flash_attn,
require_read_token, require_read_token,
@@ -59,7 +60,7 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
if is_torch_available() if is_torch_available()
else () else ()
) )
all_generative_model_classes = () all_generative_model_classes = (Gemma2ForCausalLM,) if is_torch_available() else ()
pipeline_model_mapping = ( pipeline_model_mapping = (
{ {
"feature-extraction": Gemma2Model, "feature-extraction": Gemma2Model,
@@ -89,6 +90,101 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
def test_eager_matches_sdpa_inference(self): def test_eager_matches_sdpa_inference(self):
pass pass
@parameterized.expand([("random",), ("same",)])
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
pass
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
pass
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_sample(self):
pass
@unittest.skip("Gemma2 has HybridCache which is not compatible with dola decoding")
def test_dola_decoding_sample(self):
pass
@parameterized.expand([(1, False), (1, True), (4, False)])
@unittest.skip("Gemma2 has HybridCache and doesn't support old tuple format at all")
def test_new_cache_format(self, num_beams, do_sample):
pass
@unittest.skip("Gemma2 has HybridCache and doesn't support continue from past kv")
def test_generate_continue_from_past_key_values(self):
pass
@unittest.skip("Gemma2 has HybridCache and doesn't support low_memory generation")
def test_beam_search_low_memory(self):
pass
@unittest.skip("Gemma2 has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate(self):
pass
@unittest.skip("Gemma2 has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate_dict_outputs_use_cache(self):
pass
@unittest.skip("Gemma2 has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate_low_memory(self):
pass
@unittest.skip("Gemma2 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
def test_generate_with_static_cache(self):
pass
@unittest.skip("Gemma2 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
def test_generate_from_inputs_embeds_with_static_cache(self):
pass
# overwrite because HybridCache has fixed length for key/values
def _check_attentions_for_generate(
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
):
self.assertIsInstance(attentions, tuple)
self.assertListEqual(
[isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
)
self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
for idx, iter_attentions in enumerate(attentions):
tgt_len = min_length + idx if not use_cache else 1
src_len = min_length + idx if not use_cache else max_length
expected_shape = (
batch_size * num_beam_groups,
config.num_attention_heads,
tgt_len,
src_len,
)
# check attn size
self.assertListEqual(
[layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
)
# overwrite because HybridCache has fixed length for key/values
def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config, num_beam_groups=1):
self.assertIsInstance(past_key_values, HybridCache)
# check shape key, value (batch, head, max_seq_length, head_features)
head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
num_key_value_heads = (
config.num_attention_heads
if getattr(config, "num_key_value_heads", None) is None
else config.num_key_value_heads
)
num_hidden_layers = config.num_hidden_layers
# we should get `max_length` in shape, not `max_length - embeds_length`
# `+1` because the test in Mixin subtracts 1 which is needed for tuple cache
static_cache_shape = (batch_size, num_key_value_heads, seq_length + 1, head_dim)
static_layers = [layer_idx for layer_idx, boolean in enumerate(past_key_values.is_sliding) if not boolean]
self.assertTrue(len(past_key_values.key_cache) == num_hidden_layers)
self.assertTrue(past_key_values.key_cache[static_layers[0]].shape == static_cache_shape)
@unittest.skip("Gemma2's eager attn/sdpa attn outputs are expected to be different") @unittest.skip("Gemma2's eager attn/sdpa attn outputs are expected to be different")
def test_sdpa_equivalence(self): def test_sdpa_equivalence(self):
pass pass
@@ -203,6 +299,5 @@ class Gemma2IntegrationTest(unittest.TestCase):
output = model.generate(**inputs, max_new_tokens=100, do_sample=False) output = model.generate(**inputs, max_new_tokens=100, do_sample=False)
output_text = tokenizer.batch_decode(output, skip_special_tokens=False) output_text = tokenizer.batch_decode(output, skip_special_tokens=False)
print(output_text)
self.assertEqual(output_text, EXPECTED_TEXTS) self.assertEqual(output_text, EXPECTED_TEXTS)