[whisper] static kv cache (#31166)
* make work with cache abstraction
* correct for static cache
* hacks for compile
* make fast
* fix
* fix pos ids
* generate
* fix sdpa
* fix sdpa cache pos
* fix fa2
* clean fa2
* integrate cache into generate
* make style
* copies
* more copies
* update eager
* update sdpa
* update fa2
* simplify
* use cache pos
* always compute cross-cache for debug
* avoid recompiles
  Co-authored-by: Arthur Zucker <arthur@huggingface.co>
* fix fix
* fix fix fix
* more fix
* try encoder-decoder cache (too messy)
* revert encoder-decoder cache
* check cross-attn cache
* use enc-dec dataclass
* use richer enc-dec dataclass
* clean-up
* revert static cache changes
* small fixes
* revert to cpu flag
* fix copies
* add static slow test
* past k/v docstring
* more docstrings
* cache_position docstrings
* add to docs
* add enc-dec cache to docs
* make style
* fix after rebase
* fix beam
* style
* fix generation strategies
* fix most decoder-only tests
* style
* skip test
* more clean up
* small docstrings
* Apply suggestions from code review
  Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
* add todo
* only crop self-attn
* check cache in mixin
* style
* fix re-compile after rebase
* move `is_updated` logic to enc-dec wrapper
* revert back
* revert cache back
* finalise design
* fix
* fix fix
* style
* Update src/transformers/cache_utils.py
  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* deprecate
* updates
* final updates
* style
* style

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -57,7 +57,7 @@ if is_torch_available():
|
||||
ImageGPTForCausalImageModeling,
|
||||
SpeechEncoderDecoderModel,
|
||||
)
|
||||
from transformers.cache_utils import DynamicCache, QuantoQuantizedCache
|
||||
from transformers.cache_utils import DynamicCache, EncoderDecoderCache, QuantoQuantizedCache
|
||||
from transformers.generation import (
|
||||
BeamSampleDecoderOnlyOutput,
|
||||
BeamSampleEncoderDecoderOutput,
|
||||
@@ -1636,7 +1636,6 @@ class GenerationTesterMixin:
|
||||
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
generation_kwargs = {
|
||||
@@ -1652,15 +1651,21 @@ class GenerationTesterMixin:
|
||||
set_seed(seed)
|
||||
legacy_results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
||||
set_seed(seed)
|
||||
if config.is_encoder_decoder:
|
||||
cache_cls = EncoderDecoderCache
|
||||
past_key_values = cache_cls(DynamicCache(), DynamicCache())
|
||||
else:
|
||||
cache_cls = DynamicCache
|
||||
past_key_values = cache_cls()
|
||||
new_results = model.generate(
|
||||
input_ids, attention_mask=attention_mask, past_key_values=DynamicCache(), **generation_kwargs
|
||||
input_ids, attention_mask=attention_mask, past_key_values=past_key_values, **generation_kwargs
|
||||
)
|
||||
|
||||
# The two sets of generated sequences must match, despite the cache format between forward passes being
|
||||
# different
|
||||
self.assertListEqual(legacy_results.sequences.tolist(), new_results.sequences.tolist())
|
||||
self.assertTrue(isinstance(legacy_results.past_key_values, tuple))
|
||||
self.assertTrue(isinstance(new_results.past_key_values, DynamicCache))
|
||||
self.assertTrue(isinstance(new_results.past_key_values, cache_cls))
|
||||
|
||||
# The contents of the two caches, when converted to the same format (in both directions!), must match
|
||||
legacy_cache = legacy_results.past_key_values
|
||||
@@ -1675,7 +1680,7 @@ class GenerationTesterMixin:
|
||||
)
|
||||
|
||||
new_cache = new_results.past_key_values
|
||||
legacy_cache_converted = DynamicCache.from_legacy_cache(legacy_results.past_key_values)
|
||||
legacy_cache_converted = cache_cls.from_legacy_cache(legacy_results.past_key_values)
|
||||
for layer_idx in range(len(new_cache)):
|
||||
for kv_idx in range(len(new_cache[layer_idx])):
|
||||
self.assertTrue(
|
||||
|
||||
Reference in New Issue
Block a user