[tests] expand flex-attn test for vision models (#38434)
* expand the test for VLMs * typo * mark models `supports_flex` + expand test for additional kwargs * flex attn for refactored vision models * fix copies * fix * unskip * style * address comments
This commit is contained in:
committed by
GitHub
parent
de4cf5a38e
commit
bf68dd9e6e
@@ -65,7 +65,8 @@ class NewTaskModelForNewTask(PaliGemmaForConditionalGeneration):
|
||||
# L2 normalization
|
||||
embeddings = proj / proj.norm(dim=-1, keepdim=True) # (batch_size, sequence_length, dim)
|
||||
|
||||
embeddings = embeddings * attention_mask.unsqueeze(-1) # (batch_size, sequence_length, dim)
|
||||
if attention_mask is not None:
|
||||
embeddings = embeddings * attention_mask.unsqueeze(-1) # (batch_size, sequence_length, dim)
|
||||
|
||||
return (embeddings,) + vlm_outputs
|
||||
|
||||
|
||||
Reference in New Issue
Block a user