[whisper] static kv cache (#31166)

* make work with cache abstraction

* correct for static cache

* hacks for compile

* make fast

* fix

* fix pos ids

* generate

* fix sdpa

* fix sdpa cache pos

* fix fa2

* clean fa2

* integrate cache into generate

* make style

* copies

* more copies

* update eager

* update sdpa

* update fa2

* simplify

* use cache pos

* always compute cross-cache for debug

* avoid recompiles
Co-authored-by: Arthur Zucker <arthur@huggingface.co>

* fix fix

* fix fix fix

* more fix

* try encoder-decoder cache (too messy)

* revert encoder-decoder cache

* check cross-attn cache

* use enc-dec dataclass

* use richer enc-dec dataclass

* clean-up

* revert static cache changes

* small fixes

* revert to cpu flag

* fix copies

* add static slow test

* past k/v docstring

* more docstrings

* cache_position docstrings

* add to docs

* add enc-dec cache to docs

* make style

* fix after rebase

* fix beam

* style

* fix generation strategies

* fix most decoder-only tests

* style

* skip test

* more clean up

* small docstrings

* Apply suggestions from code review

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* add todo

* only crop self-attn

* check cache in mixin

* style

* fix re-compile after rebase

* move `is_updated` logic to enc-dec wrapper

* revert back

* revert cache back

* finalise design

* fix

* fix fix

* style

* Update src/transformers/cache_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* deprecate

* updates

* final updates

* style

* style

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Sanchit Gandhi
2024-07-02 13:24:15 +01:00
committed by GitHub
parent 57d7594a79
commit a9701953ff
10 changed files with 704 additions and 257 deletions

View File

@@ -57,7 +57,7 @@ if is_torch_available():
ImageGPTForCausalImageModeling,
SpeechEncoderDecoderModel,
)
from transformers.cache_utils import DynamicCache, QuantoQuantizedCache
from transformers.cache_utils import DynamicCache, EncoderDecoderCache, QuantoQuantizedCache
from transformers.generation import (
BeamSampleDecoderOnlyOutput,
BeamSampleEncoderDecoderOutput,
@@ -1636,7 +1636,6 @@ class GenerationTesterMixin:
config, input_ids, attention_mask = self._get_input_ids_and_config()
config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
generation_kwargs = {
@@ -1652,15 +1651,21 @@ class GenerationTesterMixin:
set_seed(seed)
legacy_results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
set_seed(seed)
if config.is_encoder_decoder:
cache_cls = EncoderDecoderCache
past_key_values = cache_cls(DynamicCache(), DynamicCache())
else:
cache_cls = DynamicCache
past_key_values = cache_cls()
new_results = model.generate(
input_ids, attention_mask=attention_mask, past_key_values=DynamicCache(), **generation_kwargs
input_ids, attention_mask=attention_mask, past_key_values=past_key_values, **generation_kwargs
)
# The two sets of generated sequences must match, despite the cache format between forward passes being
# different
self.assertListEqual(legacy_results.sequences.tolist(), new_results.sequences.tolist())
self.assertTrue(isinstance(legacy_results.past_key_values, tuple))
self.assertTrue(isinstance(new_results.past_key_values, DynamicCache))
self.assertTrue(isinstance(new_results.past_key_values, cache_cls))
# The contents of the two caches, when converted to the same format (in both directions!), must match
legacy_cache = legacy_results.past_key_values
@@ -1675,7 +1680,7 @@ class GenerationTesterMixin:
)
new_cache = new_results.past_key_values
legacy_cache_converted = DynamicCache.from_legacy_cache(legacy_results.past_key_values)
legacy_cache_converted = cache_cls.from_legacy_cache(legacy_results.past_key_values)
for layer_idx in range(len(new_cache)):
for kv_idx in range(len(new_cache[layer_idx])):
self.assertTrue(

View File

@@ -1539,6 +1539,46 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
def test_longform_generate_multi_batch_cond_prev(self):
self._check_longform_generate_multi_batch(condition_on_prev_tokens=True)
def test_custom_4d_attention_mask(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = WhisperForConditionalGeneration(config).to(device=torch_device, dtype=torch.float32)
model.eval()
(
input_ids,
position_ids,
input_ids_shared_prefix,
mask_shared_prefix,
position_ids_shared_prefix,
) = self._get_custom_4d_mask_test_data()
with torch.no_grad():
logits = model.forward(
decoder_input_ids=input_ids,
input_features=input_dict["input_features"],
decoder_position_ids=position_ids,
).logits
# logits.shape == torch.Size([3, 4, ...])
logits_shared_prefix = model(
decoder_input_ids=input_ids_shared_prefix,
input_features=input_dict["input_features"],
decoder_attention_mask=mask_shared_prefix,
decoder_position_ids=position_ids_shared_prefix,
)[0]
# logits_shared_prefix.shape == torch.Size([1, 6, ...])
out_last_tokens = logits[:, -1, :] # last tokens in each batch line
out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens
# comparing greedily-chosen tokens:
assert torch.equal(out_last_tokens.max(axis=1).indices, out_shared_prefix_last_tokens.max(axis=1).indices)
# comparing softmax-normalized logits:
normalized_0 = torch.nn.functional.softmax(out_last_tokens)
normalized_1 = torch.nn.functional.softmax(out_shared_prefix_last_tokens)
torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
@require_torch
@require_torchaudio
@@ -2961,6 +3001,34 @@ class WhisperModelIntegrationTests(unittest.TestCase):
torch.manual_seed(0)
model.generate(**inputs, **gen_kwargs)
@slow
def test_tiny_static_generation(self):
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
model.to(torch_device)
input_speech = self._load_datasamples(4)
input_features = processor(input_speech, return_tensors="pt", sampling_rate=16_000).input_features
input_features = input_features.to(torch_device)
eager_generated_ids = model.generate(input_features, max_new_tokens=64)
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
# compile the forward pass and assert equivalence
static_generated_ids = model.generate(input_features, max_new_tokens=64)
assert (eager_generated_ids == static_generated_ids).all()
# check the compiled graph can be re-used and that the cache is correctly reset
# reverse the ordering of the input features
permutation_idx = (
torch.arange(input_features.shape[0], 0, step=-1, dtype=torch.long, device=input_features.device) - 1
)
input_features = input_features[permutation_idx, ...]
static_generated_ids = model.generate(input_features, max_new_tokens=64)
# assert re-ordered generations match those from eager
assert (eager_generated_ids[permutation_idx, :] == static_generated_ids).all()
def prepare_whisper_encoder_inputs_dict(config, input_features, head_mask=None):
if head_mask is None:
@@ -3564,6 +3632,10 @@ class WhisperStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin,
config=config, input_ids=inputs_dict["input_ids"]
)
@unittest.skip(reason="Tested implicitly through the encoder-decoder tests")
def test_custom_4d_attention_mask(self):
pass
@unittest.skip(reason="Generate needs input ids")
def test_generate_without_input_ids(self):
# generate only works with input ids for whisper