[whisper] static kv cache (#31166)

* make work with cache abstraction

* correct for static cache

* hacks for compile

* make fast

* fix

* fix pos ids

* generate

* fix sdpa

* fix sdpa cache pos

* fix fa2

* clean fa2

* integrate cache into generate

* make style

* copies

* more copies

* update eager

* update sdpa

* update fa2

* simplify

* use cache pos

* always compute cross-cache for debug

* avoid recompiles
Co-authored-by: Arthur Zucker <arthur@huggingface.co>

* fix fix

* fix fix fix

* more fix

* try encoder-decoder cache (too messy)

* revert encoder-decoder cache

* check cross-attn cache

* use enc-dec dataclass

* use richer enc-dec dataclass

* clean-up

* revert static cache changes

* small fixes

* revert to cpu flag

* fix copies

* add static slow test

* past k/v docstring

* more docstrings

* cache_position docstrings

* add to docs

* add enc-dec cache to docs

* make style

* fix after rebase

* fix beam

* style

* fix generation strategies

* fix most decoder-only tests

* style

* skip test

* more clean up

* small docstrings

* Apply suggestions from code review

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* add todo

* only crop self-attn

* check cache in mixin

* style

* fix re-compile after rebase

* move `is_updated` logic to enc-dec wrapper

* revert back

* revert cache back

* finalise design

* fix

* fix fix

* style

* Update src/transformers/cache_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* deprecate

* updates

* final updates

* style

* style

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Sanchit Gandhi
2024-07-02 13:24:15 +01:00
committed by GitHub
parent 57d7594a79
commit a9701953ff
10 changed files with 704 additions and 257 deletions

View File

@@ -57,7 +57,7 @@ if is_torch_available():
ImageGPTForCausalImageModeling,
SpeechEncoderDecoderModel,
)
from transformers.cache_utils import DynamicCache, QuantoQuantizedCache
from transformers.cache_utils import DynamicCache, EncoderDecoderCache, QuantoQuantizedCache
from transformers.generation import (
BeamSampleDecoderOnlyOutput,
BeamSampleEncoderDecoderOutput,
@@ -1636,7 +1636,6 @@ class GenerationTesterMixin:
config, input_ids, attention_mask = self._get_input_ids_and_config()
config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
generation_kwargs = {
@@ -1652,15 +1651,21 @@ class GenerationTesterMixin:
set_seed(seed)
legacy_results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
set_seed(seed)
if config.is_encoder_decoder:
cache_cls = EncoderDecoderCache
past_key_values = cache_cls(DynamicCache(), DynamicCache())
else:
cache_cls = DynamicCache
past_key_values = cache_cls()
new_results = model.generate(
input_ids, attention_mask=attention_mask, past_key_values=DynamicCache(), **generation_kwargs
input_ids, attention_mask=attention_mask, past_key_values=past_key_values, **generation_kwargs
)
# The two sets of generated sequences must match, despite the cache format between forward passes being
# different
self.assertListEqual(legacy_results.sequences.tolist(), new_results.sequences.tolist())
self.assertTrue(isinstance(legacy_results.past_key_values, tuple))
self.assertTrue(isinstance(new_results.past_key_values, DynamicCache))
self.assertTrue(isinstance(new_results.past_key_values, cache_cls))
# The contents of the two caches, when converted to the same format (in both directions!), must match
legacy_cache = legacy_results.past_key_values
@@ -1675,7 +1680,7 @@ class GenerationTesterMixin:
)
new_cache = new_results.past_key_values
legacy_cache_converted = DynamicCache.from_legacy_cache(legacy_results.past_key_values)
legacy_cache_converted = cache_cls.from_legacy_cache(legacy_results.past_key_values)
for layer_idx in range(len(new_cache)):
for kv_idx in range(len(new_cache[layer_idx])):
self.assertTrue(