[cache] make all classes cache compatible finally (#38635)

* dump

* push other models

* fix simple greedy generation

* xmod

* add fmst and clean up some mentions of old cache format

* gpt-bigcode now follows standards

* delete tuple cache reference in generation

* fix some models

* fix some models

* fix mambas and support cache in tapas

* fix some more tests

* fix copies

* delete `_reorder_cache`

* another fix copies

* fix typos and delete unnecessary test

* fix rag generate, needs special cache reordering

* fix tapas and superglue

* reformer create special cache

* recurrent gemma `reorder_cache` was a no-op, delete

* fix-copies

* fix blio and musicgen pipeline tests

* fix reformer

* fix reformer, again...

* delete `_supports_cache_class`

* delete `supports_quantized_cache`

* fix failing tests

* fix copies

* some minor clean up

* style

* style

* fix copies

* fix tests

* fix copies

* create causal mask now needs positions?

* fixc copies

* style

* Update tests/test_modeling_common.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* clean-up of non-generative model after merging main

* check `is_decoder` for cache

* delete transpose for scores

* remove tuple cache from docs everywhere

* fix tests

* fix copies

* fix copies once more

* properly deprecate `encoder_attention_mask` in Bert-like models

* import `deprecate_kwarg` where needed

* fix copies again

* fix copies

* delete `nex_decoder_cache`

* fix copies asks to update for PLM

* fix copies

* rebasing had a few new models, fix them and merge asap!

* fix copies once more

* fix slow tests

* fix tests and updare PLM checkpoint

* add read token and revert accidentally removed line

* oh com -on, style

* just skip it, read token has no access to PLM yet

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
Raushan Turganbay
2025-07-16 17:00:17 +05:00
committed by GitHub
parent 6cb43defd0
commit c8524aeb07
268 changed files with 5707 additions and 6831 deletions

View File

@@ -1001,7 +1001,7 @@ class GenerationTesterMixin:
self.skipTest(reason="Stateful models don't support contrastive search generation")
# won't fix: FSMT and Reformer have a different cache variable type (and format).
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
if any(model_name in model_class.__name__.lower() for model_name in ["reformer"]):
self.skipTest(reason="Won't fix: old model with different cache format")
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
@@ -1030,7 +1030,7 @@ class GenerationTesterMixin:
self.skipTest(reason="Stateful models don't support contrastive search generation")
# won't fix: FSMT and Reformer have a different cache variable type (and format).
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
if any(model_name in model_class.__name__.lower() for model_name in ["reformer"]):
self.skipTest(reason="Won't fix: old model with different cache format")
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
@@ -1070,10 +1070,8 @@ class GenerationTesterMixin:
if model_class._is_stateful:
self.skipTest(reason="Stateful models don't support contrastive search generation")
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer", "speech2text"]):
if any(model_name in model_class.__name__.lower() for model_name in ["reformer"]):
self.skipTest(reason="Won't fix: old model with different cache format")
if any(model_name in model_class.__name__.lower() for model_name in ["gptbigcode"]):
self.skipTest(reason="TODO: fix me")
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
@@ -1112,22 +1110,16 @@ class GenerationTesterMixin:
for model_class in self.all_generative_model_classes:
if model_class._is_stateful:
self.skipTest(reason="Stateful models don't support assisted generation")
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
if any(model_name in model_class.__name__.lower() for model_name in ["reformer"]):
self.skipTest(reason="Won't fix: old model with different cache format")
if any(
model_name in model_class.__name__.lower()
for model_name in [
"bigbirdpegasus",
"led",
"mega",
"moshi",
"speech2text",
"git",
"prophetnet",
"seamlessm4t",
"clvp",
"mllama", # special cache sizes
"blip2", # overridden `generate()`
"blip2", # overridden `generate()` all BLIP models
"instructblip",
"instructblipvideo",
]
@@ -1196,23 +1188,16 @@ class GenerationTesterMixin:
for model_class in self.all_generative_model_classes:
if model_class._is_stateful:
self.skipTest(reason="Stateful models don't support assisted generation")
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
if any(model_name in model_class.__name__.lower() for model_name in ["reformer"]):
self.skipTest(reason="Won't fix: old model with different cache format")
if any(
model_name in model_class.__name__.lower()
for model_name in [
"bigbirdpegasus",
"led",
"mega",
"moshi",
"speech2text",
"git",
"prophetnet",
"seamlessm4t",
"clvp",
"fuyu",
"mllama", # special cache sizes
"blip2", # overridden `generate()`
"blip2", # overridden `generate()` for all BLIP models
"instructblip",
"instructblipvideo",
# All models below: shouldn't suggest image tokens. Can be fixed by passing `suppress_ids` to candidate generator: @joaa @raushan
@@ -1340,22 +1325,16 @@ class GenerationTesterMixin:
for model_class in self.all_generative_model_classes:
if model_class._is_stateful:
self.skipTest(reason="Stateful models don't support assisted generation")
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
if any(model_name in model_class.__name__.lower() for model_name in ["reformer"]):
self.skipTest(reason="Won't fix: old model with different cache format")
if any(
model_name in model_class.__name__.lower()
for model_name in [
"bigbirdpegasus",
"led",
"mega",
"moshi",
"speech2text",
"git",
"prophetnet",
"seamlessm4t",
"clvp",
"mllama", # special cache sizes
"blip2", # overridden `generate()`
"blip2", # overridden `generate()` for all BLIP models
"instructblip",
"instructblipvideo",
]
@@ -2059,12 +2038,15 @@ class GenerationTesterMixin:
@pytest.mark.generate
def test_generate_with_quant_cache(self):
for model_class in self.all_generative_model_classes:
if not model_class._supports_quantized_cache:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
if (
config.get_text_config(decoder=True).is_encoder_decoder
or not model_class._supports_default_dynamic_cache()
):
self.skipTest(reason="This model does not support the quantized cache format")
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
generation_kwargs = {
"max_new_tokens": 5,
@@ -2509,14 +2491,10 @@ class GenerationTesterMixin:
# Past Key Value States -- a few notes here:
# 1. Its inner sequence length is with respect to the inputs of the latest forward pass, hence the "-1"
# 2. We ignore models that have unique cache structures (e.g. mamba) or are in need of refatoring to match the
# standard cache format (e.g.gptbigcode )
# standard cache format (e.g.mamba architecture )
models_without_standard_cache = (
"bamba",
"ctrl",
"fsmt",
"granitemoehybrid",
"gptbigcode",
"mega",
"reformer",
"jamba",
"mamba",

View File

@@ -737,9 +737,11 @@ class BertModelIntegrationTest(unittest.TestCase):
torch.allclose(res_eager.last_hidden_state, res_sdpa.last_hidden_state, atol=1e-5, rtol=1e-4)
)
# Case where query length != kv_length.
res_eager = model(**inp, past_key_values=pkv)
res_sdpa = model_sdpa(**inp, past_key_values=pkv)
# Case where query length != kv_length. Note that model needs to be a decoder so we can use cache
model.config.is_decoder = True
model_sdpa.config.is_decoder = True
res_eager = model(**inp, past_key_values=pkv, use_cache=True)
res_sdpa = model_sdpa(**inp, past_key_values=pkv, use_cache=True)
self.assertTrue(
torch.allclose(res_eager.last_hidden_state, res_sdpa.last_hidden_state, atol=1e-5, rtol=1e-4)
)

View File

@@ -284,6 +284,7 @@ class BigBirdModelTester:
attention_mask=next_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=False,
output_hidden_states=True,
)["hidden_states"][0]
output_from_past = model(
@@ -292,6 +293,7 @@ class BigBirdModelTester:
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
use_cache=True,
output_hidden_states=True,
)["hidden_states"][0]

View File

@@ -34,6 +34,7 @@ from transformers.testing_utils import (
Expectations,
cleanup,
require_bitsandbytes,
require_optimum_quanto,
require_read_token,
require_torch,
require_torch_accelerator,
@@ -344,6 +345,12 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester
self.assertListEqual([layer_attention.shape for layer_attention in iter_attentions], expected_shapes)
@require_optimum_quanto
@pytest.mark.generate
@unittest.skip("Mllama is actually an encoder decoder cache and thus can't supports quant cache")
def test_generate_with_quant_cache(self):
pass
@unittest.skip("For some unknown reasons the tests fails in CrossAttention layer when doing torch.sdpa(). ")
def test_sdpa_can_compile_dynamic(self):
pass

View File

@@ -770,9 +770,9 @@ class MvpStandaloneDecoderModelTester:
# get two different outputs
output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
output_from_past = model(next_tokens, attention_mask=attn_mask, past_key_values=past_key_values)[
"last_hidden_state"
]
output_from_past = model(
next_tokens, attention_mask=attn_mask, past_key_values=past_key_values, use_cache=True
)["last_hidden_state"]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()

View File

@@ -21,7 +21,7 @@ from transformers import (
AutoTokenizer,
PerceptionLMProcessor,
)
from transformers.testing_utils import require_vision
from transformers.testing_utils import require_read_token, require_vision
from transformers.utils import is_torch_available, is_vision_available
from ...test_processing_common import ProcessorTesterMixin
@@ -34,11 +34,12 @@ if is_torch_available():
import torch
# TEST_MODEL_PATH = "facebook/Perception-LM-1B"
TEST_MODEL_PATH = "shumingh/plm_1b_hf" # should be replaced by the above once checkpoints are merged
TEST_MODEL_PATH = "facebook/Perception-LM-1B"
@require_vision
@require_read_token
@unittest.skip("Fequires read token and we didn't requests access yet. FIXME @ydshieh when you are back :)")
class PerceptionLMProcessorTest(ProcessorTesterMixin, unittest.TestCase):
processor_class = PerceptionLMProcessor

View File

@@ -737,7 +737,7 @@ class ProphetNetStandaloneDecoderModelTester:
# get two different outputs
output_from_no_past = model(next_input_ids)["last_hidden_state"]
output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"]
output_from_past = model(next_tokens, past_key_values=past_key_values, use_cache=True)["last_hidden_state"]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()

View File

@@ -354,9 +354,9 @@ class SpeechT5ForSpeechToTextTester:
next_attention_mask = torch.cat([attention_mask, next_attn_mask], dim=-1)
output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"]
output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[
"last_hidden_state"
]
output_from_past = model(
next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values, use_cache=True
)["last_hidden_state"]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()

View File

@@ -853,10 +853,10 @@ class ModelTesterMixin:
addition_year = int(match_object.group(1))
for model_class in self.all_model_classes:
# For now, skip everything older than 2025 and "important models" (too much models to patch otherwise)
# For now, skip everything older than 2024 and "important models" (too much models to patch otherwise)
# Use `supports_cache_class` as a proxy to judge "important" models in order to prioritize them
# TODO: relax this as we patch more and more models
if addition_year < 2024 and not model_class._supports_cache_class:
if addition_year < 2024:
self.skipTest(reason=f"{model_class} is not a priorited model for now.")
# Monkey patch the method to add a seed (we do it on PreTrainedModel._initialize_weights, which wraps
@@ -1590,18 +1590,7 @@ class ModelTesterMixin:
head_dim = model.config.hidden_size // model.config.num_attention_heads
cache_shape = (batch_size, num_heads, 0, head_dim)
empty_pkv = tuple(
(
torch.rand(cache_shape, dtype=torch.float, device=torch_device),
torch.rand(cache_shape, dtype=torch.float, device=torch_device),
)
for i in range(model.config.num_hidden_layers)
)
empty_pkv = (
DynamicCache.from_legacy_cache(empty_pkv)
if model_class._supports_cache_class
else empty_pkv
)
empty_pkv = DynamicCache()
cache_length = 9
cache_shape = (batch_size, num_heads, cache_length, head_dim)
@@ -1612,11 +1601,7 @@ class ModelTesterMixin:
)
for i in range(model.config.num_hidden_layers)
)
non_empty_pkv = (
DynamicCache.from_legacy_cache(non_empty_pkv)
if model_class._supports_cache_class
else non_empty_pkv
)
non_empty_pkv = DynamicCache.from_legacy_cache(non_empty_pkv)
inps = copy.deepcopy(inputs_to_test[0])

View File

@@ -46,7 +46,6 @@ if is_torch_available():
AutoModelForCausalLM,
AutoTokenizer,
Cache,
ClvpForCausalLM,
DynamicCache,
Gemma2Config,
GenerationConfig,
@@ -122,36 +121,6 @@ class CacheTest(unittest.TestCase):
torch.allclose(to_legacy[layer_idx][key_value_idx], new_cache[layer_idx][key_value_idx])
)
def test_reorder_cache_retrocompatibility(self):
"""Tests that Cache.reorder_cache is retrocompatible with the legacy code path"""
legacy_reorder_fn = ClvpForCausalLM._reorder_cache # An example of a legacy `_reorder_cache` function
legacy_cache = ()
new_cache = DynamicCache()
# Creates a new cache with 10 layers in both formats
for layer_idx in range(10):
new_key = torch.rand((4, 4, 8, 16))
new_value = torch.rand((4, 4, 8, 16))
new_cache.update(new_key, new_value, layer_idx)
legacy_cache += ((new_key, new_value),)
# Let's create some dummy beam indices. From the shape above, it is equivalent to the case where num_beams=4
# and batch_size=1
beam_idx = torch.randint(low=0, high=4, size=(4,))
legacy_cache_reordered = legacy_reorder_fn(legacy_cache, beam_idx)
new_cache.reorder_cache(beam_idx)
# Let's check that the results are the same
for layer_idx in range(10):
for key_value_idx in range(2):
self.assertTrue(
torch.allclose(
new_cache[layer_idx][key_value_idx], legacy_cache_reordered[layer_idx][key_value_idx]
)
)
def test_static_cache_mha_mqa_gqa(self):
"""
Tests that static cache works with multi-head attention (MHA), grouped query attention (GQA), and multi-query