From 9d889f870ed202de2859b0236895fe7f985a0c72 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 16 May 2024 12:08:35 +0100 Subject: [PATCH] Cache: add new flag to distinguish models that support `Cache` but not static cache (#30800) * jamba cache * new flag * generate exception --- src/transformers/generation/utils.py | 5 +++++ src/transformers/modeling_utils.py | 3 ++- src/transformers/models/cohere/modeling_cohere.py | 1 + src/transformers/models/dbrx/modeling_dbrx.py | 1 + src/transformers/models/gemma/modeling_gemma.py | 1 + src/transformers/models/idefics2/modeling_idefics2.py | 1 + src/transformers/models/jamba/modeling_jamba.py | 1 + src/transformers/models/llama/modeling_llama.py | 1 + src/transformers/models/mistral/modeling_mistral.py | 1 + src/transformers/models/mixtral/modeling_mixtral.py | 1 + src/transformers/models/olmo/modeling_olmo.py | 1 + src/transformers/models/persimmon/modeling_persimmon.py | 1 + src/transformers/models/phi/modeling_phi.py | 1 + src/transformers/models/phi3/modeling_phi3.py | 1 + src/transformers/models/qwen2/modeling_qwen2.py | 1 + src/transformers/models/qwen2_moe/modeling_qwen2_moe.py | 1 + .../models/recurrent_gemma/modeling_recurrent_gemma.py | 1 - src/transformers/models/starcoder2/modeling_starcoder2.py | 1 + tests/test_modeling_common.py | 2 +- 19 files changed, 23 insertions(+), 3 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 2022096aaf..4b71fe9519 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1616,6 +1616,11 @@ class GenerationMixin: "issue: https://github.com/huggingface/transformers/issues/28981." ) if generation_config.cache_implementation == "static": + if not self._supports_static_cache: + raise ValueError( + "This model does not support `cache_implementation='static'`. 
Please check the following " + "issue: https://github.com/huggingface/transformers/issues/28981" + ) model_kwargs["past_key_values"] = self._get_static_cache(batch_size, generation_config.max_length) self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 37f35a3433..599147e6cc 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1280,8 +1280,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # SDPA support _supports_sdpa = False - # Has support for a `Cache` instance as `past_key_values` + # Has support for a `Cache` instance as `past_key_values`? Does it support a `StaticCache`? _supports_cache_class = False + _supports_static_cache = False @property def dummy_inputs(self) -> Dict[str, torch.Tensor]: diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 1f08bba620..eb6be4911b 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -720,6 +720,7 @@ class CoherePreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_cache_class = True + _supports_static_cache = True def _init_weights(self, module): std = self.config.initializer_range diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index eaaad0097e..62b571b292 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -938,6 +938,7 @@ class DbrxPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_cache_class = True + _supports_static_cache = True def _init_weights(self, module: nn.Module): std = self.config.initializer_range diff --git a/src/transformers/models/gemma/modeling_gemma.py 
b/src/transformers/models/gemma/modeling_gemma.py index 4d9c0aaa38..8477dc2b5d 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -703,6 +703,7 @@ class GemmaPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_cache_class = True + _supports_static_cache = True def _init_weights(self, module): std = self.config.initializer_range diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index 7ca963808f..17ed18f6b9 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -1341,6 +1341,7 @@ class Idefics2PreTrainedModel(PreTrainedModel): _no_split_modules = ["Idefics2VisionAttention", "Idefics2MLP", "Idefics2PerceiverLayer", "Idefics2DecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_cache_class = True def _init_weights(self, module): # important: this ported version of Idefics2 isn't meant for training from scratch - only diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index 8a5cdcbc2e..cfa80d2ce3 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -1261,6 +1261,7 @@ class JambaPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = True + _supports_cache_class = True # Note: only supports HybridMambaAttentionDynamicCache def _init_weights(self, module): std = self.config.initializer_range diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 2cf0979d90..e639eac3f5 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -799,6 +799,7 @@ class 
LlamaPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_cache_class = True + _supports_static_cache = True def _init_weights(self, module): std = self.config.initializer_range diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 15ed401df9..665e95a8fd 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -810,6 +810,7 @@ class MistralPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = True + _supports_cache_class = True def _init_weights(self, module): std = self.config.initializer_range diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 4b9a6db584..e5a81c4c90 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -989,6 +989,7 @@ class MixtralPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = True + _supports_cache_class = True def _init_weights(self, module): std = self.config.initializer_range diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 063f78e5db..5f04ce61ba 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -776,6 +776,7 @@ class OlmoPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_cache_class = True + _supports_static_cache = True def _init_weights(self, module): std = self.config.initializer_range diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 3abd896079..8d4ad53207 100644 --- 
a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -457,6 +457,7 @@ class PersimmonPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["PersimmonDecoderLayer"] _skip_keys_device_placement = "past_key_values" + _supports_cache_class = True def _init_weights(self, module): std = self.config.initializer_range diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 4877e84b86..795ff18e5b 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -825,6 +825,7 @@ class PhiPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = True + _supports_cache_class = True def _init_weights(self, module): std = self.config.initializer_range diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 6e17e8695e..db88d607a3 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -921,6 +921,7 @@ class Phi3PreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = False + _supports_cache_class = True _version = "0.0.5" diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 709504aba7..b5a1370ae1 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -821,6 +821,7 @@ class Qwen2PreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = True + _supports_cache_class = True def _init_weights(self, module): std = self.config.initializer_range diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py 
b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 5fc4a9071c..838425505b 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -975,6 +975,7 @@ class Qwen2MoePreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = True + _supports_cache_class = True def _init_weights(self, module): std = self.config.initializer_range diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py index c21f99ce48..a115dc687c 100644 --- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py @@ -541,7 +541,6 @@ class RecurrentGemmaPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["cache"] _supports_flash_attn_2 = False _supports_sdpa = False # we can't compare with eager for now - _supports_cache_class = True def _init_weights(self, module): std = math.sqrt(self.config.w_init_variance_scale / self.config.conv1d_width) diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 6d404bd634..61e8518d65 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -799,6 +799,7 @@ class Starcoder2PreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = True + _supports_cache_class = True def _init_weights(self, module): std = self.config.initializer_range diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index d6102dc2ef..b2f96c7269 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4365,7 +4365,7 @@ class ModelTesterMixin: self.skipTest("Model architecture 
has no generative classes, and thus not necessarily supporting 4D masks") for model_class in self.all_generative_model_classes: - if not model_class._supports_cache_class: + if not model_class._supports_static_cache: self.skipTest(f"{model_class.__name__} is not guaranteed to work with custom 4D attention masks") config, _ = self.model_tester.prepare_config_and_inputs_for_common() model = model_class(config).to(device=torch_device, dtype=torch.float32)