[qwen2-vl] fix FA2 inference (#39121)

* fix FA2

* update is causal flag and remove mask for FA2

* update for FA2 with varlen path

* how the tests were passing with different devices?

* add comment and ref to the PR

* move mask preparation to base pretrained model

* seq len is the first dim, not second

* fix copies to fix GLM4V
This commit is contained in:
Raushan Turganbay
2025-07-01 12:18:37 +02:00
committed by GitHub
parent def9663239
commit 7a25f8dfdb
10 changed files with 363 additions and 199 deletions

View File

@@ -184,7 +184,7 @@ class Qwen2_5_VLVisionText2TextModelTester:
input_ids[:, self.num_image_tokens - 1] = self.vision_start_token_id
inputs_dict = {
"pixel_values": pixel_values,
"image_grid_thw": torch.tensor([[1, 1, 1]] * self.batch_size),
"image_grid_thw": torch.tensor([[1, 1, 1]] * self.batch_size, device=torch_device),
"input_ids": input_ids,
"attention_mask": attention_mask,
}

View File

@@ -176,7 +176,7 @@ class Qwen2VLVisionText2TextModelTester:
inputs_dict = {
"pixel_values": pixel_values,
"image_grid_thw": torch.tensor([[1, 1, 1]] * self.batch_size),
"image_grid_thw": torch.tensor([[1, 1, 1]] * self.batch_size, device=torch_device),
"input_ids": input_ids,
"attention_mask": attention_mask,
}