[whisper] static kv cache (#31166)
* make work with cache abstraction * correct for static cache * hacks for compile * make fast * fix * fix pos ids * generate * fix sdpa * fix sdpa cache pos * fix fa2 * clean fa2 * integrate cache into generate * make style * copies * more copies * update eager * update sdpa * update fa2 * simplify * use cache pos * always compute cross-cache for debug * avoid recompiles Co-authored-by: Arthur Zucker <arthur@huggingface.co> * fix fix * fix fix fix * more fix * try encoder-decoder cache (too messy) * revert encoder-decoder cache * check cross-attn cache * use enc-dec dataclass * use richer enc-dec dataclass * clean-up * revert static cache changes * small fixes * revert to cpu flag * fix copies * add static slow test * past k/v docstring * more docstrings * cache_position docstrings * add to docs * add enc-dec cache to docs * make style * fix after rebase * fix beam * style * fix generation strategies * fix most decoder-only tests * style * skip test * more clean up * small docstrings * Apply suggestions from code review Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * add todo * only crop self-attn * check cache in mixin * style * fix re-compile after rebase * move `is_updated` logic to enc-dec wrapper * revert back * revert cache back * finalise design * fix * fix fix * style * Update src/transformers/cache_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * deprecate * updates * final updates * style * style --------- Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -1539,6 +1539,46 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
def test_longform_generate_multi_batch_cond_prev(self):
    # Same shared long-form multi-batch check as the sibling tests, but with
    # conditioning on previously generated tokens switched on.
    self._check_longform_generate_multi_batch(condition_on_prev_tokens=True)
|
||||
|
||||
def test_custom_4d_attention_mask(self):
    """Compare a regular batched decoder pass with a packed pass that uses a custom decoder mask.

    The model is run twice on the same `input_features`:

    1. with `input_ids` / `position_ids` (logits of shape `[3, 4, ...]` per the inline note), and
    2. with `input_ids_shared_prefix`, `mask_shared_prefix` and `position_ids_shared_prefix`
       (logits of shape `[1, 6, ...]` per the inline note).

    The last-token logits of every row in (1) must agree with the last three positions of the
    single row in (2), both for the greedily selected token and for the softmax distribution.

    NOTE(review): the exact layout of the tensors returned by `_get_custom_4d_mask_test_data`
    is defined outside this block; the shapes above are taken from the inline comments only.
    """
    config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
    model = WhisperForConditionalGeneration(config).to(device=torch_device, dtype=torch.float32)
    model.eval()

    (
        input_ids,
        position_ids,
        input_ids_shared_prefix,
        mask_shared_prefix,
        position_ids_shared_prefix,
    ) = self._get_custom_4d_mask_test_data()

    with torch.no_grad():
        # Call the module itself (not `.forward`) so that registered hooks run and both
        # passes go through the same entry point.
        logits = model(
            decoder_input_ids=input_ids,
            input_features=input_dict["input_features"],
            decoder_position_ids=position_ids,
        ).logits
        # logits.shape == torch.Size([3, 4, ...])

        logits_shared_prefix = model(
            decoder_input_ids=input_ids_shared_prefix,
            input_features=input_dict["input_features"],
            decoder_attention_mask=mask_shared_prefix,
            decoder_position_ids=position_ids_shared_prefix,
        )[0]
        # logits_shared_prefix.shape == torch.Size([1, 6, ...])

    out_last_tokens = logits[:, -1, :]  # last tokens in each batch line
    out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :]  # last three tokens

    # comparing greedily-chosen tokens:
    assert torch.equal(out_last_tokens.max(axis=1).indices, out_shared_prefix_last_tokens.max(axis=1).indices)

    # comparing softmax-normalized logits:
    # `dim` is passed explicitly: the implicit choice is deprecated and warns. Both tensors
    # are 2D here, so the last axis is the same axis the implicit rule would have picked.
    normalized_0 = torch.nn.functional.softmax(out_last_tokens, dim=-1)
    normalized_1 = torch.nn.functional.softmax(out_shared_prefix_last_tokens, dim=-1)
    torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_torchaudio
|
||||
@@ -2961,6 +3001,34 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
torch.manual_seed(0)
|
||||
model.generate(**inputs, **gen_kwargs)
|
||||
|
||||
@slow
def test_tiny_static_generation(self):
    """Generate with the static cache + compiled forward and compare against eager generation.

    The compiled model is run twice: once on the original batch and once on the batch in
    reversed order, to check that the compiled graph is re-used and the cache is reset
    between `generate` calls.
    """
    checkpoint = "openai/whisper-tiny.en"
    processor = WhisperProcessor.from_pretrained(checkpoint)
    model = WhisperForConditionalGeneration.from_pretrained(checkpoint)
    model.to(torch_device)

    speech_samples = self._load_datasamples(4)
    features = processor(speech_samples, return_tensors="pt", sampling_rate=16_000).input_features
    features = features.to(torch_device)

    generate_kwargs = {"max_new_tokens": 64}

    # reference run: default (eager) generation
    reference_ids = model.generate(features, **generate_kwargs)

    # switch to the static cache and compile the forward pass
    model.generation_config.cache_implementation = "static"
    model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

    # first compiled run must reproduce the eager output
    compiled_ids = model.generate(features, **generate_kwargs)
    assert (reference_ids == compiled_ids).all()

    # second compiled run on the batch in reversed order: indices n-1, n-2, ..., 0
    batch_size = features.shape[0]
    reversed_idx = torch.arange(batch_size - 1, -1, -1, dtype=torch.long, device=features.device)
    reversed_features = features[reversed_idx, ...]
    compiled_ids = model.generate(reversed_features, **generate_kwargs)

    # the re-ordered generations must match the eager ones under the same permutation
    assert (reference_ids[reversed_idx, :] == compiled_ids).all()
|
||||
|
||||
|
||||
def prepare_whisper_encoder_inputs_dict(config, input_features, head_mask=None):
|
||||
if head_mask is None:
|
||||
@@ -3564,6 +3632,10 @@ class WhisperStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin,
|
||||
config=config, input_ids=inputs_dict["input_ids"]
|
||||
)
|
||||
|
||||
@unittest.skip(reason="Tested implicitly through the encoder-decoder tests")
def test_custom_4d_attention_mask(self):
    # Deliberately empty: this test is skipped for the standalone decoder
    # (see the skip reason above).
    pass
|
||||
|
||||
@unittest.skip(reason="Generate needs input ids")
|
||||
def test_generate_without_input_ids(self):
|
||||
# generate only works with input ids for whisper
|
||||
|
||||
Reference in New Issue
Block a user