Cache: new Cache format in decoder-only models (#31421)
* draft bart with new cache * add cache for decoder-only models * revert utils * modify docstring * revert bart * minor fixes * fix copies (not related) * revert tests * remove enc-dec related code * remove bloom * remove opt (enc-dec) * update docstring * git, codegen, gpt_neo, gpt_neox, gptj * clean up * copied from statements * revert * tmp * update warning msg * forgot git * add more flags * run-slow git,codegen,gpt_neo,gpt_neox,gptj * add cache flag to VLMs * remove files * style * video LLMs also need a flag * style * llava will go in another PR * style * [run-slow] codegen, falcon, git, gpt_neo, gpt_neox, gptj, idefics * Update src/transformers/models/gpt_neo/modeling_gpt_neo.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * copy from * deprecate until v4.45 and warn if not training * nit * fix test * test static cache * add more tests and fix models * fix copies * return sliding window mask * run slow tests & fix + codestyle * one more falcon fix for alibi --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
6af0854efa
commit
a30c865f99
@@ -59,7 +59,7 @@ if is_torch_available():
|
||||
ImageGPTForCausalImageModeling,
|
||||
SpeechEncoderDecoderModel,
|
||||
)
|
||||
from transformers.cache_utils import DynamicCache, EncoderDecoderCache, QuantoQuantizedCache
|
||||
from transformers.cache_utils import DynamicCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache
|
||||
from transformers.generation import (
|
||||
BeamSampleDecoderOnlyOutput,
|
||||
BeamSampleEncoderDecoderOutput,
|
||||
@@ -1769,6 +1769,53 @@ class GenerationTesterMixin:
|
||||
)
|
||||
)
|
||||
|
||||
def test_generate_with_static_cache(self):
    """
    Tests that `generate` works when `cache_implementation="static"` is requested.

    This doesn't test if generation quality is good, but tests that models with
    `_supports_static_cache` don't throw an error when generating and return a
    `StaticCache` object at the end. The returned cache is expected to hold one key
    tensor per hidden layer, each of shape
    `(batch_size, num_key_value_heads, seq_length + max_new_tokens, head_dim)`.
    """
    for model_class in self.all_generative_model_classes:
        # NOTE: `skipTest` raises, so it ends the whole test rather than only this iteration;
        # model classes listed after the first skipped one are not exercised.
        if not model_class._supports_static_cache:
            self.skipTest(reason="This model does not support the static cache format")

        config, input_ids, attention_mask = self._get_input_ids_and_config()
        if config.is_encoder_decoder:
            self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache")

        config.use_cache = True
        config.is_decoder = True
        batch_size, seq_length = input_ids.shape
        max_new_tokens = 20

        model = model_class(config).to(torch_device).eval()
        generation_kwargs = {
            "max_length": None,
            "max_new_tokens": max_new_tokens,
            "cache_implementation": "static",
            "return_dict_in_generate": True,  # Required to return `past_key_values`
        }

        # The cache is expected to cover the prompt plus every newly generated token.
        max_cache_len = seq_length + max_new_tokens

        # Read every dimension from the instantiated model's config so that the
        # "is it set?" check and the value that is used always come from the same object.
        model_config = model.config
        head_dim = getattr(model_config, "head_dim", None)
        if head_dim is None:
            head_dim = model_config.hidden_size // model_config.num_attention_heads
        num_key_value_heads = getattr(model_config, "num_key_value_heads", None)
        if num_key_value_heads is None:
            # No separate key/value head count: fall back to the number of attention heads.
            num_key_value_heads = model_config.num_attention_heads
        num_hidden_layers = model_config.num_hidden_layers

        results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)

        cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim)
        # Dedicated assertions report the offending values on failure, unlike `assertTrue(a == b)`.
        self.assertIsInstance(results.past_key_values, StaticCache)
        self.assertEqual(len(results.past_key_values.key_cache), num_hidden_layers)
        self.assertEqual(tuple(results.past_key_values.key_cache[0].shape), cache_shape)
|
||||
|
||||
@require_quanto
|
||||
def test_generate_with_quant_cache(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
|
||||
Reference in New Issue
Block a user