Cache: new Cache format in decoder-only models (#31421)

* draft bart with new cache

* add cache for decoder-only models

* revert utils

* modify docstring

* revert bart

* minor fixes

* fix copies (not related)

* revert tests

* remove enc-dec related code

* remove bloom

* remove opt (enc-dec)

* update docstring

* git, codegen, gpt_neo, gpt_neox, gpj

* clean up

* copied from statements

* revert

* tmp

* update warning msg

* forgot git

* add more flags

* run-slow git,codegen,gpt_neo,gpt_neox,gpj

* add cache flag to VLMs

* remove files

* style

* video LLMs also need a flag

* style

* llava will go in another PR

* style

* [run-slow] codegen, falcon, git, gpt_neo, gpt_neox, gptj, idefics

* Update src/transformers/models/gpt_neo/modeling_gpt_neo.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* copy from

* deprecate until v4.45 and warn if not training

* nit

* fix test

* test static cache

* add more tests and fix models

* fix copies

* return sliding window mask

* run slow tests & fix + codestyle

* one more falcon fix for alibi

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Raushan Turganbay
2024-08-07 10:02:16 +05:00
committed by GitHub
parent 6af0854efa
commit a30c865f99
11 changed files with 1915 additions and 781 deletions

View File

@@ -59,7 +59,7 @@ if is_torch_available():
ImageGPTForCausalImageModeling,
SpeechEncoderDecoderModel,
)
from transformers.cache_utils import DynamicCache, EncoderDecoderCache, QuantoQuantizedCache
from transformers.cache_utils import DynamicCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache
from transformers.generation import (
BeamSampleDecoderOnlyOutput,
BeamSampleEncoderDecoderOutput,
@@ -1769,6 +1769,53 @@ class GenerationTesterMixin:
)
)
def test_generate_with_static_cache(self):
"""
Tests if StaticCache works if we set attn_implementation=static when generation.
This doesn't test if generation quality is good, but tests that models with
self._supports_static_cache don't throw an error when generating and return
a StaticCache object at the end.
"""
for model_class in self.all_generative_model_classes:
if not model_class._supports_static_cache:
self.skipTest(reason="This model does not support the static cache format")
config, input_ids, attention_mask = self._get_input_ids_and_config()
if config.is_encoder_decoder:
self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache")
config.use_cache = True
config.is_decoder = True
batch_size, seq_length = input_ids.shape
max_new_tokens = 20
model = model_class(config).to(torch_device).eval()
generation_kwargs = {
"max_length": None,
"max_new_tokens": max_new_tokens,
"cache_implementation": "static",
"return_dict_in_generate": True, # Required to return `past_key_values`
}
max_cache_len = seq_length + max_new_tokens
head_dim = (
model.config.head_dim
if hasattr(model.config, "head_dim")
else model.config.hidden_size // model.config.num_attention_heads
)
num_key_value_heads = (
model.config.num_attention_heads
if getattr(config, "num_key_value_heads", None) is None
else model.config.num_key_value_heads
)
num_hidden_layers = config.num_hidden_layers
results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim)
self.assertTrue(isinstance(results.past_key_values, StaticCache))
self.assertTrue(len(results.past_key_values.key_cache) == num_hidden_layers)
self.assertTrue(results.past_key_values.key_cache[0].shape == cache_shape)
@require_quanto
def test_generate_with_quant_cache(self):
for model_class in self.all_generative_model_classes: