[VLMs] support attention backends (#37576)
* update models * why rename * return attn weights when sdpa * fixes * fix attn implementation composite * fix moshi * add message * add typings * use explicitly all flags for each attn type * fix some tests * import what is needed * kosmos on main has ew attention already, yay * new models in main, run fixup * won't fix kosmos yet * fix-copies * clean up after rebasing * fix tests * style * dont cast attns to fp32 * did we update ruff? oke, let's just do what it asks * fix pixtral after rebase
This commit is contained in:
committed by
GitHub
parent
e296c63cd4
commit
d23aae2b8c
@@ -3765,6 +3765,10 @@ class ModelTesterMixin:
|
||||
key = "decoder_hidden_states"
|
||||
elif "logits" in outputs_eager and "Classification" in model_class.__name__:
|
||||
key = "logits"
|
||||
elif "language_model_outputs" in outputs_eager and "blip" in model_class.__name__.lower():
|
||||
outputs_eager = outputs_eager["language_model_outputs"]
|
||||
outputs_sdpa = outputs_sdpa["language_model_outputs"]
|
||||
key = "hidden_states" if "hidden_states" in outputs_eager else "decoder_hidden_states"
|
||||
else:
|
||||
key = "hidden_states"
|
||||
|
||||
@@ -4002,14 +4006,14 @@ class ModelTesterMixin:
|
||||
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
|
||||
|
||||
sub_models_supporting_fa2 = [
|
||||
module._supports_flash_attn_2
|
||||
(module._supports_flash_attn_2 or module._supports_attention_backend)
|
||||
for name, module in model.named_modules()
|
||||
if isinstance(module, PreTrainedModel) and name != ""
|
||||
]
|
||||
supports_fa2_all_modules = (
|
||||
all(sub_models_supporting_fa2)
|
||||
if len(sub_models_supporting_fa2) > 0
|
||||
else model._supports_flash_attn_2
|
||||
else (model._supports_flash_attn_2 or model._supports_attention_backend)
|
||||
)
|
||||
if not supports_fa2_all_modules:
|
||||
with self.assertRaises(ValueError):
|
||||
|
||||
Reference in New Issue
Block a user