Cache: add new flag to distinguish models that Cache but not static cache (#30800)
* jamba cache * new flag * generate exception
This commit is contained in:
@@ -1616,6 +1616,11 @@ class GenerationMixin:
|
|||||||
"issue: https://github.com/huggingface/transformers/issues/28981."
|
"issue: https://github.com/huggingface/transformers/issues/28981."
|
||||||
)
|
)
|
||||||
if generation_config.cache_implementation == "static":
|
if generation_config.cache_implementation == "static":
|
||||||
|
if not self._supports_static_cache:
|
||||||
|
raise ValueError(
|
||||||
|
"This model does not support `cache_implementation='static'`. Please check the following "
|
||||||
|
"issue: https://github.com/huggingface/transformers/issues/28981"
|
||||||
|
)
|
||||||
model_kwargs["past_key_values"] = self._get_static_cache(batch_size, generation_config.max_length)
|
model_kwargs["past_key_values"] = self._get_static_cache(batch_size, generation_config.max_length)
|
||||||
|
|
||||||
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
|
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
|
||||||
|
|||||||
@@ -1280,8 +1280,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
# SDPA support
|
# SDPA support
|
||||||
_supports_sdpa = False
|
_supports_sdpa = False
|
||||||
|
|
||||||
# Has support for a `Cache` instance as `past_key_values`
|
# Has support for a `Cache` instance as `past_key_values`? Does it support a `StaticCache`?
|
||||||
_supports_cache_class = False
|
_supports_cache_class = False
|
||||||
|
_supports_static_cache = False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dummy_inputs(self) -> Dict[str, torch.Tensor]:
|
def dummy_inputs(self) -> Dict[str, torch.Tensor]:
|
||||||
|
|||||||
@@ -720,6 +720,7 @@ class CoherePreTrainedModel(PreTrainedModel):
|
|||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
_supports_cache_class = True
|
_supports_cache_class = True
|
||||||
|
_supports_static_cache = True
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
std = self.config.initializer_range
|
std = self.config.initializer_range
|
||||||
|
|||||||
@@ -938,6 +938,7 @@ class DbrxPreTrainedModel(PreTrainedModel):
|
|||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
_supports_cache_class = True
|
_supports_cache_class = True
|
||||||
|
_supports_static_cache = True
|
||||||
|
|
||||||
def _init_weights(self, module: nn.Module):
|
def _init_weights(self, module: nn.Module):
|
||||||
std = self.config.initializer_range
|
std = self.config.initializer_range
|
||||||
|
|||||||
@@ -703,6 +703,7 @@ class GemmaPreTrainedModel(PreTrainedModel):
|
|||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
_supports_cache_class = True
|
_supports_cache_class = True
|
||||||
|
_supports_static_cache = True
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
std = self.config.initializer_range
|
std = self.config.initializer_range
|
||||||
|
|||||||
@@ -1341,6 +1341,7 @@ class Idefics2PreTrainedModel(PreTrainedModel):
|
|||||||
_no_split_modules = ["Idefics2VisionAttention", "Idefics2MLP", "Idefics2PerceiverLayer", "Idefics2DecoderLayer"]
|
_no_split_modules = ["Idefics2VisionAttention", "Idefics2MLP", "Idefics2PerceiverLayer", "Idefics2DecoderLayer"]
|
||||||
_skip_keys_device_placement = "past_key_values"
|
_skip_keys_device_placement = "past_key_values"
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
|
_supports_cache_class = True
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
# important: this ported version of Idefics2 isn't meant for training from scratch - only
|
# important: this ported version of Idefics2 isn't meant for training from scratch - only
|
||||||
|
|||||||
@@ -1261,6 +1261,7 @@ class JambaPreTrainedModel(PreTrainedModel):
|
|||||||
_skip_keys_device_placement = "past_key_values"
|
_skip_keys_device_placement = "past_key_values"
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
|
_supports_cache_class = True # Note: only supports HybridMambaAttentionDynamicCache
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
std = self.config.initializer_range
|
std = self.config.initializer_range
|
||||||
|
|||||||
@@ -799,6 +799,7 @@ class LlamaPreTrainedModel(PreTrainedModel):
|
|||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
_supports_cache_class = True
|
_supports_cache_class = True
|
||||||
|
_supports_static_cache = True
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
std = self.config.initializer_range
|
std = self.config.initializer_range
|
||||||
|
|||||||
@@ -810,6 +810,7 @@ class MistralPreTrainedModel(PreTrainedModel):
|
|||||||
_skip_keys_device_placement = "past_key_values"
|
_skip_keys_device_placement = "past_key_values"
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
|
_supports_cache_class = True
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
std = self.config.initializer_range
|
std = self.config.initializer_range
|
||||||
|
|||||||
@@ -989,6 +989,7 @@ class MixtralPreTrainedModel(PreTrainedModel):
|
|||||||
_skip_keys_device_placement = "past_key_values"
|
_skip_keys_device_placement = "past_key_values"
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
|
_supports_cache_class = True
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
std = self.config.initializer_range
|
std = self.config.initializer_range
|
||||||
|
|||||||
@@ -776,6 +776,7 @@ class OlmoPreTrainedModel(PreTrainedModel):
|
|||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
_supports_cache_class = True
|
_supports_cache_class = True
|
||||||
|
_supports_static_cache = True
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
std = self.config.initializer_range
|
std = self.config.initializer_range
|
||||||
|
|||||||
@@ -457,6 +457,7 @@ class PersimmonPreTrainedModel(PreTrainedModel):
|
|||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
_no_split_modules = ["PersimmonDecoderLayer"]
|
_no_split_modules = ["PersimmonDecoderLayer"]
|
||||||
_skip_keys_device_placement = "past_key_values"
|
_skip_keys_device_placement = "past_key_values"
|
||||||
|
_supports_cache_class = True
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
std = self.config.initializer_range
|
std = self.config.initializer_range
|
||||||
|
|||||||
@@ -825,6 +825,7 @@ class PhiPreTrainedModel(PreTrainedModel):
|
|||||||
_skip_keys_device_placement = "past_key_values"
|
_skip_keys_device_placement = "past_key_values"
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
|
_supports_cache_class = True
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
std = self.config.initializer_range
|
std = self.config.initializer_range
|
||||||
|
|||||||
@@ -921,6 +921,7 @@ class Phi3PreTrainedModel(PreTrainedModel):
|
|||||||
_skip_keys_device_placement = "past_key_values"
|
_skip_keys_device_placement = "past_key_values"
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
_supports_sdpa = False
|
_supports_sdpa = False
|
||||||
|
_supports_cache_class = True
|
||||||
|
|
||||||
_version = "0.0.5"
|
_version = "0.0.5"
|
||||||
|
|
||||||
|
|||||||
@@ -821,6 +821,7 @@ class Qwen2PreTrainedModel(PreTrainedModel):
|
|||||||
_skip_keys_device_placement = "past_key_values"
|
_skip_keys_device_placement = "past_key_values"
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
|
_supports_cache_class = True
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
std = self.config.initializer_range
|
std = self.config.initializer_range
|
||||||
|
|||||||
@@ -975,6 +975,7 @@ class Qwen2MoePreTrainedModel(PreTrainedModel):
|
|||||||
_skip_keys_device_placement = "past_key_values"
|
_skip_keys_device_placement = "past_key_values"
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
|
_supports_cache_class = True
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
std = self.config.initializer_range
|
std = self.config.initializer_range
|
||||||
|
|||||||
@@ -541,7 +541,6 @@ class RecurrentGemmaPreTrainedModel(PreTrainedModel):
|
|||||||
_skip_keys_device_placement = ["cache"]
|
_skip_keys_device_placement = ["cache"]
|
||||||
_supports_flash_attn_2 = False
|
_supports_flash_attn_2 = False
|
||||||
_supports_sdpa = False # we can't compare with eager for now
|
_supports_sdpa = False # we can't compare with eager for now
|
||||||
_supports_cache_class = True
|
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
std = math.sqrt(self.config.w_init_variance_scale / self.config.conv1d_width)
|
std = math.sqrt(self.config.w_init_variance_scale / self.config.conv1d_width)
|
||||||
|
|||||||
@@ -799,6 +799,7 @@ class Starcoder2PreTrainedModel(PreTrainedModel):
|
|||||||
_skip_keys_device_placement = "past_key_values"
|
_skip_keys_device_placement = "past_key_values"
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
|
_supports_cache_class = True
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
std = self.config.initializer_range
|
std = self.config.initializer_range
|
||||||
|
|||||||
@@ -4365,7 +4365,7 @@ class ModelTesterMixin:
|
|||||||
self.skipTest("Model architecture has no generative classes, and thus not necessarily supporting 4D masks")
|
self.skipTest("Model architecture has no generative classes, and thus not necessarily supporting 4D masks")
|
||||||
|
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
if not model_class._supports_cache_class:
|
if not model_class._supports_static_cache:
|
||||||
self.skipTest(f"{model_class.__name__} is not guaranteed to work with custom 4D attention masks")
|
self.skipTest(f"{model_class.__name__} is not guaranteed to work with custom 4D attention masks")
|
||||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
model = model_class(config).to(device=torch_device, dtype=torch.float32)
|
model = model_class(config).to(device=torch_device, dtype=torch.float32)
|
||||||
|
|||||||
Reference in New Issue
Block a user