[cache] make all classes cache compatible finally (#38635)
* dump * push other models * fix simple greedy generation * xmod * add fmst and clean up some mentions of old cache format * gpt-bigcode now follows standards * delete tuple cache reference in generation * fix some models * fix some models * fix mambas and support cache in tapas * fix some more tests * fix copies * delete `_reorder_cache` * another fix copies * fix typos and delete unnecessary test * fix rag generate, needs special cache reordering * fix tapas and superglue * reformer create special cache * recurrent gemma `reorder_cache` was a no-op, delete * fix-copies * fix blio and musicgen pipeline tests * fix reformer * fix reformer, again... * delete `_supports_cache_class` * delete `supports_quantized_cache` * fix failing tests * fix copies * some minor clean up * style * style * fix copies * fix tests * fix copies * create causal mask now needs positions? * fixc copies * style * Update tests/test_modeling_common.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * clean-up of non-generative model after merging main * check `is_decoder` for cache * delete transpose for scores * remove tuple cache from docs everywhere * fix tests * fix copies * fix copies once more * properly deprecate `encoder_attention_mask` in Bert-like models * import `deprecate_kwarg` where needed * fix copies again * fix copies * delete `nex_decoder_cache` * fix copies asks to update for PLM * fix copies * rebasing had a few new models, fix them and merge asap! * fix copies once more * fix slow tests * fix tests and updare PLM checkpoint * add read token and revert accidentally removed line * oh com -on, style * just skip it, read token has no access to PLM yet --------- Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
committed by
GitHub
parent
6cb43defd0
commit
c8524aeb07
@@ -1001,7 +1001,7 @@ class GenerationTesterMixin:
|
||||
self.skipTest(reason="Stateful models don't support contrastive search generation")
|
||||
|
||||
# won't fix: FSMT and Reformer have a different cache variable type (and format).
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["reformer"]):
|
||||
self.skipTest(reason="Won't fix: old model with different cache format")
|
||||
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
@@ -1030,7 +1030,7 @@ class GenerationTesterMixin:
|
||||
self.skipTest(reason="Stateful models don't support contrastive search generation")
|
||||
|
||||
# won't fix: FSMT and Reformer have a different cache variable type (and format).
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["reformer"]):
|
||||
self.skipTest(reason="Won't fix: old model with different cache format")
|
||||
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
@@ -1070,10 +1070,8 @@ class GenerationTesterMixin:
|
||||
if model_class._is_stateful:
|
||||
self.skipTest(reason="Stateful models don't support contrastive search generation")
|
||||
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer", "speech2text"]):
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["reformer"]):
|
||||
self.skipTest(reason="Won't fix: old model with different cache format")
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["gptbigcode"]):
|
||||
self.skipTest(reason="TODO: fix me")
|
||||
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
|
||||
|
||||
@@ -1112,22 +1110,16 @@ class GenerationTesterMixin:
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if model_class._is_stateful:
|
||||
self.skipTest(reason="Stateful models don't support assisted generation")
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["reformer"]):
|
||||
self.skipTest(reason="Won't fix: old model with different cache format")
|
||||
if any(
|
||||
model_name in model_class.__name__.lower()
|
||||
for model_name in [
|
||||
"bigbirdpegasus",
|
||||
"led",
|
||||
"mega",
|
||||
"moshi",
|
||||
"speech2text",
|
||||
"git",
|
||||
"prophetnet",
|
||||
"seamlessm4t",
|
||||
"clvp",
|
||||
"mllama", # special cache sizes
|
||||
"blip2", # overridden `generate()`
|
||||
"blip2", # overridden `generate()` all BLIP models
|
||||
"instructblip",
|
||||
"instructblipvideo",
|
||||
]
|
||||
@@ -1196,23 +1188,16 @@ class GenerationTesterMixin:
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if model_class._is_stateful:
|
||||
self.skipTest(reason="Stateful models don't support assisted generation")
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["reformer"]):
|
||||
self.skipTest(reason="Won't fix: old model with different cache format")
|
||||
if any(
|
||||
model_name in model_class.__name__.lower()
|
||||
for model_name in [
|
||||
"bigbirdpegasus",
|
||||
"led",
|
||||
"mega",
|
||||
"moshi",
|
||||
"speech2text",
|
||||
"git",
|
||||
"prophetnet",
|
||||
"seamlessm4t",
|
||||
"clvp",
|
||||
"fuyu",
|
||||
"mllama", # special cache sizes
|
||||
"blip2", # overridden `generate()`
|
||||
"blip2", # overridden `generate()` for all BLIP models
|
||||
"instructblip",
|
||||
"instructblipvideo",
|
||||
# All models below: shouldn't suggest image tokens. Can be fixed by passing `suppress_ids` to candidate generator: @joaa @raushan
|
||||
@@ -1340,22 +1325,16 @@ class GenerationTesterMixin:
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if model_class._is_stateful:
|
||||
self.skipTest(reason="Stateful models don't support assisted generation")
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["reformer"]):
|
||||
self.skipTest(reason="Won't fix: old model with different cache format")
|
||||
if any(
|
||||
model_name in model_class.__name__.lower()
|
||||
for model_name in [
|
||||
"bigbirdpegasus",
|
||||
"led",
|
||||
"mega",
|
||||
"moshi",
|
||||
"speech2text",
|
||||
"git",
|
||||
"prophetnet",
|
||||
"seamlessm4t",
|
||||
"clvp",
|
||||
"mllama", # special cache sizes
|
||||
"blip2", # overridden `generate()`
|
||||
"blip2", # overridden `generate()` for all BLIP models
|
||||
"instructblip",
|
||||
"instructblipvideo",
|
||||
]
|
||||
@@ -2059,12 +2038,15 @@ class GenerationTesterMixin:
|
||||
@pytest.mark.generate
|
||||
def test_generate_with_quant_cache(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if not model_class._supports_quantized_cache:
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
|
||||
if (
|
||||
config.get_text_config(decoder=True).is_encoder_decoder
|
||||
or not model_class._supports_default_dynamic_cache()
|
||||
):
|
||||
self.skipTest(reason="This model does not support the quantized cache format")
|
||||
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
config.is_decoder = True
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
generation_kwargs = {
|
||||
"max_new_tokens": 5,
|
||||
@@ -2509,14 +2491,10 @@ class GenerationTesterMixin:
|
||||
# Past Key Value States -- a few notes here:
|
||||
# 1. Its inner sequence length is with respect to the inputs of the latest forward pass, hence the "-1"
|
||||
# 2. We ignore models that have unique cache structures (e.g. mamba) or are in need of refatoring to match the
|
||||
# standard cache format (e.g.gptbigcode )
|
||||
# standard cache format (e.g.mamba architecture )
|
||||
models_without_standard_cache = (
|
||||
"bamba",
|
||||
"ctrl",
|
||||
"fsmt",
|
||||
"granitemoehybrid",
|
||||
"gptbigcode",
|
||||
"mega",
|
||||
"reformer",
|
||||
"jamba",
|
||||
"mamba",
|
||||
|
||||
@@ -737,9 +737,11 @@ class BertModelIntegrationTest(unittest.TestCase):
|
||||
torch.allclose(res_eager.last_hidden_state, res_sdpa.last_hidden_state, atol=1e-5, rtol=1e-4)
|
||||
)
|
||||
|
||||
# Case where query length != kv_length.
|
||||
res_eager = model(**inp, past_key_values=pkv)
|
||||
res_sdpa = model_sdpa(**inp, past_key_values=pkv)
|
||||
# Case where query length != kv_length. Note that model needs to be a decoder so we can use cache
|
||||
model.config.is_decoder = True
|
||||
model_sdpa.config.is_decoder = True
|
||||
res_eager = model(**inp, past_key_values=pkv, use_cache=True)
|
||||
res_sdpa = model_sdpa(**inp, past_key_values=pkv, use_cache=True)
|
||||
self.assertTrue(
|
||||
torch.allclose(res_eager.last_hidden_state, res_sdpa.last_hidden_state, atol=1e-5, rtol=1e-4)
|
||||
)
|
||||
|
||||
@@ -284,6 +284,7 @@ class BigBirdModelTester:
|
||||
attention_mask=next_attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
use_cache=False,
|
||||
output_hidden_states=True,
|
||||
)["hidden_states"][0]
|
||||
output_from_past = model(
|
||||
@@ -292,6 +293,7 @@ class BigBirdModelTester:
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=True,
|
||||
output_hidden_states=True,
|
||||
)["hidden_states"][0]
|
||||
|
||||
|
||||
@@ -34,6 +34,7 @@ from transformers.testing_utils import (
|
||||
Expectations,
|
||||
cleanup,
|
||||
require_bitsandbytes,
|
||||
require_optimum_quanto,
|
||||
require_read_token,
|
||||
require_torch,
|
||||
require_torch_accelerator,
|
||||
@@ -344,6 +345,12 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester
|
||||
|
||||
self.assertListEqual([layer_attention.shape for layer_attention in iter_attentions], expected_shapes)
|
||||
|
||||
@require_optimum_quanto
|
||||
@pytest.mark.generate
|
||||
@unittest.skip("Mllama is actually an encoder decoder cache and thus can't supports quant cache")
|
||||
def test_generate_with_quant_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("For some unknown reasons the tests fails in CrossAttention layer when doing torch.sdpa(). ")
|
||||
def test_sdpa_can_compile_dynamic(self):
|
||||
pass
|
||||
|
||||
@@ -770,9 +770,9 @@ class MvpStandaloneDecoderModelTester:
|
||||
|
||||
# get two different outputs
|
||||
output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
|
||||
output_from_past = model(next_tokens, attention_mask=attn_mask, past_key_values=past_key_values)[
|
||||
"last_hidden_state"
|
||||
]
|
||||
output_from_past = model(
|
||||
next_tokens, attention_mask=attn_mask, past_key_values=past_key_values, use_cache=True
|
||||
)["last_hidden_state"]
|
||||
|
||||
# select random slice
|
||||
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||
|
||||
@@ -21,7 +21,7 @@ from transformers import (
|
||||
AutoTokenizer,
|
||||
PerceptionLMProcessor,
|
||||
)
|
||||
from transformers.testing_utils import require_vision
|
||||
from transformers.testing_utils import require_read_token, require_vision
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
from ...test_processing_common import ProcessorTesterMixin
|
||||
@@ -34,11 +34,12 @@ if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
# TEST_MODEL_PATH = "facebook/Perception-LM-1B"
|
||||
TEST_MODEL_PATH = "shumingh/plm_1b_hf" # should be replaced by the above once checkpoints are merged
|
||||
TEST_MODEL_PATH = "facebook/Perception-LM-1B"
|
||||
|
||||
|
||||
@require_vision
|
||||
@require_read_token
|
||||
@unittest.skip("Fequires read token and we didn't requests access yet. FIXME @ydshieh when you are back :)")
|
||||
class PerceptionLMProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
processor_class = PerceptionLMProcessor
|
||||
|
||||
|
||||
@@ -737,7 +737,7 @@ class ProphetNetStandaloneDecoderModelTester:
|
||||
|
||||
# get two different outputs
|
||||
output_from_no_past = model(next_input_ids)["last_hidden_state"]
|
||||
output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"]
|
||||
output_from_past = model(next_tokens, past_key_values=past_key_values, use_cache=True)["last_hidden_state"]
|
||||
|
||||
# select random slice
|
||||
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||
|
||||
@@ -354,9 +354,9 @@ class SpeechT5ForSpeechToTextTester:
|
||||
next_attention_mask = torch.cat([attention_mask, next_attn_mask], dim=-1)
|
||||
|
||||
output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"]
|
||||
output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[
|
||||
"last_hidden_state"
|
||||
]
|
||||
output_from_past = model(
|
||||
next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values, use_cache=True
|
||||
)["last_hidden_state"]
|
||||
|
||||
# select random slice
|
||||
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||
|
||||
@@ -853,10 +853,10 @@ class ModelTesterMixin:
|
||||
addition_year = int(match_object.group(1))
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
# For now, skip everything older than 2025 and "important models" (too much models to patch otherwise)
|
||||
# For now, skip everything older than 2024 and "important models" (too much models to patch otherwise)
|
||||
# Use `supports_cache_class` as a proxy to judge "important" models in order to prioritize them
|
||||
# TODO: relax this as we patch more and more models
|
||||
if addition_year < 2024 and not model_class._supports_cache_class:
|
||||
if addition_year < 2024:
|
||||
self.skipTest(reason=f"{model_class} is not a priorited model for now.")
|
||||
|
||||
# Monkey patch the method to add a seed (we do it on PreTrainedModel._initialize_weights, which wraps
|
||||
@@ -1590,18 +1590,7 @@ class ModelTesterMixin:
|
||||
head_dim = model.config.hidden_size // model.config.num_attention_heads
|
||||
|
||||
cache_shape = (batch_size, num_heads, 0, head_dim)
|
||||
empty_pkv = tuple(
|
||||
(
|
||||
torch.rand(cache_shape, dtype=torch.float, device=torch_device),
|
||||
torch.rand(cache_shape, dtype=torch.float, device=torch_device),
|
||||
)
|
||||
for i in range(model.config.num_hidden_layers)
|
||||
)
|
||||
empty_pkv = (
|
||||
DynamicCache.from_legacy_cache(empty_pkv)
|
||||
if model_class._supports_cache_class
|
||||
else empty_pkv
|
||||
)
|
||||
empty_pkv = DynamicCache()
|
||||
|
||||
cache_length = 9
|
||||
cache_shape = (batch_size, num_heads, cache_length, head_dim)
|
||||
@@ -1612,11 +1601,7 @@ class ModelTesterMixin:
|
||||
)
|
||||
for i in range(model.config.num_hidden_layers)
|
||||
)
|
||||
non_empty_pkv = (
|
||||
DynamicCache.from_legacy_cache(non_empty_pkv)
|
||||
if model_class._supports_cache_class
|
||||
else non_empty_pkv
|
||||
)
|
||||
non_empty_pkv = DynamicCache.from_legacy_cache(non_empty_pkv)
|
||||
|
||||
inps = copy.deepcopy(inputs_to_test[0])
|
||||
|
||||
|
||||
@@ -46,7 +46,6 @@ if is_torch_available():
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
Cache,
|
||||
ClvpForCausalLM,
|
||||
DynamicCache,
|
||||
Gemma2Config,
|
||||
GenerationConfig,
|
||||
@@ -122,36 +121,6 @@ class CacheTest(unittest.TestCase):
|
||||
torch.allclose(to_legacy[layer_idx][key_value_idx], new_cache[layer_idx][key_value_idx])
|
||||
)
|
||||
|
||||
def test_reorder_cache_retrocompatibility(self):
|
||||
"""Tests that Cache.reorder_cache is retrocompatible with the legacy code path"""
|
||||
legacy_reorder_fn = ClvpForCausalLM._reorder_cache # An example of a legacy `_reorder_cache` function
|
||||
|
||||
legacy_cache = ()
|
||||
new_cache = DynamicCache()
|
||||
|
||||
# Creates a new cache with 10 layers in both formats
|
||||
for layer_idx in range(10):
|
||||
new_key = torch.rand((4, 4, 8, 16))
|
||||
new_value = torch.rand((4, 4, 8, 16))
|
||||
new_cache.update(new_key, new_value, layer_idx)
|
||||
legacy_cache += ((new_key, new_value),)
|
||||
|
||||
# Let's create some dummy beam indices. From the shape above, it is equivalent to the case where num_beams=4
|
||||
# and batch_size=1
|
||||
beam_idx = torch.randint(low=0, high=4, size=(4,))
|
||||
|
||||
legacy_cache_reordered = legacy_reorder_fn(legacy_cache, beam_idx)
|
||||
new_cache.reorder_cache(beam_idx)
|
||||
|
||||
# Let's check that the results are the same
|
||||
for layer_idx in range(10):
|
||||
for key_value_idx in range(2):
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
new_cache[layer_idx][key_value_idx], legacy_cache_reordered[layer_idx][key_value_idx]
|
||||
)
|
||||
)
|
||||
|
||||
def test_static_cache_mha_mqa_gqa(self):
|
||||
"""
|
||||
Tests that static cache works with multi-head attention (MHA), grouped query attention (GQA), and multi-query
|
||||
|
||||
Reference in New Issue
Block a user