Cache: add new flag to distinguish models that Cache but not static cache (#30800)

* jamba cache

* new flag

* generate exception
This commit is contained in:
Joao Gante
2024-05-16 12:08:35 +01:00
committed by GitHub
parent 17cc71e149
commit 9d889f870e
19 changed files with 23 additions and 3 deletions

View File

@@ -4365,7 +4365,7 @@ class ModelTesterMixin:
self.skipTest("Model architecture has no generative classes, and thus not necessarily supporting 4D masks")
for model_class in self.all_generative_model_classes:
if not model_class._supports_cache_class:
if not model_class._supports_static_cache:
self.skipTest(f"{model_class.__name__} is not guaranteed to work with custom 4D attention masks")
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config).to(device=torch_device, dtype=torch.float32)