[tests] expand flex-attn test for vision models (#38434)

* expand the test for VLMs

* typo

* mark models `supports_flex` + expand test for additional kwargs

* flex attn for refactored vision models

* fix copies

* fix

* unskip

* style

* address comments
This commit is contained in:
Raushan Turganbay
2025-06-03 09:40:44 +02:00
committed by GitHub
parent de4cf5a38e
commit bf68dd9e6e
45 changed files with 429 additions and 195 deletions

View File

@@ -65,7 +65,8 @@ class NewTaskModelForNewTask(PaliGemmaForConditionalGeneration):
# L2 normalization
embeddings = proj / proj.norm(dim=-1, keepdim=True) # (batch_size, sequence_length, dim)
- embeddings = embeddings * attention_mask.unsqueeze(-1) # (batch_size, sequence_length, dim)
+ if attention_mask is not None:
+     embeddings = embeddings * attention_mask.unsqueeze(-1) # (batch_size, sequence_length, dim)
return (embeddings,) + vlm_outputs