[tests] Test all cache implementations (#37873)

This commit is contained in:
Joao Gante
2025-04-30 15:37:00 +01:00
committed by GitHub
parent 2c1155519f
commit 1b222903c3
45 changed files with 338 additions and 438 deletions

View File

@@ -22,7 +22,7 @@ from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple, Union
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter
@@ -991,10 +991,10 @@ class AriaTextModel(AriaTextPreTrainedModel):
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
@@ -1005,7 +1005,7 @@ class AriaTextModel(AriaTextPreTrainedModel):
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_static_cache:
if using_compilable_cache:
target_length = past_key_values.get_max_cache_shape()
else:
target_length = (