Fix: StaticCache & inputs_embeds (#32932)

squash commit
This commit is contained in:
Raushan Turganbay
2024-09-06 09:56:59 +02:00
committed by GitHub
parent 5792c459ed
commit 1759bb9126
3 changed files with 169 additions and 8 deletions

View File

@@ -1453,6 +1453,9 @@ class GenerationTesterMixin:
model = model_class(config).to(torch_device).eval()
signature = inspect.signature(model.forward).parameters.keys()
# no cache as some models require special cache classes to be init outside forward
model.generation_config.use_cache = False
# Without padding
model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, signature)
next_logits_wo_padding = model(**model_kwargs).logits[:, -1, :]
@@ -1593,6 +1596,59 @@ class GenerationTesterMixin:
outputs_from_embeds_wo_ids.tolist(),
)
@pytest.mark.generate
def test_generate_from_inputs_embeds_with_static_cache(self):
"""
Test that StaticCache can generate from inputs_embeds and calculates max_cache_length
correctly in `generate()`. We force the model to not stop generation until max-length is reached
to verify that the cache length is indeed set correctly and we don't run out of index when slicing the cache.
"""
for model_class in self.all_generative_model_classes:
if not model_class._supports_static_cache:
self.skipTest(reason="This model does not support the static cache format")
config, input_ids, attention_mask = self._get_input_ids_and_config()
if config.is_encoder_decoder:
self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache")
model = model_class(config).to(torch_device).eval()
if "inputs_embeds" not in inspect.signature(model.prepare_inputs_for_generation).parameters.keys():
self.skipTest(reason="This model does not support `inputs_embeds` in generation")
model.config.use_cache = True
model.config.is_decoder = True
batch_size, seq_length = input_ids.shape
max_cache_len = 30
# here we force to not stop at eos and go until max-length
model.generation_config.eos_token_id = model.config.eos_token_id = -1
generation_kwargs = {
"max_length": max_cache_len,
"cache_implementation": "static",
"return_dict_in_generate": True, # Required to return `past_key_values`
}
head_dim = (
model.config.head_dim
if hasattr(model.config, "head_dim")
else model.config.hidden_size // model.config.num_attention_heads
)
num_key_value_heads = (
model.config.num_attention_heads
if getattr(config, "num_key_value_heads", None) is None
else model.config.num_key_value_heads
)
num_hidden_layers = config.num_hidden_layers
inputs_embeds = model.get_input_embeddings()(input_ids)
outputs = model.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, **generation_kwargs)
# we should get `max_length` in shape, not `max_length - embeds_length`
cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim)
self.assertTrue(isinstance(outputs.past_key_values, StaticCache))
self.assertTrue(len(outputs.past_key_values.key_cache) == num_hidden_layers)
self.assertTrue(outputs.past_key_values.key_cache[0].shape == cache_shape)
@pytest.mark.generate
def test_generate_continue_from_past_key_values(self):
# Tests that we can continue generating from past key values, returned from a previous `generate` call